package org.apache.ignite.ml.composition.stacking;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/* loaded from: input_file:org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.class */
/**
 * {@link DatasetTrainer} that produces a {@link StackedModel}.
 *
 * <p>Training is assembled in {@link #getTrainer()}: all registered submodel trainers are wrapped into a
 * {@link TrainersParallelComposition}, the list of submodel outputs is folded into a single value with
 * {@code aggregatingInputMerger}, and the result is chained to {@code aggregatorTrainer}. When a
 * submodel-input-to-aggregator-input converter is configured, an identity trainer mapped through that converter
 * is added as one more submodel, so the original input also takes part in the merge.
 *
 * <p>The instance is configured fluently through the {@code with...} / {@link #addTrainer(DatasetTrainer)}
 * methods; the configuration is validated lazily, when {@code fit} or {@code update} is called.
 *
 * <p>Several lambdas in this class target {@link IgniteFunction} / {@link IgniteBiFunction}. The lambda
 * deserialization hook for them is emitted by the Java compiler and therefore is not written by hand here.
 *
 * @param <IS> Type of submodels input.
 * @param <IA> Type of submodels output, which is also the input type of the aggregator model.
 * @param <O> Type of the aggregator model output.
 * @param <AM> Type of the aggregator model.
 * @param <L> Type of labels.
 */
public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
    extends DatasetTrainer<StackedModel<IS, IA, O, AM>, L> {
    /** Operator used to fold the list of submodel outputs into a single aggregator input. */
    private IgniteBinaryOperator<IA> aggregatingInputMerger;

    /** Converts submodel input to aggregator input; {@code null} means the original features are dropped. */
    private IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter;

    /** Trainers of submodels. */
    private List<DatasetTrainer<IgniteModel<IS, IA>, L>> submodelsTrainers;

    /** Trainer of the aggregator model. */
    private DatasetTrainer<AM, L> aggregatorTrainer;

    /** Converts a {@link Vector} to submodel input. */
    private IgniteFunction<Vector, IS> vector2SubmodelInputConverter;

    /** Converts submodel output to a {@link Vector}. */
    private IgniteFunction<IA, Vector> submodelOutput2VectorConverter;

    /**
     * Creates a fully specified trainer.
     *
     * @param datasetTrainer Trainer of the aggregator model.
     * @param igniteBinaryOperator Operator merging submodel outputs into one aggregator input.
     * @param igniteFunction Converter from submodel input to aggregator input, may be {@code null}.
     * @param list Submodel trainers; the list is copied, so later changes to it are not observed.
     * @param igniteFunction2 Converter from {@link Vector} to submodel input.
     * @param igniteFunction3 Converter from submodel output to {@link Vector}.
     */
    public StackedDatasetTrainer(DatasetTrainer<AM, L> datasetTrainer, IgniteBinaryOperator<IA> igniteBinaryOperator, IgniteFunction<IS, IA> igniteFunction, List<DatasetTrainer<IgniteModel<IS, IA>, L>> list, IgniteFunction<Vector, IS> igniteFunction2, IgniteFunction<IA, Vector> igniteFunction3) {
        this.aggregatorTrainer = datasetTrainer;
        this.aggregatingInputMerger = igniteBinaryOperator;
        this.submodelInput2AggregatingInputConverter = igniteFunction;
        this.submodelsTrainers = new ArrayList<>(list);
        this.vector2SubmodelInputConverter = igniteFunction2;
        this.submodelOutput2VectorConverter = igniteFunction3;
    }

    /**
     * Creates a trainer with no submodel trainers and no vector converters; they have to be supplied later
     * through the fluent methods before training.
     *
     * @param datasetTrainer Trainer of the aggregator model.
     * @param igniteBinaryOperator Operator merging submodel outputs into one aggregator input.
     * @param igniteFunction Converter from submodel input to aggregator input, may be {@code null}.
     */
    public StackedDatasetTrainer(DatasetTrainer<AM, L> datasetTrainer, IgniteBinaryOperator<IA> igniteBinaryOperator, IgniteFunction<IS, IA> igniteFunction) {
        this(datasetTrainer, igniteBinaryOperator, igniteFunction, new ArrayList<>(), null, null);
    }

    /**
     * Creates an empty trainer; every collaborator has to be supplied through the fluent methods before training.
     */
    public StackedDatasetTrainer() {
        this(null, null, null, new ArrayList<>(), null, null);
    }

    /**
     * Keeps the original features: the submodel input, converted with the given function, is merged together
     * with the submodel outputs.
     *
     * @param igniteFunction Converter from submodel input to aggregator input.
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesKept(IgniteFunction<IS, IA> igniteFunction) {
        this.submodelInput2AggregatingInputConverter = igniteFunction;
        return this;
    }

    /**
     * Drops the original features: only submodel outputs are merged into the aggregator input.
     *
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesDropped() {
        this.submodelInput2AggregatingInputConverter = null;
        return this;
    }

    /**
     * Sets the converter from submodel output to {@link Vector}.
     *
     * @param igniteFunction Converter from submodel output to {@link Vector}.
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withSubmodelOutput2VectorConverter(IgniteFunction<IA, Vector> igniteFunction) {
        this.submodelOutput2VectorConverter = igniteFunction;
        return this;
    }

    /**
     * Sets the converter from {@link Vector} to submodel input.
     *
     * @param igniteFunction Converter from {@link Vector} to submodel input.
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withVector2SubmodelInputConverter(IgniteFunction<Vector, IS> igniteFunction) {
        this.vector2SubmodelInputConverter = igniteFunction;
        return this;
    }

    /**
     * Sets the trainer of the aggregator model.
     *
     * @param datasetTrainer Trainer of the aggregator model.
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorTrainer(DatasetTrainer<AM, L> datasetTrainer) {
        this.aggregatorTrainer = datasetTrainer;
        return this;
    }

    /**
     * Sets the operator that merges submodel outputs into one aggregator input.
     *
     * @param igniteBinaryOperator Merging operator.
     * @return This object.
     */
    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorInputMerger(IgniteBinaryOperator<IA> igniteBinaryOperator) {
        this.aggregatingInputMerger = igniteBinaryOperator;
        return this;
    }

    /**
     * Registers one more submodel trainer.
     *
     * @param datasetTrainer Submodel trainer.
     * @param <M1> Type of the submodel produced by the trainer.
     * @return This object.
     */
    public <M1 extends IgniteModel<IS, IA>> StackedDatasetTrainer<IS, IA, O, AM, L> addTrainer(DatasetTrainer<M1, L> datasetTrainer) {
        // The concrete model type is erased to IgniteModel<IS, IA> so that all submodel trainers fit in one list.
        this.submodelsTrainers.add(CompositionUtils.unsafeCoerce(datasetTrainer));
        return this;
    }

    /**
     * Trains the composed trainer and wraps the resulting model into a {@link StackedModel}.
     *
     * @throws IllegalStateException If the trainer configuration is incomplete.
     */
    @Override
    public <K, V> StackedModel<IS, IA, O, AM> fit(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return new StackedModel<>(getTrainer().fit(datasetBuilder, preprocessor));
    }

    /**
     * Updates the given model with the composed trainer and wraps the result into a new {@link StackedModel}.
     *
     * @throws IllegalStateException If the trainer configuration is incomplete.
     */
    @Override
    public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> stackedModel, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return new StackedModel<>(getTrainer().update(stackedModel, datasetBuilder, preprocessor));
    }

    /**
     * Validates the configuration and builds the composed trainer: submodel trainers in parallel, their outputs
     * merged with {@code aggregatingInputMerger}, followed by {@code aggregatorTrainer}.
     *
     * @return Composed trainer.
     * @throws IllegalStateException If the trainer configuration is incomplete.
     */
    private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() {
        checkConsistency();

        List<DatasetTrainer<IgniteModel<IS, IA>, L>> trainers = new ArrayList<>();

        if (this.submodelInput2AggregatingInputConverter != null) {
            // Original features are kept: represent them as an identity "submodel" followed by the converter.
            DatasetTrainer<IgniteModel<IS, IS>, L> identity = DatasetTrainer.identityTrainer();
            DatasetTrainer<IgniteModel<IS, IA>, L> originalFeatures = CompositionUtils.unsafeCoerce(
                AdaptableDatasetTrainer.of(identity).afterTrainedModel(this.submodelInput2AggregatingInputConverter));

            trainers.add(originalFeatures);
        }

        trainers.addAll(this.submodelsTrainers);

        TrainersParallelComposition<IS, IA, L> composition = new TrainersParallelComposition<>(trainers);

        IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> featureMapper =
            getFeatureExtractorForAggregator(this.submodelOutput2VectorConverter, this.vector2SubmodelInputConverter);

        return AdaptableDatasetTrainer
            .of(composition)
            // checkConsistency() guarantees at least one trainer in the composition, so the reduction of the
            // submodel outputs is expected to be non-empty.
            .afterTrainedModel(outputs -> outputs.stream().reduce(this.aggregatingInputMerger).get())
            .andThen(this.aggregatorTrainer, trained -> new IgniteFunction<LabeledVector<L>, LabeledVector<L>>() {
                /**
                 * Replaces the features of a labeled vector with the concatenated vectorized outputs of the
                 * trained submodels; the label is passed through unchanged.
                 */
                @Override
                public LabeledVector<L> apply(LabeledVector<L> labeledVector) {
                    List<IgniteModel<IS, IA>> submodels =
                        ((ModelsParallelComposition<IS, IA>) trained.innerModel()).submodels();

                    return new LabeledVector<>(featureMapper.apply(submodels, labeledVector.features()), labeledVector.label());
                }
            })
            .unsafeSimplyTyped();
    }

    /**
     * Checks that everything needed to build the composed trainer has been specified.
     *
     * @throws IllegalStateException If there is neither an original-features converter nor any submodel trainer,
     * if any of the vector converters is missing, if the merging operator is missing, or if the aggregator
     * trainer is missing.
     */
    private void checkConsistency() {
        if (this.submodelInput2AggregatingInputConverter == null && this.submodelsTrainers.isEmpty()) {
            throw new IllegalStateException("There should be at least one way for submodels input to be propageted to aggregator.");
        }
        if (this.submodelOutput2VectorConverter == null || this.vector2SubmodelInputConverter == null) {
            throw new IllegalStateException("There should be a specified way to convert vectors to submodels input and submodels output to vector");
        }
        if (this.aggregatingInputMerger == null) {
            throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified");
        }
        // Checked last so that the exceptions raised by the checks above keep their precedence.
        if (this.aggregatorTrainer == null) {
            throw new IllegalStateException("Aggregator trainer is not specified");
        }
    }

    /**
     * Applies the environment builder to every submodel trainer registered so far and to the aggregator trainer,
     * if one has already been set.
     *
     * @param learningEnvironmentBuilder Learning environment builder.
     * @return This object.
     */
    @Override
    public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(LearningEnvironmentBuilder learningEnvironmentBuilder) {
        this.submodelsTrainers = this.submodelsTrainers.stream()
            .map(trainer -> trainer.withEnvironmentBuilder(learningEnvironmentBuilder))
            .collect(Collectors.toList());

        // The aggregator trainer is optional until training starts (see the no-arg constructor), so this
        // method must not fail when it is called before withAggregatorTrainer(...).
        if (this.aggregatorTrainer != null) {
            this.aggregatorTrainer = this.aggregatorTrainer.withEnvironmentBuilder(learningEnvironmentBuilder);
        }

        return this;
    }

    /**
     * Builds the function that turns a feature vector into the aggregator's feature vector: every submodel is
     * applied to the vector (see {@link #applyToVector}) and the resulting vectors are concatenated in submodel
     * order.
     *
     * @param igniteFunction Converter from submodel output to {@link Vector}.
     * @param igniteFunction2 Converter from {@link Vector} to submodel input.
     * @param <IS> Type of submodels input.
     * @param <IA> Type of submodels output.
     * @return Function from (submodels, feature vector) to the concatenated vector.
     */
    private static <IS, IA> IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> getFeatureExtractorForAggregator(IgniteFunction<IA, Vector> igniteFunction, IgniteFunction<Vector, IS> igniteFunction2) {
        return (submodels, features) -> VectorUtils.concat(
            submodels.stream()
                .map(submodel -> applyToVector(submodel, igniteFunction, igniteFunction2, features))
                .toArray(Vector[]::new));
    }

    /**
     * Runs a single submodel on a vector: the vector is converted to submodel input, the submodel makes a
     * prediction, and the prediction is converted back to a vector.
     *
     * @param igniteModel Submodel.
     * @param igniteFunction Converter from submodel output to {@link Vector}.
     * @param igniteFunction2 Converter from {@link Vector} to submodel input.
     * @param vector Input vector.
     * @param <IS> Type of submodel input.
     * @param <IA> Type of submodel output.
     * @return Vectorized submodel prediction.
     */
    public static <IS, IA> Vector applyToVector(IgniteModel<IS, IA> igniteModel, IgniteFunction<IA, Vector> igniteFunction, IgniteFunction<Vector, IS> igniteFunction2, Vector vector) {
        return igniteFunction2.andThen(igniteModel::predict).andThen(igniteFunction).apply(vector);
    }

    /**
     * Not supported. {@link #update(StackedModel, DatasetBuilder, Preprocessor)} is overridden in this class and
     * delegates to the composed trainer without calling this method.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    public <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> stackedModel, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        throw new IllegalStateException();
    }

    /**
     * Not supported. {@link #update(StackedModel, DatasetBuilder, Preprocessor)} is overridden in this class and
     * delegates to the composed trainer without calling this method.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    public boolean isUpdateable(StackedModel<IS, IA, O, AM> stackedModel) {
        throw new IllegalStateException();
    }
}
