package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;

/**
 * Base class for building impurity histograms for the nodes of a random forest that are being learned.
 *
 * <p>For every node in the learning set, one histogram per used feature is built on each dataset partition
 * ({@link #aggregateImpurityStatisticsOnPartition}). The per-partition results are then merged pairwise
 * ({@link #reduceImpurityStatistics}). The concrete impurity measure is supplied by subclasses through
 * {@link #createImpurityComputerForFeature(int, BucketMeta)}.
 *
 * @param <S> Type of the impurity computer (histogram) kept for each feature.
 */
public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<BootstrappedVector, S>> implements Serializable {
    /** Serial version uid. Kept unchanged for serialization compatibility. */
    private static final long serialVersionUID = -4984067145908187508L;

    /**
     * Impurity histograms of a single tree node, one histogram per feature id.
     *
     * @param <S> Type of the impurity computer kept for each feature.
     */
    public static class NodeImpurityHistograms<S extends ImpurityComputer<BootstrappedVector, S>> implements Serializable {
        /** Serial version uid. Kept unchanged for serialization compatibility. */
        private static final long serialVersionUID = 2700045747590421768L;

        /** Id of the node these histograms belong to. */
        private final NodeId nodeId;

        /** Impurity histogram per feature id. */
        private final Map<Integer, S> perFeatureStatistics = new HashMap<>();

        /**
         * Creates an empty set of histograms for the given node.
         *
         * @param nodeId Node id.
         */
        public NodeImpurityHistograms(NodeId nodeId) {
            this.nodeId = nodeId;
        }

        /**
         * Returns a new instance holding the feature-wise sum of this node's histograms and {@code other}'s.
         *
         * <p>The per-feature maps of both operands are left untouched. A histogram that exists on only one
         * side is placed into the result as is (shared, not copied). Histograms present on both sides are
         * combined with {@code S#plus}.
         *
         * @param other Histograms of the same node (checked with an {@code assert} when assertions are enabled).
         * @return Summed histograms.
         */
        public NodeImpurityHistograms<S> plus(NodeImpurityHistograms<S> other) {
            assert nodeId.equals(other.nodeId) : "Cannot sum histograms of different nodes: " + nodeId + " and " + other.nodeId;

            NodeImpurityHistograms<S> res = new NodeImpurityHistograms<>(nodeId);
            addTo(perFeatureStatistics, res.perFeatureStatistics);
            addTo(other.perFeatureStatistics, res.perFeatureStatistics);
            return res;
        }

        /**
         * Adds every histogram of {@code from} into {@code to}.
         *
         * <p>A feature that is missing in {@code to} receives the histogram as is. Otherwise the stored
         * histogram is replaced by {@code stored.plus(added)}.
         *
         * @param from Source histograms.
         * @param to Destination histograms (modified).
         */
        private void addTo(Map<Integer, S> from, Map<Integer, S> to) {
            from.forEach((featureId, hist) -> to.merge(featureId, hist, (stored, added) -> stored.plus(added)));
        }

        /**
         * @return Id of the node these histograms belong to.
         */
        public NodeId getNodeId() {
            return nodeId;
        }

        /**
         * Finds the split with the lowest impurity among the best splits of all features.
         *
         * @return Best split, or an empty {@code Optional} if no feature histogram yields a split.
         */
        public Optional<NodeSplit> findBestSplit() {
            return perFeatureStatistics.values().stream()
                .map(hist -> hist.findBestSplit())
                .filter(Optional::isPresent)
                .map(Optional::get)
                .min(Comparator.comparingDouble(NodeSplit::getImpurity));
        }
    }

    /**
     * Computes impurity histograms for the given nodes over the whole dataset.
     *
     * <p>The histograms are computed on every partition and then reduced pairwise.
     *
     * @param roots Tree roots. The position of a root in this list matches the position of its counter in
     *     each {@code BootstrappedVector#counters()}.
     * @param histMeta Bucket meta per feature id.
     * @param nodesToLearn Nodes for which histograms are collected.
     * @param dataset Bootstrapped dataset.
     * @return Histograms per node id.
     */
    public Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatistics(ArrayList<TreeRoot> roots,
        Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> nodesToLearn,
        Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {

        return dataset.compute(
            part -> aggregateImpurityStatisticsOnPartition(part, roots, histMeta, nodesToLearn),
            this::reduceImpurityStatistics
        );
    }

    /**
     * Computes impurity histograms for the given nodes on a single partition.
     *
     * @param part Dataset partition.
     * @param roots Tree roots, indexed like the vectors' counters.
     * @param histMeta Bucket meta per feature id.
     * @param nodesToLearn Nodes for which histograms are collected.
     * @return Histograms per node id. Every node of {@code nodesToLearn} is present, possibly with no features.
     */
    private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
        BootstrappedDatasetPartition part, ArrayList<TreeRoot> roots,
        Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> nodesToLearn) {

        Map<NodeId, NodeImpurityHistograms<S>> res = nodesToLearn.keySet().stream()
            .collect(Collectors.toMap(nodeId -> nodeId, NodeImpurityHistograms::new));

        part.forEach(vector -> {
            for (int sampleId = 0; sampleId < vector.counters().length; sampleId++) {
                // A zero counter means the vector is not taken into account for this tree's sample.
                if (vector.counters()[sampleId] == 0)
                    continue;

                TreeRoot root = roots.get(sampleId);
                NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

                // The vector falls into a node that is not in the current learning set.
                if (!nodesToLearn.containsKey(key))
                    continue;

                NodeImpurityHistograms<S> statistics = res.get(key);
                for (Integer featureId : root.getUsedFeatures()) {
                    S hist = statistics.perFeatureStatistics.get(featureId);
                    if (hist == null) {
                        hist = createImpurityComputerForFeature(sampleId, histMeta.get(featureId));
                        statistics.perFeatureStatistics.put(featureId, hist);
                    }
                    hist.addElement(vector);
                }
            }
        });

        return res;
    }

    /**
     * Merges two per-node histogram maps produced by different partitions.
     *
     * <p>A {@code null} operand is treated as "no data" and the other operand is returned as is. Otherwise
     * a new map is returned and neither input map is modified. Nodes present on both sides are combined
     * with {@link NodeImpurityHistograms#plus}.
     *
     * @param left Left histograms, may be {@code null}.
     * @param right Right histograms, may be {@code null}.
     * @return Merged histograms.
     */
    private Map<NodeId, NodeImpurityHistograms<S>> reduceImpurityStatistics(
        Map<NodeId, NodeImpurityHistograms<S>> left, Map<NodeId, NodeImpurityHistograms<S>> right) {

        if (left == null)
            return right;
        if (right == null)
            return left;

        Map<NodeId, NodeImpurityHistograms<S>> res = new HashMap<>(left);
        right.forEach((nodeId, hists) -> res.merge(nodeId, hists, (mine, theirs) -> mine.plus(theirs)));
        return res;
    }

    /**
     * Creates an empty impurity computer for one feature of one tree.
     *
     * @param sampleId Index of the tree root in the roots list. This is also the index into
     *     {@code BootstrappedVector#counters()}; presumably implementations use it to read the vector's
     *     counter for that tree (confirm in subclasses).
     * @param meta Bucket meta of the feature.
     * @return New impurity computer.
     */
    protected abstract S createImpurityComputerForFeature(int sampleId, BucketMeta meta);
}
