/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.initializers.RandomInitializer;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.util.Utils;

/**
 * Trainer for {@link MultilayerPerceptron} models.
 *
 * <p>Training proceeds in global steps. In each global step a copy of the current model is refined by
 * {@code locIterations} mini-batch steps inside {@code Dataset#compute}, the updates produced by those map calls
 * are merged, reduced by the configured {@link UpdatesStrategy} and applied to the model.
 *
 * @param <P> Type of the parameter update produced by the configured {@link UpdatesStrategy}.
 */
public class MLPTrainer<P extends Serializable>
extends MultiLabelDatasetTrainer<MultilayerPerceptron> {
    /** Function that derives the network architecture from the training dataset. */
    private IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;

    /** Loss factory handed to the update calculator when an update is initialised. */
    private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Strategy supplying the update calculator and the two update reducers. */
    private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;

    /** Upper bound of the global iteration counter; the counter advances by {@link #locIterations} per global step. */
    private int maxIterations = 100;

    /** Maximum number of rows drawn for one mini-batch. */
    private int batchSize = 100;

    /** Number of local mini-batch steps performed inside one global step. */
    private int locIterations = 100;

    /** Seed used for the weight initializer and for mini-batch row selection. */
    private long seed = 1234L;

    /**
     * Creates a trainer with a fixed network architecture.
     *
     * @param arch Architecture used for every newly created model, regardless of the dataset.
     * @param loss Loss factory.
     * @param updatesStgy Update strategy.
     * @param maxIterations Upper bound of the global iteration counter.
     * @param batchSize Maximum mini-batch size.
     * @param locIterations Number of local steps per global step.
     * @param seed Random seed.
     */
    public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) {
        this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed);
    }

    /**
     * Creates a trainer whose network architecture is derived from the training dataset.
     *
     * @param archSupplier Function producing the architecture from the dataset.
     * @param loss Loss factory.
     * @param updatesStgy Update strategy.
     * @param maxIterations Upper bound of the global iteration counter.
     * @param batchSize Maximum mini-batch size.
     * @param locIterations Number of local steps per global step.
     * @param seed Random seed.
     */
    public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) {
        this.archSupplier = archSupplier;
        this.loss = loss;
        this.updatesStgy = updatesStgy;
        this.maxIterations = maxIterations;
        this.batchSize = batchSize;
        this.locIterations = locIterations;
        this.seed = seed;
    }

    /**
     * Trains a new model by delegating to
     * {@link #updateModel(MultilayerPerceptron, DatasetBuilder, Preprocessor)} with no previous model.
     *
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor passed to the labeled-data builder.
     * @return Trained model.
     */
    @Override
    public <K, V> MultilayerPerceptron fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        return updateModel((MultilayerPerceptron)null, datasetBuilder, extractor);
    }

    /**
     * Trains a model, starting from {@code lastLearnedMdl} when it is given and from a freshly initialised
     * network otherwise.
     *
     * @param lastLearnedMdl Model to continue training from, or {@code null} to create a new one.
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor passed to the labeled-data builder.
     * @return Model after the last applied update, or the result of
     *      {@code getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl)} when a global step produced
     *      no updates.
     * @throws IllegalStateException If {@code maxIterations > 0} while {@code locIterations <= 0}.
     * @throws RuntimeException Wrapping any exception thrown while the dataset is built, used or closed.
     */
    @Override
    protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedMdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        assert archSupplier != null;
        assert loss != null;
        assert updatesStgy != null;

        // A zero step never advances the global counter and a negative one moves it away from maxIterations,
        // so fail fast instead of spinning in the loop below.
        if (maxIterations > 0 && locIterations <= 0) {
            throw new IllegalStateException("locIterations must be positive when maxIterations > 0 [locIterations="
                + locIterations + ", maxIterations=" + maxIterations + ']');
        }

        try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new SimpleLabeledDatasetDataBuilder<>(extractor),
            learningEnvironment()
        )) {
            MultilayerPerceptron mdl;
            if (lastLearnedMdl != null) {
                mdl = lastLearnedMdl;
            }
            else {
                MLPArchitecture arch = archSupplier.apply(dataset);
                mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
            }

            ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();

            for (int i = 0; i < maxIterations; i += locIterations) {
                // Lambdas may only capture effectively final locals.
                MultilayerPerceptron curMdl = mdl;
                int globalStep = i;

                List<P> totUp = dataset.compute(
                    data -> localUpdates(data, curMdl, updater, globalStep),
                    (a, b) -> {
                        if (a == null) {
                            return b;
                        }
                        if (b == null) {
                            return a;
                        }
                        a.addAll(b);
                        return a;
                    }
                );

                // A null result means no updates were produced in this global step.
                if (totUp == null) {
                    return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl);
                }

                P update = updatesStgy.allUpdatesReducer().apply(totUp);
                mdl = updater.update(mdl, update);
            }

            return mdl;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs {@code locIterations} mini-batch steps on a copy of {@code mdl} against one chunk of data and
     * reduces the per-step updates into a single one.
     *
     * @param data Chunk of labeled data handed over by {@code Dataset#compute}.
     * @param mdl Model the local steps start from; it is copied, not modified directly here.
     * @param updater Update calculator.
     * @param globalStep Value of the global iteration counter for this global step.
     * @return Single-element list holding the reduced local update, or {@code null} if the chunk has no features.
     */
    private List<P> localUpdates(SimpleLabeledDatasetData data, MultilayerPerceptron mdl, ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater, int globalStep) {
        P update = updater.init(mdl, loss);
        MultilayerPerceptron mlp = Utils.copy(mdl);

        if (data.getFeatures() == null) {
            return null;
        }

        int totalRows = data.getRows();
        int batchRows = Math.min(batchSize, totalRows);
        List<P> updates = new ArrayList<>();

        for (int locStep = 0; locStep < locIterations; ++locStep) {
            // globalStep is a multiple of locIterations and locStep < locIterations, so the sum is a unique
            // index for every step. The former product (globalStep * locStep) was 0 for the whole first global
            // step, which made every local step draw the identical mini-batch; it could also overflow int.
            Random rnd = new Random(seed ^ ((long)globalStep + locStep));
            int[] rows = Utils.selectKDistinct(totalRows, batchRows, rnd);

            double[] inputsBatch = batch(data.getFeatures(), rows, totalRows);
            double[] groundTruthBatch = batch(data.getLabels(), rows, totalRows);

            // NOTE(review): the third constructor argument (0) is kept as is; its meaning is defined by DenseMatrix.
            DenseMatrix inputs = new DenseMatrix(inputsBatch, rows.length, 0);
            DenseMatrix groundTruth = new DenseMatrix(groundTruthBatch, rows.length, 0);

            update = updater.calculateNewUpdate(mlp, update, locStep, inputs.transpose(), groundTruth.transpose());
            mlp = updater.update(mlp, update);
            updates.add(update);
        }

        List<P> res = new ArrayList<>();
        res.add(updatesStgy.locStepUpdatesReducer().apply(updates));
        return res;
    }

    /**
     * @return Function that derives the network architecture from the dataset.
     */
    public IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> getArchSupplier() {
        return archSupplier;
    }

    /**
     * Sets the architecture supplier.
     *
     * @param archSupplier New architecture supplier.
     * @return This trainer.
     */
    public MLPTrainer<P> withArchSupplier(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier) {
        this.archSupplier = archSupplier;
        return this;
    }

    /**
     * @return Loss factory.
     */
    public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> getLoss() {
        return loss;
    }

    /**
     * Sets the loss factory.
     *
     * @param loss New loss factory.
     * @return This trainer.
     */
    public MLPTrainer<P> withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;
        return this;
    }

    /**
     * @return Update strategy.
     */
    public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
        return updatesStgy;
    }

    /**
     * Sets the update strategy.
     *
     * @param updatesStgy New update strategy.
     * @return This trainer.
     */
    public MLPTrainer<P> withUpdatesStgy(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
        this.updatesStgy = updatesStgy;
        return this;
    }

    /**
     * @return Upper bound of the global iteration counter.
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets the upper bound of the global iteration counter.
     *
     * @param maxIterations New bound.
     * @return This trainer.
     */
    public MLPTrainer<P> withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * @return Maximum mini-batch size.
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets the maximum mini-batch size.
     *
     * @param batchSize New maximum mini-batch size.
     * @return This trainer.
     */
    public MLPTrainer<P> withBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * @return Number of local steps per global step.
     */
    public int getLocIterations() {
        return locIterations;
    }

    /**
     * Sets the number of local steps per global step.
     *
     * @param locIterations New number of local steps.
     * @return This trainer.
     */
    public MLPTrainer<P> withLocIterations(int locIterations) {
        this.locIterations = locIterations;
        return this;
    }

    /**
     * @return Random seed.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Sets the random seed.
     *
     * @param seed New seed.
     * @return This trainer.
     */
    public MLPTrainer<P> withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Reports every model as updateable.
     *
     * @param mdl Model to check (not inspected).
     * @return Always {@code true}.
     */
    @Override
    public boolean isUpdateable(MultilayerPerceptron mdl) {
        return true;
    }

    /**
     * Copies the selected rows out of a flat array in which element (row {@code r}, column {@code j}) is stored
     * at index {@code j * totalRows + r}. The result uses the same layout restricted to the selected rows:
     * element ({@code i}, {@code j}) is stored at index {@code j * rows.length + i}.
     *
     * @param data Flat source array; its column count is taken as {@code data.length / totalRows}.
     * @param rows Indices of the rows to copy, in output order.
     * @param totalRows Number of rows stored in {@code data}; must be non-zero.
     * @return New array of length {@code cols * rows.length} holding the selected rows.
     */
    static double[] batch(double[] data, int[] rows, int totalRows) {
        int cols = data.length / totalRows;
        int batchRows = rows.length;
        double[] res = new double[cols * batchRows];

        for (int i = 0; i < batchRows; ++i) {
            int srcRow = rows[i];
            for (int j = 0; j < cols; ++j) {
                res[j * batchRows + i] = data[j * totalRows + srcRow];
            }
        }

        return res;
    }

    /**
     * Delegates to the superclass and narrows the returned trainer type to {@code MLPTrainer<P>}.
     *
     * @param envBuilder Learning environment builder.
     * @return Trainer returned by the superclass, cast to {@code MLPTrainer<P>}.
     */
    @Override
    @SuppressWarnings("unchecked") // The superclass declares a wider return type; the result is cast back to this type.
    public MLPTrainer<P> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (MLPTrainer<P>)super.withEnvironmentBuilder(envBuilder);
    }
}

