package org.apache.ignite.ml.knn.utils.indices;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.ignite.ml.knn.utils.PointWithDistance;
import org.apache.ignite.ml.knn.utils.PointWithDistanceUtil;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * {@link SpatialIndex} backed by an unbalanced KD-tree.
 *
 * <p>Points are inserted one by one in list order. A node at depth {@code d} splits its subtree on
 * coordinate {@code d % dimension}: points whose coordinate is strictly greater than the node's go to
 * the right subtree, all others (including ties) go to the left one. No rebalancing is performed, so
 * the tree shape depends on the insertion order; for example, points that increase in every
 * coordinate degenerate into a chain. The search is therefore iterative so that such deep trees
 * cannot overflow the call stack.
 *
 * @param <L> Label type.
 */
public class KDTreeSpatialIndex<L> implements SpatialIndex<L> {
    /** Distance measure used to compare the query point with indexed points. */
    private final DistanceMeasure distanceMeasure;

    /** Root of the tree, {@code null} when the index is empty. */
    private TreeNode<L> root;

    /** Tree node holding one indexed point. */
    private static final class TreeNode<L> {
        /** Indexed point. */
        private final LabeledVector<L> val;

        /** Subtree with points whose split coordinate is less than or equal to this node's. */
        private TreeNode<L> left;

        /** Subtree with points whose split coordinate is strictly greater than this node's. */
        private TreeNode<L> right;

        /**
         * @param val Indexed point.
         */
        TreeNode(LabeledVector<L> val) {
            this.val = val;
        }
    }

    /**
     * Pending unit of work for the iterative search: a subtree to explore, optionally subject to
     * pruning against the current worst candidate.
     */
    private static final class SearchStep<L> {
        /** Root of the subtree to explore, may be {@code null}. */
        private final TreeNode<L> node;

        /** Coordinate index that {@code node} splits on. */
        private final int axis;

        /** Absolute difference between the query and the parent's split coordinate. */
        private final double axisGap;

        /** Whether the subtree may be skipped when it cannot beat the current candidates. */
        private final boolean pruneable;

        /**
         * @param node Root of the subtree to explore.
         * @param axis Coordinate index that {@code node} splits on.
         * @param axisGap Distance from the query to the parent's split coordinate.
         * @param pruneable Whether the pruning check applies (far children only).
         */
        SearchStep(TreeNode<L> node, int axis, double axisGap, boolean pruneable) {
            this.node = node;
            this.axis = axis;
            this.axisGap = axisGap;
            this.pruneable = pruneable;
        }
    }

    /**
     * Builds the index by inserting all points in list order.
     *
     * @param list Points to index.
     * @param distanceMeasure Distance measure used by {@link #findKClosest(int, Vector)}.
     */
    public KDTreeSpatialIndex(List<LabeledVector<L>> list, DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;

        for (LabeledVector<L> pnt : list) {
            root = add(root, pnt);
        }
    }

    /**
     * Finds up to {@code k} indexed points closest to {@code pnt}.
     *
     * @param k Number of neighbours to find.
     * @param pnt Query point.
     * @return Closest points, ordered as produced by {@link PointWithDistanceUtil}; fewer than
     *     {@code k} (possibly none) if the index holds fewer points.
     * @throws IllegalArgumentException If {@code k} is not positive.
     */
    @Override
    public List<LabeledVector<L>> findKClosest(int k, Vector pnt) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of neighbours should be positive.");
        }

        // Reversed ordering keeps the worst current candidate at the head, where the pruning check
        // reads it (presumably PointWithDistance is ordered by distance -- confirm there).
        Queue<PointWithDistance<L>> heap = new PriorityQueue<>(Collections.reverseOrder());

        findKClosest(pnt, root, 0, heap, k);

        return PointWithDistanceUtil.transfomToListOrdered(heap);
    }

    /**
     * Offers every reachable point of the subtree rooted at {@code subtreeRoot} to {@code heap}.
     *
     * <p>Depth-first traversal driven by an explicit stack. The near child (the side of the split the
     * query falls on) is explored first. The far child is explored only after the whole near subtree,
     * and only if the heap still has fewer than {@code k} points or the query's distance to the
     * splitting coordinate is smaller than the head candidate's distance.
     *
     * <p>NOTE(review): this pruning is exact only if the distance measure is never smaller than the
     * absolute difference of a single coordinate (e.g. Euclidean, Manhattan, Chebyshev); with other
     * measures results may be approximate.
     *
     * @param pnt Query point.
     * @param subtreeRoot Root of the subtree to search, may be {@code null}.
     * @param axis Coordinate index that {@code subtreeRoot} splits on.
     * @param heap Candidate heap, updated in place.
     * @param k Maximum number of candidates to keep.
     */
    private void findKClosest(Vector pnt, TreeNode<L> subtreeRoot, int axis,
        Queue<PointWithDistance<L>> heap, int k) {
        Deque<SearchStep<L>> stack = new ArrayDeque<>();
        stack.push(new SearchStep<>(subtreeRoot, axis, 0.0, false));

        while (!stack.isEmpty()) {
            SearchStep<L> step = stack.pop();
            TreeNode<L> node = step.node;

            if (node == null) {
                continue;
            }

            // Written as "visit if ..." (not a negated skip test) to keep the original NaN semantics.
            boolean visit = !step.pruneable
                || heap.size() < k
                || step.axisGap < heap.peek().getDistance();

            if (!visit) {
                continue;
            }

            PointWithDistanceUtil.tryToAddIntoHeap(heap, k, node.val,
                distanceMeasure.compute(pnt, node.val.features()));

            double pntCoord = pnt.get(step.axis);
            double nodeCoord = node.val.get(step.axis);
            boolean goRight = pntCoord > nodeCoord;

            TreeNode<L> near = goRight ? node.right : node.left;
            TreeNode<L> far = goRight ? node.left : node.right;
            int nextAxis = (step.axis + 1) % pnt.size();

            // LIFO: the far-child check is pushed first so it is evaluated only after the entire near
            // subtree has been processed (same order as the recursive formulation).
            stack.push(new SearchStep<>(far, nextAxis, Math.abs(pntCoord - nodeCoord), true));
            stack.push(new SearchStep<>(near, nextAxis, 0.0, false));
        }
    }

    /**
     * Inserts {@code pnt} into the tree rooted at {@code treeNode}.
     *
     * @param treeNode Current root, may be {@code null}.
     * @param pnt Point to insert.
     * @return Root of the tree after insertion.
     */
    private TreeNode<L> add(TreeNode<L> treeNode, LabeledVector<L> pnt) {
        if (treeNode == null) {
            return new TreeNode<>(pnt);
        }

        addIntoExistingTree(treeNode, pnt);

        return treeNode;
    }

    /**
     * Walks down from {@code treeNode}, cycling the split axis with depth, and attaches {@code pnt}
     * as a new leaf. Strictly greater coordinates go right, others go left.
     *
     * @param treeNode Non-null root of the tree.
     * @param pnt Point to insert.
     */
    private void addIntoExistingTree(TreeNode<L> treeNode, LabeledVector<L> pnt) {
        TreeNode<L> node = treeNode;
        int axis = 0;

        while (true) {
            if (pnt.get(axis) > node.val.get(axis)) {
                if (node.right == null) {
                    node.right = new TreeNode<>(pnt);
                    return;
                }
                node = node.right;
            }
            else {
                if (node.left == null) {
                    node.left = new TreeNode<>(pnt);
                    return;
                }
                node = node.left;
            }

            axis = (axis + 1) % pnt.size();
        }
    }
}
