package org.apache.ignite.ml.nn.trainers.distributed;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.Ignite;
import org.apache.ignite.Ignition;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.util.MatrixUtil;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
import org.apache.ignite.ml.trainers.group.MetaoptimizerGroupTrainer;
import org.apache.ignite.ml.trainers.group.ResultAndUpdates;
import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
import org.apache.ignite.ml.util.Utils;

/**
 * Group trainer for {@link MultilayerPerceptron} that runs several trainings in parallel on the cluster.
 * <p>
 * Each parallel training first applies the previous update to its cached perceptron. It then performs up to
 * {@code syncPeriod} local steps on batches taken from the input's batch supplier. It stops early once the mean
 * batch loss drops below {@code tolerance}, and reduces the collected local updates with
 * {@code locStepUpdatesReducer}. The reduced updates are handed to an {@link MLPMetaoptimizer} built from
 * {@code allUpdatesReducer}. {@code maxGlobalSteps} is forwarded to the local context. Presumably it bounds the
 * number of global steps; the loop condition lives in the base class and is not visible here.
 *
 * @param <U> Type of the parameter update produced by the {@link ParameterUpdateCalculator}.
 */
public class MLPGroupUpdateTrainer<U extends Serializable> extends
    MetaoptimizerGroupTrainer<MLPGroupUpdateTrainerLocalContext,
        Void,
        MLPGroupTrainingCacheValue,
        U,
        MultilayerPerceptron,
        U,
        MultilayerPerceptron,
        AbstractMLPGroupUpdateTrainerInput,
        MLPGroupUpdateTrainingContext<U>,
        ArrayList<U>,
        MLPGroupUpdateTrainingLoopData<U>,
        U> {
    /** Default value passed as {@code maxGlobalSteps}. */
    private static final int DEFAULT_MAX_GLOBAL_STEPS = 30;

    /** Default count of local steps performed by each parallel training per global step. */
    private static final int DEFAULT_SYNC_RATE = 5;

    /** Default loss threshold below which local steps stop early. */
    private static final double DEFAULT_TOLERANCE = 0.01d;

    /** Default reducer of updates from all parallel trainings. */
    private static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate> DEFAULT_ALL_UPDATES_REDUCER =
        RPropParameterUpdate::avg;

    /** Default reducer of updates produced by the local steps of one training. */
    private static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate>
        DEFAULT_LOCAL_STEP_UPDATES_REDUCER = RPropParameterUpdate::sumLocal;

    /** Default update calculator (RProp). */
    private static final ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> DEFAULT_UPDATE_CALCULATOR =
        new RPropUpdateCalculator();

    /** Default loss function (MSE). */
    private static final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> DEFAULT_LOSS =
        LossFunctions.MSE;

    /** Loss function factory: maps a ground-truth vector to a differentiable loss of the prediction. */
    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Mean batch loss below which local steps stop early. */
    private final double tolerance;

    /** Value forwarded to the local context as the global step limit. */
    private final int maxGlobalSteps;

    /** Count of local steps per global step. */
    private final int syncPeriod;

    /** Reducer of updates from all parallel trainings (used by the metaoptimizer and local context). */
    private final IgniteFunction<List<U>, U> allUpdatesReducer;

    /** Reducer of the updates produced by the local steps of one training. */
    private final IgniteFunction<List<U>, U> locStepUpdatesReducer;

    /** Calculator producing and applying parameter updates. */
    private final ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator;

    /**
     * Constructs a trainer.
     *
     * @param maxGlobalSteps Global step limit forwarded to the local context.
     * @param syncPeriod Count of local steps each parallel training performs per global step.
     * @param allUpdatesReducer Reducer of updates coming from all parallel trainings.
     * @param locStepUpdatesReducer Reducer of updates produced by local steps of one training.
     * @param updateCalculator Calculator producing and applying parameter updates.
     * @param loss Loss function factory.
     * @param ignite Ignite instance.
     * @param tolerance Mean batch loss below which local steps stop early.
     */
    public MLPGroupUpdateTrainer(int maxGlobalSteps,
        int syncPeriod,
        IgniteFunction<List<U>, U> allUpdatesReducer,
        IgniteFunction<List<U>, U> locStepUpdatesReducer,
        ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
        Ignite ignite,
        double tolerance) {
        super(new MLPMetaoptimizer<>(allUpdatesReducer), MLPCache.getOrCreate(ignite), ignite);

        this.maxGlobalSteps = maxGlobalSteps;
        this.syncPeriod = syncPeriod;
        this.allUpdatesReducer = allUpdatesReducer;
        this.locStepUpdatesReducer = locStepUpdatesReducer;
        this.updateCalculator = updateCalculator;
        this.loss = loss;
        this.tolerance = tolerance;
    }

    /**
     * Creates a trainer with RProp updates, MSE loss and the default step counts and tolerance.
     *
     * @param ignite Ignite instance.
     * @return Trainer with default settings.
     */
    public static MLPGroupUpdateTrainer<RPropParameterUpdate> getDefault(Ignite ignite) {
        return new MLPGroupUpdateTrainer<>(DEFAULT_MAX_GLOBAL_STEPS, DEFAULT_SYNC_RATE, DEFAULT_ALL_UPDATES_REDUCER,
            DEFAULT_LOCAL_STEP_UPDATES_REDUCER, DEFAULT_UPDATE_CALCULATOR, DEFAULT_LOSS, ignite, DEFAULT_TOLERANCE);
    }

    /**
     * Initializes the training and publishes its per-training data so that remote nodes can read it by UUID.
     *
     * @param data Training input.
     * @param trainingUUID Training UUID.
     */
    @Override
    public void init(AbstractMLPGroupUpdateTrainerInput data, UUID trainingUUID) {
        super.init(data, trainingUUID);

        MLPGroupUpdateTrainerDataCache.getOrCreate(ignite).put(trainingUUID, new MLPGroupUpdateTrainingData<>(
            updateCalculator, syncPeriod, locStepUpdatesReducer, data.batchSupplier(), loss, tolerance));
    }

    /**
     * Returns a closure that caches the initial perceptron under a key and yields the calculator's initial update.
     *
     * @param data Training input.
     * @return Distributed initializer.
     */
    @Override
    public IgniteFunction<GroupTrainerCacheKey<Void>, ResultAndUpdates<U>> distributedInitializer(
        AbstractMLPGroupUpdateTrainerInput data) {
        MultilayerPerceptron mdl = data.mdl();

        // Capture plain values instead of `this`, so the serialized closure does not carry the trainer
        // (with its Ignite and cache references) to remote nodes.
        ParameterUpdateCalculator<? super MultilayerPerceptron, U> calc = updateCalculator;
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> lossFn = loss;

        return key -> ResultAndUpdates.of(calc.init(mdl, lossFn))
            .updateCache(MLPCache.getOrCreate(Ignition.localIgnite()), key, new MLPGroupTrainingCacheValue(mdl));
    }

    /**
     * Returns a function assembling loop data from a cache entry and the remote training context.
     *
     * @return Training loop step data extractor.
     */
    @Override
    public IgniteFunction<EntryAndContext<Void, MLPGroupTrainingCacheValue, MLPGroupUpdateTrainingContext<U>>,
        MLPGroupUpdateTrainingLoopData<U>> trainingLoopStepDataExtractor() {
        return entryAndCtx -> {
            MLPGroupUpdateTrainingContext<U> ctx = entryAndCtx.context();
            Map.Entry<GroupTrainerCacheKey<Void>, MLPGroupTrainingCacheValue> entry = entryAndCtx.entry();
            MLPGroupUpdateTrainingData<U> data = ctx.data();

            return new MLPGroupUpdateTrainingLoopData<>(entry.getValue().perceptron(), data.updateCalculator(),
                data.stepsCnt(), data.updateReducer(), ctx.previousUpdate(), entry.getKey(), data.batchSupplier(),
                data.loss(), data.tolerance());
        };
    }

    /**
     * Returns the supplier of keys processed on each training loop step.
     *
     * @param locCtx Local context.
     * @return Supplier of all keys of this training.
     */
    @Override
    public IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> keysToProcessInTrainingLoop(
        MLPGroupUpdateTrainerLocalContext locCtx) {
        return allTrainingKeys(locCtx);
    }

    /**
     * Returns a supplier that builds the remote context on the executing node.
     * <p>
     * The supplier reads the per-training data published by {@link #init} and pairs it with the previous update.
     *
     * @param prevUpdate Update from the previous global step.
     * @param locCtx Local context.
     * @return Remote context supplier.
     */
    @Override
    public IgniteSupplier<MLPGroupUpdateTrainingContext<U>> remoteContextExtractor(U prevUpdate,
        MLPGroupUpdateTrainerLocalContext locCtx) {
        UUID trainingUUID = locCtx.trainingUUID();

        return () -> {
            MLPGroupUpdateTrainingData<U> data =
                MLPGroupUpdateTrainerDataCache.getOrCreate(Ignition.localIgnite()).get(trainingUUID);

            return new MLPGroupUpdateTrainingContext<>(data, prevUpdate);
        };
    }

    /**
     * Returns the function executed on each node for one global step; see {@link #processLoopData}.
     *
     * @return Data processor.
     */
    @Override
    public IgniteFunction<MLPGroupUpdateTrainingLoopData<U>, ResultAndUpdates<U>> dataProcessor() {
        // Capture only the calculator, not the trainer itself (see distributedInitializer).
        ParameterUpdateCalculator<? super MultilayerPerceptron, U> trainerCalc = updateCalculator;

        return data -> processLoopData(trainerCalc, data);
    }

    /**
     * Performs one global step of a single parallel training.
     * <p>
     * It applies the previous update to the cached perceptron and runs up to {@code stepsCnt} local steps on a copy,
     * stopping once the mean batch loss is below the tolerance. It returns the reduced local updates and re-caches
     * the perceptron that has only the previous update applied.
     *
     * @param trainerCalc Calculator used to apply the previous update to the cached perceptron.
     * @param data Loop data for this training.
     * @param <P> Update type.
     * @return Reduced local update plus the cache write of the updated perceptron.
     */
    private static <P extends Serializable> ResultAndUpdates<P> processLoopData(
        ParameterUpdateCalculator<? super MultilayerPerceptron, P> trainerCalc,
        MLPGroupUpdateTrainingLoopData<P> data) {
        MultilayerPerceptron newMlp = trainerCalc.update(data.mlp(), data.previousUpdate());

        // Local steps work on a copy; the cached model only receives the previous update.
        MultilayerPerceptron mlpCp = Utils.copy(newMlp);
        ParameterUpdateCalculator<? super MultilayerPerceptron, P> locCalc = data.updateCalculator();
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> lossFn = data.loss();

        locCalc.init(mlpCp, lossFn);

        int stepsCnt = data.stepsCnt();
        List<P> updates = new ArrayList<>(stepsCnt);
        P curUpdate = data.previousUpdate();

        for (int i = 0; i < stepsCnt; i++) {
            IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
            Matrix input = batch.get1();
            Matrix truth = batch.get2();
            int batchSize = truth.columnSize();

            curUpdate = locCalc.calculateNewUpdate(mlpCp, curUpdate, i, input, truth);
            mlpCp = locCalc.update(mlpCp, curUpdate);
            updates.add(curUpdate);

            double err = MatrixUtil.zipFoldByColumns(mlpCp.apply(input), truth,
                (predCol, truthCol) -> lossFn.apply(truthCol).apply(predCol)).sum() / batchSize;

            if (err < data.tolerance())
                break;
        }

        P reduced = data.getUpdateReducer().apply(updates);

        return ResultAndUpdates.of(reduced)
            .updateCache(MLPCache.getOrCreate(Ignition.localIgnite()), data.key(),
                new MLPGroupTrainingCacheValue(newMlp));
    }

    /**
     * Creates the local context for a training.
     *
     * @param data Training input.
     * @param trainingUUID Training UUID.
     * @return Local context.
     */
    @Override
    public MLPGroupUpdateTrainerLocalContext<U> initialLocalContext(AbstractMLPGroupUpdateTrainerInput data,
        UUID trainingUUID) {
        return new MLPGroupUpdateTrainerLocalContext<>(trainingUUID, maxGlobalSteps, allUpdatesReducer,
            data.trainingsCount());
    }

    /**
     * Returns the supplier of keys from which the final result is extracted.
     *
     * @param data Last update.
     * @param locCtx Local context.
     * @return Supplier of all keys of this training.
     */
    @Override
    protected IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> finalResultKeys(U data,
        MLPGroupUpdateTrainerLocalContext locCtx) {
        return allTrainingKeys(locCtx);
    }

    /**
     * Returns a supplier of a {@code null} context, because {@link #finalResultsExtractor()} reads only cache
     * entries.
     *
     * @param data Last update.
     * @param locCtx Local context.
     * @return Supplier of {@code null}.
     */
    @Override
    protected IgniteSupplier<MLPGroupUpdateTrainingContext<U>> extractContextForFinalResultCreation(U data,
        MLPGroupUpdateTrainerLocalContext locCtx) {
        return () -> null;
    }

    /**
     * Returns a function extracting the cached perceptron from an entry.
     *
     * @return Final results extractor.
     */
    @Override
    protected IgniteFunction<EntryAndContext<Void, MLPGroupTrainingCacheValue, MLPGroupUpdateTrainingContext<U>>,
        ResultAndUpdates<MultilayerPerceptron>> finalResultsExtractor() {
        return entryAndCtx -> ResultAndUpdates.of(entryAndCtx.entry().getValue().perceptron());
    }

    /**
     * Returns a reducer picking the first non-null perceptron, or {@code null} if there is none.
     *
     * @return Final results reducer.
     */
    @Override
    protected IgniteFunction<List<MultilayerPerceptron>, MultilayerPerceptron> finalResultsReducer() {
        return mdls -> mdls.stream().filter(Objects::nonNull).findFirst().orElse(null);
    }

    /**
     * Returns the reduced perceptron unchanged.
     *
     * @param res Reduced perceptron.
     * @param locCtx Local context.
     * @return {@code res}.
     */
    @Override
    public MultilayerPerceptron mapFinalResult(MultilayerPerceptron res, MLPGroupUpdateTrainerLocalContext locCtx) {
        return res;
    }

    /**
     * Removes per-training state so that repeated trainings do not accumulate cache entries.
     * <p>
     * It removes the data entry published by {@link #init} and the perceptrons cached under this training's keys.
     * NOTE(review): this relies on the base trainer calling {@code cleanup} only after the final result has been
     * extracted; confirm against GroupTrainer#train.
     *
     * @param locCtx Local context.
     */
    @Override
    public void cleanup(MLPGroupUpdateTrainerLocalContext locCtx) {
        UUID trainingUUID = locCtx.trainingUUID();

        MLPGroupUpdateTrainerDataCache.getOrCreate(ignite).remove(trainingUUID);

        Set<GroupTrainerCacheKey<Void>> keys = MLPCache.allKeys(locCtx.parallelTrainingsCnt(), trainingUUID)
            .collect(Collectors.toSet());

        MLPCache.getOrCreate(ignite).removeAll(keys);
    }

    /**
     * Builds a supplier of all cache keys of a training.
     * <p>
     * Values are read from the context eagerly, so the closure captures only an int and a UUID.
     *
     * @param locCtx Local context.
     * @return Keys supplier.
     */
    private static IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> allTrainingKeys(
        MLPGroupUpdateTrainerLocalContext locCtx) {
        int parallelTrainingsCnt = locCtx.parallelTrainingsCnt();
        UUID trainingUUID = locCtx.trainingUUID();

        return () -> MLPCache.allKeys(parallelTrainingsCnt, trainingUUID);
    }

    /**
     * Creates a copy of this trainer with another global step limit.
     *
     * @param maxGlobalSteps New global step limit.
     * @return New trainer.
     */
    public MLPGroupUpdateTrainer<U> withMaxGlobalSteps(int maxGlobalSteps) {
        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer,
            updateCalculator, loss, ignite, tolerance);
    }

    /**
     * Creates a copy of this trainer with another count of local steps per global step.
     *
     * @param syncPeriod New sync period.
     * @return New trainer.
     */
    public MLPGroupUpdateTrainer<U> withSyncPeriod(int syncPeriod) {
        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer,
            updateCalculator, loss, ignite, tolerance);
    }

    /**
     * Creates a copy of this trainer with another early-stop tolerance.
     *
     * @param tolerance New tolerance.
     * @return New trainer.
     */
    public MLPGroupUpdateTrainer<U> withTolerance(double tolerance) {
        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer,
            updateCalculator, loss, ignite, tolerance);
    }

    /**
     * Creates a copy of this trainer using the reducers and calculator of the given strategy.
     *
     * @param updatesStrategy Updates strategy.
     * @param <U1> Update type of the strategy.
     * @return New trainer.
     */
    public <U1 extends Serializable> MLPGroupUpdateTrainer<U1> withUpdateStrategy(
        UpdatesStrategy<? super MultilayerPerceptron, U1> updatesStrategy) {
        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, updatesStrategy.allUpdatesReducer(),
            updatesStrategy.locStepUpdatesReducer(), updatesStrategy.getUpdatesCalculator(), loss, ignite, tolerance);
    }
}
