package org.apache.ignite.ml.regressions.linear;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer of {@link LinearRegressionModel} that delegates the optimization to {@link MLPTrainer}.
 * <p>
 * The network being trained has the dataset's feature columns as input and a single output neuron with a
 * {@link Activators#LINEAR linear} activation, optimized with the {@link LossFunctions#MSE MSE} loss. After training,
 * the last parameter of the network is used as the model intercept and all preceding parameters as the coefficients.
 *
 * @param <P> Type of the parameters of the update strategy.
 */
public class LinearRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LinearRegressionModel> {
    /** Update strategy (forwarded unchanged to {@link MLPTrainer}). */
    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;

    /** Maximum number of iterations (forwarded unchanged to {@link MLPTrainer}). */
    private final int maxIterations;

    /** Batch size (forwarded unchanged to {@link MLPTrainer}). */
    private final int batchSize;

    /** Number of local iterations (forwarded unchanged to {@link MLPTrainer}). */
    private final int locIterations;

    /** Seed for the random generator (forwarded unchanged to {@link MLPTrainer}). */
    private final long seed;

    /**
     * Constructs a new linear regression trainer. All arguments are stored and later passed as-is to
     * {@link MLPTrainer}.
     *
     * @param updatesStgy Update strategy.
     * @param maxIterations Maximum number of iterations.
     * @param batchSize Batch size.
     * @param locIterations Number of local iterations.
     * @param seed Seed for the random generator.
     */
    public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
        int batchSize, int locIterations, long seed) {
        this.updatesStgy = updatesStgy;
        this.maxIterations = maxIterations;
        this.batchSize = batchSize;
        this.locIterations = locIterations;
        this.seed = seed;
    }

    /**
     * Trains a linear regression model on the data provided by the given dataset builder.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Extractor of the feature vector from an upstream key/value pair.
     * @param lbExtractor Extractor of the label from an upstream key/value pair.
     * @param <K> Type of a key in the upstream data.
     * @param <V> Type of a value in the upstream data.
     * @return Trained linear regression model.
     * @throws IllegalStateException If no partition of the dataset contains rows, so the number of features cannot
     *      be determined.
     */
    @Override
    public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {

        // Builds the network architecture once the number of feature columns is known from the dataset.
        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
            Integer cols = dataset.compute(data -> {
                // A partition without rows carries no information about the number of columns; returning null
                // also avoids a division by zero.
                if (data.getFeatures() == null || data.getRows() == 0) {
                    return null;
                }

                return data.getFeatures().length / data.getRows();
            }, (a, b) -> a == null ? b : a);

            if (cols == null) {
                throw new IllegalStateException("Cannot determine the number of features: the dataset is empty.");
            }

            return new MLPArchitecture(cols).withAddedLayer(1, true, Activators.LINEAR);
        };

        MLPTrainer<?> trainer = new MLPTrainer<>(
            archSupplier,
            LossFunctions.MSE,
            updatesStgy,
            maxIterations,
            batchSize,
            locIterations,
            seed
        );

        // MLPTrainer expects a vector label. This is a lambda on purpose: an anonymous class created here would
        // capture the enclosing trainer instance and carry it into the serializable closure. The lambda captures
        // only lbExtractor.
        IgniteBiFunction<K, V, double[]> vectorLbExtractor = (k, v) -> new double[] {lbExtractor.apply(k, v)};

        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, vectorLbExtractor);

        double[] params = mlp.parameters().getStorage().data();

        // The last parameter becomes the intercept; all preceding ones are the regression coefficients.
        double[] coefficients = Arrays.copyOf(params, params.length - 1);
        double intercept = params[params.length - 1];

        return new LinearRegressionModel(new DenseLocalOnHeapVector(coefficients), intercept);
    }
}
