package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;

/* loaded from: input_file:org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.class */
public class GiniHistogram extends ImpurityHistogram implements ImpurityComputer<BootstrappedVector, GiniHistogram> {
    private static final long serialVersionUID = 5780670356098827667L;
    private final BucketMeta bucketMeta;
    private final int sampleId;
    private final ArrayList<ObjectHistogram<BootstrappedVector>> hists;
    private final Map<Double, Integer> lblMapping;
    private final Set<Integer> bucketIds;

    public GiniHistogram(int i, Map<Double, Integer> map, BucketMeta bucketMeta) {
        super(bucketMeta.getFeatureMeta().getFeatureId());
        this.hists = new ArrayList<>(map.size());
        this.sampleId = i;
        this.bucketMeta = bucketMeta;
        this.lblMapping = map;
        this.bucketIds = new TreeSet();
        for (int i2 = 0; i2 < map.size(); i2++) {
            this.hists.add(new CountersHistogram(this.bucketIds, bucketMeta, this.featureId, i));
        }
    }

    @Override // org.apache.ignite.ml.dataset.feature.Histogram
    public void addElement(BootstrappedVector bootstrappedVector) {
        this.hists.get(this.lblMapping.get(bootstrappedVector.label()).intValue()).addElement(bootstrappedVector);
    }

    @Override // org.apache.ignite.ml.dataset.feature.Histogram
    public Optional<Double> getValue(Integer num) {
        throw new IllegalStateException("Gini histogram doesn't support 'getValue' method");
    }

    @Override // org.apache.ignite.ml.dataset.feature.Histogram
    public GiniHistogram plus(GiniHistogram giniHistogram) {
        GiniHistogram giniHistogram2 = new GiniHistogram(this.sampleId, this.lblMapping, this.bucketMeta);
        giniHistogram2.bucketIds.addAll(this.bucketIds);
        giniHistogram2.bucketIds.addAll(giniHistogram.bucketIds);
        for (int i = 0; i < this.hists.size(); i++) {
            giniHistogram2.hists.set(i, this.hists.get(i).plus(giniHistogram.hists.get(i)));
        }
        return giniHistogram2;
    }

    @Override // org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer
    public Optional<NodeSplit> findBestSplit() {
        if (this.bucketIds.size() < 2) {
            return Optional.empty();
        }
        double d = Double.NEGATIVE_INFINITY;
        int i = -1;
        double d2 = 0.0d;
        double d3 = 0.0d;
        List list = (List) this.hists.stream().map((v0) -> {
            return v0.computeDistributionFunction();
        }).collect(Collectors.toList());
        double[] array = list.stream().mapToDouble(treeMap -> {
            if (treeMap.isEmpty()) {
                return 0.0d;
            }
            return ((Double) treeMap.lastEntry().getValue()).doubleValue();
        }).toArray();
        double sum = Arrays.stream(array).sum();
        for (int i2 = 0; i2 < this.lblMapping.size(); i2++) {
            double d4 = array[i2] / sum;
            d3 += d4 * (1.0d - d4);
        }
        if (d3 < Precision.EPSILON) {
            return Optional.empty();
        }
        HashMap hashMap = new HashMap();
        for (int i3 = 0; i3 < this.lblMapping.size(); i3++) {
            hashMap.put(Integer.valueOf(i3), Double.valueOf(0.0d));
        }
        for (Integer num : this.bucketIds) {
            double d5 = 0.0d;
            double d6 = 0.0d;
            double d7 = 0.0d;
            double d8 = 0.0d;
            for (int i4 = 0; i4 < this.lblMapping.size(); i4++) {
                Double d9 = (Double) ((TreeMap) list.get(i4)).get(num);
                if (d9 == null) {
                    d9 = (Double) hashMap.get(Integer.valueOf(i4));
                }
                d5 += d9.doubleValue();
                d6 += array[i4] - d9.doubleValue();
                hashMap.put(Integer.valueOf(i4), d9);
            }
            for (int i5 = 0; i5 < this.lblMapping.size(); i5++) {
                Double d10 = (Double) ((TreeMap) list.get(i5)).getOrDefault(num, hashMap.get(Integer.valueOf(i5)));
                if (d10.doubleValue() > 0.0d) {
                    double doubleValue = d10.doubleValue() / d5;
                    d7 += doubleValue * (1.0d - doubleValue);
                }
                double doubleValue2 = array[i5] - d10.doubleValue();
                if (doubleValue2 > 0.0d) {
                    double d11 = doubleValue2 / d6;
                    d8 += d11 * (1.0d - d11);
                }
            }
            double d12 = d3 - ((d7 * (d5 / (d5 + d6))) + (d8 * (d6 / (d5 + d6))));
            if (d12 > d2) {
                d = this.bucketMeta.bucketIdToValue(num.intValue());
                i = num.intValue();
                d2 = d12;
            }
        }
        return checkAndReturnSplitValue(i, d, d2, d3);
    }

    @Override // org.apache.ignite.ml.dataset.feature.Histogram
    public Set<Integer> buckets() {
        return this.bucketIds;
    }

    ObjectHistogram<BootstrappedVector> getHistForLabel(Double d) {
        return this.hists.get(this.lblMapping.get(d).intValue());
    }

    @Override // org.apache.ignite.ml.dataset.feature.Histogram
    public boolean isEqualTo(GiniHistogram giniHistogram) {
        HashSet hashSet = new HashSet(buckets());
        hashSet.addAll(giniHistogram.bucketIds);
        if (hashSet.size() != this.bucketIds.size()) {
            return false;
        }
        HashSet hashSet2 = new HashSet(this.lblMapping.keySet());
        hashSet2.addAll(giniHistogram.lblMapping.keySet());
        if (hashSet2.size() != this.lblMapping.size()) {
            return false;
        }
        Iterator it = hashSet2.iterator();
        while (it.hasNext()) {
            Double d = (Double) it.next();
            if (this.lblMapping.get(d) != giniHistogram.lblMapping.get(d) || !getHistForLabel(d).isEqualTo(giniHistogram.getHistForLabel(d))) {
                return false;
            }
        }
        return true;
    }
}
