package org.apache.ignite.ml.selection.cv;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.DoubleConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.pipeline.Pipeline;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy;
import org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy;
import org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.paramgrid.RandomStrategy;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.genetic.Chromosome;
import org.apache.ignite.ml.util.genetic.GeneticAlgorithm;
import org.jetbrains.annotations.NotNull;

/**
 * Base class for k-fold cross-validation of a {@link DatasetTrainer} or a {@link Pipeline}, optionally combined
 * with hyper-parameter tuning over a {@link ParamGrid}.
 * <p>
 * Every upstream entry is mapped by {@link #mapper} to a number. Fold {@code i} of {@code k = amountOfFolds}
 * covers the interval {@code [i / k, (i + 1) / k]}, so the mapper presumably returns values in {@code [0, 1]}
 * (confirm against {@link UniformMapper}). Subclasses implement {@link #scoreByFolds()}, presumably by delegating
 * to {@link #score} or {@link #scorePipeline} with storage-specific suppliers.
 *
 * @param <M> Type of the model.
 * @param <L> Type of the label.
 * @param <K> Type of the upstream keys.
 * @param <V> Type of the upstream values.
 */
public abstract class AbstractCrossValidation<M extends IgniteModel<Vector, L>, L, K, V> {
    /** Upper bound on the initial population size of the evolutionary (genetic) hyper-parameter search. */
    private static final int EVOLUTION_POPULATION_SIZE = 20;

    /** Trainer fitted on every fold by {@link #score}. */
    protected DatasetTrainer<M, L> trainer;

    /** Pipeline fitted on every fold by {@link #scorePipeline}. */
    protected Pipeline<K, V, Integer, Double> pipeline;

    /** Metric applied to every fold's test cursor. */
    protected Metric<L> metric;

    /** Preprocessor passed to {@link #trainer} in {@link #score}. */
    protected Preprocessor<K, V> preprocessor;

    /** Number of folds. */
    protected int amountOfFolds;

    /** Not read in this class; presumably used by subclasses — TODO confirm. */
    protected int parts;

    /** Hyper-parameter grid (values, setters, tuning strategy) used by {@link #tuneHyperParamterers()}. */
    protected ParamGrid paramGrid;

    /** Builder of {@link #environment}. */
    protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();

    /** Environment whose parallelism strategy runs hyper-parameter evaluation tasks. */
    protected LearningEnvironment environment = envBuilder.buildForTrainer();

    /** Upstream filter; accepts every entry by default. Not read in this class (presumably used by subclasses). */
    protected IgniteBiPredicate<K, V> filter = (k, v) -> true;

    /** Whether to cross-validate {@link #pipeline} instead of {@link #trainer}. Not read in this class. */
    protected boolean isRunningOnPipeline = true;

    /** Maps every entry to the number that decides which fold it belongs to. */
    protected UniformMapper<K, V> mapper = new SHA256UniformMapper();

    /** Result of cross-validating one fixed set of hyper-parameter values. */
    public static class TaskResult {
        /** Hyper-parameter name to the value used for this task. */
        private Map<String, Double> paramMap;

        /** Score of every fold, as returned by {@link AbstractCrossValidation#scoreByFolds()}. */
        private double[] locScores;

        /**
         * @param paramMap Hyper-parameter name to value.
         * @param locScores Per-fold scores.
         */
        public TaskResult(Map<String, Double> paramMap, double[] locScores) {
            this.paramMap = paramMap;
            this.locScores = locScores;
        }

        /** @param paramMap Hyper-parameter name to value. */
        public void setParamMap(Map<String, Double> paramMap) {
            this.paramMap = paramMap;
        }

        /** @param locScores Per-fold scores. */
        public void setLocScores(double[] locScores) {
            this.locScores = locScores;
        }
    }

    /**
     * Runs hyper-parameter tuning with the strategy configured on {@link #paramGrid}: brute force, random search,
     * or evolutionary search.
     *
     * @return Result holding the scores of every evaluated parameter set and the best one found.
     * @throws UnsupportedOperationException If the configured strategy is none of the supported ones.
     */
    public CrossValidationResult tuneHyperParamterers() {
        HyperParameterTuningStrategy stgy = paramGrid.getHyperParameterTuningStrategy();

        if (stgy instanceof BruteForceStrategy) {
            return scoreBruteForceHyperparameterOptimization();
        }
        if (stgy instanceof RandomStrategy) {
            return scoreRandomSearchHyperparameterOptimization();
        }
        if (stgy instanceof EvolutionOptimizationStrategy) {
            return scoreEvolutionAlgorithmSearchHyperparameterOptimization();
        }

        throw new UnsupportedOperationException("This strategy is not supported yet [strategy=" + stgy.getName() + "]");
    }

    /**
     * Evolutionary search: seeds a genetic algorithm with up to {@value #EVOLUTION_POPULATION_SIZE} randomly chosen
     * parameter sets and uses the average fold score as the fitness of a chromosome.
     */
    private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
        EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy)paramGrid.getHyperParameterTuningStrategy();

        List<Double[]> paramSets = shuffledParamSets(stgy.getSeed());

        // The population cannot be larger than the number of available combinations (previously IOOBE).
        List<Double[]> initialPopulation = paramSets.subList(0, Math.min(EVOLUTION_POPULATION_SIZE, paramSets.size()));

        CrossValidationResult cvRes = new CrossValidationResult();

        // NOTE(review): under runParallel this function may be invoked concurrently, and it mutates cvRes —
        // confirm CrossValidationResult tolerates concurrent updates.
        Function<Chromosome, Double> fitnessFunction =
            chromosome -> recordResult(cvRes, calculateScoresForFixedParamSet(chromosome.toDoubleArray()), true);

        // Mutation replaces a gene with a random value from that parameter's raw value list.
        Random rnd = new Random(stgy.getSeed());
        BiFunction<Integer, Double, Double> mutationOperator = (geneIdx, oldGeneVal) -> {
            Double[] possibleValues = paramGrid.getParamRawData().get(geneIdx.intValue());
            return possibleValues[rnd.nextInt(possibleValues.length)];
        };

        GeneticAlgorithm ga = new GeneticAlgorithm(initialPopulation);
        ga.withFitnessFunction(fitnessFunction)
            .withMutationOperator(mutationOperator)
            .withAmountOfEliteChromosomes(stgy.getAmountOfEliteChromosomes())
            .withCrossingoverProbability(stgy.getCrossingoverProbability())
            .withCrossoverStgy(stgy.getCrossoverStgy())
            .withAmountOfGenerations(stgy.getAmountOfGenerations())
            .withSelectionStgy(stgy.getSelectionStgy())
            .withMutationProbability(stgy.getMutationProbability());

        if (environment.parallelismStrategy().getParallelism() > 1) {
            ga.runParallel(environment);
        }
        else {
            ga.run();
        }

        return cvRes;
    }

    /** Random search: evaluates up to {@code maxTries} randomly chosen parameter sets. */
    private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
        RandomStrategy stgy = (RandomStrategy)paramGrid.getHyperParameterTuningStrategy();

        List<Double[]> paramSets = shuffledParamSets(stgy.getSeed());

        // Cannot try more combinations than exist (previously IOOBE).
        int tries = Math.min(stgy.getMaxTries(), paramSets.size());

        CrossValidationResult cvRes = new CrossValidationResult();

        for (TaskResult res : evaluateAll(paramSets.subList(0, tries))) {
            recordResult(cvRes, res, true);
        }

        return cvRes;
    }

    /** Brute-force search: evaluates every parameter set of the grid. */
    private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
        List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

        CrossValidationResult cvRes = new CrossValidationResult();

        // Strict comparison: among equal averages the first evaluated parameter set stays the best.
        for (TaskResult res : evaluateAll(paramSets)) {
            recordResult(cvRes, res, false);
        }

        return cvRes;
    }

    /**
     * Returns every parameter combination of {@link #paramGrid}, shuffled deterministically with {@code seed}.
     *
     * @param seed Seed of the shuffling {@link Random}.
     * @return Mutable, shuffled copy of all combinations.
     */
    private List<Double[]> shuffledParamSets(long seed) {
        List<Double[]> paramSets =
            new ArrayList<>(new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate());

        Collections.shuffle(paramSets, new Random(seed));

        return paramSets;
    }

    /**
     * Submits one scoring task per parameter set to the environment's parallelism strategy and waits for all of them.
     *
     * @param paramSets Parameter sets to evaluate.
     * @return Task results in submission order.
     */
    private List<TaskResult> evaluateAll(List<Double[]> paramSets) {
        List<IgniteSupplier<TaskResult>> tasks = new ArrayList<>(paramSets.size());

        for (Double[] paramSet : paramSets) {
            tasks.add(() -> calculateScoresForFixedParamSet(paramSet));
        }

        return environment.parallelismStrategy().submit(tasks).stream()
            .map(promise -> promise.unsafeGet())
            .collect(Collectors.toList());
    }

    /**
     * Adds {@code res} to {@code cvRes} and makes it the best result if its average score beats the current best.
     *
     * @param cvRes Accumulated result.
     * @param res Result of one parameter set.
     * @param replaceOnTie Whether a result whose average equals the current best average replaces it.
     * @return Average of the fold scores of {@code res}.
     */
    private static double recordResult(CrossValidationResult cvRes, TaskResult res, boolean replaceOnTie) {
        cvRes.addScores(res.locScores, res.paramMap);

        // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, not the most negative one. It is only
        // used when there are no fold scores; kept for compatibility.
        double avg = Arrays.stream(res.locScores).average().orElse(Double.MIN_VALUE);
        double best = cvRes.getBestAvgScore();

        if (avg > best || (replaceOnTie && avg == best)) {
            cvRes.setBestScore(res.locScores);
            cvRes.setBestHyperParams(res.paramMap);
        }

        return avg;
    }

    /**
     * Injects {@code paramSet} through the grid's setters and then scores the current configuration by folds.
     * <p>
     * NOTE(review): the setters mutate shared state (presumably the trainer or pipeline) before
     * {@link #scoreByFolds()} reads it. If the parallelism strategy runs several of these tasks concurrently, they
     * may interfere — confirm.
     *
     * @param paramSet Value of every hyper-parameter, indexed like the grid.
     * @return Injected parameters together with the per-fold scores.
     */
    private TaskResult calculateScoresForFixedParamSet(Double[] paramSet) {
        Map<String, Double> paramMap = injectAndGetParametersFromPipeline(paramGrid, paramSet);

        return new TaskResult(paramMap, scoreByFolds());
    }

    /**
     * Scores the current configuration on every fold.
     *
     * @return Score of every fold.
     */
    public abstract double[] scoreByFolds();

    /**
     * Applies every value of {@code paramSet} through the grid's setter for that index.
     *
     * @param paramGrid Grid providing setters and parameter names.
     * @param paramSet Value of every hyper-parameter, indexed like the grid.
     * @return Parameter name to injected value.
     */
    @NotNull
    private Map<String, Double> injectAndGetParametersFromPipeline(ParamGrid paramGrid, Double[] paramSet) {
        Map<String, Double> paramMap = new HashMap<>();

        for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) {
            DoubleConsumer setter = paramGrid.getSetterByIndex(paramIdx);
            Double paramVal = paramSet[paramIdx];

            setter.accept(paramVal.doubleValue());
            paramMap.put(paramGrid.getParamNameByIndex(paramIdx), paramVal);
        }

        return paramMap;
    }

    /**
     * Cross-validates {@link #trainer} (with {@link #preprocessor}) on every fold.
     *
     * @param function Builds the dataset for a given train-set filter.
     * @param biFunction Builds the test-data cursor from the train-set filter and the fitted model. It presumably
     * iterates over the entries the filter rejects.
     * @return Score of every fold.
     */
    public double[] score(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> function,
        BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> biFunction) {
        return scoreFolds(function, biFunction, datasetBuilder -> trainer.fit(datasetBuilder, preprocessor));
    }

    /**
     * Cross-validates {@link #pipeline} on every fold.
     *
     * @param function Builds the dataset for a given train-set filter.
     * @param biFunction Builds the test-data cursor from the train-set filter and the fitted pipeline model.
     * @return Score of every fold.
     */
    @SuppressWarnings("unchecked")
    public double[] scorePipeline(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> function,
        BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> biFunction) {
        // The pipeline model is handed to the test-data supplier as M; this is an unchecked cast by design.
        return scoreFolds(function, biFunction, datasetBuilder -> (M)pipeline.fit(datasetBuilder));
    }

    /**
     * Shared k-fold loop: for every fold, builds a filter, fits a model on the filtered dataset, and scores it.
     *
     * @param datasetBuilderSupplier Builds the dataset for a given train-set filter.
     * @param testDataIterSupplier Builds the test-data cursor from the train-set filter and the fitted model.
     * @param fitter Fits a model on a dataset.
     * @return Score of every fold.
     */
    private double[] scoreFolds(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier,
        BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
        Function<DatasetBuilder<K, V>, M> fitter) {
        double[] scores = new double[amountOfFolds];
        double foldSize = 1.0 / amountOfFolds;

        for (int i = 0; i < amountOfFolds; i++) {
            IgniteBiPredicate<K, V> trainSetFilter = foldComplementFilter(foldSize * i, foldSize * (i + 1));

            M mdl = fitter.apply(datasetBuilderSupplier.apply(trainSetFilter));

            // try-with-resources closes the cursor even when scoring fails (the decompiled form leaked it).
            try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) {
                scores[i] = metric.score(cursor.iterator());
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        return scores;
    }

    /**
     * Builds a predicate that accepts entries whose mapped value lies strictly outside {@code [from, to]}.
     * <p>
     * The predicate reads {@link #mapper} through {@code this} when evaluated, so serializing it also serializes
     * this object.
     *
     * @param from Lower bound of the current fold.
     * @param to Upper bound of the current fold.
     * @return Train-set filter for the fold.
     */
    private IgniteBiPredicate<K, V> foldComplementFilter(double from, double to) {
        return (k, v) -> {
            double pnt = mapper.map(k, v);

            return pnt < from || pnt > to;
        };
    }

    /**
     * @param datasetTrainer Trainer to cross-validate.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withTrainer(DatasetTrainer<M, L> datasetTrainer) {
        this.trainer = datasetTrainer;
        return this;
    }

    /**
     * @param metric Metric applied to every fold.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withMetric(Metric<L> metric) {
        this.metric = metric;
        return this;
    }

    /**
     * @param preprocessor Preprocessor passed to the trainer.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withPreprocessor(Preprocessor<K, V> preprocessor) {
        this.preprocessor = preprocessor;
        return this;
    }

    /**
     * @param igniteBiPredicate Upstream filter.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withFilter(IgniteBiPredicate<K, V> igniteBiPredicate) {
        this.filter = igniteBiPredicate;
        return this;
    }

    /**
     * @param i Number of folds.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withAmountOfFolds(int i) {
        this.amountOfFolds = i;
        return this;
    }

    /**
     * @param paramGrid Hyper-parameter grid used by {@link #tuneHyperParamterers()}.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withParamGrid(ParamGrid paramGrid) {
        this.paramGrid = paramGrid;
        return this;
    }

    /**
     * @param z Whether to cross-validate the pipeline instead of the trainer.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> isRunningOnPipeline(boolean z) {
        this.isRunningOnPipeline = z;
        return this;
    }

    /**
     * Sets the environment builder and rebuilds {@link #environment} from it.
     *
     * @param learningEnvironmentBuilder Environment builder.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withEnvironmentBuilder(
        LearningEnvironmentBuilder learningEnvironmentBuilder) {
        this.envBuilder = learningEnvironmentBuilder;
        this.environment = learningEnvironmentBuilder.buildForTrainer();
        return this;
    }

    /**
     * @param pipeline Pipeline to cross-validate.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withPipeline(Pipeline<K, V, Integer, Double> pipeline) {
        this.pipeline = pipeline;
        return this;
    }

    /**
     * @param uniformMapper Mapper deciding the fold of every entry.
     * @return This instance.
     */
    public AbstractCrossValidation<M, L, K, V> withMapper(UniformMapper<K, V> uniformMapper) {
        this.mapper = uniformMapper;
        return this;
    }
}
