package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.BootstrappedVectorsHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;

/* loaded from: input_file:org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.class */
/**
 * Impurity histogram for one feature that scores candidate splits by the sum of squared deviations
 * of labels from their mean on each side of the split (MSE criterion without normalization).
 *
 * <p>Three histograms are maintained over the same bucket ids for a single bootstrap sample
 * {@code sampleId}: a counters histogram, a histogram of weighted label sums and a histogram of
 * weighted squared-label sums. The label histograms weight each vector by its bootstrap counter for
 * {@code sampleId}. All three are constructed with this histogram's {@code bucketIds} set.
 */
public class MSEHistogram extends ImpurityHistogram implements ImpurityComputer<BootstrappedVector, MSEHistogram> {
    private static final long serialVersionUID = 9175485616887867623L;

    /** Bucket metadata of the feature this histogram is built for. */
    private final BucketMeta bucketMeta;

    /** Index of the bootstrap sample whose counters weight every added vector. */
    private final int sampleId;

    /** Per-bucket counters (a {@code CountersHistogram}; presumably sums of sample counters — not shown here). */
    private ObjectHistogram<BootstrappedVector> counters;

    /** Per-bucket sum of {@code counter * label}. */
    private ObjectHistogram<BootstrappedVector> sumOfLabels;

    /** Per-bucket sum of {@code counter * label^2}. */
    private ObjectHistogram<BootstrappedVector> sumOfSquaredLabels;

    /**
     * Histogram accumulating, per bucket, {@code counter * label^labelPower} over added vectors, where
     * {@code counter} is the vector's bootstrap counter for {@code sampleId}.
     */
    private static class SumOfLabelsHistogram extends BootstrappedVectorsHistogram {
        private static final long serialVersionUID = -3846156279667677800L;

        /** Index of the bootstrap sample whose counters weight the labels. */
        private final int sampleId;

        /** Power applied to each label before weighting (1 for label sums, 2 for squared-label sums). */
        private final double labelPower;

        /**
         * @param bucketIds Bucket-id set handed to the superclass; every id produced by
         *     {@link #mapToBucket} is added to {@code this.bucketIds} (presumably this same set).
         * @param bucketMeta Bucket metadata of the feature.
         * @param featureId Index of the feature in the vector.
         * @param sampleId Bootstrap sample index.
         * @param labelPower Power applied to each label.
         */
        public SumOfLabelsHistogram(Set<Integer> bucketIds, BucketMeta bucketMeta, int featureId, int sampleId,
            double labelPower) {
            super(bucketIds, bucketMeta, featureId);
            this.sampleId = sampleId;
            this.labelPower = labelPower;
        }

        /** Maps a vector to the bucket of its feature value and records that bucket id. */
        @Override
        public Integer mapToBucket(BootstrappedVector vec) {
            int bucketId = this.bucketMeta.getBucketId(Double.valueOf(vec.features().get(this.featureId)));
            this.bucketIds.add(Integer.valueOf(bucketId));
            return Integer.valueOf(bucketId);
        }

        /** Returns {@code counter * label^labelPower} for the vector. */
        @Override
        public Double mapToCounter(BootstrappedVector vec) {
            return Double.valueOf(vec.counters()[this.sampleId] * Math.pow(vec.label().doubleValue(), this.labelPower));
        }

        /** Creates an empty histogram with the same parameters and the same bucket-id set instance. */
        @Override
        public ObjectHistogram<BootstrappedVector> newInstance() {
            return new SumOfLabelsHistogram(this.bucketIds, this.bucketMeta, this.featureId, this.sampleId,
                this.labelPower);
        }
    }

    /**
     * Creates an empty MSE histogram.
     *
     * @param sampleId Bootstrap sample index whose counters weight the vectors.
     * @param bucketMeta Bucket metadata; the feature id is taken from its feature meta.
     */
    public MSEHistogram(int sampleId, BucketMeta bucketMeta) {
        super(bucketMeta.getFeatureMeta().getFeatureId());
        this.bucketMeta = bucketMeta;
        this.sampleId = sampleId;
        this.counters = new CountersHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId);
        this.sumOfLabels = new SumOfLabelsHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId, 1.0d);
        this.sumOfSquaredLabels = new SumOfLabelsHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId, 2.0d);
    }

    /** Adds the vector to all three underlying histograms. */
    @Override
    public void addElement(BootstrappedVector vec) {
        this.counters.addElement(vec);
        this.sumOfLabels.addElement(vec);
        this.sumOfSquaredLabels.addElement(vec);
    }

    /**
     * Returns a new histogram merging this one with {@code other}. Neither input's fields are
     * reassigned here.
     *
     * <p>NOTE(review): the merged sub-histograms come from {@code this.xxx.plus(...)}. If
     * {@code ObjectHistogram.plus} builds its result through {@code newInstance()}, they reference
     * {@code this.bucketIds} (see {@code SumOfLabelsHistogram.newInstance}) rather than
     * {@code res.bucketIds} — confirm before adding elements to a merged histogram.
     */
    @Override
    public MSEHistogram plus(MSEHistogram other) {
        MSEHistogram res = new MSEHistogram(this.sampleId, this.bucketMeta);
        res.counters = this.counters.plus(other.counters);
        res.sumOfLabels = this.sumOfLabels.plus(other.sumOfLabels);
        res.sumOfSquaredLabels = this.sumOfSquaredLabels.plus(other.sumOfSquaredLabels);
        // The merged bucket set must be the union of both inputs; findBestSplit iterates it.
        res.bucketIds.addAll(this.bucketIds);
        res.bucketIds.addAll(other.bucketIds);
        return res;
    }

    /** Returns the live bucket-id set (not a copy). */
    @Override
    public Set<Integer> buckets() {
        return this.bucketIds;
    }

    /**
     * Not supported for this histogram.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    public Optional<Double> getValue(Integer bucketId) {
        throw new IllegalStateException("MSE histogram doesn't support 'getValue' method");
    }

    /**
     * Finds the bucket whose value, used as a split threshold, minimizes the summed impurity of the
     * left (buckets up to and including it) and right parts.
     *
     * @return Empty if nothing was added yet; otherwise the result of
     *     {@code checkAndReturnSplitValue} for the best bucket.
     */
    @Override
    public Optional<NodeSplit> findBestSplit() {
        TreeMap<Integer, Double> cntDistrib = this.counters.computeDistributionFunction();
        TreeMap<Integer, Double> ysDistrib = this.sumOfLabels.computeDistributionFunction();
        TreeMap<Integer, Double> y2sDistrib = this.sumOfSquaredLabels.computeDistributionFunction();

        // An empty histogram has no totals (lastEntry() would be null) and nothing to split.
        if (cntDistrib.isEmpty()) {
            return Optional.empty();
        }

        // The last entry of a cumulative distribution is the total over all buckets.
        double totalCnt = lastValue(cntDistrib);
        double totalY = lastValue(ysDistrib);
        double totalY2 = lastValue(y2sDistrib);

        double bestImpurity = Double.POSITIVE_INFINITY;
        double bestSplitVal = Double.NEGATIVE_INFINITY;
        int bestBucketId = -1;

        for (Integer bucketId : this.bucketIds) {
            double leftCnt = cumulativeAt(cntDistrib, bucketId);
            double leftY = cumulativeAt(ysDistrib, bucketId);
            double leftY2 = cumulativeAt(y2sDistrib, bucketId);

            double rightCnt = totalCnt - leftCnt;
            double rightY = totalY - leftY;
            double rightY2 = totalY2 - leftY2;

            double impurity = 0.0d;
            if (leftCnt > 0.0d) {
                impurity += impurity(leftCnt, leftY, leftY2);
            }
            if (rightCnt > 0.0d) {
                impurity += impurity(rightCnt, rightY, rightY2);
            }

            if (impurity < bestImpurity) {
                bestImpurity = impurity;
                bestSplitVal = this.bucketMeta.bucketIdToValue(bucketId.intValue());
                bestBucketId = bucketId.intValue();
            }
        }
        return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity);
    }

    /** Returns the value of the last (largest-key) entry, or 0 for an empty distribution. */
    private static double lastValue(TreeMap<Integer, Double> distrib) {
        return distrib.isEmpty() ? 0.0d : distrib.lastEntry().getValue().doubleValue();
    }

    /**
     * Returns the cumulative value at {@code bucketId}: the entry for the largest key not greater
     * than it, or 0 if every key is greater.
     */
    private static double cumulativeAt(TreeMap<Integer, Double> distrib, int bucketId) {
        Integer floorKey = distrib.floorKey(bucketId);
        return floorKey == null ? 0.0d : distrib.get(floorKey).doubleValue();
    }

    /**
     * Sum of squared deviations from the mean: {@code y2 - 2*y*y/cnt + (y/cnt)^2 * cnt}, which
     * algebraically equals {@code y2 - y^2/cnt}.
     *
     * @param cnt Weighted number of samples (must be positive).
     * @param y Weighted sum of labels.
     * @param y2 Weighted sum of squared labels.
     */
    private double impurity(double cnt, double y, double y2) {
        return (y2 - (((2.0d * y) / cnt) * y)) + (Math.pow(y / cnt, 2.0d) * cnt);
    }

    /** Returns the counters histogram. */
    ObjectHistogram<BootstrappedVector> getCounters() {
        return this.counters;
    }

    /** Returns the weighted label-sum histogram. */
    ObjectHistogram<BootstrappedVector> getSumOfLabels() {
        return this.sumOfLabels;
    }

    /** Returns the weighted squared-label-sum histogram. */
    ObjectHistogram<BootstrappedVector> getSumOfSquaredLabels() {
        return this.sumOfSquaredLabels;
    }

    /**
     * Returns true when both histograms have exactly the same bucket ids and all three underlying
     * histograms are equal according to their own {@code isEqualTo}.
     */
    @Override
    public boolean isEqualTo(MSEHistogram other) {
        Set<Integer> allBuckets = new HashSet<>(buckets());
        allBuckets.addAll(other.bucketIds);
        // Union size equal to both sizes <=> the two bucket sets are equal (symmetric check).
        if (allBuckets.size() != this.bucketIds.size() || allBuckets.size() != other.bucketIds.size()) {
            return false;
        }
        return this.counters.isEqualTo(other.counters)
            && this.sumOfLabels.isEqualTo(other.sumOfLabels)
            && this.sumOfSquaredLabels.isEqualTo(other.sumOfSquaredLabels);
    }
}
