package org.apache.ignite.ml.knn.utils.indices;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.ignite.ml.knn.utils.PointWithDistance;
import org.apache.ignite.ml.knn.utils.PointWithDistanceUtil;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;

/* loaded from: input_file:org/apache/ignite/ml/knn/utils/indices/BallTreeSpatialIndex.class */
/**
 * Spatial index backed by a ball tree.
 *
 * <p>Points are recursively partitioned into nested balls (a center plus a radius that covers every point
 * stored below that node). A k-nearest-neighbour query descends into the nearer child first. It skips the
 * other child when that child's ball cannot contain a point closer than the current k-th best candidate.
 * NOTE(review): this pruning is only exact when the distance measure satisfies the triangle inequality.
 *
 * @param <L> Type of the labels.
 */
public class BallTreeSpatialIndex<L> implements SpatialIndex<L> {
    /** Point count at or below which a node becomes a leaf instead of being split further. */
    private static final int MAX_LEAF_SIZE = 42;

    /**
     * Relative position of the child centers along the split dimension. The left center is placed at
     * {@code min + (max - min) * SPLIT_BALL_MARGIN} and the right one at
     * {@code min + (max - min) * (1 - SPLIT_BALL_MARGIN)}.
     */
    private static final double SPLIT_BALL_MARGIN = 0.2d;

    /** Distance measure used both to build the tree and to answer queries. */
    private final DistanceMeasure distanceMeasure;

    /** Root node of the tree. */
    private final TreeNode root;

    /**
     * Inner node of the tree that owns two child balls.
     */
    public final class TreeInnerNode extends TreeNode {
        /** Left child (built from the points closer to the left split center). */
        private final TreeNode left;

        /** Right child (built from the remaining points). */
        private final TreeNode right;

        /**
         * Constructs a new inner node.
         *
         * @param center Center of the ball.
         * @param radius Radius of the ball.
         * @param left Left child.
         * @param right Right child.
         */
        TreeInnerNode(Vector center, double radius, TreeNode left, TreeNode right) {
            super(center, radius);
            this.left = left;
            this.right = right;
        }

        /** {@inheritDoc} */
        @Override
        void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k) {
            double distToLeft = computeDistToCenter(pnt, left);
            double distToRight = computeDistToCenter(pnt, right);

            // Visit the closer child first so the heap fills with good candidates early,
            // which makes the pruning test for the farther child more selective.
            TreeNode closer = distToLeft > distToRight ? right : left;
            TreeNode farther = closer == right ? left : right;
            double distToFarther = closer == right ? distToLeft : distToRight;

            if (closer != null) {
                closer.findKClosest(pnt, heap, k);
            }

            if (farther != null) {
                // Lower bound on the distance from the query to any point inside the farther ball.
                double lowerBound = distToFarther - farther.getRadius();

                // The heap is ordered so that peek() returns the current worst (largest) candidate.
                if (heap.size() < k || lowerBound < heap.peek().getDistance()) {
                    farther.findKClosest(pnt, heap, k);
                }
            }
        }

        /**
         * Computes the distance from the given point to the center of the given node.
         *
         * @param pnt Query point.
         * @param node Node, may be {@code null}.
         * @return Distance to the node center, or {@link Double#MAX_VALUE} if the node is {@code null}.
         */
        private double computeDistToCenter(Vector pnt, TreeNode node) {
            if (node == null) {
                return Double.MAX_VALUE;
            }
            return distanceMeasure.compute(pnt, node.getCenter());
        }
    }

    /**
     * Leaf node of the tree that stores its points explicitly.
     */
    public final class TreeLeafNode extends TreeNode {
        /** Points stored in this leaf. */
        private final List<LabeledVector<L>> points;

        /**
         * Constructs a new leaf node.
         *
         * @param center Center of the ball.
         * @param radius Radius of the ball.
         * @param points Points stored in the leaf.
         */
        TreeLeafNode(Vector center, double radius, List<LabeledVector<L>> points) {
            super(center, radius);
            this.points = points;
        }

        /** {@inheritDoc} */
        @Override
        void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k) {
            for (LabeledVector<L> candidate : points) {
                double dist = distanceMeasure.compute(pnt, candidate.features());
                PointWithDistanceUtil.tryToAddIntoHeap(heap, k, candidate, dist);
            }
        }
    }

    /**
     * Base class for tree nodes: a ball described by a center and a radius.
     */
    public abstract class TreeNode {
        /** Center of the ball. */
        private final Vector center;

        /** Radius of the ball. */
        private final double radius;

        /**
         * Constructs a new tree node.
         *
         * @param center Center of the ball.
         * @param radius Radius of the ball.
         */
        TreeNode(Vector center, double radius) {
            this.center = center;
            this.radius = radius;
        }

        /**
         * Searches this subtree and offers candidates to the heap, which keeps at most {@code k} entries.
         *
         * @param pnt Query point.
         * @param heap Heap of the best candidates found so far (largest distance at the head).
         * @param k Number of neighbours to find.
         */
        abstract void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k);

        /** @return Center of the ball (may be {@code null} only for the root of an empty index). */
        public Vector getCenter() {
            return center;
        }

        /** @return Radius of the ball. */
        public double getRadius() {
            return radius;
        }
    }

    /**
     * Constructs a new ball tree index.
     *
     * @param list Points to index. The list is copied and is not modified by this index.
     * @param distanceMeasure Distance measure.
     */
    public BallTreeSpatialIndex(List<LabeledVector<L>> list, DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        // Work on a private copy: tree construction clears intermediate lists and leaves keep references
        // to their lists, so the caller's list (possibly immutable) must not be used directly.
        this.root = buildTree(new ArrayList<>(list));
    }

    /**
     * Finds the {@code i} points closest to {@code vector}.
     *
     * @param i Number of neighbours to find; a non-positive value yields an empty list.
     * @param vector Query point.
     * @return Closest points in the order produced by {@code PointWithDistanceUtil.transfomToListOrdered}.
     */
    @Override
    public List<LabeledVector<L>> findKClosest(int i, Vector vector) {
        if (i <= 0) {
            // Nothing to find. This also prevents an empty-heap peek() in the inner-node pruning test.
            return new ArrayList<>();
        }

        // Max-heap by distance: the head is the worst candidate kept so far.
        PriorityQueue<PointWithDistance<L>> heap =
            new PriorityQueue<>(Collections.<PointWithDistance<L>>reverseOrder());
        root.findKClosest(vector, heap, i);
        return PointWithDistanceUtil.transfomToListOrdered(heap);
    }

    /**
     * Builds a tree whose root ball is centered at the mean of the points.
     *
     * @param points Points (a list owned by this index).
     * @return Root node.
     */
    private TreeNode buildTree(List<LabeledVector<L>> points) {
        Vector center = calculateCenter(points);
        return buildTree(points, center, calculateRadius(points, center));
    }

    /**
     * Recursively builds a subtree for the given points.
     *
     * @param points Points of the subtree (a list owned by this index; cleared if the node is split).
     * @param center Center of the subtree ball.
     * @param radius Radius of the subtree ball.
     * @return Subtree root.
     */
    private TreeNode buildTree(List<LabeledVector<L>> points, Vector center, double radius) {
        if (points.size() <= MAX_LEAF_SIZE) {
            return new TreeLeafNode(center, radius, points);
        }

        int splitDim = calculateBestDimForSplit(points);
        if (splitDim < 0) {
            // No dimension has a positive spread (e.g. all points coincide): nothing to split on.
            return new TreeLeafNode(center, radius, points);
        }

        Vector leftCenter = calculateCenter(points);
        Vector rightCenter = leftCenter.copy();

        double min = calculateMin(points, splitDim);
        double max = calculateMax(points, splitDim);
        double range = max - min;

        leftCenter.set(splitDim, min + range * SPLIT_BALL_MARGIN);
        rightCenter.set(splitDim, min + range * (1 - SPLIT_BALL_MARGIN));

        List<LabeledVector<L>> leftPoints = new ArrayList<>();
        List<LabeledVector<L>> rightPoints = new ArrayList<>();
        splitPoints(points, leftCenter, rightCenter, leftPoints, rightPoints);

        if (leftPoints.isEmpty() || rightPoints.isEmpty()) {
            // The split did not separate anything (possible with ties under some distance measures).
            // Recursing on the same point set would never terminate, so keep these points in one leaf.
            return new TreeLeafNode(center, radius, points);
        }

        // Inner nodes do not keep points; release this internal list's storage.
        points.clear();

        TreeNode left = buildTree(leftPoints, leftCenter, calculateRadius(leftPoints, leftCenter));
        TreeNode right = buildTree(rightPoints, rightCenter, calculateRadius(rightPoints, rightCenter));

        return new TreeInnerNode(center, radius, left, right);
    }

    /**
     * Assigns every point to the nearer of the two centers. Ties go to the right list.
     *
     * @param points Points to split.
     * @param leftCenter Left center.
     * @param rightCenter Right center.
     * @param leftPoints Output list for points closer to the left center.
     * @param rightPoints Output list for the remaining points.
     */
    private void splitPoints(List<LabeledVector<L>> points, Vector leftCenter, Vector rightCenter,
        List<LabeledVector<L>> leftPoints, List<LabeledVector<L>> rightPoints) {
        for (LabeledVector<L> pnt : points) {
            double distToLeft = distanceMeasure.compute(leftCenter, pnt.features());
            double distToRight = distanceMeasure.compute(rightCenter, pnt.features());

            if (distToLeft < distToRight) {
                leftPoints.add(pnt);
            } else {
                rightPoints.add(pnt);
            }
        }
    }

    /**
     * Computes the smallest radius around {@code center} that covers all points.
     *
     * @param points Points.
     * @param center Ball center.
     * @return Maximum distance from the center to a point, or {@code 0} for an empty list.
     */
    private double calculateRadius(List<LabeledVector<L>> points, Vector center) {
        double radius = 0.0d;
        for (LabeledVector<L> pnt : points) {
            radius = Math.max(radius, distanceMeasure.compute(center, pnt.features()));
        }
        return radius;
    }

    /**
     * Computes the per-dimension mean of the points.
     *
     * @param points Points.
     * @return Mean vector, or {@code null} if the list is empty.
     */
    private Vector calculateCenter(List<LabeledVector<L>> points) {
        if (points.isEmpty()) {
            return null;
        }

        double[] mean = new double[points.get(0).size()];
        for (int dim = 0; dim < mean.length; dim++) {
            mean[dim] = calculateMean(points, dim);
        }
        return VectorUtils.of(mean);
    }

    /**
     * Finds the dimension with the largest standard deviation.
     *
     * @param points Points.
     * @return Index of the dimension, or {@code -1} if the list is empty or no dimension has a positive
     *     standard deviation.
     */
    private int calculateBestDimForSplit(List<LabeledVector<L>> points) {
        if (points.isEmpty()) {
            return -1;
        }

        double bestStd = 0.0d;
        int bestDim = -1;
        int dims = points.get(0).size();
        for (int dim = 0; dim < dims; dim++) {
            double std = calculateStd(points, dim);
            if (std > bestStd) {
                bestStd = std;
                bestDim = dim;
            }
        }
        return bestDim;
    }

    /**
     * @param points Points.
     * @param dim Dimension index.
     * @return Maximum value in the dimension ({@link Double#NEGATIVE_INFINITY} for an empty list).
     */
    private double calculateMax(List<LabeledVector<L>> points, int dim) {
        double max = Double.NEGATIVE_INFINITY;
        for (LabeledVector<L> pnt : points) {
            max = Math.max(max, pnt.get(dim));
        }
        return max;
    }

    /**
     * @param points Points.
     * @param dim Dimension index.
     * @return Minimum value in the dimension ({@link Double#POSITIVE_INFINITY} for an empty list).
     */
    private double calculateMin(List<LabeledVector<L>> points, int dim) {
        double min = Double.POSITIVE_INFINITY;
        for (LabeledVector<L> pnt : points) {
            min = Math.min(min, pnt.get(dim));
        }
        return min;
    }

    /**
     * Computes the population standard deviation of one dimension.
     *
     * @param points Points.
     * @param dim Dimension index.
     * @return Standard deviation.
     */
    private double calculateStd(List<LabeledVector<L>> points, int dim) {
        double mean = calculateMean(points, dim);
        double sumSq = 0.0d;
        for (LabeledVector<L> pnt : points) {
            double diff = pnt.get(dim) - mean;
            sumSq += diff * diff;
        }
        return Math.sqrt(sumSq / points.size());
    }

    /**
     * Computes the arithmetic mean of one dimension.
     *
     * @param points Points.
     * @param dim Dimension index.
     * @return Mean value.
     */
    private double calculateMean(List<LabeledVector<L>> points, int dim) {
        double sum = 0.0d;
        for (LabeledVector<L> pnt : points) {
            sum += pnt.get(dim);
        }
        return sum / points.size();
    }
}
