package org.apache.ignite.ml.knn.classification;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:org/apache/ignite/ml/knn/classification/KNNClassificationModel.class */
/**
 * k-nearest-neighbours classifier backed by a partitioned training {@link Dataset}.
 *
 * <p>A prediction is made in two stages: every partition of the training dataset proposes its own
 * nearest candidates, the proposals are merged, and the nearest candidates of the merged set vote
 * for their label. The label with the largest accumulated vote is returned.
 *
 * <p>The {@code withXxx} methods mutate this instance and return it for chaining.
 *
 * @param <K> not referenced by the code of this class (presumably the key type of the upstream
 *     data source; confirm against callers).
 * @param <V> not referenced by the code of this class (presumably the value type of the upstream
 *     data source; confirm against callers).
 */
public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Exportable<KNNModelFormat> {
    private static final long serialVersionUID = -127386523291350345L;

    /** Number of neighbours consulted for each prediction. */
    protected int k = 5;

    /** Metric used to compare the query vector with training vectors. */
    protected DistanceMeasure distanceMeasure = new EuclideanDistance();

    /** Voting strategy applied to the selected neighbours. */
    protected KNNStrategy stgy = KNNStrategy.SIMPLE;

    /** Training data the neighbours are searched in; may be {@code null}, see {@link #apply(Vector)}. */
    private Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset;

    /**
     * Creates a model over the given training data.
     *
     * @param dataset training dataset; it is stored as is, without validation.
     */
    public KNNClassificationModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
        this.dataset = dataset;
    }

    /**
     * Predicts the label of the given vector.
     *
     * @param vector vector to classify.
     * @return label that collected the largest vote among the nearest neighbours.
     * @throws IllegalStateException if the training dataset is {@code null} or no neighbour could
     *     be found (for example because the dataset holds no rows or {@code k} is not positive).
     */
    @Override
    public Double apply(Vector vector) {
        if (this.dataset == null) {
            throw new IllegalStateException("The train kNN dataset is null");
        }
        return classify(findKNearestNeighbors(vector), vector, this.stgy);
    }

    /**
     * Exports the model parameters. Only {@code k}, the distance measure and the strategy are
     * passed to the exported format; the training dataset is not.
     *
     * @param exporter exporter that receives the model format.
     * @param p exporter-specific destination argument, forwarded unchanged.
     * @param <P> type of the destination argument.
     */
    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P p) {
        exporter.save(new KNNModelFormat(this.k, this.distanceMeasure, this.stgy), p);
    }

    /**
     * Sets the number of neighbours to consult.
     *
     * @param i new value of {@code k}; it is not validated here.
     * @return this model.
     */
    public KNNClassificationModel<K, V> withK(int i) {
        this.k = i;
        return this;
    }

    /**
     * Sets the voting strategy.
     *
     * @param kNNStrategy new strategy.
     * @return this model.
     */
    public KNNClassificationModel<K, V> withStrategy(KNNStrategy kNNStrategy) {
        this.stgy = kNNStrategy;
        return this;
    }

    /**
     * Sets the distance measure.
     *
     * @param distanceMeasure new distance measure.
     * @return this model.
     */
    public KNNClassificationModel<K, V> withDistanceMeasure(DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        return this;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Finds the training vectors closest to the given vector.
     *
     * <p>Each partition contributes its own closest candidates; the merged candidates are then
     * filtered once more so that at most {@code k} vectors remain. The lambdas passed to
     * {@code compute} capture this model and the query vector.
     *
     * @param vector query vector.
     * @return at most {@code k} non-null neighbours; an empty list if the dataset produced no
     *     candidates.
     */
    public List<LabeledVector> findKNearestNeighbors(Vector vector) {
        List<LabeledVector> candidates = this.dataset.compute(
            partition -> Arrays.asList(getKClosestVectors(partition, getDistances(vector, partition))),
            (left, right) -> {
                // Either side may be absent; never dereference a missing partial result.
                if (left == null) {
                    return right;
                }
                if (right == null) {
                    return left;
                }
                return Stream.concat(left.stream(), right.stream()).collect(Collectors.toList());
            });

        if (candidates == null || candidates.isEmpty()) {
            return Collections.emptyList();
        }

        LabeledDataset<Double, LabeledVector> merged = buildLabeledDatasetOnListOfVectors(candidates);
        return Arrays.asList(getKClosestVectors(merged, getDistances(vector, merged)));
    }

    /**
     * Wraps a list of labeled vectors into a {@link LabeledDataset}.
     *
     * @param list vectors to wrap, in the order they should appear as rows.
     * @return dataset built over a fresh array holding the list elements.
     */
    private LabeledDataset<Double, LabeledVector> buildLabeledDatasetOnListOfVectors(List<LabeledVector> list) {
        return new LabeledDataset<>(list.toArray(new LabeledVector[0]));
    }

    /**
     * Picks the closest rows of a dataset.
     *
     * <p>If the dataset has no more than {@code k} rows, all of its non-null rows are returned in
     * row order. Otherwise rows are taken in ascending distance order until {@code k} rows are
     * collected or the distance map is exhausted.
     *
     * @param labeledDataset dataset the rows are read from.
     * @param treeMap distance -> row indices, as produced by {@code getDistances} for the same
     *     dataset; it only references non-null rows.
     * @return array of at most {@code k} non-null rows; never {@code null}.
     */
    @NotNull
    private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> labeledDataset, TreeMap<Double, Set<Integer>> treeMap) {
        int rowCnt = labeledDataset.rowSize();
        // Clamp at zero so a non-positive k yields an empty result instead of a negative array size.
        int capacity = Math.max(0, Math.min(this.k, rowCnt));
        LabeledVector[] closest = new LabeledVector[capacity];
        int found = 0;

        if (rowCnt <= this.k) {
            for (int row = 0; row < rowCnt; row++) {
                LabeledVector candidate = (LabeledVector) labeledDataset.getRow(row);
                // Null rows carry no label and must not reach the voting stage.
                if (candidate != null && found < capacity) {
                    closest[found++] = candidate;
                }
            }
        } else {
            // TreeMap iterates in ascending key order, i.e. nearest rows first.
            for (Set<Integer> idxs : treeMap.values()) {
                for (int idx : idxs) {
                    if (found == capacity) {
                        return closest;
                    }
                    closest[found++] = (LabeledVector) labeledDataset.getRow(idx);
                }
            }
        }

        // Fewer rows than planned were available (null rows were skipped): trim the tail.
        return found == capacity ? closest : Arrays.copyOf(closest, found);
    }

    /**
     * Computes the distance from the query vector to every non-null row of a dataset.
     *
     * @param vector query vector.
     * @param labeledDataset dataset whose rows are measured.
     * @return sorted map from distance to the indices of the rows at that distance; never
     *     {@code null}.
     */
    @NotNull
    private TreeMap<Double, Set<Integer>> getDistances(Vector vector, LabeledDataset<Double, LabeledVector> labeledDataset) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
        int rowCnt = labeledDataset.rowSize();
        for (int row = 0; row < rowCnt; row++) {
            LabeledVector candidate = (LabeledVector) labeledDataset.getRow(row);
            if (candidate != null) {
                double distance = this.distanceMeasure.compute(vector, (Vector) candidate.features());
                putDistanceIdxPair(distanceIdxPairs, row, distance);
            }
        }
        return distanceIdxPairs;
    }

    /**
     * Registers a row index under its distance, creating the index set on first use.
     *
     * @param map distance -> row indices map to update.
     * @param i row index.
     * @param d distance of that row from the query vector.
     */
    private void putDistanceIdxPair(Map<Double, Set<Integer>> map, int i, double d) {
        map.computeIfAbsent(d, unused -> new HashSet<>()).add(i);
    }

    /**
     * Lets the neighbours vote and returns the winning label.
     *
     * @param list neighbours taking part in the vote.
     * @param vector query vector, needed to weight the votes by distance.
     * @param kNNStrategy voting strategy.
     * @return label with the largest accumulated vote.
     * @throws IllegalStateException if there are no neighbours to vote.
     */
    private double classify(List<LabeledVector> list, Vector vector, KNNStrategy kNNStrategy) {
        if (list.isEmpty()) {
            throw new IllegalStateException("No neighbors found to classify the vector");
        }

        Map<Double, Double> votes = new HashMap<>();
        for (LabeledVector neighbor : list) {
            double label = ((Double) neighbor.label()).doubleValue();
            double distance = this.distanceMeasure.compute(vector, (Vector) neighbor.features());
            votes.merge(label, getClassVoteForVector(kNNStrategy, distance), Double::sum);
        }
        return getClassWithMaxVotes(votes);
    }

    /**
     * Returns the key whose value is the largest.
     *
     * @param map label -> accumulated vote; must not be empty.
     * @return label with the maximal vote.
     */
    private double getClassWithMaxVotes(Map<Double, Double> map) {
        return Collections.max(map.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    /**
     * Computes the weight of a single neighbour's vote.
     *
     * @param kNNStrategy voting strategy.
     * @param d distance between the neighbour and the query vector.
     * @return {@code 1 / d} for the weighted strategy (positive infinity when {@code d} is zero,
     *     so an exact match dominates the vote), {@code 1} otherwise.
     */
    private double getClassVoteForVector(KNNStrategy kNNStrategy, double d) {
        return KNNStrategy.WEIGHTED.equals(kNNStrategy) ? 1.0d / d : 1.0d;
    }

    /**
     * Hash code derived from {@code k}, the distance measure and the strategy, in line with
     * {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int res = 37 + this.k;
        res = res * 37 + this.distanceMeasure.hashCode();
        res = res * 37 + this.stgy.hashCode();
        return res;
    }

    /**
     * Two models are equal when they are of the same class and agree on {@code k}, the distance
     * measure and the strategy. The training dataset does not take part in the comparison.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        KNNClassificationModel<?, ?> that = (KNNClassificationModel<?, ?>) obj;
        return this.k == that.k
            && this.distanceMeasure.equals(that.distanceMeasure)
            && this.stgy.equals(that.stgy);
    }

    // NOTE: the decompiled $deserializeLambda$ method was dropped on purpose. It is synthesized
    // by the Java compiler for serializable lambdas and must not be present in source form.
}
