package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Random;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.util.Utils;

/* loaded from: input_file:org/apache/ignite/ml/nn/MLPTrainer.class */
/**
 * Trainer that fits a {@link MultilayerPerceptron} on a partitioned Ignite {@link Dataset}.
 *
 * <p>Training is configured by these fields, each readable via {@code getX()} and settable via a fluent
 * {@code withX(...)} that returns {@code this}:
 * <ul>
 *   <li>architecture supplier,</li>
 *   <li>loss,</li>
 *   <li>update strategy,</li>
 *   <li>global/local iteration counts,</li>
 *   <li>mini-batch size,</li>
 *   <li>random seed.</li>
 * </ul>
 *
 * <p>NOTE(review): this file is JADX decompiler output, not original source.
 * <ul>
 *   <li>{@link #updateModel} was not decompiled and, as written here, only throws
 *       {@link UnsupportedOperationException}.</li>
 *   <li>The synthetic {@code $deserializeLambda$} is not valid Java: a {@code boolean} is assigned {@code -1}
 *       and {@code 2}, and the switch has a duplicate {@code case true}.</li>
 * </ul>
 * Restore the class from the original {@code org.apache.ignite.ml} sources before editing or recompiling it.
 *
 * @param <P> Type of the parameter-update object produced and consumed by the
 *     {@link ParameterUpdateCalculator} of the configured {@link UpdatesStrategy}.
 */
public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer<MultilayerPerceptron> {
    /** Maps the training dataset to the architecture of the network to train. */
    private IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;
    /**
     * Loss factory. It maps a {@link Vector} (presumably the ground truth; confirm) to a differentiable
     * vector-to-double function, and is passed to {@code ParameterUpdateCalculator.init}.
     */
    private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
    /**
     * Update strategy. Its {@code locStepUpdatesReducer()} combines the updates of one partition's local steps.
     * The rest of its use lives in the undecompiled {@link #updateModel}.
     */
    private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
    /** Global iteration budget (consumed by the undecompiled {@link #updateModel}). */
    private int maxIterations;
    /** Upper bound on mini-batch size; each local step samples {@code min(batchSize, partitionRows)} rows. */
    private int batchSize;
    /** Number of local update steps each partition performs per invocation of the per-partition lambda. */
    private int locIterations;
    /** Seed from which the per-step {@link Random} used for mini-batch sampling is derived. */
    private long seed;
    /** Synthetic flag emitted by javac for classes that contain {@code assert} statements. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates a trainer for a fixed architecture.
     *
     * <p>The architecture is wrapped in a supplier that ignores the dataset. All arguments are then forwarded to
     * the supplier-based constructor.
     *
     * @param mLPArchitecture Architecture of the network to train.
     * @param igniteFunction Loss factory.
     * @param updatesStrategy Update strategy.
     * @param i Max iterations.
     * @param i2 Batch size.
     * @param i3 Local iterations.
     * @param j Random seed.
     */
    public MLPTrainer(MLPArchitecture mLPArchitecture, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> igniteFunction, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStrategy, int i, int i2, int i3, long j) {
        this((IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture>) dataset -> {
            return mLPArchitecture;
        }, igniteFunction, updatesStrategy, i, i2, i3, j);
    }

    /**
     * Creates a trainer whose architecture is computed from the training dataset.
     *
     * @param igniteFunction Architecture supplier.
     * @param igniteFunction2 Loss factory.
     * @param updatesStrategy Update strategy.
     * @param i Max iterations.
     * @param i2 Batch size.
     * @param i3 Local iterations.
     * @param j Random seed.
     */
    public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> igniteFunction, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> igniteFunction2, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStrategy, int i, int i2, int i3, long j) {
        // NOTE(review): these four defaults are overwritten a few lines below, so they have no effect here.
        // They look like field initializers that the decompiler inlined into the constructor.
        this.maxIterations = 100;
        this.batchSize = 100;
        this.locIterations = 100;
        this.seed = 1234L;
        this.archSupplier = igniteFunction;
        this.loss = igniteFunction2;
        this.updatesStgy = updatesStrategy;
        this.maxIterations = i;
        this.batchSize = i2;
        this.locIterations = i3;
        this.seed = j;
    }

    /**
     * Trains a model from scratch by delegating to {@link #updateModel} with a {@code null} previous model.
     *
     * <p>The raw-type casts are decompiler artifacts. In this decompiled form the call ends in
     * {@link UnsupportedOperationException}, because {@code updateModel} was not recovered.
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> igniteBiFunction, IgniteBiFunction<K, V, double[]> igniteBiFunction2) {
        return updateModel((MultilayerPerceptron) null, (DatasetBuilder) datasetBuilder, (IgniteBiFunction) igniteBiFunction, (IgniteBiFunction) igniteBiFunction2);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Code restructure failed: missing block: B:32:0x00ca, code lost:
    
        r0 = (org.apache.ignite.ml.nn.MultilayerPerceptron) getLastTrainedModelOrThrowEmptyDatasetException(r9);
     */
    /* JADX WARN: Code restructure failed: missing block: B:33:0x00d6, code lost:
    
        if (r0 == 0) goto L38;
     */
    /* JADX WARN: Code restructure failed: missing block: B:35:0x00db, code lost:
    
        if (0 == 0) goto L37;
     */
    /* JADX WARN: Code restructure failed: missing block: B:36:0x00f4, code lost:
    
        r0.close();
     */
    /* JADX WARN: Code restructure failed: missing block: B:38:0x00de, code lost:
    
        r0.close();
     */
    /* JADX WARN: Code restructure failed: missing block: B:40:0x00e8, code lost:
    
        r22 = move-exception;
     */
    /* JADX WARN: Code restructure failed: missing block: B:41:0x00ea, code lost:
    
        r14.addSuppressed(r22);
     */
    /* JADX WARN: Code restructure failed: missing block: B:63:0x012d, code lost:
    
        r0 = r15;
     */
    /* JADX WARN: Code restructure failed: missing block: B:64:0x0133, code lost:
    
        if (r0 == 0) goto L49;
     */
    /* JADX WARN: Code restructure failed: missing block: B:66:0x0138, code lost:
    
        if (0 == 0) goto L48;
     */
    /* JADX WARN: Code restructure failed: missing block: B:67:0x0151, code lost:
    
        r0.close();
     */
    /* JADX WARN: Code restructure failed: missing block: B:69:0x013b, code lost:
    
        r0.close();
     */
    /* JADX WARN: Code restructure failed: missing block: B:71:0x0145, code lost:
    
        r18 = move-exception;
     */
    /* JADX WARN: Code restructure failed: missing block: B:72:0x0147, code lost:
    
        r14.addSuppressed(r18);
     */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains (or continues training) a multilayer perceptron. Originally declared {@code protected}, per the
     * JADX note above.
     *
     * <p>NOTE(review): the body was NOT decompiled, and this stub always throws
     * {@link UnsupportedOperationException}. The JADX notes above show what survives of the original:
     * <ul>
     *   <li>a call to {@code getLastTrainedModelOrThrowEmptyDatasetException(r9)}, where {@code r9} is the
     *       previous-model argument;</li>
     *   <li>{@code close()} / {@code addSuppressed} sequences, which are consistent with a try-with-resources
     *       block.</li>
     * </ul>
     * The two {@code lambda$updateModel$...} cases in {@code $deserializeLambda$} hold the per-partition training
     * step and the list reducer that this method used. Recover the real body from the original sources or the
     * bytecode dump.
     *
     * @param r9 Previously trained model, or {@code null} to start from scratch ({@link #fit} passes
     *     {@code null}).
     * @param r10 Dataset builder.
     * @param r11 Feature extractor.
     * @param r12 Label extractor.
     * @return Trained model (in the original implementation).
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public <K, V> org.apache.ignite.ml.nn.MultilayerPerceptron updateModel(org.apache.ignite.ml.nn.MultilayerPerceptron r9, org.apache.ignite.ml.dataset.DatasetBuilder<K, V> r10, org.apache.ignite.ml.math.functions.IgniteBiFunction<K, V, org.apache.ignite.ml.math.primitives.vector.Vector> r11, org.apache.ignite.ml.math.functions.IgniteBiFunction<K, V, double[]> r12) {
        /*
            Method dump skipped, instructions count: 412
            To view this dump add '--comments-level debug' option
        */
        // Placeholder emitted by the decompiler in place of the lost 412-instruction body.
        throw new UnsupportedOperationException("Method not decompiled: org.apache.ignite.ml.nn.MLPTrainer.updateModel(org.apache.ignite.ml.nn.MultilayerPerceptron, org.apache.ignite.ml.dataset.DatasetBuilder, org.apache.ignite.ml.math.functions.IgniteBiFunction, org.apache.ignite.ml.math.functions.IgniteBiFunction):org.apache.ignite.ml.nn.MultilayerPerceptron");
    }

    /** @return Function that maps the training dataset to the network architecture. */
    public IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> getArchSupplier() {
        return this.archSupplier;
    }

    /**
     * Sets the architecture supplier.
     *
     * @param igniteFunction New architecture supplier.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withArchSupplier(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> igniteFunction) {
        this.archSupplier = igniteFunction;
        return this;
    }

    /** @return Loss factory. */
    public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> getLoss() {
        return this.loss;
    }

    /**
     * Sets the loss factory.
     *
     * @param igniteFunction New loss factory.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> igniteFunction) {
        this.loss = igniteFunction;
        return this;
    }

    /** @return Update strategy. */
    public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
        return this.updatesStgy;
    }

    /**
     * Sets the update strategy.
     *
     * @param updatesStrategy New update strategy.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withUpdatesStgy(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStrategy) {
        this.updatesStgy = updatesStrategy;
        return this;
    }

    /** @return Maximum number of iterations. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param i New maximum number of iterations.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withMaxIterations(int i) {
        this.maxIterations = i;
        return this;
    }

    /** @return Upper bound on the mini-batch size. */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the upper bound on the mini-batch size.
     *
     * @param i New batch size.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withBatchSize(int i) {
        this.batchSize = i;
        return this;
    }

    /** @return Number of local steps per partition. */
    public int getLocIterations() {
        return this.locIterations;
    }

    /**
     * Sets the number of local steps per partition.
     *
     * @param i New number of local iterations.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withLocIterations(int i) {
        this.locIterations = i;
        return this;
    }

    /** @return Random seed used for mini-batch sampling. */
    public long getSeed() {
        return this.seed;
    }

    /**
     * Sets the random seed used for mini-batch sampling.
     *
     * @param j New seed.
     * @return This trainer, for chaining.
     */
    public MLPTrainer<P> withSeed(long j) {
        this.seed = j;
        return this;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Checks whether a previously trained model can be updated. This implementation accepts any model and
     * unconditionally returns {@code true}.
     *
     * @param multilayerPerceptron Previously trained model (ignored).
     * @return Always {@code true}.
     */
    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public boolean checkState(MultilayerPerceptron multilayerPerceptron) {
        return true;
    }

    /**
     * Extracts the given rows from a flattened column-major matrix.
     *
     * <p>{@code dArr} is read as a matrix with {@code i} rows stored column by column, so that element
     * {@code (row, col)} is at {@code col * i + row}. The number of columns is {@code dArr.length / i}, using
     * integer division, so trailing elements beyond a whole column are ignored.
     *
     * <p>The result has the same column-major layout, with {@code iArr.length} rows. Output row {@code k} is
     * input row {@code iArr[k]}.
     *
     * @param dArr Flattened column-major source matrix.
     * @param iArr Indices of the rows to extract, in output order.
     * @param i Number of rows in {@code dArr}; must be non-zero.
     * @return Flattened column-major matrix of the selected rows.
     */
    static double[] batch(double[] dArr, int[] iArr, int i) {
        int length = dArr.length / i;
        double[] dArr2 = new double[length * iArr.length];
        for (int i2 = 0; i2 < iArr.length; i2++) {
            for (int i3 = 0; i3 < length; i3++) {
                dArr2[(i3 * iArr.length) + i2] = dArr[(i3 * i) + iArr[i2]];
            }
        }
        return dArr2;
    }

    /**
     * Synthetic method generated by javac to re-create this class's serializable lambdas from their
     * {@link SerializedLambda} form.
     *
     * <p>It recognises three implementation methods:
     * <ul>
     *   <li>index 0, {@code lambda$updateModel$754eb3ff$1}: the per-partition training step;</li>
     *   <li>index 1, {@code lambda$new$d28cd7c9$1}: the constant architecture supplier created by the
     *       fixed-architecture constructor;</li>
     *   <li>index 2, {@code lambda$updateModel$94e705a1$1}: the null-tolerant list-concatenating reducer.</li>
     * </ul>
     * The {@code getImplMethodKind()} values 7 and 6 match {@code MethodHandleInfo.REF_invokeSpecial}
     * (instance lambda) and {@code REF_invokeStatic}.
     *
     * <p>NOTE(review): this decompiled form is not valid Java.
     * <ul>
     *   <li>{@code z} is typed {@code boolean} but assigned {@code -1} and {@code 2}, so the original was an
     *       {@code int} index.</li>
     *   <li>The second switch has a duplicate {@code case true}.</li>
     *   <li>The index-0 lambda body uses {@code this} inside a static method. In the bytecode it is an instance
     *       method of the captured {@code MLPTrainer} (captured arg 0).</li>
     * </ul>
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -953369490:
                if (implMethodName.equals("lambda$updateModel$754eb3ff$1")) {
                    z = false;
                    break;
                }
                break;
            case 654111696:
                if (implMethodName.equals("lambda$updateModel$94e705a1$1")) {
                    z = 2;
                    break;
                }
                break;
            case 1822821200:
                if (implMethodName.equals("lambda$new$d28cd7c9$1")) {
                    z = true;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                // Per-partition training step: runs locIterations mini-batch updates on a copy of the model and
                // returns a singleton list with the reduced update, or null for a partition without features.
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/nn/MLPTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator;Lorg/apache/ignite/ml/nn/MultilayerPerceptron;ILorg/apache/ignite/ml/dataset/primitive/data/SimpleLabeledDatasetData;)Ljava/util/List;")) {
                    MLPTrainer mLPTrainer = (MLPTrainer) serializedLambda.getCapturedArg(0);
                    ParameterUpdateCalculator parameterUpdateCalculator = (ParameterUpdateCalculator) serializedLambda.getCapturedArg(1);
                    MultilayerPerceptron multilayerPerceptron = (MultilayerPerceptron) serializedLambda.getCapturedArg(2);
                    // Captured int; presumably the global iteration index (confirm against updateModel).
                    int intValue = ((Integer) serializedLambda.getCapturedArg(3)).intValue();
                    return simpleLabeledDatasetData -> {
                        P init = parameterUpdateCalculator.init(multilayerPerceptron, this.loss);
                        // Local steps mutate a copy, so the shared captured model stays untouched.
                        MultilayerPerceptron multilayerPerceptron2 = (MultilayerPerceptron) Utils.copy(multilayerPerceptron);
                        if (simpleLabeledDatasetData.getFeatures() == null) {
                            return null;
                        }
                        ArrayList arrayList = new ArrayList();
                        for (int i = 0; i < this.locIterations; i++) {
                            // NOTE(review): the seed is seed ^ (intValue * i). When intValue == 0 every local step,
                            // and in every call step i == 0, reuses the same seed and so draws the same mini-batch.
                            int[] selectKDistinct = Utils.selectKDistinct(simpleLabeledDatasetData.getRows(), Math.min(this.batchSize, simpleLabeledDatasetData.getRows()), new Random(this.seed ^ (intValue * i)));
                            // Batches are transposed before being handed to the calculator. The trailing 0 in
                            // DenseMatrix(..., 0) presumably selects the storage mode (confirm).
                            init = parameterUpdateCalculator.calculateNewUpdate(multilayerPerceptron2, init, i, new DenseMatrix(batch(simpleLabeledDatasetData.getFeatures(), selectKDistinct, simpleLabeledDatasetData.getRows()), selectKDistinct.length, 0).transpose(), new DenseMatrix(batch(simpleLabeledDatasetData.getLabels(), selectKDistinct, simpleLabeledDatasetData.getRows()), selectKDistinct.length, 0).transpose());
                            multilayerPerceptron2 = (MultilayerPerceptron) parameterUpdateCalculator.update(multilayerPerceptron2, init);
                            arrayList.add(init);
                        }
                        ArrayList arrayList2 = new ArrayList();
                        arrayList2.add(this.updatesStgy.locStepUpdatesReducer().apply(arrayList));
                        return arrayList2;
                    };
                }
                break;
            case true:
                // Constant architecture supplier from the fixed-architecture constructor.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/nn/MLPTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/nn/architecture/MLPArchitecture;Lorg/apache/ignite/ml/dataset/Dataset;)Lorg/apache/ignite/ml/nn/architecture/MLPArchitecture;")) {
                    MLPArchitecture mLPArchitecture = (MLPArchitecture) serializedLambda.getCapturedArg(0);
                    return dataset -> {
                        return mLPArchitecture;
                    };
                }
                break;
            case true:
                // Reducer for per-partition results. Null means "no contribution"; otherwise the second list is
                // appended into the first, which is mutated and returned.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBinaryOperator") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/nn/MLPTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/List;Ljava/util/List;)Ljava/util/List;")) {
                    return (list, list2) -> {
                        if (list == null) {
                            return list2;
                        }
                        if (list2 == null) {
                            return list;
                        }
                        list.addAll(list2);
                        return list;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    // Synthetic initializer for the javac-generated assertion flag.
    static {
        $assertionsDisabled = !MLPTrainer.class.desiredAssertionStatus();
    }
}
