package org.apache.ignite.ml.tree.boosting;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.DecisionTree;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;

/**
 * Gradient boosting learning strategy specialized for decision-tree base learners.
 * <p>
 * Before every boosting iteration the current composition is checked for convergence. If it has not converged, the
 * labels stored in each dataset partition are overwritten with the negative loss gradient evaluated at the current
 * composition's prediction. One more {@link DecisionTree} is then fitted on the relabelled dataset. Each partition
 * keeps a one-time copy of its original labels so that gradients are always computed against the true targets.
 */
public class GDBOnTreesLearningStrategy extends GDBLearningStrategy {
    /**
     * Flag forwarded to {@link DecisionTreeDataBuilder} when the training dataset is built
     * (presumably enables index-based tree data — confirm against DecisionTreeDataBuilder).
     */
    private final boolean useIdx;

    /**
     * Creates a learning strategy.
     *
     * @param useIdx Flag forwarded to {@link DecisionTreeDataBuilder} when the training dataset is built.
     */
    public GDBOnTreesLearningStrategy(boolean useIdx) {
        this.useIdx = useIdx;
    }

    /**
     * Runs up to {@code cntOfIterations} boosting iterations, starting from the state returned by
     * {@code initLearningState(mdlToUpdate)}, and stops early when the convergence checker reports convergence.
     * <p>
     * Side effects: partition labels in the built dataset are overwritten with pseudo-residuals, and
     * {@code compositionWeights} is resized to the number of returned models.
     *
     * @param mdlToUpdate Model whose state is used to initialize learning.
     * @param datasetBuilder Builder of the training dataset.
     * @param preprocessor Upstream preprocessor used both for the convergence checker and the tree data builder.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return List of base models: the initial ones plus one tree per completed iteration.
     * @throws RuntimeException Wrapping any exception raised while building, training on, or closing the dataset.
     */
    @Override
    public <K, V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
        assert trainer instanceof DecisionTree : "Base model trainer must be a DecisionTree";
        DecisionTree treeTrainer = (DecisionTree)trainer;

        List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);

        ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
            externalLbToInternalMapping, loss, datasetBuilder, preprocessor);

        // try-with-resources guarantees the dataset is closed even when an iteration throws.
        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(preprocessor, useIdx))) {

            for (int i = 0; i < cntOfIterations; i++) {
                double[] weights = Arrays.copyOf(compositionWeights, models.size());
                ModelsComposition currComposition =
                    new ModelsComposition(models, new WeightedPredictionsAggregator(weights, meanLbVal));

                if (convCheck.isConverged(dataset, currComposition))
                    break;

                // Replace partition labels with pseudo-residuals (negative loss gradient) of the current composition.
                dataset.compute(part -> {
                    // Keep the true labels once; later iterations overwrite getLabels() in place.
                    if (part.getCopiedOriginalLabels() == null)
                        part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));

                    for (int j = 0; j < part.getLabels().length; j++) {
                        double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                        double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);

                        part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                    }
                });

                long startTs = System.currentTimeMillis();
                models.add(treeTrainer.fit(dataset));
                double learningTime = (System.currentTimeMillis() - startTs) / 1000.0;

                trainerEnvironment.logger(getClass())
                    .log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }

        compositionWeights = Arrays.copyOf(compositionWeights, models.size());
        return models;
    }
}
