package org.apache.ignite.ml.trainers;

import java.lang.invoke.SerializedLambda;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.DatasetMapping;
import org.apache.ignite.ml.composition.combinators.sequential.TrainersSequentialComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/* loaded from: input_file:org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.class */
/**
 * Trainer that adapts the input and output types of a wrapped {@link DatasetTrainer}.
 *
 * <p>Fitting is delegated to {@code wrapped}; the fitted model is packed into an {@link AdaptableDatasetModel}
 * together with a {@code before} function (from the adapted input {@code I} to the wrapped model input
 * {@code IW}) and an {@code after} function (from the wrapped model output {@code OW} to the adapted output
 * {@code O}). Before the data reaches the wrapped trainer, {@code afterExtractor} is mapped over the
 * preprocessor and {@code upstreamTransformerBuilder} is attached to the dataset builder.
 *
 * <p>All fields declared here are final; every {@code before*}/{@code after*}/{@code with*} method returns a
 * new trainer. NOTE(review): the private constructor does not copy {@code envBuilder}, so a new trainer
 * presumably starts with the default environment builder of {@link DatasetTrainer} — confirm.
 *
 * @param <I> Input type of the adapted model.
 * @param <O> Output type of the adapted model.
 * @param <IW> Input type of the wrapped model.
 * @param <OW> Output type of the wrapped model.
 * @param <M> Type of the wrapped model.
 * @param <L> Label type.
 */
public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>, L> extends DatasetTrainer<AdaptableDatasetModel<I, O, IW, OW, M>, L> {
    /** Trainer that performs the actual fitting and updating. */
    private final DatasetTrainer<M, L> wrapped;

    /** Maps the adapted model input {@code I} to the wrapped model input {@code IW}. */
    private final IgniteFunction<I, IW> before;

    /** Maps the wrapped model output {@code OW} to the adapted model output {@code O}. */
    private final IgniteFunction<OW, O> after;

    /** Mapped over the preprocessor output (via {@link Preprocessor#map}) before data reaches {@code wrapped}. */
    private final IgniteFunction<LabeledVector<L>, LabeledVector<L>> afterExtractor;

    /** Upstream transformer attached to the dataset builder on every fit and update. */
    private final UpstreamTransformerBuilder upstreamTransformerBuilder;

    /**
     * Wraps {@code datasetTrainer} without any adaptation: identity {@code before}, {@code after} and
     * extractor functions and an identity upstream transformer.
     *
     * @param datasetTrainer Trainer to wrap.
     * @param <I> Input type of the wrapped model.
     * @param <O> Output type of the wrapped model.
     * @param <M> Type of the wrapped model.
     * @param <L> Label type.
     * @return Adaptable trainer delegating to {@code datasetTrainer}.
     */
    public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> datasetTrainer) {
        return new AdaptableDatasetTrainer<>(
            IgniteFunction.identity(),
            datasetTrainer,
            IgniteFunction.identity(),
            IgniteFunction.identity(),
            UpstreamTransformerBuilder.identity());
    }

    /**
     * Creates a trainer from all of its parts; use {@link #of(DatasetTrainer)} and the {@code with*} methods.
     *
     * @param before Adapted input to wrapped input.
     * @param wrapped Trainer doing the actual work.
     * @param after Wrapped output to adapted output.
     * @param afterExtractor Post-processing of extracted labeled vectors.
     * @param upstreamTransformerBuilder Upstream transformer attached to the dataset builder.
     */
    private AdaptableDatasetTrainer(
        IgniteFunction<I, IW> before,
        DatasetTrainer<M, L> wrapped,
        IgniteFunction<OW, O> after,
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> afterExtractor,
        UpstreamTransformerBuilder upstreamTransformerBuilder) {
        this.before = before;
        this.wrapped = wrapped;
        this.after = after;
        this.afterExtractor = afterExtractor;
        this.upstreamTransformerBuilder = upstreamTransformerBuilder;
    }

    /**
     * Fits the wrapped trainer (configured with this trainer's environment builder) on {@code datasetBuilder}
     * with the upstream transformer attached and {@code afterExtractor} mapped over {@code preprocessor}, then
     * combines the fitted model with {@code before} and {@code after}.
     */
    @Override
    public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        M innerModel = wrapped.withEnvironmentBuilder2(envBuilder).fit(
            datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
            preprocessor.map(afterExtractor));

        return new AdaptableDatasetModel<>(before, innerModel, after);
    }

    /** Delegates to the wrapped trainer, asking about the model inside {@code adaptableDatasetModel}. */
    @Override
    public boolean isUpdateable(AdaptableDatasetModel<I, O, IW, OW, M> adaptableDatasetModel) {
        return wrapped.isUpdateable(adaptableDatasetModel.innerModel());
    }

    /**
     * Updates the inner model through the wrapped trainer, using the same upstream transformer and extractor
     * post-processing as fitting, and returns {@code adaptableDatasetModel} with the updated inner model.
     *
     * <p>Declared {@code public} here although the decompiler reports it as {@code protected} originally;
     * kept public so existing callers keep compiling.
     */
    @Override
    public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> adaptableDatasetModel, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        M updatedInner = wrapped.withEnvironmentBuilder2(envBuilder).updateModel(
            adaptableDatasetModel.innerModel(),
            datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
            preprocessor.map(afterExtractor));

        return adaptableDatasetModel.withInnerModel(updatedInner);
    }

    /**
     * Returns a trainer whose {@code after} function is {@code x -> igniteFunction.apply(after.apply(x))}.
     *
     * @param igniteFunction Function applied to the current adapted output {@code O}.
     * @param <O1> New adapted output type.
     * @return New trainer; this one is unchanged.
     */
    public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> igniteFunction) {
        // andThen captures only the two functions; a lambda reading this.after would capture this trainer,
        // dragging it along whenever the model's output function is serialized.
        IgniteFunction<OW, O1> newAfter = after.andThen(igniteFunction);
        return new AdaptableDatasetTrainer<>(before, wrapped, newAfter, afterExtractor, upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer whose {@code before} function is {@code x -> before.apply(igniteFunction.apply(x))}.
     *
     * @param igniteFunction Function mapping the new input type {@code I1} to the current input type {@code I}.
     * @param <I1> New adapted input type.
     * @return New trainer; this one is unchanged.
     */
    public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> igniteFunction) {
        // andThen captures only the two functions; a lambda reading this.before would capture this trainer,
        // dragging it along whenever the model's input function is serialized.
        IgniteFunction<I1, IW> newBefore = igniteFunction.andThen(before);
        return new AdaptableDatasetTrainer<>(newBefore, wrapped, after, afterExtractor, upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer that additionally applies {@code datasetMapping} to features and labels of every labeled
     * vector, after {@code afterExtractor}, right before the data reaches the wrapped trainer.
     *
     * <p>The {@code before}, {@code after}, {@code afterExtractor} and upstream transformer of this trainer are
     * kept. The resulting trainer reports models as non-updateable, so its {@code updateModel} (which returns
     * {@code null}) is not meant to be reached; NOTE(review): presumably {@link DatasetTrainer#update} then
     * refits from scratch — confirm.
     *
     * @param datasetMapping Mapping applied to features and labels.
     * @return New trainer; this one is unchanged.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withDatasetMapping(final DatasetMapping<L, L> datasetMapping) {
        DatasetTrainer<M, L> mappingTrainer = new DatasetTrainer<M, L>() {
            /** Fits the enclosing trainer's wrapped trainer on mapped data. */
            @Override
            public <K, V> M fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                return AdaptableDatasetTrainer.this.wrapped.fit(
                    datasetBuilder,
                    preprocessor.map(lv -> mapLabeledVector(datasetMapping, lv)));
            }

            /** Updates through the enclosing trainer's wrapped trainer on mapped data. */
            @Override
            public <K, V> M update(M m, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                return AdaptableDatasetTrainer.this.wrapped.update(
                    m,
                    datasetBuilder,
                    preprocessor.map(lv -> mapLabeledVector(datasetMapping, lv)));
            }

            /** Always {@code false}: incremental updates are not supported through the mapping. */
            @Override
            public boolean isUpdateable(M m) {
                return false;
            }

            /** Never used because {@link #isUpdateable} is {@code false}; returns {@code null}. */
            @Override
            public <K, V> M updateModel(M m, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                return null;
            }
        };

        // Built directly (not via of(...)) so this trainer's afterExtractor and upstream transformer survive.
        return new AdaptableDatasetTrainer<>(before, mappingTrainer, after, afterExtractor, upstreamTransformerBuilder);
    }

    /**
     * Chains this trainer with {@code datasetTrainer} into a {@link TrainersSequentialComposition}.
     *
     * @param datasetTrainer Second-stage trainer whose model consumes this trainer's output {@code O}.
     * @param igniteFunction Produces a labeled-vector mapping from the model fitted by this trainer; how the
     *     composition applies it is defined by {@link TrainersSequentialComposition} (presumably to the second
     *     stage's training data — confirm there).
     * @param <O1> Output type of the second-stage model.
     * @param <M1> Type of the second-stage model.
     * @return Sequential composition of this trainer and {@code datasetTrainer}.
     */
    public <O1, M1 extends IgniteModel<O, O1>> TrainersSequentialComposition<I, O, O1, L> andThen(DatasetTrainer<M1, L> datasetTrainer, IgniteFunction<AdaptableDatasetModel<I, O, IW, OW, M>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> igniteFunction) {
        // The first stage is this trainer, so the model handed back is expected to be the AdaptableDatasetModel it
        // fitted; the composition only knows it as IgniteModel<I, O>, hence the unchecked downcast.
        @SuppressWarnings("unchecked")
        IgniteFunction<IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> coercedMapping =
            mdl -> igniteFunction.apply((AdaptableDatasetModel<I, O, IW, OW, M>) mdl);

        return new TrainersSequentialComposition<>(this, datasetTrainer, coercedMapping);
    }

    /**
     * Returns a trainer that, after the current {@code afterExtractor}, replaces the features of every labeled
     * vector with {@code igniteFunction} applied to them; labels are kept.
     *
     * @param igniteFunction Feature transformation.
     * @return New trainer; this one is unchanged.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterFeatureExtractor(IgniteFunction<Vector, Vector> igniteFunction) {
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> featureStep =
            lv -> new LabeledVector<>(igniteFunction.apply(lv.features()), lv.label());
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> newExtractor = afterExtractor.andThen(featureStep);

        return new AdaptableDatasetTrainer<>(before, wrapped, after, newExtractor, upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer that, after the current {@code afterExtractor}, replaces the label of every labeled
     * vector with {@code igniteFunction} applied to it; features are kept.
     *
     * @param igniteFunction Label transformation.
     * @return New trainer; this one is unchanged.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterLabelExtractor(IgniteFunction<L, L> igniteFunction) {
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> labelStep =
            lv -> new LabeledVector<>(lv.features(), igniteFunction.apply(lv.label()));
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> newExtractor = afterExtractor.andThen(labelStep);

        return new AdaptableDatasetTrainer<>(before, wrapped, after, newExtractor, upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer identical to this one except for the upstream transformer builder.
     *
     * @param upstreamTransformerBuilder Upstream transformer builder to attach to the dataset builder.
     * @return New trainer; this one is unchanged.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withUpstreamTransformerBuilder(UpstreamTransformerBuilder upstreamTransformerBuilder) {
        return new AdaptableDatasetTrainer<>(before, wrapped, after, afterExtractor, upstreamTransformerBuilder);
    }

    /**
     * Applies {@code mapping} to the features and the label of {@code lv}.
     *
     * <p>Static so that lambdas calling it capture only the mapping, not a trainer instance.
     *
     * @param mapping Mapping to apply.
     * @param lv Labeled vector; its label is expected to be of type {@code T} (unchecked).
     * @param <T> Label type.
     * @return New mapped labeled vector.
     */
    @SuppressWarnings("unchecked")
    private static <T> LabeledVector<T> mapLabeledVector(DatasetMapping<T, T> mapping, LabeledVector<?> lv) {
        // The preprocessor's element type does not carry the label type statically; at runtime it is T.
        T label = (T) lv.label();
        return new LabeledVector<>(mapping.mapFeatures(lv.features()), mapping.mapLabels(label));
    }
}
