/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.combinators.sequential;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.sequential.ModelsSequentialComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/**
 * Trainer that fits two trainers one after another and combines their models into a
 * {@link ModelsSequentialComposition}.
 *
 * <p>The first trainer is fitted with the given preprocessor. The second trainer is fitted with the same
 * preprocessor mapped through {@code datasetMapping.apply(0, firstModel)}. This lets the data seen by the second
 * stage depend on the first trained model.
 *
 * @param <I> Type of input of the first model.
 * @param <O1> Type of output of the first model, which is also the input type of the second model.
 * @param <O2> Type of output of the second model.
 * @param <L> Type of labels.
 */
public class TrainersSequentialComposition<I, O1, O2, L>
extends DatasetTrainer<ModelsSequentialComposition<I, O1, O2>, L> {
    /** Trainer of the first model. */
    private final DatasetTrainer<IgniteModel<I, O1>, L> tr1;

    /** Trainer of the second model. */
    private final DatasetTrainer<IgniteModel<O1, O2>, L> tr2;

    /**
     * Given a stage index and the model trained at that stage, produces the labeled-vector mapping. That mapping is
     * applied to the preprocessor before the next stage is trained.
     */
    protected IgniteBiFunction<Integer, ? super IgniteModel<I, O1>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping;

    /**
     * Creates a composition whose dataset mapping does not depend on the stage index.
     *
     * @param tr1 Trainer of the first model.
     * @param tr2 Trainer of the second model.
     * @param datasetMapping Builds, from the first model, the mapping applied to the preprocessor used by {@code tr2}.
     */
    public TrainersSequentialComposition(DatasetTrainer<? extends IgniteModel<I, O1>, L> tr1,
        DatasetTrainer<? extends IgniteModel<O1, O2>, L> tr2,
        IgniteFunction<? super IgniteModel<I, O1>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping) {
        // Adapt the one-argument mapping to the (index, model) form by ignoring the index.
        this(tr1, tr2, (idx, mdl) -> datasetMapping.apply(mdl));
    }

    /**
     * Creates a composition.
     *
     * @param tr1 Trainer of the first model.
     * @param tr2 Trainer of the second model.
     * @param datasetMapping Builds, from a stage index and the model trained at that stage, the mapping applied to
     *     the preprocessor of the next stage.
     */
    public TrainersSequentialComposition(DatasetTrainer<? extends IgniteModel<I, O1>, L> tr1,
        DatasetTrainer<? extends IgniteModel<O1, O2>, L> tr2,
        IgniteBiFunction<Integer, ? super IgniteModel<I, O1>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping) {
        this.tr1 = CompositionUtils.unsafeCoerce(tr1);
        this.tr2 = CompositionUtils.unsafeCoerce(tr2);
        this.datasetMapping = datasetMapping;
    }

    /**
     * Creates a composition that repeatedly fits the same trainer.
     *
     * <p>After the fit at stage {@code i}, the preprocessor for the next fit is mapped through
     * {@code datasetMapping.apply(i, model)}. Training stops once {@code shouldStop} returns {@code true}. The
     * predicate is never consulted at stage 0, so at least one model is always trained.
     *
     * @param tr Trainer used at every stage.
     * @param datasetMapping Builds the mapping for the next stage from the stage index and the trained model.
     * @param shouldStop Stop condition evaluated with the stage index and a trained model.
     * @param out2In Converts a model's output into the next model's input.
     * @param <I> Type of model input.
     * @param <O> Type of model output.
     * @param <L> Type of labels.
     * @return Composition trainer.
     */
    public static <I, O, L> TrainersSequentialComposition<I, O, O, L> ofSame(
        DatasetTrainer<? extends IgniteModel<I, O>, L> tr,
        IgniteBiFunction<Integer, ? super IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping,
        IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop,
        IgniteFunction<O, I> out2In) {
        return new SameTrainersSequentialComposition<I, O, L>(
            CompositionUtils.unsafeCoerce(tr), datasetMapping, shouldStop, out2In);
    }

    /**
     * Fits the first trainer, then fits the second trainer on the dataset mapped by the first model.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Preprocessor used for the first stage.
     * @return Composition of both trained models.
     */
    @Override
    public <K, V> ModelsSequentialComposition<I, O1, O2> fitWithInitializedDeployingContext(
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, preprocessor);
        IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder, secondStagePreprocessor(preprocessor, mdl1));
        return new ModelsSequentialComposition<>(mdl1, mdl2);
    }

    /**
     * Not used by this trainer, because {@link #update} is overridden directly.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    protected <K, V> ModelsSequentialComposition<I, O1, O2> updateModel(ModelsSequentialComposition<I, O1, O2> mdl,
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        throw new IllegalStateException();
    }

    /**
     * Updates the first model and then the second model. The second update uses the dataset mapped by the updated
     * first model.
     *
     * @param mdl Composition to update.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Preprocessor used for the first stage.
     * @return Composition of both updated models.
     */
    @Override
    public <K, V> ModelsSequentialComposition<I, O1, O2> update(ModelsSequentialComposition<I, O1, O2> mdl,
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        IgniteModel<I, O1> firstUpdated = tr1.update(mdl.firstModel(), datasetBuilder, preprocessor);
        IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(), datasetBuilder,
            secondStagePreprocessor(preprocessor, firstUpdated));
        return new ModelsSequentialComposition<>(firstUpdated, secondUpdated);
    }

    /**
     * Not used by this trainer, because {@link #update} is overridden directly.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    public boolean isUpdateable(ModelsSequentialComposition<I, O1, O2> mdl) {
        throw new IllegalStateException();
    }

    /**
     * Views this trainer as a trainer of a plain {@code IgniteModel<I, O2>} via an unchecked coercion.
     *
     * @return This trainer coerced by {@link CompositionUtils#unsafeCoerce}.
     */
    public DatasetTrainer<IgniteModel<I, O2>, L> unsafeSimplyTyped() {
        return CompositionUtils.unsafeCoerce(this);
    }

    /**
     * Maps {@code preprocessor} through the stage-0 dataset mapping computed from the first model.
     *
     * @param preprocessor Preprocessor used for the first stage.
     * @param firstMdl Model produced (or updated) by the first trainer.
     * @return Preprocessor for the second stage.
     */
    private <K, V> Preprocessor<K, V> secondStagePreprocessor(Preprocessor<K, V> preprocessor,
        IgniteModel<I, O1> firstMdl) {
        IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping = datasetMapping.apply(0, firstMdl);
        return preprocessor.map(mapping);
    }

    /**
     * Composition that fits one trainer repeatedly until a stop condition holds.
     *
     * <p>The inherited first/second trainers are intentionally {@code null}. Every inherited entry point that would
     * use them is therefore overridden here.
     *
     * @param <I> Type of model input.
     * @param <O> Type of model output.
     * @param <L> Type of labels.
     */
    private static class SameTrainersSequentialComposition<I, O, L>
    extends TrainersSequentialComposition<I, O, O, L> {
        /** Trainer used at every stage. */
        private final DatasetTrainer<IgniteModel<I, O>, L> tr;

        /** Stop condition; wrapped so that it never stops at stage 0. */
        private final IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop;

        /** Converts a model's output into the next model's input. */
        private final IgniteFunction<O, I> out2Input;

        /**
         * Creates the composition.
         *
         * @param tr Trainer used at every stage.
         * @param datasetMapping Builds the mapping for the next stage from the stage index and the trained model.
         * @param shouldStop User stop condition; consulted only for stages other than 0.
         * @param out2Input Converts a model's output into the next model's input.
         */
        public SameTrainersSequentialComposition(DatasetTrainer<IgniteModel<I, O>, L> tr,
            IgniteBiFunction<Integer, ? super IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping,
            IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop,
            IgniteFunction<O, I> out2Input) {
            // The inherited two-trainer fields are not used by this class.
            super(null, null, datasetMapping);
            this.tr = tr;
            // The lambda refers to the constructor parameter, not the field. Short-circuiting at stage 0 guarantees
            // that at least one model is trained. It also means the user predicate never sees the initial null model.
            this.shouldStop = (IgniteBiPredicate<Integer, IgniteModel<I, O>> & Serializable)
                (iteration, model) -> iteration != 0 && shouldStop.apply(iteration, model);
            this.out2Input = out2Input;
        }

        /** {@inheritDoc} */
        @Override
        public <K, V> ModelsSequentialComposition<I, O, O> fit(DatasetBuilder<K, V> datasetBuilder,
            Preprocessor<K, V> preprocessor) {
            return fitSequentially(datasetBuilder, preprocessor);
        }

        /**
         * Runs the same iterative training as {@link #fit}. The inherited implementation would dereference the
         * {@code null} first/second trainers.
         */
        @Override
        public <K, V> ModelsSequentialComposition<I, O, O> fitWithInitializedDeployingContext(
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return fitSequentially(datasetBuilder, preprocessor);
        }

        /**
         * Updating is not supported. The inherited implementation relies on the first/second trainers, which this
         * class leaves {@code null}.
         *
         * @throws UnsupportedOperationException Always.
         */
        @Override
        public <K, V> ModelsSequentialComposition<I, O, O> update(ModelsSequentialComposition<I, O, O> mdl,
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            throw new UnsupportedOperationException(
                "Update is not supported for a sequential composition of the same trainer.");
        }

        /**
         * Fits {@link #tr} stage by stage, remapping the dataset between stages, until {@link #shouldStop} holds.
         *
         * @param datasetBuilder Dataset builder.
         * @param preprocessor Preprocessor used for stage 0.
         * @return Composition of all trained models, in training order.
         */
        private <K, V> ModelsSequentialComposition<I, O, O> fitSequentially(DatasetBuilder<K, V> datasetBuilder,
            Preprocessor<K, V> preprocessor) {
            int i = 0;
            IgniteModel<I, O> currMdl = null;
            IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping = IgniteFunction.identity();
            List<IgniteModel<I, O>> mdls = new ArrayList<>();

            while (!shouldStop.apply(i, currMdl)) {
                currMdl = tr.fit(datasetBuilder, preprocessor.map(mapping));
                mdls.add(currMdl);

                if (shouldStop.apply(i, currMdl)) {
                    break;
                }

                // Mapping for the next stage depends on the model just trained.
                mapping = datasetMapping.apply(i, currMdl);
                i++;
            }

            return ModelsSequentialComposition.ofSame(mdls, out2Input);
        }
    }
}

