package org.apache.ignite.ml.trainers;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;

/* loaded from: input_file:org/apache/ignite/ml/trainers/TrainerTransformers.class */
public class TrainerTransformers {

    /**
     * Model decorator that passes every input through a mapping function before handing it to the
     * wrapped model.
     *
     * @param <X> Input type of the wrapped model.
     * @param <Y> Output type of the wrapped model.
     * @param <M> Concrete type of the wrapped model.
     */
    private static class ModelWithMapping<X, Y, M extends IgniteModel<X, Y>> implements IgniteModel<X, Y> {
        /** Wrapped model that produces the actual prediction. */
        private final M model;

        /** Function applied to the input before it reaches {@link #model}; replaceable via {@link #setMapping}. */
        private IgniteFunction<X, X> mapping;

        /**
         * Wraps a model with the identity mapping, i.e. inputs are forwarded unchanged.
         *
         * @param m Model to wrap.
         */
        public ModelWithMapping(M m) {
            this(m, obj -> obj);
        }

        /**
         * Wraps a model with an explicit input mapping.
         *
         * @param m Model to wrap.
         * @param igniteFunction Mapping applied to every input before prediction.
         */
        public ModelWithMapping(M m, IgniteFunction<X, X> igniteFunction) {
            this.model = m;
            this.mapping = igniteFunction;
        }

        /**
         * Replaces the input mapping.
         *
         * @param igniteFunction New mapping applied to every input before prediction.
         */
        public void setMapping(IgniteFunction<X, X> igniteFunction) {
            this.mapping = igniteFunction;
        }

        /**
         * Applies the mapping to the input and returns the wrapped model's prediction for the result.
         *
         * @param x Raw input.
         * @return Prediction of the wrapped model for the mapped input.
         */
        @Override
        public Y predict(X x) {
            // M is bounded by IgniteModel<X, Y>, so the result is already a Y; no cast is required.
            return model.predict(mapping.apply(x));
        }

        /**
         * @return The wrapped model.
         */
        public M model() {
            return model;
        }

        /**
         * @return The mapping currently applied to inputs.
         */
        public IgniteFunction<X, X> mapping() {
            return mapping;
        }

        // NOTE: no explicit $deserializeLambda$ here. javac synthesizes that method for the
        // serializable identity lambda in the one-argument constructor; declaring it by hand
        // collides with the compiler-generated symbol.
    }

    /**
     * Convenience overload: builds a bagged trainer by delegating to the six-argument
     * {@code makeBagged}, passing {@code -1} for both of its extra {@code int} arguments.
     * NOTE(review): {@code -1} presumably means "no feature subspace selection" - confirm against
     * {@code BaggedTrainer}.
     *
     * @param datasetTrainer Trainer to be bagged.
     * @param i Forwarded unchanged as the first {@code int} argument of the delegate.
     * @param d Forwarded unchanged as the {@code double} argument of the delegate.
     * @param predictionsAggregator Aggregator forwarded to the delegate.
     * @param <L> Label type of the trainer.
     * @return Bagged trainer produced by the six-argument overload.
     */
    public static <L> BaggedTrainer<L> makeBagged(DatasetTrainer<? extends IgniteModel, L> datasetTrainer, int i, double d, PredictionsAggregator predictionsAggregator) {
        final int unset = -1;

        return makeBagged(datasetTrainer, i, d, unset, unset, predictionsAggregator);
    }

    /**
     * Builds a {@link BaggedTrainer} around the given trainer. All arguments are handed to the
     * {@code BaggedTrainer} constructor in the order (trainer, aggregator, i, d, i2, i3).
     *
     * @param datasetTrainer Trainer to be bagged.
     * @param i First {@code int} constructor argument after the aggregator.
     * @param d {@code double} constructor argument.
     * @param i2 Second {@code int} constructor argument.
     * @param i3 Third {@code int} constructor argument.
     * @param predictionsAggregator Aggregator passed to the bagged trainer.
     * @param <M> Model type produced by the wrapped trainer.
     * @param <L> Label type of the trainer.
     * @return Newly constructed bagged trainer.
     */
    public static <M extends IgniteModel<Vector, Double>, L> BaggedTrainer<L> makeBagged(DatasetTrainer<M, L> datasetTrainer, int i, double d, int i2, int i3, PredictionsAggregator predictionsAggregator) {
        BaggedTrainer<L> bagged = new BaggedTrainer<>(datasetTrainer, predictionsAggregator, i, d, i2, i3);

        return bagged;
    }

    /**
     * Trains {@code i} models, each on its own upstream-transformed view of the dataset, and
     * combines them into a {@link ModelsComposition}.
     * <p>
     * When {@code i2 > 0} and {@code i3 != i2}, a separate index mapping is drawn for every model
     * via {@link #getMapping(int, int, long)}. The feature extractor handed to that model's
     * training task is then wrapped so it only emits the mapped components, and after training
     * the model is given the matching projector so the same mapping is applied at prediction
     * time. Otherwise the extractor is used as is and models keep the identity mapping.
     *
     * @param igniteTriFunction Factory producing a training task from (dataset builder, model index, feature extractor).
     * @param datasetBuilder Base dataset builder; each model gets a copy with a bagging upstream transformer.
     * @param i Number of models to train.
     * @param d Value passed to {@code BaggingUpstreamTransformer.builder} together with the model index.
     * @param i2 First argument of {@code getMapping}; mappings are only drawn when it is positive.
     * @param i3 Second argument of {@code getMapping}; mappings are only drawn when it differs from {@code i2}.
     * @param igniteBiFunction Feature extractor.
     * @param predictionsAggregator Aggregator used by the resulting composition.
     * @param learningEnvironment Source of the logger, the random generator and the parallelism strategy.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @param <M> Type of the trained models.
     * @return Composition of the trained models.
     */
    private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble(IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> igniteTriFunction, DatasetBuilder<K, V> datasetBuilder, int i, double d, int i2, int i3, IgniteBiFunction<K, V, Vector> igniteBiFunction, PredictionsAggregator predictionsAggregator, LearningEnvironment learningEnvironment) {
        MLLogger logger = learningEnvironment.logger(datasetBuilder.getClass());
        logger.log(MLLogger.VerboseLevel.LOW, "Start learning.");

        // One index mapping per model; stays null when no mapping is requested.
        List<int[]> mappings = null;
        if (i2 > 0 && i3 != i2) {
            mappings = IntStream.range(0, i)
                // The model index is added to the drawn seed so every model gets its own seed.
                .mapToObj(modelIdx -> getMapping(i2, i3, learningEnvironment.randomNumbersGenerator().nextLong() + modelIdx))
                .collect(Collectors.toList());
        }

        long startTs = System.currentTimeMillis();

        List<IgniteSupplier<M>> tasks = new ArrayList<>();
        List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
        if (mappings != null) {
            for (int[] mapping : mappings) {
                extractors.add(wrapExtractor(igniteBiFunction, mapping));
            }
        }

        for (int modelIdx = 0; modelIdx < i; modelIdx++) {
            DatasetBuilder<K, V> modelBuilder =
                datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(d, modelIdx));
            IgniteBiFunction<K, V, Vector> modelExtractor = mappings != null ? extractors.get(modelIdx) : igniteBiFunction;

            tasks.add(igniteTriFunction.apply(modelBuilder, modelIdx, modelExtractor));
        }

        List<ModelWithMapping<Vector, Double, M>> models = learningEnvironment.parallelismStrategy().submit(tasks)
            .stream()
            .map(promise -> promise.unsafeGet())
            .map(ModelWithMapping<Vector, Double, M>::new)
            .collect(Collectors.toList());

        // Models were trained on mapped features, so they must apply the same mapping when predicting.
        if (mappings != null) {
            for (int modelIdx = 0; modelIdx < models.size(); modelIdx++) {
                models.get(modelIdx).setMapping(VectorUtils.getProjector(mappings.get(modelIdx)));
            }
        }

        double learningTime = (System.currentTimeMillis() - startTs) / 1000.0d;
        logger.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
        logger.log(MLLogger.VerboseLevel.LOW, "Learning finished.");

        return new ModelsComposition(models, predictionsAggregator);
    }

    /**
     * Produces an index mapping by calling {@code Utils.selectKDistinct(i, i2, rnd)} with a
     * {@link Random} seeded by {@code j}.
     * NOTE(review): presumably this selects {@code i2} distinct indices out of {@code i} - confirm
     * against {@code Utils.selectKDistinct}.
     *
     * @param i First argument forwarded to {@code Utils.selectKDistinct}.
     * @param i2 Second argument forwarded to {@code Utils.selectKDistinct}.
     * @param j Seed of the random generator used for the selection.
     * @return Array returned by {@code Utils.selectKDistinct}.
     */
    public static int[] getMapping(int i, int i2, long j) {
        Random seededRnd = new Random(j);

        return Utils.selectKDistinct(i, i2, seededRnd);
    }

    /**
     * Wraps a feature extractor so that its output is reduced to the components listed in
     * {@code iArr}: component {@code k} of the resulting vector is component {@code iArr[k]} of
     * the vector produced by the original extractor.
     *
     * @param igniteBiFunction Original feature extractor.
     * @param iArr Indices of the components to keep, in output order.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Extractor emitting a vector of length {@code iArr.length} built from the selected components.
     */
    private static <K, V> IgniteBiFunction<K, V, Vector> wrapExtractor(IgniteBiFunction<K, V, Vector> igniteBiFunction, int[] iArr) {
        return igniteBiFunction.andThen(vector -> {
            int selectedCnt = iArr.length;
            double[] selected = new double[selectedCnt];

            for (int k = 0; k < selectedCnt; k++) {
                selected[k] = vector.get(iArr[k]);
            }

            return VectorUtils.of(selected);
        });
    }

    // NOTE: no explicit $deserializeLambda$ here. javac synthesizes that method for the
    // serializable lambda in wrapExtractor; declaring it by hand collides with the
    // compiler-generated symbol.
}
