/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.BootstrappedVectorsHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;

/**
 * Impurity histogram for a single feature and a single bootstrap sample that
 * keeps three per-bucket accumulators: a counters histogram, a histogram of
 * counter-weighted labels and a histogram of counter-weighted squared labels.
 * The accumulators are used by {@link #findBestSplit()} to pick the bucket
 * whose split gives the largest impurity reduction.
 */
public class MSEHistogram extends ImpurityHistogram
    implements ImpurityComputer<BootstrappedVector, MSEHistogram> {
    private static final long serialVersionUID = 9175485616887867623L;

    /** Bucket meta used to map feature values to bucket ids and back. */
    private final BucketMeta bucketMeta;

    /** Index into {@code BootstrappedVector.counters()} used by this histogram. */
    private final int sampleId;

    /** Per-bucket counters histogram (not final: replaced by {@link #plus(MSEHistogram)}). */
    private ObjectHistogram<BootstrappedVector> counters;

    /** Per-bucket sum of {@code counter * label}. */
    private ObjectHistogram<BootstrappedVector> sumOfLabels;

    /** Per-bucket sum of {@code counter * label^2}. */
    private ObjectHistogram<BootstrappedVector> sumOfSquaredLabels;

    /**
     * Creates an empty histogram.
     *
     * @param sampleId   index of the bootstrap sample whose counters are accumulated.
     * @param bucketMeta bucket meta of the feature; its feature id is passed to the superclass.
     */
    public MSEHistogram(int sampleId, BucketMeta bucketMeta) {
        super(bucketMeta.getFeatureMeta().getFeatureId());
        this.bucketMeta = bucketMeta;
        this.sampleId = sampleId;
        // All three accumulators are handed the same bucketIds set instance as this histogram.
        this.counters = new CountersHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId);
        this.sumOfLabels = new SumOfLabelsHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId, 1.0);
        this.sumOfSquaredLabels = new SumOfLabelsHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId, 2.0);
    }

    /**
     * Adds the vector to each of the three underlying accumulators.
     *
     * @param vector vector to account for.
     */
    @Override
    public void addElement(BootstrappedVector vector) {
        this.counters.addElement(vector);
        this.sumOfLabels.addElement(vector);
        this.sumOfSquaredLabels.addElement(vector);
    }

    /**
     * Builds a new histogram that combines this histogram with {@code other}.
     * Neither operand's fields are reassigned by this method.
     *
     * @param other histogram to merge with.
     * @return new histogram holding the merged accumulators and the union of both bucket id sets.
     */
    @Override
    public MSEHistogram plus(MSEHistogram other) {
        MSEHistogram res = new MSEHistogram(this.sampleId, this.bucketMeta);
        res.counters = this.counters.plus(other.counters);
        res.sumOfLabels = this.sumOfLabels.plus(other.sumOfLabels);
        res.sumOfSquaredLabels = this.sumOfSquaredLabels.plus(other.sumOfSquaredLabels);
        res.bucketIds.addAll(this.bucketIds);
        // Fix: previously this.bucketIds was added a second time, so bucket ids known only to
        // `other` were missing from the result and never considered by findBestSplit().
        res.bucketIds.addAll(other.bucketIds);
        return res;
    }

    /**
     * @return the bucket id set of this histogram (the live set, not a copy).
     */
    @Override
    public Set<Integer> buckets() {
        return this.bucketIds;
    }

    /**
     * Not supported by this histogram.
     *
     * @param bucketId ignored.
     * @return never returns normally.
     * @throws IllegalStateException always.
     */
    @Override
    public Optional<Double> getValue(Integer bucketId) {
        throw new IllegalStateException("MSE histogram doesn't support 'getValue' method");
    }

    /**
     * Searches the known bucket ids for the split with the largest positive gain, where
     * gain is the impurity of the whole node minus the summed impurity of the left and
     * right parts.
     *
     * @return result of {@code checkAndReturnSplitValue} for the best bucket found, or
     *         {@link Optional#empty()} when any of the distribution functions is empty.
     */
    @Override
    public Optional<NodeSplit> findBestSplit() {
        TreeMap<Integer, Double> cntrDistrib = this.counters.computeDistributionFunction();
        TreeMap<Integer, Double> ysDistrib = this.sumOfLabels.computeDistributionFunction();
        TreeMap<Integer, Double> y2sDistrib = this.sumOfSquaredLabels.computeDistributionFunction();

        // TreeMap.lastEntry() returns null for an empty map; without this guard an empty
        // histogram would fail with a NullPointerException below.
        if (cntrDistrib.isEmpty() || ysDistrib.isEmpty() || y2sDistrib.isEmpty()) {
            return Optional.empty();
        }

        // The value at the largest key of each distribution is used as the node-wide total.
        double cntrMax = cntrDistrib.lastEntry().getValue();
        double ysMax = ysDistrib.lastEntry().getValue();
        double y2sMax = y2sDistrib.lastEntry().getValue();
        double nodeImpurity = this.impurity(cntrMax, ysMax, y2sMax);

        // NOTE(review): these defaults are never updated inside the loop, so a bucket id that
        // is missing from a distribution is treated as having a left-side value of 0.0.
        // The names suggest the previous bucket's value may have been meant to carry
        // forward - confirm before changing.
        double lastLeftCntrVal = 0.0;
        double lastLeftYVal = 0.0;
        double lastLeftY2Val = 0.0;

        double bestSplitVal = Double.NEGATIVE_INFINITY;
        int bestBucketId = -1;
        double bestGain = 0.0;

        for (Integer bucketId : this.bucketIds) {
            // Values taken from the distributions at this bucket form the left part.
            double leftCnt = cntrDistrib.getOrDefault(bucketId, lastLeftCntrVal);
            double leftY = ysDistrib.getOrDefault(bucketId, lastLeftYVal);
            double leftY2 = y2sDistrib.getOrDefault(bucketId, lastLeftY2Val);

            // The right part is whatever remains of the node totals.
            double rightCnt = cntrMax - leftCnt;
            double rightY = ysMax - leftY;
            double rightY2 = y2sMax - leftY2;

            // A side with a non-positive count is skipped, which also avoids dividing by zero.
            double childrenImpurity = 0.0;
            if (leftCnt > 0.0) {
                childrenImpurity += this.impurity(leftCnt, leftY, leftY2);
            }
            if (rightCnt > 0.0) {
                childrenImpurity += this.impurity(rightCnt, rightY, rightY2);
            }

            // Strict comparison: a NaN gain or a tie never replaces the current best.
            double gain = nodeImpurity - childrenImpurity;
            if (gain > bestGain) {
                bestGain = gain;
                bestSplitVal = this.bucketMeta.bucketIdToValue(bucketId);
                bestBucketId = bucketId;
            }
        }

        return this.checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestGain, nodeImpurity);
    }

    /**
     * Computes the impurity term from aggregated statistics. The expression is
     * algebraically equal to {@code y2s - ys * ys / cnt}; it is kept in its expanded
     * form so that floating-point results stay unchanged.
     *
     * @param cnt aggregated counter value; callers in this class pass only positive values
     *            except for the node-wide total, which is not checked.
     * @param ys  aggregated sum of labels.
     * @param y2s aggregated sum of squared labels.
     * @return impurity value.
     */
    private double impurity(double cnt, double ys, double y2s) {
        return y2s - 2.0 * ys / cnt * ys + Math.pow(ys / cnt, 2.0) * cnt;
    }

    /**
     * Maps a vector to its bucket id and records that id in {@code bucketIds}.
     * Has no call sites inside this class.
     *
     * @param vec vector.
     * @return bucket id of the vector's value for this histogram's feature.
     */
    private Integer bucketMap(BootstrappedVector vec) {
        int bucketId = this.bucketMeta.getBucketId(vec.features().get(this.featureId));
        this.bucketIds.add(bucketId);
        return bucketId;
    }

    /**
     * Returns the vector's counter for this histogram's sample.
     * Has no call sites inside this class.
     *
     * @param vec vector.
     * @return counter value as a double.
     */
    private Double counterMap(BootstrappedVector vec) {
        // Explicit cast, matching ysMap/y2sMap: the element must be widened to double
        // before it can be boxed to Double.
        return (double)vec.counters()[this.sampleId];
    }

    /**
     * Returns {@code counter * label} for this histogram's sample.
     * Has no call sites inside this class.
     *
     * @param vec vector.
     * @return counter-weighted label.
     */
    private Double ysMap(BootstrappedVector vec) {
        return (double)vec.counters()[this.sampleId] * (Double)vec.label();
    }

    /**
     * Returns {@code counter * label^2} for this histogram's sample.
     * Has no call sites inside this class.
     *
     * @param vec vector.
     * @return counter-weighted squared label.
     */
    private Double y2sMap(BootstrappedVector vec) {
        return (double)vec.counters()[this.sampleId] * Math.pow((Double)vec.label(), 2.0);
    }

    /**
     * @return the counters histogram.
     */
    ObjectHistogram<BootstrappedVector> getCounters() {
        return this.counters;
    }

    /**
     * @return the sum-of-labels histogram.
     */
    ObjectHistogram<BootstrappedVector> getSumOfLabels() {
        return this.sumOfLabels;
    }

    /**
     * @return the sum-of-squared-labels histogram.
     */
    ObjectHistogram<BootstrappedVector> getSumOfSquaredLabels() {
        return this.sumOfSquaredLabels;
    }

    /**
     * Compares this histogram with another one.
     *
     * @param other histogram to compare with.
     * @return {@code false} when {@code other} has a bucket id this histogram lacks, or when
     *         any of the three accumulators reports inequality; {@code true} otherwise.
     */
    @Override
    public boolean isEqualTo(MSEHistogram other) {
        // NOTE(review): this only detects bucket ids present in `other` but absent here;
        // the reverse direction is left to the accumulator comparisons below.
        Set<Integer> unionBuckets = new HashSet<>(this.buckets());
        unionBuckets.addAll(other.bucketIds);
        if (unionBuckets.size() != this.bucketIds.size()) {
            return false;
        }
        if (!this.counters.isEqualTo(other.counters)) {
            return false;
        }
        if (!this.sumOfLabels.isEqualTo(other.sumOfLabels)) {
            return false;
        }
        return this.sumOfSquaredLabels.isEqualTo(other.sumOfSquaredLabels);
    }

    /**
     * Histogram whose per-vector contribution is {@code counter * label^labelPower}
     * for a fixed bootstrap sample.
     */
    private static class SumOfLabelsHistogram extends BootstrappedVectorsHistogram {
        private static final long serialVersionUID = -3846156279667677800L;

        /** Index into {@code BootstrappedVector.counters()}. */
        private final int sampleId;

        /** Exponent applied to the label before weighting by the counter. */
        private final double labelPower;

        /**
         * @param bucketIds  bucket id set that {@link #mapToBucket} adds to; also passed to the superclass.
         * @param bucketMeta bucket meta of the feature.
         * @param featureId  feature id.
         * @param sampleId   bootstrap sample index.
         * @param labelPower exponent applied to the label.
         */
        public SumOfLabelsHistogram(Set<Integer> bucketIds, BucketMeta bucketMeta, int featureId, int sampleId, double labelPower) {
            super(bucketIds, bucketMeta, featureId);
            this.sampleId = sampleId;
            this.labelPower = labelPower;
        }

        /**
         * Maps the vector to a bucket id and records that id in {@code bucketIds}.
         *
         * @param vec vector.
         * @return bucket id of the vector's value for the configured feature.
         */
        @Override
        public Integer mapToBucket(BootstrappedVector vec) {
            int bucketId = this.bucketMeta.getBucketId(vec.features().get(this.featureId));
            this.bucketIds.add(bucketId);
            return bucketId;
        }

        /**
         * @param vec vector.
         * @return {@code counter * label^labelPower} for the configured sample.
         */
        @Override
        public Double mapToCounter(BootstrappedVector vec) {
            return (double)vec.counters()[this.sampleId] * Math.pow((Double)vec.label(), this.labelPower);
        }

        /**
         * @return a new histogram constructed with this instance's bucket id set, bucket meta,
         *         feature id, sample id and label power.
         */
        @Override
        public ObjectHistogram<BootstrappedVector> newInstance() {
            return new SumOfLabelsHistogram(this.bucketIds, this.bucketMeta, this.featureId, this.sampleId, this.labelPower);
        }
    }
}

