package org.apache.ignite.ml.composition;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.Utils;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:org/apache/ignite/ml/composition/BaggingModelTrainer.class */
public abstract class BaggingModelTrainer extends DatasetTrainer<ModelsComposition, Double> {
    /** Aggregator handed to every {@link ModelsComposition} this trainer builds. */
    private final PredictionsAggregator predictionsAggregator;
    /** Number of distinct feature indices selected for each model of the ensemble. */
    private final int maximumFeaturesCntPerMdl;
    /** Number of models trained by one call to {@code fit}. */
    private final int ensembleSize;
    /** Threshold compared against the sample mapper's output: a row is kept when the mapped value is below it. */
    private final double samplePartSizePerMdl;
    /** Upper bound passed to the feature-index selection (presumably the size of the full feature vector). */
    private final int featureVectorSize;

    /**
     * Creates a bagging trainer.
     *
     * @param predictionsAggregator aggregator passed to the resulting models composition
     * @param i stored as the feature vector size
     * @param i2 stored as the maximum number of features used per model
     * @param i3 stored as the ensemble size (number of models to train)
     * @param d stored as the sample part size per model
     */
    public BaggingModelTrainer(PredictionsAggregator predictionsAggregator, int i, int i2, int i3, double d) {
        // Assignments follow the parameter order to make the (easy to confuse) int arguments traceable.
        this.predictionsAggregator = predictionsAggregator;
        this.featureVectorSize = i;
        this.maximumFeaturesCntPerMdl = i2;
        this.ensembleSize = i3;
        this.samplePartSizePerMdl = d;
    }

    /**
     * Trains {@code ensembleSize} models and wraps them into a single composition.
     *
     * <p>One training task per model is submitted to the environment's parallelism strategy;
     * the results are collected in submission order and combined with the configured
     * predictions aggregator.
     *
     * @param datasetBuilder source of the training data
     * @param igniteBiFunction feature extractor
     * @param igniteBiFunction2 label extractor
     * @return composition of all trained models
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> igniteBiFunction, IgniteBiFunction<K, V, Double> igniteBiFunction2) {
        MLLogger logger = this.environment.logger(getClass());
        logger.log(MLLogger.VerboseLevel.LOW, "Start learning");
        long startedAtMs = System.currentTimeMillis();

        // Typed task list: a raw list gives the lambda no functional-interface target type.
        List<IgniteSupplier<ModelOnFeaturesSubspace>> tasks = new ArrayList<>();
        for (int mdlIdx = 0; mdlIdx < this.ensembleSize; mdlIdx++) {
            tasks.add(() -> learnModel(datasetBuilder, igniteBiFunction, igniteBiFunction2));
        }

        List<Model<Vector, Double>> models = this.environment.parallelismStrategy().submit(tasks).stream()
            .map(promise -> promise.unsafeGet())
            .collect(Collectors.toList());

        double elapsedSec = (System.currentTimeMillis() - startedAtMs) / 1000.0d;
        logger.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", elapsedSec);
        logger.log(MLLogger.VerboseLevel.LOW, "Learning finished");
        return new ModelsComposition(models, this.predictionsAggregator);
    }

    /**
     * Trains one member of the ensemble.
     *
     * <p>A fresh {@link Random} drives both the row filter (via {@link SHA256UniformMapper}) and the
     * seed for the feature-subspace mapping. The model is trained on rows whose mapped value is below
     * {@code samplePartSizePerMdl}, with feature vectors projected through the mapping.
     *
     * @param datasetBuilder source of the training data
     * @param igniteBiFunction feature extractor producing full feature vectors
     * @param igniteBiFunction2 label extractor
     * @return the trained model together with the feature mapping it was trained with
     */
    @NotNull
    private <K, V> ModelOnFeaturesSubspace learnModel(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> igniteBiFunction, IgniteBiFunction<K, V, Double> igniteBiFunction2) {
        Random rnd = new Random();
        // The mapper is created before the seed is drawn; keep this order so the random stream is consumed identically.
        SHA256UniformMapper sampleMapper = new SHA256UniformMapper(rnd);
        long featuresSeed = rnd.nextLong();
        Map<Integer, Integer> featuresMapping = createFeaturesMapping(featuresSeed, this.featureVectorSize);

        long startedAtMs = System.currentTimeMillis();
        DatasetTrainer<? extends Model<Vector, Double>, Double> memberTrainer = buildDatasetTrainerForModel();
        Model<Vector, Double> trainedMdl = memberTrainer.fit(
            datasetBuilder.withFilter((key, val) -> sampleMapper.map(key, val) < this.samplePartSizePerMdl),
            wrapFeatureExtractor(igniteBiFunction, featuresMapping),
            igniteBiFunction2);

        double elapsedSec = (System.currentTimeMillis() - startedAtMs) / 1000.0d;
        this.environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, "One model training time was %.2fs", elapsedSec);
        return new ModelOnFeaturesSubspace(featuresMapping, trainedMdl);
    }

    /**
     * Builds the feature-subspace mapping for one model.
     *
     * <p>Keys are the local (projected) feature positions {@code 0..maximumFeaturesCntPerMdl-1};
     * values are the indices selected from the full feature vector. The feature extractor wrapper
     * reads the map as {@code projected[key] = full.get(value)}.
     *
     * @param j seed for the random selection
     * @param i number of features to select from
     * @return mapping from local feature index to selected feature index
     */
    private Map<Integer, Integer> createFeaturesMapping(long j, int i) {
        int[] selectedFeatureIdxs = Utils.selectKDistinct(i, this.maximumFeaturesCntPerMdl, new Random(j));
        Map<Integer, Integer> mapping = new HashMap<>();
        // The previous loop body was empty, so the mapping stayed empty and models were trained on zero-length vectors.
        for (int locIdx = 0; locIdx < this.maximumFeaturesCntPerMdl; locIdx++) {
            mapping.put(locIdx, selectedFeatureIdxs[locIdx]);
        }
        return mapping;
    }

    /**
     * Supplies the trainer used to fit a single model of the ensemble.
     *
     * <p>Called once per model from {@code learnModel}; the returned trainer's {@code fit} receives the
     * filtered dataset builder, the feature extractor wrapped with the model's feature mapping, and the
     * label extractor.
     *
     * @return trainer for one ensemble member
     */
    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildDatasetTrainerForModel();

    /**
     * Decorates a feature extractor so that it emits only the features named in the mapping.
     *
     * <p>For every entry {@code local -> source} of the mapping, position {@code local} of the
     * resulting vector receives element {@code source} of the originally extracted vector.
     *
     * @param igniteBiFunction extractor producing full feature vectors
     * @param map mapping from local feature index to index in the full vector
     * @return extractor producing vectors of length {@code map.size()}
     */
    private <K, V> IgniteBiFunction<K, V, Vector> wrapFeatureExtractor(IgniteBiFunction<K, V, Vector> igniteBiFunction, Map<Integer, Integer> map) {
        return igniteBiFunction.andThen(fullFeatures -> {
            double[] projected = new double[map.size()];
            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                int locIdx = entry.getKey();
                int srcIdx = entry.getValue();
                projected[locIdx] = fullFeatures.get(srcIdx);
            }
            return VectorUtils.of(projected);
        });
    }

    /**
     * Extends an existing composition with a freshly trained ensemble.
     *
     * <p>The models of {@code modelsComposition} are copied, the models produced by a new
     * {@code fit} run on the given data are appended, and the combined list is wrapped into a new
     * composition that uses this trainer's predictions aggregator.
     *
     * @param modelsComposition previously trained composition
     * @param datasetBuilder source of the new training data
     * @param igniteBiFunction feature extractor
     * @param igniteBiFunction2 label extractor
     * @return new composition containing the old models followed by the newly trained ones
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public <K, V> ModelsComposition updateModel(ModelsComposition modelsComposition, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> igniteBiFunction, IgniteBiFunction<K, V, Double> igniteBiFunction2) {
        // Typed copy: the input composition's model list is left untouched.
        List<Model<Vector, Double>> merged = new ArrayList<>(modelsComposition.getModels());
        merged.addAll(fit(datasetBuilder, igniteBiFunction, igniteBiFunction2).getModels());
        return new ModelsComposition(merged, this.predictionsAggregator);
    }

    /*
     * Synthetic lambda-deserialization hook as rendered by the decompiler.
     *
     * It selects a lambda by the implementation method name stored in the SerializedLambda,
     * verifies the functional interface, implementing class and signatures, and then rebuilds
     * the lambda from the captured arguments. Any unrecognised or mismatching descriptor ends
     * in an IllegalArgumentException.
     *
     * NOTE(review): this is not hand-written source. The selector `z` is declared boolean but is
     * assigned -1 and 2, two `case true` labels appear in the second switch, and the static method
     * body refers to `this` and to the instance method learnModel. These look like decompiler
     * artefacts; the method is normally generated by the compiler and should not be maintained by hand.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        // First switch: map the implementation method name (via its hash, then equals) to a selector value.
        switch (implMethodName.hashCode()) {
            case 566929805:
                if (implMethodName.equals("lambda$learnModel$81b8a671$1")) {
                    z = true;
                    break;
                }
                break;
            case 853783226:
                if (implMethodName.equals("lambda$wrapFeatureExtractor$1fe5d37a$1")) {
                    z = 2;
                    break;
                }
                break;
            case 1787730271:
                if (implMethodName.equals("lambda$fit$6d074a00$1")) {
                    z = false;
                    break;
                }
                break;
        }
        // Second switch: validate the full descriptor for the selected lambda and re-create it.
        switch (z) {
            case false:
                // Training task created in fit(): a supplier that calls learnModel with the captured arguments.
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteSupplier") && serializedLambda.getFunctionalInterfaceMethodName().equals("get") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/composition/BaggingModelTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/dataset/DatasetBuilder;Lorg/apache/ignite/ml/math/functions/IgniteBiFunction;Lorg/apache/ignite/ml/math/functions/IgniteBiFunction;)Lorg/apache/ignite/ml/composition/ModelOnFeaturesSubspace;")) {
                    BaggingModelTrainer baggingModelTrainer = (BaggingModelTrainer) serializedLambda.getCapturedArg(0);
                    DatasetBuilder datasetBuilder = (DatasetBuilder) serializedLambda.getCapturedArg(1);
                    IgniteBiFunction igniteBiFunction = (IgniteBiFunction) serializedLambda.getCapturedArg(2);
                    IgniteBiFunction igniteBiFunction2 = (IgniteBiFunction) serializedLambda.getCapturedArg(3);
                    return () -> {
                        return learnModel(datasetBuilder, igniteBiFunction, igniteBiFunction2);
                    };
                }
                break;
            case true:
                // Row filter created in learnModel(): keeps rows whose mapped value is below samplePartSizePerMdl.
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/lang/IgniteBiPredicate") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Z") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/composition/BaggingModelTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper;Ljava/lang/Object;Ljava/lang/Object;)Z")) {
                    BaggingModelTrainer baggingModelTrainer2 = (BaggingModelTrainer) serializedLambda.getCapturedArg(0);
                    SHA256UniformMapper sHA256UniformMapper = (SHA256UniformMapper) serializedLambda.getCapturedArg(1);
                    return (obj, obj2) -> {
                        return sHA256UniformMapper.map(obj, obj2) < this.samplePartSizePerMdl;
                    };
                }
                break;
            case true:
                // Feature projection created in wrapFeatureExtractor(): copies mapped elements into a new vector.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/composition/BaggingModelTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/Map;Lorg/apache/ignite/ml/math/primitives/vector/Vector;)Lorg/apache/ignite/ml/math/primitives/vector/Vector;")) {
                    Map map = (Map) serializedLambda.getCapturedArg(0);
                    return vector -> {
                        double[] dArr = new double[map.size()];
                        map.forEach((num, num2) -> {
                            dArr[num.intValue()] = vector.get(num2.intValue());
                        });
                        return VectorUtils.of(dArr);
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
