/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.statistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.randomforest.data.statistics.NormalDistributionStatistics;

/**
 * Computes per-feature {@link NormalDistributionStatistics} (min, max, sum of squares, sum of values and
 * row count) over a bootstrapped dataset.
 *
 * <p>Only non-categorical features are accumulated. Categorical features keep the initial accumulator values:
 * min = {@code +Infinity}, max = {@code -Infinity} and zero sums.
 *
 * <p>NOTE(review): the lambdas handed to {@link Dataset#compute} capture {@code this}. That is presumably why
 * this class is {@link Serializable}; confirm before changing that.
 */
public class NormalDistributionStatisticsComputer
implements Serializable {
    private static final long serialVersionUID = -3699071003012595743L;

    /**
     * Computes statistics on every partition of {@code dataset} and merges the partial results.
     *
     * @param meta Per-feature metadata; its size defines the number of returned statistics.
     * @param dataset Bootstrapped dataset to scan.
     * @return One statistics object per feature, in the order of {@code meta}.
     */
    public List<NormalDistributionStatistics> computeStatistics(List<FeatureMeta> meta, Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        // The result type is inferred from the return type, so no raw or unchecked casts are needed.
        return dataset.compute(
            part -> computeStatsOnPartition(part, meta),
            (left, right) -> reduceStats(left, right, meta)
        );
    }

    /**
     * Accumulates per-feature statistics over all rows of a single partition.
     *
     * <p>For each non-categorical feature this sums the values and their squares and tracks min/max.
     * Features that are categorical, or that lie beyond a row's vector length, keep their initial accumulator
     * values for that row. Every returned entry carries the partition's total row count.
     *
     * @param part Partition to scan.
     * @param meta Per-feature metadata; its size defines the number of accumulators.
     * @return One statistics object per entry of {@code meta}, in the same order.
     */
    public List<NormalDistributionStatistics> computeStatsOnPartition(BootstrappedDatasetPartition part, List<FeatureMeta> meta) {
        int featuresCnt = meta.size();
        int rowsCnt = part.getRowsCount();

        // Resolve categorical flags once, instead of querying meta for every row.
        boolean[] categorical = new boolean[featuresCnt];
        for (int featureId = 0; featureId < featuresCnt; ++featureId) {
            categorical[featureId] = meta.get(featureId).isCategoricalFeature();
        }

        double[] sumOfValues = new double[featuresCnt];
        double[] sumOfSquares = new double[featuresCnt];
        double[] min = new double[featuresCnt];
        double[] max = new double[featuresCnt];
        // Neutral elements for min/max, so the first observed value always replaces them.
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);

        for (int rowIdx = 0; rowIdx < rowsCnt; ++rowIdx) {
            // NOTE(review): features() is read through the accessor chain rather than bound to a typed local.
            // This assumes getRow()/features() are cheap accessors; confirm.
            int vecSize = part.getRow(rowIdx).features().size();
            for (int featureId = 0; featureId < vecSize; ++featureId) {
                if (categorical[featureId]) {
                    continue;
                }
                double featureVal = part.getRow(rowIdx).features().get(featureId);
                sumOfValues[featureId] += featureVal;
                sumOfSquares[featureId] += featureVal * featureVal;
                min[featureId] = Math.min(min[featureId], featureVal);
                max[featureId] = Math.max(max[featureId], featureVal);
            }
        }

        List<NormalDistributionStatistics> res = new ArrayList<>(featuresCnt);
        for (int featureId = 0; featureId < featuresCnt; ++featureId) {
            res.add(new NormalDistributionStatistics(min[featureId], max[featureId], sumOfSquares[featureId], sumOfValues[featureId], rowsCnt));
        }
        return res;
    }

    /**
     * Merges two partial per-feature results element-wise via {@code NormalDistributionStatistics#plus}.
     *
     * <p>A {@code null} argument is treated as "no result yet", and the other argument is returned unchanged.
     *
     * @param left Partial result, may be {@code null}.
     * @param right Partial result, may be {@code null}.
     * @param meta Per-feature metadata; both non-null inputs are asserted to have its size.
     * @return Merged statistics, one per feature, or the non-null input when the other one is {@code null}.
     */
    public List<NormalDistributionStatistics> reduceStats(List<NormalDistributionStatistics> left, List<NormalDistributionStatistics> right, List<FeatureMeta> meta) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }

        int featuresCnt = meta.size();
        assert featuresCnt == left.size() && featuresCnt == right.size();

        List<NormalDistributionStatistics> res = new ArrayList<>(featuresCnt);
        for (int featureId = 0; featureId < featuresCnt; ++featureId) {
            res.add(left.get(featureId).plus(right.get(featureId)));
        }
        return res;
    }
}

