package org.apache.ignite.ml.composition.boosting;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.jetbrains.annotations.NotNull;

/**
 * Abstract base trainer for gradient-boosted compositions ({@link ModelsComposition}).
 * <p>
 * Subclasses supply the trainer of a single base model, the mapping between the external (dataset)
 * label representation and the internal one used during boosting, and a label-learning hook. This class
 * computes the initial (mean internal) label value, configures a {@link GDBLearningStrategy} and wraps
 * the learned models into a {@link GDBModel}.
 */
public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
    /** Default gradient step size handed to the learning strategy. */
    private final double gradientStep;

    /** Number of boosting iterations handed to the learning strategy. */
    private final int cntOfIterations;

    /** Loss handed to the learning strategy as its loss gradient. */
    protected final Loss loss;

    /**
     * Factory of convergence checkers handed to the learning strategy. Defaults to a
     * {@link MeanAbsValueConvergenceCheckerFactory} built with {@code 0.001} (presumably its precision).
     */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001d);

    /**
     * Composition of boosted models whose aggregated prediction is mapped from the internal label
     * representation back to the external one before being returned.
     */
    public static class GDBModel extends ModelsComposition {
        /** Serial version UID. */
        private static final long serialVersionUID = 3476661240155508004L;

        /** Applied to every aggregated prediction before it is returned from {@link #predict}. */
        private final IgniteFunction<Double, Double> internalToExternalLblMapping;

        /**
         * Creates a boosted composition.
         *
         * @param models Boosted base models.
         * @param predictionsAggregator Aggregator combining the base models' predictions.
         * @param internalToExternalLblMapping Mapping applied to the aggregated prediction.
         */
        public GDBModel(List<? extends IgniteModel<Vector, Double>> models,
            WeightedPredictionsAggregator predictionsAggregator,
            IgniteFunction<Double, Double> internalToExternalLblMapping) {
            super(models, predictionsAggregator);
            this.internalToExternalLblMapping = internalToExternalLblMapping;
        }

        /**
         * Predicts via the underlying composition and maps the result to the external label representation.
         *
         * @param features Feature vector.
         * @return Prediction in the external label representation.
         */
        @Override
        public Double predict(Vector features) {
            return internalToExternalLblMapping.apply(super.predict(features));
        }
    }

    /**
     * Creates a gradient boosting trainer.
     *
     * @param gradStepSize Default gradient step size.
     * @param cntOfIterations Number of boosting iterations; must not be {@code null} (it is unboxed).
     * @param loss Loss function whose gradient drives boosting.
     */
    public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
        this.gradientStep = gradStepSize;
        this.cntOfIterations = cntOfIterations;
        this.loss = loss;
    }

    /**
     * Trains a new composition from scratch; equivalent to {@code updateModel(null, ...)}.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained composition.
     */
    @Override
    public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return updateModel(null, datasetBuilder, preprocessor);
    }

    /**
     * Trains a new composition ({@code mdl == null}) or continues training of an existing one.
     * <p>
     * If {@link #learnLabels} returns {@code false}, or {@link #computeInitialValue} returns {@code null}
     * (no result or no labels), training is skipped and the result of
     * {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} is returned instead.
     *
     * @param mdl Model to update, or {@code null} to train from scratch. A non-null value is cast to {@link GDBModel}.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained or updated composition.
     */
    @Override
    public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        if (!learnLabels(datasetBuilder, preprocessor))
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(envBuilder, datasetBuilder, preprocessor);
        if (initAndSampleSize == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        double meanLblVal = initAndSampleSize.get1();
        long sampleSize = initAndSampleSize.get2();

        long learningStartTs = System.currentTimeMillis();

        GDBLearningStrategy stgy = getLearningStrategy()
            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
            .withExternalLabelToInternal(this::externalLabelToInternal)
            .withCntOfIterations(cntOfIterations)
            .withEnvironmentBuilder(envBuilder)
            .withLossGradient(loss)
            .withSampleSize(sampleSize)
            .withMeanLabelValue(meanLblVal)
            .withDefaultGradStepSize(gradientStep)
            .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

        List<IgniteModel<Vector, Double>> models = mdl != null
            ? stgy.update((GDBModel)mdl, datasetBuilder, preprocessor)
            : stgy.learnModels(datasetBuilder, preprocessor);

        double learningTimeSec = (System.currentTimeMillis() - learningStartTs) / 1000.0d;
        environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTimeSec);

        WeightedPredictionsAggregator aggregator =
            new WeightedPredictionsAggregator(stgy.getCompositionWeights(), stgy.getMeanValue());

        return new GDBModel(models, aggregator, this::internalLabelToExternal);
    }

    /**
     * Checks whether the given model can be updated by this trainer.
     *
     * @param mdl Model.
     * @return {@code true} iff {@code mdl} is a {@link GDBModel}.
     */
    @Override
    public boolean isUpdateable(ModelsComposition mdl) {
        return mdl instanceof GDBModel;
    }

    /**
     * Sets the learning environment builder, delegating to the superclass.
     *
     * @param envBuilder Learning environment builder.
     * @return This trainer.
     */
    @Override
    public DatasetTrainer<ModelsComposition, Double> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (GDBTrainer)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Hook invoked at the start of {@link #updateModel}. When it returns {@code false}, training is skipped.
     * NOTE(review): presumably used by subclasses to collect label metadata — confirm in implementations.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param <V> Type of a value in upstream data.
     * @param <K> Type of a key in upstream data.
     * @return {@code true} if training should proceed.
     */
    protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor);

    /**
     * Creates the trainer of a single base model; passed as a supplier to the learning strategy.
     *
     * @return Base model trainer.
     */
    @NotNull
    protected abstract DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> buildBaseModelTrainer();

    /**
     * Maps a dataset (external) label to the internal representation used while boosting. Applied to every
     * label in {@link #computeInitialValue} and handed to the learning strategy.
     *
     * @param lbl External label.
     * @return Internal label.
     */
    protected abstract double externalLabelToInternal(double lbl);

    /**
     * Maps an internal value back to the external label representation. Applied by {@link GDBModel} to the
     * aggregated prediction.
     *
     * @param lbl Internal value.
     * @return External label.
     */
    protected abstract double internalLabelToExternal(double lbl);

    /**
     * Computes the mean of all labels (after {@link #externalLabelToInternal}) and the number of labels.
     * The dataset built here is always closed, including when the computation fails.
     *
     * @param envBuilder Learning environment builder.
     * @param builder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param <V> Type of a value in upstream data.
     * @param <K> Type of a key in upstream data.
     * @param <C> Unused; kept for signature compatibility with existing overrides and callers.
     * @return Tuple of (mean internal label, label count), or {@code null} when the computation produced no
     * result or the dataset contains no labels (avoids a NaN {@code 0.0 / 0} mean).
     * @throws RuntimeException Wrapping any exception raised while building, computing or closing the dataset.
     */
    protected <V, K, C extends Serializable> IgniteBiTuple<Double, Long> computeInitialValue(
        LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> builder, Preprocessor<K, V> preprocessor) {
        try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(preprocessor, false)
        )) {
            IgniteBiTuple<Double, Long> meanTuple = dataset.compute(
                data -> {
                    double sum = Arrays.stream(data.getLabels()).map(this::externalLabelToInternal).sum();
                    return new IgniteBiTuple<>(sum, (long)data.getLabels().length);
                },
                (a, b) -> {
                    if (a == null)
                        return b;
                    if (b == null)
                        return a;

                    // Accumulate partial (sum, count) pairs in place.
                    a.set1(a.get1() + b.get1());
                    a.set2(a.get2() + b.get2());
                    return a;
                }
            );

            // No partitions, or partitions without labels: nothing to average over.
            if (meanTuple == null || meanTuple.get2() == 0L)
                return null;

            meanTuple.set1(meanTuple.get1() / meanTuple.get2());
            return meanTuple;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the convergence checker factory handed to the learning strategy.
     *
     * @param factory Convergence checker factory.
     * @return This trainer.
     */
    public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Creates the learning strategy that {@link #updateModel} configures. Subclasses may override this to
     * supply a specialized strategy.
     *
     * @return New learning strategy.
     */
    protected GDBLearningStrategy getLearningStrategy() {
        return new GDBLearningStrategy();
    }
}
