/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.stacking;

import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.stacking.SimpleStackedDatasetTrainer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/**
 * Stacking trainer whose submodels consume and produce {@link Vector}s.
 *
 * <p>Specializes {@link SimpleStackedDatasetTrainer} to {@code Vector} input. The overridden fluent
 * methods only delegate to the superclass and narrow the declared return type, so chained calls keep
 * the {@code StackedVectorDatasetTrainer} type. The two adapter helpers
 * ({@link #addTrainerWithDoubleOutput} and {@link #addMatrix2MatrixTrainer}) wrap submodel trainers
 * whose models do not map {@code Vector} to {@code Vector}.
 *
 * @param <O> Output type of the aggregating model.
 * @param <AM> Type of the aggregating model; it maps {@link Vector} to {@code O}.
 * @param <L> Label type used by the trainers.
 */
public class StackedVectorDatasetTrainer<O, AM extends IgniteModel<Vector, O>, L>
    extends SimpleStackedDatasetTrainer<Vector, O, AM, L> {
    /**
     * Constructs a trainer around the given aggregating trainer.
     *
     * <p>Passes {@link VectorUtils#concat} and three {@link IgniteFunction#identity()} functions to the
     * superclass constructor. NOTE(review): the role of each argument is defined by the superclass
     * constructor; {@code concat} is presumably the aggregator-input merger (compare
     * {@link #withAggregatorInputMerger}).
     *
     * @param aggregatingTrainer Trainer of the aggregating model; the no-arg constructor passes
     *     {@code null} here (presumably to be supplied later via {@link #withAggregatorTrainer}).
     */
    public StackedVectorDatasetTrainer(DatasetTrainer<AM, L> aggregatingTrainer) {
        super(aggregatingTrainer,
            VectorUtils::concat,
            IgniteFunction.identity(),
            IgniteFunction.identity(),
            IgniteFunction.identity());
    }

    /**
     * Constructs a trainer with no aggregating trainer ({@code null} is passed to the main constructor).
     */
    public StackedVectorDatasetTrainer() {
        this(null);
    }

    /**
     * Adds a submodel trainer whose model maps {@link Vector} to {@link Vector}.
     *
     * @param trainer Submodel trainer.
     * @param <M1> Submodel type.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public <M1 extends IgniteModel<Vector, Vector>> StackedVectorDatasetTrainer<O, AM, L> addTrainer(
        DatasetTrainer<M1, L> trainer) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.addTrainer(trainer);
    }

    /**
     * Sets the aggregating trainer by delegating to the superclass.
     *
     * @param aggregatorTrainer Trainer of the aggregating model.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withAggregatorTrainer(DatasetTrainer<AM, L> aggregatorTrainer) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withAggregatorTrainer(aggregatorTrainer);
    }

    /**
     * Delegates to the superclass's no-argument {@code withOriginalFeaturesKept()}.
     *
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept() {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept();
    }

    /**
     * Delegates to the superclass's {@code withOriginalFeaturesDropped()}.
     *
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesDropped() {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesDropped();
    }

    /**
     * Delegates to the superclass's {@code withOriginalFeaturesKept(IgniteFunction)}.
     *
     * @param submodelInput2AggregatingInputConverter Converter forwarded unchanged to the superclass.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept(
        IgniteFunction<Vector, Vector> submodelInput2AggregatingInputConverter) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept(
            submodelInput2AggregatingInputConverter);
    }

    /**
     * Sets the submodel-output-to-vector converter by delegating to the superclass.
     *
     * @param submodelOutput2VectorConverter Converter forwarded unchanged to the superclass.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withSubmodelOutput2VectorConverter(
        IgniteFunction<Vector, Vector> submodelOutput2VectorConverter) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withSubmodelOutput2VectorConverter(
            submodelOutput2VectorConverter);
    }

    /**
     * Sets the vector-to-submodel-input converter by delegating to the superclass.
     *
     * @param vector2SubmodelInputConverter Converter forwarded unchanged to the superclass.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withVector2SubmodelInputConverter(
        IgniteFunction<Vector, Vector> vector2SubmodelInputConverter) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withVector2SubmodelInputConverter(
            vector2SubmodelInputConverter);
    }

    /**
     * Sets the aggregator input merger by delegating to the superclass.
     *
     * @param merger Binary operator on vectors forwarded unchanged to the superclass.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withAggregatorInputMerger(IgniteBinaryOperator<Vector> merger) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withAggregatorInputMerger(merger);
    }

    /**
     * Sets the learning environment builder by delegating to the superclass.
     *
     * @param envBuilder Learning environment builder.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public StackedVectorDatasetTrainer<O, AM, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (StackedVectorDatasetTrainer<O, AM, L>)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Delegates label conversion to the superclass and narrows the result type.
     *
     * <p>NOTE(review): the cast succeeds at runtime only if the superclass returns a
     * {@code StackedVectorDatasetTrainer}; confirm this for a call that changes the label type.
     *
     * @param new2Old Function mapping new labels to the labels this trainer expects.
     * @param <L1> New label type.
     * @return Result of the superclass call, typed as {@code StackedVectorDatasetTrainer}.
     */
    @Override
    public <L1> StackedVectorDatasetTrainer<O, AM, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) {
        return (StackedVectorDatasetTrainer<O, AM, L1>)super.withConvertedLabels(new2Old);
    }

    /**
     * Adds a submodel trainer whose model outputs a {@link Double}. The trainer is wrapped with
     * {@link AdaptableDatasetTrainer}, and each model output is converted with
     * {@link VectorUtils#num2Vec}.
     *
     * @param trainer Submodel trainer producing {@code Double} outputs.
     * @param <M1> Submodel type.
     * @return Result of {@link #addTrainer}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <M1 extends IgniteModel<Vector, Double>> StackedVectorDatasetTrainer<O, AM, L> addTrainerWithDoubleOutput(
        DatasetTrainer<M1, L> trainer) {
        // The raw cast is kept so the adapter is passed to addTrainer without a compile-time check of
        // its model type. NOTE(review): it may be removable if the adapted model type satisfies
        // IgniteModel<Vector, Vector>; verify against AdaptableDatasetTrainer.
        return this.addTrainer((DatasetTrainer)AdaptableDatasetTrainer.of(trainer).afterTrainedModel(VectorUtils::num2Vec));
    }

    /**
     * Adds a submodel trainer whose model maps {@link Matrix} to {@link Matrix}.
     *
     * <p>Before the model runs, each input vector is turned into
     * {@code new DenseMatrix(v.asArray(), 1)} (presumably a one-row matrix). After the model runs, row 0
     * of its output is used.
     *
     * @param trainer Submodel trainer producing matrix-to-matrix models.
     * @param <M1> Submodel type.
     * @return Result of {@link #addTrainer}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <M1 extends IgniteModel<Matrix, Matrix>> StackedVectorDatasetTrainer<O, AM, L> addMatrix2MatrixTrainer(
        DatasetTrainer<M1, L> trainer) {
        // The lambda parameter types must be explicit. This call chain has no target type, so an
        // implicitly typed lambda would make the compiler infer Object for the input and reject
        // v.asArray().
        AdaptableDatasetTrainer adapted = AdaptableDatasetTrainer.of(trainer)
            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0));
        return this.addTrainer((DatasetTrainer)adapted);
    }
}

