package org.apache.ignite.ml.composition.boosting;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Learning strategy for gradient boosting.
 *
 * <p>Each iteration fits one base model (produced by {@link #baseMdlTrainerBuilder}) to the negative
 * gradient of {@link #loss}. The gradient is evaluated at the prediction of the composition of all
 * models trained so far. Training stops after {@link #cntOfIterations} iterations, or earlier when the
 * convergence checker built by {@link #checkConvergenceStgyFactory} reports convergence.
 */
public class GDBLearningStrategy {
    /** Learning environment; used here to obtain the logger. */
    protected LearningEnvironment environment;

    /** Maximum number of base models added by a single {@link #update} call. */
    protected int cntOfIterations;

    /** Loss function whose gradient the base models are fitted to. */
    protected Loss loss;

    /** Mapping applied to every extracted label before it is passed to the loss gradient. */
    protected IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Supplier of the trainer used to fit the base models (called once per {@link #update}). */
    protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder;

    /** Bias of the composition's {@link WeightedPredictionsAggregator}. */
    protected double meanLbVal;

    /** Sample size passed to the loss gradient and to the convergence checker factory. */
    protected long sampleSize;

    /** Per-model weights of the composition; rebuilt by {@link #initLearningState}. */
    protected double[] compositionWeights;

    /** Factory of the convergence checker; defaults to a mean-abs-value checker built with 0.001. */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001d);

    /** Weight assigned to every model trained by this strategy. */
    private double defaultGradStepSize;

    /**
     * Trains a new boosted composition from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained base models.
     */
    public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return update(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Continues boosting on top of an existing model.
     *
     * <p>The method trains up to {@link #cntOfIterations} additional base models and stops early when the
     * convergence checker reports convergence. As a side effect it rebuilds {@link #compositionWeights}.
     * When {@code mdlToUpdate} is non-null, it also overwrites {@link #meanLbVal}.
     *
     * @param mdlToUpdate Model to continue from, or {@code null} to start from scratch.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return The models of {@code mdlToUpdate} (if any) followed by the newly trained models.
     */
    public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Model<Vector, Double>> models = initLearningState(mdlToUpdate);

        ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
            externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor);

        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();

        // Capture plain locals instead of `this`. The label extractor below is serializable, and
        // capturing `this` would drag the whole (non-Serializable) strategy along with it.
        final Loss locLoss = loss;
        final long locSampleSize = sampleSize;
        final IgniteFunction<Double, Double> locLbMapping = externalLbToInternalMapping;

        for (int i = 0; i < cntOfIterations; i++) {
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);

            // Snapshot the models so this composition stays consistent with `weights` even after
            // `models` grows below.
            List<Model<Vector, Double>> currModels = new ArrayList<>(models);
            ModelsComposition currComposition = new ModelsComposition(currModels, aggregator);

            if (convCheck.isConverged(datasetBuilder, currComposition)) {
                break;
            }

            // The next model learns the negative loss gradient at the current composition's prediction.
            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
                double realAnswer = locLbMapping.apply(lbExtractor.apply(k, v));
                double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v));
                return -locLoss.gradient(locSampleSize, realAnswer, mdlAnswer);
            };

            // nanoTime is monotonic, unlike the wall clock, so it is the right source for elapsed time.
            long startNanos = System.nanoTime();
            models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap));
            double learningTimeSec = (System.nanoTime() - startNanos) / 1e9;

            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW,
                "One model training time was %.2fs", learningTimeSec);
        }

        return models;
    }

    /**
     * Prepares the list of models and the composition weights before boosting starts.
     *
     * <p>When {@code mdlToUpdate} is given, the method copies its models, its aggregator's bias (into
     * {@link #meanLbVal}) and its weights. It then reserves {@link #cntOfIterations} extra weight slots.
     * All new slots are filled with the default gradient step size.
     *
     * @param mdlToUpdate Model to continue from, or {@code null}. Its predictions aggregator is cast to
     *     {@link WeightedPredictionsAggregator}.
     * @return Mutable list holding the models of {@code mdlToUpdate}, or an empty list.
     */
    @NotNull
    public List<Model<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
        List<Model<Vector, Double>> models = new ArrayList<>();

        if (mdlToUpdate != null) {
            models.addAll(mdlToUpdate.getModels());

            WeightedPredictionsAggregator aggregator =
                (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
            meanLbVal = aggregator.getBias();

            compositionWeights = new double[models.size() + cntOfIterations];
            // Fetch the weights once instead of once per element.
            System.arraycopy(aggregator.getWeights(), 0, compositionWeights, 0, models.size());
        } else {
            compositionWeights = new double[cntOfIterations];
        }

        Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);
        return models;
    }

    /**
     * Sets the learning environment.
     *
     * @param environment Learning environment.
     * @return This strategy.
     */
    public GDBLearningStrategy withEnvironment(LearningEnvironment environment) {
        this.environment = environment;
        return this;
    }

    /**
     * Sets the maximum number of boosting iterations per {@link #update} call.
     *
     * @param cntOfIterations Count of iterations.
     * @return This strategy.
     */
    public GDBLearningStrategy withCntOfIterations(int cntOfIterations) {
        this.cntOfIterations = cntOfIterations;
        return this;
    }

    /**
     * Sets the loss whose gradient the base models are fitted to.
     *
     * @param loss Loss function.
     * @return This strategy.
     */
    public GDBLearningStrategy withLossGradient(Loss loss) {
        this.loss = loss;
        return this;
    }

    /**
     * Sets the mapping applied to extracted labels before computing the loss gradient.
     *
     * @param externalLbToInternal Label mapping.
     * @return This strategy.
     */
    public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double, Double> externalLbToInternal) {
        this.externalLbToInternalMapping = externalLbToInternal;
        return this;
    }

    /**
     * Sets the supplier of the base model trainer.
     *
     * @param buildBaseMdlTrainer Trainer supplier.
     * @return This strategy.
     */
    public GDBLearningStrategy withBaseModelTrainerBuilder(
        IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
        this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
        return this;
    }

    /**
     * Sets the composition bias.
     *
     * @param meanLbVal Mean label value.
     * @return This strategy.
     */
    public GDBLearningStrategy withMeanLabelValue(double meanLbVal) {
        this.meanLbVal = meanLbVal;
        return this;
    }

    /**
     * Sets the sample size.
     *
     * @param sampleSize Sample size.
     * @return This strategy.
     */
    public GDBLearningStrategy withSampleSize(long sampleSize) {
        this.sampleSize = sampleSize;
        return this;
    }

    /**
     * Sets the composition weights. The array is stored by reference.
     *
     * @param compositionWeights Composition weights.
     * @return This strategy.
     */
    public GDBLearningStrategy withCompositionWeights(double[] compositionWeights) {
        this.compositionWeights = compositionWeights;
        return this;
    }

    /**
     * Sets the convergence checker factory.
     *
     * @param factory Convergence checker factory.
     * @return This strategy.
     */
    public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Sets the weight given to each newly trained model.
     *
     * @param defaultGradStepSize Default gradient step size.
     * @return This strategy.
     */
    public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize) {
        this.defaultGradStepSize = defaultGradStepSize;
        return this;
    }

    /**
     * Returns the composition weights (internal array, not a copy).
     *
     * @return Composition weights.
     */
    public double[] getCompositionWeights() {
        return compositionWeights;
    }

    /**
     * Returns the composition bias.
     *
     * @return Mean label value.
     */
    public double getMeanValue() {
        return meanLbVal;
    }
}
