/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.standardscaling;

import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.standardscaling.StandardScalerData;
import org.apache.ignite.ml.preprocessing.standardscaling.StandardScalerPreprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Trainer of the {@link StandardScalerPreprocessor}.
 *
 * <p>Scans the upstream data once, applying {@code basePreprocessor} to every entry and accumulating per-feature
 * sums, per-feature sums of squares and the row count. From these it derives the per-feature mean and population
 * standard deviation that are handed to the resulting preprocessor.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 */
public class StandardScalerTrainer<K, V>
implements PreprocessingTrainer<K, V> {
    /**
     * Fits a {@link StandardScalerPreprocessor} on the features produced by {@code basePreprocessor}.
     *
     * <p>For every feature {@code i} this computes {@code mean[i] = sum[i] / cnt} and the population standard
     * deviation {@code sigma[i] = sqrt((squaredSum[i] - sum[i]^2 / cnt) / cnt)}. The divisor is {@code cnt}, not
     * {@code cnt - 1}.
     *
     * @param envBuilder Learning environment builder.
     * @param datasetBuilder Builder of the dataset the statistics are computed on.
     * @param basePreprocessor Preprocessor applied to every upstream entry before statistics are collected.
     * @return Preprocessor holding the fitted means and standard deviations, wrapping {@code basePreprocessor}.
     * @throws IllegalStateException If no rows were found in the dataset.
     */
    @Override
    public StandardScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
        StandardScalerData stats = computeSum(envBuilder, datasetBuilder, basePreprocessor);

        // A partition without rows produces sum == null and cnt == 0. Depending on how partitions are merged, a
        // dataset with no rows at all may therefore yield null, null arrays or a zero count. Fail explicitly
        // instead of throwing an NPE or returning NaN statistics.
        if (stats == null || stats.sum == null || stats.squaredSum == null || stats.cnt == 0) {
            throw new IllegalStateException("Cannot fit StandardScalerTrainer: the dataset contains no rows");
        }

        int n = stats.sum.length;
        double cnt = stats.cnt;
        double[] mean = new double[n];
        double[] sigma = new double[n];

        for (int i = 0; i < n; i++) {
            double sum = stats.sum[i];
            mean[i] = sum / cnt;

            // Var = (sum(x^2) - sum(x)^2 / cnt) / cnt. For (near-)constant features, floating-point cancellation
            // can make this marginally negative, and Math.sqrt would then return NaN. Clamp it at zero.
            double variance = (stats.squaredSum[i] - sum * sum / cnt) / cnt;
            sigma[i] = Math.sqrt(Math.max(0.0, variance));
        }

        return new StandardScalerPreprocessor<K, V>(mean, sigma, basePreprocessor);
    }

    /**
     * Aggregates, over all partitions of the dataset, the per-feature sum, the per-feature sum of squares and the
     * number of rows of the features produced by {@code basePreprocessor}.
     *
     * <p>Each partition builds its own {@link StandardScalerData}. A partition without rows yields null arrays and
     * a zero count. The partial results are combined with {@code StandardScalerData.merge}, and a {@code null}
     * partial result is skipped.
     *
     * @param envBuilder Learning environment builder.
     * @param datasetBuilder Builder of the dataset to aggregate over.
     * @param basePreprocessor Preprocessor applied to every upstream entry.
     * @return Aggregated statistics; may be {@code null} if the reduction saw only {@code null} partial results.
     * @throws RuntimeException Wrapping any exception raised while building, computing on, or closing the dataset.
     */
    private StandardScalerData computeSum(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
        try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            (env, upstream, upstreamSize, ctx) -> {
                double[] sum = null;
                double[] squaredSum = null;
                long cnt = 0L;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector labeled = (LabeledVector)basePreprocessor.apply(entity.getKey(), entity.getValue());
                    int size = labeled.features().size();

                    // Arrays are sized lazily from the first row; every later row must have the same width.
                    if (sum == null) {
                        sum = new double[size];
                        squaredSum = new double[size];
                    }
                    else {
                        assert sum.length == size : "Base preprocessor must return exactly " + sum.length + " features";
                    }

                    cnt++;
                    for (int i = 0; i < size; i++) {
                        double x = labeled.features().get(i);
                        sum[i] += x;
                        squaredSum[i] += x * x;
                    }
                }

                return new StandardScalerData(sum, squaredSum, cnt);
            },
            learningEnvironment(basePreprocessor))) {

            return dataset.compute(data -> data, (a, b) -> {
                if (a == null) {
                    return b;
                }
                if (b == null) {
                    return a;
                }
                return a.merge(b);
            });
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}

