package org.apache.ignite.ml.tree.impurity.gini;

import java.util.Arrays;
import java.util.Map;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;

/* loaded from: input_file:org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.class */
/**
 * Impurity measure calculator that, for every feature column of a dataset, builds a step function mapping candidate
 * split thresholds to {@link GiniImpurityMeasure Gini impurity measures}.
 * <p>
 * Labels are translated to integer codes via the label encoder passed to the constructor; each produced measure holds
 * per-code label counts for the left and the right side of the corresponding split.
 */
public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> {
    /** Serial version uid. */
    private static final long serialVersionUID = -522995134128519679L;

    /**
     * Label encoder: label value -> integer code. Codes index count arrays of length {@code lbEncoder.size()}, so
     * they must lie in {@code [0, lbEncoder.size())}.
     */
    private final Map<Double, Integer> lbEncoder;

    /**
     * Constructs a new instance of the Gini impurity measure calculator.
     *
     * @param lbEncoder Label encoder mapping each label value to an integer code in {@code [0, lbEncoder.size())}.
     */
    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) {
        this.lbEncoder = lbEncoder;
    }

    /**
     * Calculates, for each feature column, a step function of Gini impurity over candidate split points.
     * <p>
     * For column {@code col} the first point of the resulting function is {@link Double#NEGATIVE_INFINITY}, paired
     * with a measure that has every row on the right side. Each following point is a feature value, taken in the row
     * order produced by {@code data.sort(col)} (presumably ascending), at the last row of each run of equal values.
     * It is paired with a measure whose left side counts all rows up to and including that row.
     *
     * @param data Decision tree data; its rows are re-sorted by every column in turn.
     * @return One step function per feature column, or {@code null} if {@code data} contains no rows.
     */
    @Override // org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator
    public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) {
        double[][] features = data.getFeatures();
        double[] labels = data.getLabels();

        if (features.length == 0) {
            return null;
        }

        // Generic array creation is not expressible in Java; every element stored below is a
        // StepFunction<GiniImpurityMeasure>, so the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length];

        for (int col = 0; col < res.length; col++) {
            // NOTE(review): relies on sort() reordering the arrays already referenced by `features` and `labels`
            // in place (they are fetched once, before sorting) — confirm against DecisionTreeData.
            data.sort(col);
            res[col] = calculateColumn(features, labels, col);
        }

        return res;
    }

    /**
     * Builds the impurity step function for a single column; {@code features} must already be sorted by it.
     *
     * @param features Feature matrix (rows x columns), sorted by column {@code col}.
     * @param labels Labels aligned with the rows of {@code features}.
     * @param col Index of the column to build the function for.
     * @return Step function of Gini impurity measures for column {@code col}.
     */
    private StepFunction<GiniImpurityMeasure> calculateColumn(double[][] features, double[] labels, int col) {
        // At most one point per row, plus the leading -Infinity point.
        double[] x = new double[features.length + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1];

        long[] left = new long[lbEncoder.size()];
        long[] right = new long[lbEncoder.size()];

        // Initially every row lies on the right side of the split.
        for (double lb : labels) {
            right[getLabelCode(lb)]++;
        }

        int ptr = 0;
        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr] = snapshot(left, right);
        ptr++;

        for (int i = 0; i < features.length; i++) {
            // Move row i from the right side to the left side.
            int code = getLabelCode(labels[i]);
            left[code]++;
            right[code]--;

            // Emit a point only at the last row of a run of equal values, so equal values are never split apart.
            boolean lastOfRun = i == features.length - 1 || features[i + 1][col] != features[i][col];
            if (lastOfRun) {
                x[ptr] = features[i][col];
                y[ptr] = snapshot(left, right);
                ptr++;
            }
        }

        return new StepFunction<>(Arrays.copyOf(x, ptr), Arrays.copyOf(y, ptr));
    }

    /**
     * Creates a measure from copies of the running label counts; copies are required because the caller keeps
     * mutating {@code left} and {@code right} after this call.
     *
     * @param left Per-label-code counts of rows on the left side of the split.
     * @param right Per-label-code counts of rows on the right side of the split.
     * @return New measure over independent copies of both arrays.
     */
    private static GiniImpurityMeasure snapshot(long[] left, long[] right) {
        return new GiniImpurityMeasure(Arrays.copyOf(left, left.length), Arrays.copyOf(right, right.length));
    }

    /**
     * Returns the integer code assigned to the given label by the label encoder.
     *
     * @param lb Label value.
     * @return Code of {@code lb}.
     * @throws AssertionError If the label is unknown and assertions are enabled.
     * @throws IllegalArgumentException If the label is unknown and assertions are disabled.
     */
    int getLabelCode(double lb) {
        Integer code = lbEncoder.get(lb);

        assert code != null : "Can't find code for label " + lb;

        // Without -ea the assertion is skipped; fail with a descriptive error instead of an NPE on unboxing.
        if (code == null) {
            throw new IllegalArgumentException("Can't find code for label " + lb);
        }

        return code;
    }
}
