package org.apache.ignite.ml.composition.boosting;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.loss.LogLoss;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;

/* loaded from: input_file:org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.class */
public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
    private double externalFirstCls;
    private double externalSecondCls;

    /**
     * Creates a binary-classification trainer that uses a fresh {@link LogLoss} as its loss.
     *
     * <p>Delegates to {@link #GDBBinaryClassifierTrainer(double, Integer, Loss)}, which forwards
     * all three values to the {@code GDBTrainer} constructor.
     *
     * @param d numeric setting forwarded unchanged to {@code GDBTrainer}
     *          (presumably the gradient step size; confirm against {@code GDBTrainer})
     * @param num setting forwarded unchanged to {@code GDBTrainer}
     *          (presumably the number of boosting iterations; confirm against {@code GDBTrainer})
     */
    public GDBBinaryClassifierTrainer(double d, Integer num) {
        this(d, num, new LogLoss());
    }

    /**
     * Creates a binary-classification trainer with a caller-supplied loss function.
     *
     * <p>All three arguments are forwarded unchanged to the {@code GDBTrainer} constructor.
     *
     * @param d numeric setting forwarded to {@code GDBTrainer}
     *          (presumably the gradient step size; confirm against {@code GDBTrainer})
     * @param num setting forwarded to {@code GDBTrainer}
     *          (presumably the number of boosting iterations; confirm against {@code GDBTrainer})
     * @param loss loss function forwarded to {@code GDBTrainer}
     */
    public GDBBinaryClassifierTrainer(double d, Integer num, Loss loss) {
        super(d, num, loss);
    }

    /**
     * Scans the dataset for its distinct label values and, when there are exactly two of them,
     * stores them as the external class labels of this trainer.
     *
     * <p>Each partition contributes the set of its labels; the partial sets are merged pairwise,
     * with a {@code null} partial result treated as "nothing to merge". Which of the two labels
     * becomes the "first" class is determined by the iteration order of the merged set; this
     * method does not sort the labels.
     *
     * @param datasetBuilder builder used to create the dataset that is scanned for labels
     * @param preprocessor preprocessor handed to the labeled-partition data builder
     * @return {@code true} if exactly two distinct labels were found and stored;
     *         {@code false} if the merged result is {@code null} or has any other size
     */
    @Override // org.apache.ignite.ml.composition.boosting.GDBTrainer
    protected <V, K> boolean learnLabels(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        // NOTE(review): the object returned by build(...) is never closed here - confirm whether
        // it holds resources that should be released after compute(...).
        Set distinctLabels = (Set) datasetBuilder
            .build(this.envBuilder, new EmptyContextBuilder(), new LabeledDatasetPartitionDataBuilderOnHeap(preprocessor))
            .compute(
                // Map step: the distinct labels of a single partition.
                partition -> (Set) Arrays.stream(partition.labels()).boxed().collect(Collectors.toSet()),
                // Reduce step: union of two partial results, tolerating a null on either side.
                (left, right) -> {
                    if (left == null) {
                        return right;
                    }
                    if (right != null) {
                        left.addAll(right);
                    }
                    return left;
                });

        boolean exactlyTwoLabels = distinctLabels != null && distinctLabels.size() == 2;
        if (exactlyTwoLabels) {
            // Copy into a list only to address the two elements by position.
            ArrayList ordered = new ArrayList(distinctLabels);
            this.externalFirstCls = (Double) ordered.get(0);
            this.externalSecondCls = (Double) ordered.get(1);
        }
        return exactlyTwoLabels;
    }

    /**
     * Maps a label from the dataset to the internal 0/1 encoding.
     *
     * @param d external label value
     * @return {@code 0.0} if {@code d} equals the stored first class label; {@code 1.0} for every
     *         other value (the second class label is not checked explicitly)
     */
    @Override // org.apache.ignite.ml.composition.boosting.GDBTrainer
    protected double externalLabelToInternal(double d) {
        if (d == this.externalFirstCls) {
            return 0.0d;
        }
        return 1.0d;
    }

    /**
     * Maps a raw internal value back to one of the two external class labels.
     *
     * <p>The value is passed through the logistic function {@code 1 / (1 + e^-d)}; a result of
     * {@code 0.5} or more selects the second class, anything else selects the first class.
     *
     * @param d raw internal value
     * @return the stored second class label if the logistic value is at least {@code 0.5},
     *         otherwise the stored first class label
     */
    @Override // org.apache.ignite.ml.composition.boosting.GDBTrainer
    protected double internalLabelToExternal(double d) {
        double sigma = 1.0d / (1.0d + Math.exp(-d));
        // ">=" is false for NaN, so a NaN input falls through to the first class.
        return sigma >= 0.5d ? this.externalSecondCls : this.externalFirstCls;
    }

    /**
     * Applies the given learning-environment builder by delegating to the superclass
     * implementation and returns the result.
     *
     * <p>The {@code 2} suffix was added by the decompiler (see the JADX notes below); the
     * original method name is {@code withEnvironmentBuilder}.
     *
     * @param learningEnvironmentBuilder environment builder passed on to the superclass
     * @return the object returned by the superclass call, after the downcast below
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.ignite.ml.composition.boosting.GDBTrainer, org.apache.ignite.ml.trainers.DatasetTrainer
    /* renamed from: withEnvironmentBuilder */
    public DatasetTrainer<ModelsComposition, Double> withEnvironmentBuilder2(LearningEnvironmentBuilder learningEnvironmentBuilder) {
        // NOTE(review): this casts to the concrete subclass GDBBinaryClassifierOnTreesTrainer although
        // this class is abstract; any other subclass would fail with ClassCastException here, unless
        // the superclass returns something else than `this` - confirm against the original sources.
        return (GDBBinaryClassifierOnTreesTrainer) super.withEnvironmentBuilder2(learningEnvironmentBuilder);
    }

    /**
     * Synthetic hook used when a serialized lambda belonging to this class is deserialized.
     *
     * <p>It recognises the two lambdas of {@code learnLabels} by their implementation method
     * name, verifies the recorded functional interface, method names and signatures, and returns
     * an equivalent lambda instance.
     *
     * <p>NOTE(review): this is decompiler output. {@code boolean z = -1} and the {@code switch}
     * over a {@code boolean} are not valid Java; the variable looks like an integer selector in
     * the original bytecode. Treat the method as documentation of the generated code rather than
     * as source to be recompiled.
     *
     * @param serializedLambda serialized form describing the lambda to recreate
     * @return a lambda equivalent to the one that was serialized
     * @throws IllegalArgumentException if the description matches none of the known lambdas
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        // First pass: map the implementation method name to a selector. The hash code narrows the
        // candidates and equals() confirms the exact name.
        switch (implMethodName.hashCode()) {
            case -748449694:
                if (implMethodName.equals("lambda$learnLabels$22f84ea5$1")) {
                    z = false;
                    break;
                }
                break;
            case 135571973:
                if (implMethodName.equals("lambda$learnLabels$aaf4018c$1")) {
                    z = true;
                    break;
                }
                break;
        }
        // Second pass: validate the remaining metadata for the selected lambda and rebuild it.
        // Implementation method kind 6 is MethodHandleInfo.REF_invokeStatic.
        switch (z) {
            case false:
                // Per-partition step of learnLabels: LabeledVectorSet -> Set of its labels.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/structures/LabeledVectorSet;)Ljava/util/Set;")) {
                    return labeledVectorSet -> {
                        return (Set) Arrays.stream(labeledVectorSet.labels()).boxed().collect(Collectors.toSet());
                    };
                }
                break;
            case true:
                // Merge step of learnLabels: union of two partial label sets, null-tolerant.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBinaryOperator") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/Set;Ljava/util/Set;)Ljava/util/Set;")) {
                    return (set2, set3) -> {
                        if (set2 == null) {
                            return set3;
                        }
                        if (set3 == null) {
                            return set2;
                        }
                        set2.addAll(set3);
                        return set2;
                    };
                }
                break;
        }
        // Nothing matched: the serialized form does not describe a lambda of this class.
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
