package org.apache.ignite.ml.svm;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/* loaded from: input_file:org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.class */
/**
 * Multi-class linear SVM trainer using a one-vs-rest scheme.
 *
 * <p>For every distinct label found in the dataset, a separate {@link SVMLinearBinaryClassificationTrainer} is
 * fitted on a relabelled view of the data. In that view the current label maps to {@code 1.0} and every other
 * label maps to {@code -1.0}. The resulting binary models are collected into one
 * {@link SVMLinearMultiClassClassificationModel}.
 */
public class SVMLinearMultiClassClassificationTrainer implements SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
    /** Value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfIterations}; defaults to 20. */
    private int amountOfIterations = 20;

    /** Value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfLocIterations}; defaults to 50. */
    private int amountOfLocIterations = 50;

    /** Value forwarded to {@code SVMLinearBinaryClassificationTrainer.withLambda}; defaults to 0.2, must be positive. */
    private double lambda = 0.2;

    /**
     * Trains one binary linear SVM per distinct class label (one-vs-rest) and combines them into a multi-class model.
     *
     * @param datasetBuilder Dataset builder. It is used once to collect the labels and once more per label for training.
     * @param featureExtractor Extracts the feature vector from an upstream entry.
     * @param lbExtractor Extracts the class label from an upstream entry.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Multi-class model holding one binary model per class label found in the dataset.
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

        SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();

        for (Double clsLb : classes) {
            multiClsMdl.add(clsLb, new SVMLinearBinaryClassificationTrainer()
                .withAmountOfIterations(amountOfIterations())
                .withAmountOfLocIterations(amountOfLocIterations())
                .withLambda(lambda())
                .fit(datasetBuilder, featureExtractor, oneVsRest(lbExtractor, clsLb)));
        }

        return multiClsMdl;
    }

    /**
     * Builds a label extractor that maps {@code clsLb} to {@code 1.0} and every other label to {@code -1.0}.
     *
     * <p>The method is static so the returned lambda captures only its two arguments and never this trainer
     * instance. This matters because {@code IgniteBiFunction} lambdas may be serialized.
     *
     * @param lbExtractor Original label extractor.
     * @param clsLb Label treated as the positive class.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Binary label extractor. A {@code null} original label is treated as "not this class".
     */
    private static <K, V> IgniteBiFunction<K, V, Double> oneVsRest(IgniteBiFunction<K, V, Double> lbExtractor,
        Double clsLb) {
        return (k, v) -> clsLb.equals(lbExtractor.apply(k, v)) ? 1.0 : -1.0;
    }

    /**
     * Scans every partition of the dataset and collects the distinct class labels.
     *
     * <p>The temporary dataset is always closed, even when the computation fails.
     *
     * @param datasetBuilder Dataset builder; must not be {@code null}.
     * @param lbExtractor Label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Distinct labels found in the dataset (unordered); empty if the computation yields no labels.
     * @throws RuntimeException Wrapping any exception raised while building, computing over or closing the dataset.
     */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        assert datasetBuilder != null;

        LabelPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabelPartitionDataBuilderOnHeap<>(lbExtractor);

        List<Double> res = new ArrayList<>();

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder)) {

            Set<Double> clsLabels = dataset.compute(data -> {
                Set<Double> locClsLabels = new HashSet<>();
                for (double lb : data.getY()) {
                    locClsLabels.add(lb);
                }
                return locClsLabels;
            }, (a, b) -> {
                // A null side means "no labels collected yet" (the reduction starts from a null accumulator).
                if (a == null) {
                    return b;
                }
                if (b == null) {
                    return a;
                }
                Set<Double> merged = new HashSet<>(a);
                merged.addAll(b);
                return merged;
            });

            // The reduction can yield null (e.g. when nothing was reduced); treat that as "no labels".
            if (clsLabels != null) {
                res.addAll(clsLabels);
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }

        return res;
    }

    /**
     * Sets the lambda value forwarded to each binary trainer.
     *
     * @param lambda Lambda value; must be strictly positive.
     * @return This trainer, for chaining.
     * @throws IllegalArgumentException If {@code lambda} is not strictly positive (including {@code NaN}).
     */
    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
        // Validate unconditionally: a public setter must not rely on assertions being enabled.
        if (!(lambda > 0.0)) {
            throw new IllegalArgumentException("lambda must be positive: " + lambda);
        }
        this.lambda = lambda;
        return this;
    }

    /**
     * Returns the lambda value forwarded to each binary trainer.
     *
     * @return Lambda value.
     */
    public double lambda() {
        return lambda;
    }

    /**
     * Returns the value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfIterations}.
     *
     * @return Amount of iterations.
     */
    public int amountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Sets the value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfIterations}.
     *
     * @param amountOfIterations Amount of iterations.
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * Returns the value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfLocIterations}.
     *
     * @return Amount of local iterations.
     */
    public int amountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Sets the value forwarded to {@code SVMLinearBinaryClassificationTrainer.withAmountOfLocIterations}.
     *
     * @param amountOfLocIterations Amount of local iterations.
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }
}
