/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.imputing;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionContextBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.imputing.ImputerPartitionData;
import org.apache.ignite.ml.preprocessing.imputing.ImputerPreprocessor;
import org.apache.ignite.ml.preprocessing.imputing.ImputingStrategy;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Trainer that fits an {@link ImputerPreprocessor}: for every feature produced by the base preprocessor it computes
 * one imputing value from the non-{@code NaN} observations, according to the configured {@link ImputingStrategy}
 * ({@code MEAN} or {@code MOST_FREQUENT}).
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 */
public class ImputerTrainer<K, V>
implements PreprocessingTrainer<K, V> {
    /** Message of the exception thrown for a strategy this trainer does not implement. */
    private static final String UNSUPPORTED_STGY_MSG = "The chosen strategy is not supported";

    /** Message of the exception thrown when no partition contributed a single row. */
    private static final String EMPTY_DATASET_MSG = "Cannot fit imputer: the dataset contains no rows";

    /** The imputing strategy; {@link ImputingStrategy#MEAN} unless changed via {@link #withImputingStrategy}. */
    private ImputingStrategy imputingStgy = ImputingStrategy.MEAN;

    /**
     * Builds a dataset over the upstream data, collects per-partition statistics for the configured strategy,
     * reduces them and returns a preprocessor holding one imputing value per feature.
     *
     * @param envBuilder Learning environment builder.
     * @param datasetBuilder Dataset builder.
     * @param basePreprocessor Preprocessor whose output rows are analysed.
     * @return Preprocessor wrapping {@code basePreprocessor} with the fitted imputing values.
     * @throws RuntimeException Wrapping any failure raised while building, computing or closing the dataset
     *      (including an unsupported strategy or an empty dataset).
     */
    @Override
    public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> basePreprocessor) {
        // Read the strategy once: the partition builder and the aggregation below must agree on it, and the
        // partition-data lambda then only captures locals instead of the whole trainer (`this`).
        ImputingStrategy stgy = imputingStgy;

        PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();

        try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(
            envBuilder,
            ctxBuilder,
            (env, upstream, upstreamSize, ctx) -> {
                double[] sums = null;
                int[] counts = null;
                Map<Double, Integer>[] valuesByFreq = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                    switch (stgy) {
                        case MEAN:
                            sums = calculateTheSums(row, sums);
                            counts = calculateTheCounts(row, counts);
                            break;

                        case MOST_FREQUENT:
                            valuesByFreq = calculateFrequencies(row, valuesByFreq);
                            break;

                        default:
                            throw new UnsupportedOperationException(UNSUPPORTED_STGY_MSG);
                    }
                }

                // An empty partition leaves the arrays null; the reducers below skip null partials.
                switch (stgy) {
                    case MEAN:
                        return new ImputerPartitionData().withSums(sums).withCounts(counts);

                    case MOST_FREQUENT:
                        return new ImputerPartitionData().withValuesByFrequency(valuesByFreq);

                    default:
                        throw new UnsupportedOperationException(UNSUPPORTED_STGY_MSG);
                }
            },
            learningEnvironment(basePreprocessor))) {

            Vector imputingValues;

            switch (stgy) {
                case MEAN:
                    imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                    break;

                case MOST_FREQUENT:
                    imputingValues = VectorUtils.of(calculateImputingValuesByFrequencies(dataset));
                    break;

                default:
                    throw new UnsupportedOperationException(UNSUPPORTED_STGY_MSG);
            }

            return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
        }
        catch (Exception e) {
            // Dataset#close() declares a checked exception; everything is surfaced unchecked to the caller.
            throw new RuntimeException(e);
        }
    }

    /**
     * Merges per-partition value frequencies and picks, for each feature, the value seen most often.
     *
     * @param dataset Dataset whose partitions hold value-frequency maps.
     * @return Most frequent value per feature; {@code 0.0} for a feature that had no non-NaN observation.
     * @throws IllegalStateException If no partition contained any row.
     */
    private static double[] calculateImputingValuesByFrequencies(
        Dataset<EmptyContext, ImputerPartitionData> dataset) {
        Map<Double, Integer>[] frequencies = dataset.compute(
            ImputerPartitionData::valuesByFrequency,
            (a, b) -> {
                if (a == null)
                    return b;

                if (b == null)
                    return a;

                assert a.length == b.length;

                // NOTE(review): this accumulates into `b` in place, which may be partition-owned data; it is
                // only read once per fit, but confirm before reusing the dataset for further computations.
                for (int i = 0; i < a.length; i++) {
                    Map<Double, Integer> target = b[i];

                    a[i].forEach((val, freq) -> target.merge(val, freq, Integer::sum));
                }

                return b;
            }
        );

        // All partitions empty: every partial was null, so the reduction yields null.
        if (frequencies == null)
            throw new IllegalStateException(EMPTY_DATASET_MSG);

        double[] res = new double[frequencies.length];

        for (int i = 0; i < frequencies.length; i++) {
            Optional<Double> mostFrequent = frequencies[i].entrySet()
                .stream()
                .max(Comparator.comparingInt(Map.Entry::getValue))
                .map(Map.Entry::getKey);

            // A feature that was NaN in every row has an empty map and keeps the default 0.0.
            res[i] = mostFrequent.orElse(0.0);
        }

        return res;
    }

    /**
     * Reduces per-partition sums and counts of non-NaN values into per-feature means.
     *
     * @param dataset Dataset whose partitions hold sums and counts.
     * @return Mean per feature; {@code NaN} for a feature that had no non-NaN observation (0.0 / 0).
     * @throws IllegalStateException If no partition contained any row.
     */
    private static double[] calculateImputingValuesBySumsAndCounts(
        Dataset<EmptyContext, ImputerPartitionData> dataset) {
        double[] sums = dataset.compute(
            ImputerPartitionData::sums,
            (a, b) -> {
                if (a == null)
                    return b;

                if (b == null)
                    return a;

                assert a.length == b.length;

                for (int i = 0; i < a.length; i++)
                    a[i] += b[i];

                return a;
            }
        );

        int[] counts = dataset.compute(
            ImputerPartitionData::counts,
            (a, b) -> {
                if (a == null)
                    return b;

                if (b == null)
                    return a;

                assert a.length == b.length;

                for (int i = 0; i < a.length; i++)
                    a[i] += b[i];

                return a;
            }
        );

        // All partitions empty: every partial was null, so the reduction yields null.
        if (sums == null || counts == null)
            throw new IllegalStateException(EMPTY_DATASET_MSG);

        double[] means = new double[sums.length];

        for (int i = 0; i < means.length; i++)
            means[i] = sums[i] / counts[i];

        return means;
    }

    /**
     * Adds the non-NaN values of {@code row} to per-feature frequency maps, creating them on the first row.
     *
     * @param row Row produced by the base preprocessor.
     * @param valuesByFreq Frequency maps accumulated so far, or {@code null} before the first row.
     * @return Updated frequency maps (one per feature).
     */
    @SuppressWarnings("unchecked")
    private static Map<Double, Integer>[] calculateFrequencies(LabeledVector row, Map<Double, Integer>[] valuesByFreq) {
        if (valuesByFreq == null) {
            valuesByFreq = new HashMap[row.size()];

            for (int i = 0; i < valuesByFreq.length; i++)
                valuesByFreq[i] = new HashMap<>();
        }
        else {
            assert valuesByFreq.length == row.size() : "Base preprocessor must return exactly "
                + valuesByFreq.length + " features";
        }

        for (int i = 0; i < valuesByFreq.length; i++) {
            double v = row.get(i);

            // NaN marks a missing value and must not influence the statistics.
            if (!Double.isNaN(v))
                valuesByFreq[i].merge(v, 1, Integer::sum);
        }

        return valuesByFreq;
    }

    /**
     * Adds the non-NaN values of {@code row} to per-feature running sums, allocating them on the first row.
     *
     * @param row Row produced by the base preprocessor.
     * @param sums Sums accumulated so far, or {@code null} before the first row.
     * @return Updated sums (one per feature).
     */
    private static double[] calculateTheSums(LabeledVector row, double[] sums) {
        if (sums == null)
            sums = new double[row.size()];
        else {
            assert sums.length == row.size() : "Base preprocessor must return exactly " + sums.length
                + " features";
        }

        for (int i = 0; i < sums.length; i++) {
            double v = row.get(i);

            if (!Double.isNaN(v))
                sums[i] += v;
        }

        return sums;
    }

    /**
     * Counts, per feature, how many non-NaN values {@code row} contributes, allocating counters on the first row.
     *
     * @param row Row produced by the base preprocessor.
     * @param counts Counts accumulated so far, or {@code null} before the first row.
     * @return Updated counts (one per feature).
     */
    private static int[] calculateTheCounts(LabeledVector row, int[] counts) {
        if (counts == null)
            counts = new int[row.size()];
        else {
            assert counts.length == row.size() : "Base preprocessor must return exactly " + counts.length
                + " features";
        }

        for (int i = 0; i < counts.length; i++) {
            if (!Double.isNaN(row.get(i)))
                counts[i]++;
        }

        return counts;
    }

    /**
     * Sets the imputing strategy.
     *
     * @param imputingStgy The given value.
     * @return This trainer, for call chaining.
     */
    public ImputerTrainer<K, V> withImputingStrategy(ImputingStrategy imputingStgy) {
        this.imputingStgy = imputingStgy;
        return this;
    }
}

