package org.apache.ignite.ml.trees.trainers.columnbased.vectors;

import com.zaxxer.sparsebits.SparseBitSet;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.trees.CategoricalRegionInfo;
import org.apache.ignite.ml.trees.CategoricalSplitInfo;
import org.apache.ignite.ml.trees.RegionInfo;
import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;

/**
 * Feature processor for categorical features: it enumerates candidate subsets of categories,
 * evaluates the split each subset induces, and selects the one with the largest information gain.
 */
public class CategoricalFeatureProcessor implements FeatureProcessor<CategoricalRegionInfo, CategoricalSplitInfo<CategoricalRegionInfo>> {
    /** Number of categories; {@link #createInitialRegion} marks bits {@code 0..catsCnt-1} as present in the initial region. */
    private final int catsCnt;
    /** Function applied to a stream of label values; its result is stored as the impurity of a region. */
    private final IgniteFunction<DoubleStream, Double> calc;

    /**
     * Iterator over candidate category subsets, each encoded as a {@link BitSet}.
     *
     * <p>For a constructor argument {@code n} it emits the bit sets whose numeric value runs from
     * {@code 1} up to {@code (1 << (n - 1)) - 1}: every non-empty subset of bits {@code 0..n-2}.
     * The bound is computed with an {@code int} shift, so it is only meaningful for {@code n}
     * between 1 and 31.
     */
    public static class PSI implements Iterator<BitSet> {
        /** Next value to emit; starts at 1, so the empty subset is never produced. */
        private int i = 1;

        /** Exclusive upper bound for emitted values. */
        final int size;

        /**
         * @param i number of categories the subsets are drawn from
         */
        PSI(int i) {
            this.size = 1 << (i - 1);
        }

        /** {@inheritDoc} */
        @Override
        public boolean hasNext() {
            return this.i < this.size;
        }

        /**
         * Returns the next subset.
         *
         * @return bit set whose set bits are the binary digits of the current counter value
         * @throws NoSuchElementException if all subsets have already been returned
         */
        @Override
        public BitSet next() {
            // Honour the Iterator contract instead of silently emitting values past the bound.
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            BitSet subset = BitSet.valueOf(new long[] {this.i});
            this.i++;
            return subset;
        }
    }

    /**
     * Creates a processor for a categorical feature.
     *
     * @param igniteFunction function applied to a stream of label values to obtain an impurity
     * @param i number of categories of the feature
     */
    public CategoricalFeatureProcessor(IgniteFunction<DoubleStream, Double> igniteFunction, int i) {
        catsCnt = i;
        calc = igniteFunction;
    }

    /**
     * Evaluates one candidate split: samples whose category (translated through {@code map}) is set
     * in {@code bitSet} go left, all others go right.
     *
     * <p>The information gain is the parent impurity minus the size-weighted impurities of both
     * parts. Weights are computed in floating point; integer division here would truncate every
     * weight to 0 unless one part held all samples. If {@code numArr} is empty the weights are
     * {@code NaN}.
     *
     * @param bitSet subset of categories, addressed by the compacted indices stored in {@code map}
     * @param i value passed as the first constructor argument of the split info (presumably the region index)
     * @param map translation from raw category value to compacted bit index
     * @param numArr sample indexes belonging to the region
     * @param dArr feature values, indexed by sample index
     * @param dArr2 label values, indexed by sample index
     * @param d impurity of the region before the split
     * @return split info carrying both part impurities, the category subset and the information gain
     */
    private SplitInfo<CategoricalRegionInfo> split(BitSet bitSet, int i, Map<Integer, Integer> map, Integer[] numArr, double[] dArr, double[] dArr2, double d) {
        Map<Boolean, List<Integer>> parts = Arrays.stream(numArr)
            .collect(Collectors.partitioningBy(smpl -> bitSet.get(map.get((int) dArr[smpl]))));

        List<Integer> left = parts.get(true);
        double leftImpurity = this.calc.apply(left.stream().mapToDouble(smpl -> dArr2[smpl]));

        List<Integer> right = parts.get(false);
        double rightImpurity = this.calc.apply(right.stream().mapToDouble(smpl -> dArr2[smpl]));

        // Kept as double so that the weights below are real fractions, not truncated ints.
        double total = left.size() + right.size();
        double leftWeight = left.size() / total;
        double rightWeight = right.size() / total;

        // Category sets of the parts are not known here, hence the null second argument.
        CategoricalSplitInfo<CategoricalRegionInfo> res = new CategoricalSplitInfo<>(
            i,
            new CategoricalRegionInfo(leftImpurity, null),
            new CategoricalRegionInfo(rightImpurity, null),
            bitSet);

        res.setInfoGain(d - leftWeight * leftImpurity - rightWeight * rightImpurity);
        return res;
    }

    /**
     * Builds a sequential stream over the category subsets produced by {@link PSI}.
     *
     * @param i number of categories handed to the {@link PSI} constructor
     * @return lazily evaluated, non-parallel stream of subsets
     */
    private Stream<BitSet> powerSet(int i) {
        Iterable<BitSet> subsets = () -> new PSI(i);

        return StreamSupport.stream(subsets.spliterator(), false);
    }

    /**
     * Finds the split of the given region with the greatest information gain.
     *
     * <p>NOTE(review): the number of enumerated subsets is derived from {@code BitSet.length()}
     * (highest set bit + 1), not from the number of categories present — confirm this is intended
     * when the region's category set has gaps.
     *
     * @param regionProjection region to split
     * @param dArr feature values, indexed by sample index
     * @param dArr2 label values, indexed by sample index
     * @param i value forwarded to every candidate split (presumably the region index)
     * @return best candidate split, or {@code null} when there are no candidates
     */
    @Override
    public SplitInfo findBestSplit(RegionProjection<CategoricalRegionInfo> regionProjection, double[] dArr, double[] dArr2, int i) {
        Map<Integer, Integer> catToBit = mapping(regionProjection.data().cats());
        int subsetWidth = regionProjection.data().cats().length();

        Stream<SplitInfo<CategoricalRegionInfo>> candidates = powerSet(subsetWidth)
            .map(leftCats -> split(leftCats, i, catToBit, regionProjection.sampleIndexes(), dArr, dArr2,
                regionProjection.data().impurity()));

        return candidates
            .max(Comparator.comparingDouble(SplitInfo::infoGain))
            .orElse(null);
    }

    /**
     * Creates the root region: all given samples, every category {@code 0..catsCnt-1} marked as
     * present, depth 0.
     *
     * @param numArr sample indexes of the region
     * @param dArr feature values (not used here)
     * @param dArr2 label values from which the region impurity is computed
     * @return projection describing the initial region
     */
    @Override
    public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] numArr, double[] dArr, double[] dArr2) {
        BitSet allCats = new BitSet();
        allCats.set(0, catsCnt);

        double impurity = calc.apply(Arrays.stream(dArr2));
        CategoricalRegionInfo rootInfo = new CategoricalRegionInfo(impurity, allCats);

        return new RegionProjection<>(numArr, rootInfo, 0);
    }

    /**
     * For every sample of the region records whether its category is set in the split's bit set.
     *
     * <p>NOTE(review): the bit set is indexed here by the raw category value, whereas
     * {@code split} tests membership through the compacted index from {@code mapping} — confirm
     * both agree for regions whose category set is not {@code 0..k-1}.
     *
     * @param regionProjection region whose samples are classified
     * @param dArr feature values, indexed by sample index
     * @param categoricalSplitInfo split whose bit set decides ownership
     * @return sparse bit set with bit {@code s} equal to the membership of sample {@code s}
     */
    @Override
    public SparseBitSet calculateOwnershipBitSet(RegionProjection<CategoricalRegionInfo> regionProjection, double[] dArr, CategoricalSplitInfo<CategoricalRegionInfo> categoricalSplitInfo) {
        SparseBitSet ownership = new SparseBitSet();

        for (Integer smpl : regionProjection.sampleIndexes()) {
            int idx = smpl;
            int cat = (int) dArr[idx];
            ownership.set(idx, categoricalSplitInfo.bitSet().get(cat));
        }

        return ownership;
    }

    /**
     * Splits the region according to the ownership bit set by delegating to
     * {@link #performSplitGeneric}.
     *
     * <p>NOTE(review): {@code null} is passed as the feature values array; performSplitGeneric
     * forwards it to calculateCats, which indexes into that array for each sample index. This looks
     * like it would throw a NullPointerException for any non-empty part — confirm how this path is used.
     *
     * @param sparseBitSet ownership bit set of the samples
     * @param regionProjection region being split
     * @param categoricalRegionInfo region info whose impurity is used for the first part
     * @param categoricalRegionInfo2 region info whose impurity is used for the second part
     * @return pair of projections for the two parts
     */
    @Override // org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor
    public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet sparseBitSet, RegionProjection<CategoricalRegionInfo> regionProjection, CategoricalRegionInfo categoricalRegionInfo, CategoricalRegionInfo categoricalRegionInfo2) {
        return performSplitGeneric(sparseBitSet, null, regionProjection, categoricalRegionInfo, categoricalRegionInfo2);
    }

    /**
     * Splits the region's samples into two parts using the ownership bit set and builds a
     * projection for each part, one level deeper than the parent.
     *
     * @param sparseBitSet ownership bit set; its cardinality is taken as the size of the first part
     * @param dArr feature values, indexed by sample index, used to collect each part's categories
     * @param regionProjection region being split
     * @param regionInfo supplies the impurity of the first part
     * @param regionInfo2 supplies the impurity of the second part
     * @return pair (first part, second part) of new region projections
     */
    @Override
    public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet sparseBitSet, double[] dArr, RegionProjection<CategoricalRegionInfo> regionProjection, RegionInfo regionInfo, RegionInfo regionInfo2) {
        int parentDepth = regionProjection.depth();
        int firstCnt = sparseBitSet.cardinality();
        int secondCnt = regionProjection.sampleIndexes().length - firstCnt;

        IgniteBiTuple<Integer[], Integer[]> parts =
            FeatureVectorProcessorUtils.splitByBitSet(firstCnt, secondCnt, regionProjection.sampleIndexes(), sparseBitSet);

        CategoricalRegionInfo firstInfo = new CategoricalRegionInfo(regionInfo.impurity(), calculateCats(parts.get1(), dArr));
        RegionProjection<CategoricalRegionInfo> firstPrj = new RegionProjection<>(parts.get1(), firstInfo, parentDepth + 1);

        CategoricalRegionInfo secondInfo = new CategoricalRegionInfo(regionInfo2.impurity(), calculateCats(parts.get2(), dArr));
        RegionProjection<CategoricalRegionInfo> secondPrj = new RegionProjection<>(parts.get2(), secondInfo, parentDepth + 1);

        return new IgniteBiTuple<>(firstPrj, secondPrj);
    }

    /**
     * Assigns consecutive ordinals (0, 1, 2, ...) to the set bits of the given bit set, in
     * ascending bit order.
     *
     * @param bitSet bit set whose set bits are enumerated
     * @return map from set-bit index to its ordinal among the set bits
     */
    private Map<Integer, Integer> mapping(BitSet bitSet) {
        Map<Integer, Integer> ordinals = new HashMap<>();
        int ordinal = 0;

        for (int bit = bitSet.nextSetBit(0); bit != -1; bit = bitSet.nextSetBit(bit + 1)) {
            ordinals.put(bit, ordinal);
            ordinal++;
        }

        return ordinals;
    }

    /**
     * Collects the categories that occur among the given samples.
     *
     * @param numArr sample indexes to inspect
     * @param dArr feature values, indexed by sample index; each value is truncated to {@code int}
     * @return bit set with one bit set per category value seen
     */
    private BitSet calculateCats(Integer[] numArr, double[] dArr) {
        BitSet seenCats = new BitSet();

        Arrays.stream(numArr)
            .mapToInt(smpl -> (int) dArr[smpl])
            .forEach(seenCats::set);

        return seenCats;
    }
}
