package org.apache.ignite.ml.nn.trainers.distributed;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.trainers.group.Metaoptimizer;

/* loaded from: input_file:org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.class */
/**
 * {@link Metaoptimizer} used with {@link MLPGroupUpdateTrainerLocalContext} and {@link MLPGroupUpdateTrainingLoopData}.
 * <p>
 * Each distributed result is wrapped into a single-element list, the per-node lists are flattened into one list, and
 * that list (with {@code null} entries dropped) is reduced by the user-supplied {@code allUpdatesReducer}. Training
 * continues while the reduced value is non-null and the local context's current step is below its global step limit.
 *
 * @param <P> Type of the values being reduced (presumably model/parameter updates — confirm against callers).
 */
public class MLPMetaoptimizer<P> implements Metaoptimizer<MLPGroupUpdateTrainerLocalContext,
    MLPGroupUpdateTrainingLoopData<P>, P, P, P, ArrayList<P>> {
    /** Reduces a list of values into a single value; used both as the initial reducer and on every local step. */
    private final IgniteFunction<List<P>, P> allUpdatesReducer;

    /**
     * Constructs the metaoptimizer.
     *
     * @param allUpdatesReducer Function reducing a list of values into one value.
     */
    public MLPMetaoptimizer(IgniteFunction<List<P>, P> allUpdatesReducer) {
        this.allUpdatesReducer = allUpdatesReducer;
    }

    /**
     * Returns the reducer supplied at construction time.
     *
     * @return Reducer of a list of values into one value.
     */
    @Override public IgniteFunction<List<P>, P> initialReducer() {
        return allUpdatesReducer;
    }

    /**
     * Returns the initial data unchanged; no local preprocessing is performed.
     *
     * @param data Initial data.
     * @param locCtx Local context (unused).
     * @return {@code data} as is.
     */
    @Override public P locallyProcessInitData(P data, MLPGroupUpdateTrainerLocalContext locCtx) {
        return data;
    }

    /**
     * Returns a function that wraps a single distributed result into a new one-element {@link ArrayList}, so that
     * results can later be concatenated by {@link #postProcessReducer()}.
     *
     * @return Wrapping function.
     */
    @Override public IgniteFunction<P, ArrayList<P>> distributedPostprocessor() {
        // Captures nothing, so the (serializable) lambda does not drag the enclosing instance along.
        return res -> {
            ArrayList<P> wrapped = new ArrayList<>();
            wrapped.add(res);
            return wrapped;
        };
    }

    /**
     * Returns a function that flattens a list of lists into a single {@link ArrayList}, preserving encounter order.
     *
     * @return Flattening function.
     */
    @Override public IgniteFunction<List<ArrayList<P>>, ArrayList<P>> postProcessReducer() {
        return lists -> lists.stream()
            .flatMap(List::stream)
            .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Advances the local step counter and reduces the collected values, ignoring {@code null} entries.
     *
     * @param input Collected values (may contain {@code null}s, which are filtered out before reduction).
     * @param locCtx Local context whose current step is incremented.
     * @return Result of applying {@code allUpdatesReducer} to the non-null values.
     */
    @Override public P localProcessor(ArrayList<P> input, MLPGroupUpdateTrainerLocalContext locCtx) {
        locCtx.incrementCurrentStep();

        List<P> nonNull = input.stream()
            .filter(Objects::nonNull)
            .collect(Collectors.toList());

        return allUpdatesReducer.apply(nonNull);
    }

    /**
     * Decides whether the training loop should run another step.
     *
     * @param input Result of the last local processing step.
     * @param locCtx Local context holding the current step and the global step limit.
     * @return {@code true} iff {@code input} is non-null and the current step is below the global step limit.
     */
    @Override public boolean shouldContinue(P input, MLPGroupUpdateTrainerLocalContext locCtx) {
        return input != null && locCtx.currentStep() < locCtx.globalStepsMaxCount();
    }
}
