package org.apache.ignite.ml.knn;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.knn.utils.indices.ArraySpatialIndex;
import org.apache.ignite.ml.knn.utils.indices.BallTreeSpatialIndex;
import org.apache.ignite.ml.knn.utils.indices.KDTreeSpatialIndex;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndexType;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.preprocessing.Preprocessor;

/* loaded from: input_file:org/apache/ignite/ml/knn/KNNPartitionDataBuilder.class */
/**
 * {@link PartitionDataBuilder} that turns the upstream entries of one partition into a {@link SpatialIndex}.
 *
 * <p>Every upstream entry is passed through the configured {@link Preprocessor}. The resulting points are handed,
 * together with the configured {@link DistanceMeasure}, to the {@link SpatialIndex} implementation selected by
 * the configured {@link SpatialIndexType}.
 *
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public class KNNPartitionDataBuilder<K, V> implements PartitionDataBuilder<K, V, EmptyContext, SpatialIndex<Double>> {
    /** Converts every upstream (key, value) pair into a point stored in the index. */
    private final Preprocessor<K, V> preprocessor;

    /** Selects which {@link SpatialIndex} implementation is created by {@link #build}. */
    private final SpatialIndexType spatialIdxType;

    /** Distance measure handed to the created spatial index. */
    private final DistanceMeasure distanceMeasure;

    /**
     * Constructs a new partition data builder.
     *
     * @param preprocessor Preprocessor applied to every upstream entry.
     * @param spatialIndexType Type of spatial index to build for each partition.
     * @param distanceMeasure Distance measure passed to the spatial index.
     */
    public KNNPartitionDataBuilder(Preprocessor<K, V> preprocessor, SpatialIndexType spatialIndexType,
        DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        this.spatialIdxType = spatialIndexType;
        this.preprocessor = preprocessor;
    }

    /**
     * Drains {@code upstreamData}, preprocesses every entry and builds a spatial index over the results.
     *
     * <p>{@code env}, {@code upstreamDataSize} and {@code ctx} are not used by this implementation.
     *
     * @param env Learning environment (unused).
     * @param upstreamData Iterator over the partition's upstream entries; it is fully consumed.
     * @param upstreamDataSize Size of the upstream data (unused).
     * @param ctx Partition context (unused).
     * @return Spatial index built over the preprocessed entries.
     * @throws IllegalArgumentException If the configured spatial index type is not supported.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override public SpatialIndex<Double> build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData,
        long upstreamDataSize, EmptyContext ctx) {
        // Raw list: the preprocessor's result type is declared outside this file, so the element type is left
        // unchecked here, exactly as in the original implementation.
        ArrayList dataPnts = new ArrayList();

        for (UpstreamEntry<K, V> entry; upstreamData.hasNext(); ) {
            entry = upstreamData.next();

            dataPnts.add(preprocessor.apply(entry.getKey(), entry.getValue()));
        }

        return createIndex(dataPnts);
    }

    /**
     * Creates the spatial index implementation selected by {@link #spatialIdxType} over the given points.
     *
     * @param dataPnts Preprocessed data points.
     * @return Spatial index over {@code dataPnts} using {@link #distanceMeasure}.
     * @throws IllegalArgumentException If {@link #spatialIdxType} is not one of the handled types.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private SpatialIndex<Double> createIndex(ArrayList dataPnts) {
        // Note: a null spatialIdxType makes the enum switch itself throw NullPointerException.
        switch (spatialIdxType) {
            case ARRAY:
                return new ArraySpatialIndex(dataPnts, distanceMeasure);

            case KD_TREE:
                return new KDTreeSpatialIndex(dataPnts, distanceMeasure);

            case BALL_TREE:
                return new BallTreeSpatialIndex(dataPnts, distanceMeasure);

            default:
                throw new IllegalArgumentException("Unknown spatial index type [type=" + spatialIdxType + "]");
        }
    }
}
