/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.trainers;

import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.DatasetMapping;
import org.apache.ignite.ml.composition.combinators.sequential.TrainersSequentialComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.AdaptableDatasetModel;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/**
 * Trainer that wraps another {@link DatasetTrainer} and adapts the types around it.
 *
 * <p>The wrapped trainer produces a model of type {@code M} (an {@code IgniteModel<IW, OW>}); this trainer
 * produces an {@link AdaptableDatasetModel} built from a {@code before} function ({@code I -> IW}), the
 * wrapped model, and an {@code after} function ({@code OW -> O}). In addition, labeled vectors produced by the
 * preprocessor are passed through {@code afterExtractor}, and the dataset builder is decorated with an
 * {@link UpstreamTransformerBuilder}, before the wrapped trainer sees them.
 *
 * <p>Instances are immutable: every {@code with*}/{@code after*}/{@code before*} method returns a new trainer.
 *
 * @param <I> Input type of the produced (adapted) model.
 * @param <O> Output type of the produced (adapted) model.
 * @param <IW> Input type of the wrapped model.
 * @param <OW> Output type of the wrapped model.
 * @param <M> Type of the model produced by the wrapped trainer.
 * @param <L> Type of labels.
 */
public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>, L>
    extends DatasetTrainer<AdaptableDatasetModel<I, O, IW, OW, M>, L> {
    /** Trainer that does the actual fitting/updating. */
    private final DatasetTrainer<M, L> wrapped;

    /** Function handed to the produced model as its input adapter ({@code I -> IW}). */
    private final IgniteFunction<I, IW> before;

    /** Function handed to the produced model as its output adapter ({@code OW -> O}). */
    private final IgniteFunction<OW, O> after;

    /** Function mapped over the preprocessor's labeled vectors before the wrapped trainer consumes them. */
    private final IgniteFunction<LabeledVector<L>, LabeledVector<L>> afterExtractor;

    /** Upstream transformer builder attached to the dataset builder before training. */
    private final UpstreamTransformerBuilder upstreamTransformerBuilder;

    /**
     * Wraps the given trainer without changing anything: all adapter functions are identities and the
     * upstream transformer builder is {@link UpstreamTransformerBuilder#identity()}.
     *
     * @param wrapped Trainer to wrap.
     * @param <I> Input type of the wrapped model.
     * @param <O> Output type of the wrapped model.
     * @param <M> Type of the wrapped model.
     * @param <L> Type of labels.
     * @return Adaptable trainer delegating to {@code wrapped}.
     */
    public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) {
        return new AdaptableDatasetTrainer<>(
            IgniteFunction.identity(),
            wrapped,
            IgniteFunction.identity(),
            IgniteFunction.identity(),
            UpstreamTransformerBuilder.identity());
    }

    /**
     * Creates a trainer from all of its parts.
     *
     * @param before Input adapter for the produced model.
     * @param wrapped Trainer to delegate to.
     * @param after Output adapter for the produced model.
     * @param afterExtractor Function mapped over extracted labeled vectors.
     * @param builder Upstream transformer builder.
     */
    private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, IgniteFunction<OW, O> after, IgniteFunction<LabeledVector<L>, LabeledVector<L>> afterExtractor, UpstreamTransformerBuilder builder) {
        this.before = before;
        this.wrapped = wrapped;
        this.after = after;
        this.afterExtractor = afterExtractor;
        this.upstreamTransformerBuilder = builder;
    }

    /**
     * Fits the wrapped trainer (configured with this trainer's environment builder) on the dataset decorated
     * with the upstream transformer and on the preprocessor mapped through {@code afterExtractor}, then wraps
     * the resulting model together with the {@code before}/{@code after} adapters.
     *
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor producing labeled vectors.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Adapted model around the freshly fitted inner model.
     */
    @Override
    public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        M fit = this.wrapped
            .withEnvironmentBuilder(this.envBuilder)
            .fit(datasetBuilder.withUpstreamTransformer(this.upstreamTransformerBuilder), extractor.map(this.afterExtractor));

        return new AdaptableDatasetModel<>(this.before, fit, this.after);
    }

    /**
     * Delegates the updateability check to the wrapped trainer, applied to the inner model.
     *
     * @param mdl Adapted model.
     * @return Whatever the wrapped trainer reports for {@code mdl.innerModel()}.
     */
    @Override
    public boolean isUpdateable(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
        return this.wrapped.isUpdateable(mdl.innerModel());
    }

    /**
     * Updates the inner model through the wrapped trainer, using the same dataset/preprocessor decoration as
     * {@link #fitWithInitializedDeployingContext}, and returns {@code mdl} with the inner model replaced.
     *
     * @param mdl Adapted model to update.
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor producing labeled vectors.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Result of {@code mdl.withInnerModel(updated)}.
     */
    @Override
    protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        M updated = this.wrapped
            .withEnvironmentBuilder(this.envBuilder)
            .updateModel(
                mdl.innerModel(),
                datasetBuilder.withUpstreamTransformer(this.upstreamTransformerBuilder),
                extractor.map(this.afterExtractor));

        return mdl.withInnerModel(updated);
    }

    /**
     * Returns a trainer whose output adapter is the current one followed by {@code after}.
     *
     * @param after Function applied to the current output.
     * @param <O1> New output type.
     * @return New trainer with output type {@code O1}; this instance is not modified.
     */
    public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> after) {
        // The parameter shadows the field, hence the explicit this.after for the existing adapter.
        IgniteFunction<OW, O1> composed = ow -> after.apply(this.after.apply(ow));

        return new AdaptableDatasetTrainer<>(this.before, this.wrapped, composed, this.afterExtractor, this.upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer whose input adapter is {@code before} followed by the current one.
     *
     * @param before Function applied ahead of the current input adapter.
     * @param <I1> New input type.
     * @return New trainer with input type {@code I1}; this instance is not modified.
     */
    public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> before) {
        // The parameter shadows the field, hence the explicit this.before for the existing adapter.
        IgniteFunction<I1, IW> composed = in -> this.before.apply(before.apply(in));

        return new AdaptableDatasetTrainer<>(composed, this.wrapped, this.after, this.afterExtractor, this.upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer that applies {@code mapping} to the features and labels of every labeled vector
     * before handing it to the wrapped trainer, keeping the current {@code before}/{@code after} adapters.
     *
     * <p>NOTE(review): the result is rebuilt through {@link #of(DatasetTrainer)}, so the current
     * {@code afterExtractor} and {@code upstreamTransformerBuilder} are replaced by identities, and the inner
     * calls do not go through {@code withEnvironmentBuilder}. Confirm that this is intended.
     *
     * @param mapping Mapping applied to features and labels.
     * @return New trainer training on the mapped dataset.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withDatasetMapping(final DatasetMapping<L, L> mapping) {
        // Shared by fit and update below; it captures only the mapping.
        final IgniteFunction<LabeledVector<L>, LabeledVector<L>> remap =
            lv -> new LabeledVector<>(mapping.mapFeatures(lv.features()), mapping.mapLabels(lv.label()));

        DatasetTrainer<M, L> mappedTrainer = new DatasetTrainer<M, L>() {
            @Override
            public <K, V> M fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
                return AdaptableDatasetTrainer.this.wrapped.fit(datasetBuilder, extractor.map(remap));
            }

            @Override
            public <K, V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
                return AdaptableDatasetTrainer.this.wrapped.update(mdl, datasetBuilder, vectorizer.map(remap));
            }

            // update(...) is overridden above to delegate directly to the wrapped trainer; the two methods
            // below are stubs (presumably never reached through update -- confirm against DatasetTrainer).
            @Override
            public boolean isUpdateable(M mdl) {
                return false;
            }

            @Override
            protected <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                return null;
            }
        };

        return AdaptableDatasetTrainer.<IW, OW, M, L>of(mappedTrainer)
            .beforeTrainedModel(this.before)
            .afterTrainedModel(this.after);
    }

    /**
     * Creates a {@link TrainersSequentialComposition} of this trainer followed by {@code tr}.
     *
     * @param tr Trainer to run after this one.
     * @param datasetMappingProducer Function that, given the model produced by this trainer, yields the
     * labeled-vector mapping passed on to the composition.
     * @param <O1> Output type of the second trainer's model.
     * @param <M1> Type of the second trainer's model.
     * @return Sequential composition of this trainer and {@code tr}.
     */
    public <O1, M1 extends IgniteModel<O, O1>> TrainersSequentialComposition<I, O, O1, L> andThen(DatasetTrainer<M1, L> tr, IgniteFunction<AdaptableDatasetModel<I, O, IW, OW, M>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMappingProducer) {
        // The composition is handed a function over the general model type, so the argument is downcast back
        // to the concrete model type this trainer produces. The cast is unchecked only in its type arguments.
        @SuppressWarnings("unchecked")
        IgniteFunction<IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> coercedMapping =
            mdl -> datasetMappingProducer.apply((AdaptableDatasetModel<I, O, IW, OW, M>)mdl);

        return new TrainersSequentialComposition<>(this, tr, coercedMapping);
    }

    /**
     * Returns a trainer that additionally applies {@code after} to the features of every labeled vector,
     * after the current {@code afterExtractor}; labels are passed through unchanged.
     *
     * @param after Function applied to feature vectors.
     * @return New trainer with the extended extractor mapping.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterFeatureExtractor(IgniteFunction<Vector, Vector> after) {
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> featureStep =
            slv -> new LabeledVector<>(after.apply(slv.features()), slv.label());
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> newExtractor = this.afterExtractor.andThen(featureStep);

        return new AdaptableDatasetTrainer<>(this.before, this.wrapped, this.after, newExtractor, this.upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer that additionally applies {@code after} to the label of every labeled vector,
     * after the current {@code afterExtractor}; features are passed through unchanged.
     *
     * @param after Function applied to labels.
     * @return New trainer with the extended extractor mapping.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterLabelExtractor(IgniteFunction<L, L> after) {
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> labelStep =
            slv -> new LabeledVector<>(slv.features(), after.apply(slv.label()));
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> newExtractor = this.afterExtractor.andThen(labelStep);

        return new AdaptableDatasetTrainer<>(this.before, this.wrapped, this.after, newExtractor, this.upstreamTransformerBuilder);
    }

    /**
     * Returns a trainer identical to this one except for the upstream transformer builder.
     *
     * @param upstreamTransformerBuilder Builder to attach to dataset builders before training.
     * @return New trainer using {@code upstreamTransformerBuilder}.
     */
    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withUpstreamTransformerBuilder(UpstreamTransformerBuilder upstreamTransformerBuilder) {
        return new AdaptableDatasetTrainer<>(this.before, this.wrapped, this.after, this.afterExtractor, upstreamTransformerBuilder);
    }
}

