package org.apache.ignite.ml.knn.ann;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNModelFormat;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;
import org.jetbrains.annotations.NotNull;

/* loaded from: input_file:org/apache/ignite/ml/knn/ann/ANNClassificationModel.class */
public class ANNClassificationModel extends NNClassificationModel {
    private static final long serialVersionUID = -127312378991350345L;

    /** Labeled candidate vectors; every label is a {@link ProbableLabel} holding per-class weights. */
    private final LabeledVectorSet<ProbableLabel, LabeledVector> candidates;

    /**
     * Centroid statistics handed over by the trainer. Only exposed via the getter and exported by
     * {@link #saveModel}; it does not take part in prediction, {@link #equals} or {@link #hashCode}.
     */
    private final ANNClassificationTrainer.CentroidStat centroindsStat;

    /**
     * Creates the model.
     *
     * @param labeledVectorSet Candidate vectors (presumably the centroids built by
     *     {@link ANNClassificationTrainer} — TODO confirm) labeled with {@link ProbableLabel}s.
     * @param centroidStat Centroid statistics produced during training.
     */
    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> labeledVectorSet, ANNClassificationTrainer.CentroidStat centroidStat) {
        this.candidates = labeledVectorSet;
        this.centroindsStat = centroidStat;
    }

    /** @return Candidate vectors used as neighbours during prediction. */
    public LabeledVectorSet<ProbableLabel, LabeledVector> getCandidates() {
        return this.candidates;
    }

    /** @return Centroid statistics supplied at construction time. */
    public ANNClassificationTrainer.CentroidStat getCentroindsStat() {
        return this.centroindsStat;
    }

    /**
     * Classifies a vector by weighted voting among its nearest candidates.
     *
     * @param vector Vector to classify.
     * @return Class label that collected the most votes.
     */
    @Override
    public Double predict(Vector vector) {
        return Double.valueOf(classify(findKNearestNeighbors(vector), vector, this.stgy));
    }

    /**
     * Exports the model (k, distance measure, strategy, candidates and centroid statistics).
     *
     * @param exporter Exporter that writes the model format.
     * @param p Destination passed through to the exporter.
     * @param <P> Type of the destination.
     */
    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P p) {
        exporter.save(new ANNModelFormat(this.k, this.distanceMeasure, this.stgy, this.candidates, this.centroindsStat), p);
    }

    /** Returns up to {@code k} candidates closest to {@code vector}. */
    private List<LabeledVector> findKNearestNeighbors(Vector vector) {
        return Arrays.asList(getKClosestVectors(getDistances(vector)));
    }

    /**
     * Selects the neighbour vectors.
     *
     * <p>When there are at most {@code k} candidates, every non-null candidate is returned in index
     * order. Otherwise the first {@code k} indices are taken in ascending distance order. If the
     * distance map holds fewer than {@code k} indices (null rows are never put into it), the result
     * is shorter than {@code k} instead of failing.
     *
     * @param distanceIdxPairs Distance to set-of-candidate-indices map, as built by {@link #getDistances}.
     * @return Selected neighbours; never contains {@code null}.
     */
    @NotNull
    private LabeledVector[] getKClosestVectors(TreeMap<Double, Set<Integer>> distanceIdxPairs) {
        int rowCnt = this.candidates.rowSize();

        if (rowCnt <= this.k) {
            LabeledVector[] all = new LabeledVector[rowCnt];
            int filled = 0;
            for (int i = 0; i < rowCnt; i++) {
                LabeledVector row = (LabeledVector) this.candidates.getRow(i);
                // Null rows would break classification (label() on null), so skip them.
                if (row != null) {
                    all[filled++] = row;
                }
            }
            return filled == rowCnt ? all : Arrays.copyOf(all, filled);
        }

        LabeledVector[] res = new LabeledVector[this.k];
        int filled = 0;
        // TreeMap.values() iterates in ascending key (distance) order; equidistant indices share one set.
        // Both loops also stop when the map runs out, which happens if some candidate rows were null.
        for (Iterator<Set<Integer>> byDistance = distanceIdxPairs.values().iterator();
             byDistance.hasNext() && filled < this.k; ) {
            for (Iterator<Integer> idxIt = byDistance.next().iterator(); idxIt.hasNext() && filled < this.k; ) {
                res[filled++] = (LabeledVector) this.candidates.getRow(idxIt.next().intValue());
            }
        }
        return filled == this.k ? res : Arrays.copyOf(res, filled);
    }

    /**
     * Computes the distance from {@code vector} to every non-null candidate.
     *
     * @param vector Query vector.
     * @return Map from distance to the indices of candidates at that distance, sorted by distance.
     */
    @NotNull
    private TreeMap<Double, Set<Integer>> getDistances(Vector vector) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
        for (int i = 0; i < this.candidates.rowSize(); i++) {
            LabeledVector candidate = (LabeledVector) this.candidates.getRow(i);
            if (candidate != null) {
                putDistanceIdxPair(distanceIdxPairs, i, this.distanceMeasure.compute(vector, candidate.features()));
            }
        }
        return distanceIdxPairs;
    }

    /**
     * Accumulates class votes from the neighbours and returns the winning class.
     *
     * <p>Each neighbour contributes {@code probability * vote(strategy, distance)} to every class in
     * its {@link ProbableLabel}.
     *
     * @param neighbors Selected neighbour vectors (labels must be {@link ProbableLabel}).
     * @param vector Query vector.
     * @param nNStrategy Strategy used to weight a neighbour's vote by its distance.
     * @return Class label with the highest accumulated vote.
     */
    private double classify(List<LabeledVector> neighbors, Vector vector, NNStrategy nNStrategy) {
        HashMap<Double, Double> clsVotes = new HashMap<>();
        for (LabeledVector neighbor : neighbors) {
            TreeMap<Double, Double> probableClsLbls = ((ProbableLabel) neighbor.label()).clsLbls;
            double distance = this.distanceMeasure.compute(vector, neighbor.features());
            // The vote weight depends only on this neighbour's distance, so compute it once, not per label.
            double weight = getClassVoteForVector(nNStrategy, distance);
            probableClsLbls.forEach((lbl, probability) ->
                clsVotes.put(lbl, clsVotes.getOrDefault(lbl, 0.0d) + probability.doubleValue() * weight));
        }
        return getClassWithMaxVotes(clsVotes);
    }

    /** Hash over k, distance measure, strategy and candidates (centroid statistics are excluded). */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + this.k;
        res = res * 37 + this.distanceMeasure.hashCode();
        res = res * 37 + this.stgy.hashCode();
        res = res * 37 + this.candidates.hashCode();
        return res;
    }

    /** Two models are equal when k, distance measure, strategy and candidates are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        ANNClassificationModel other = (ANNClassificationModel) obj;
        return this.k == other.k
            && this.distanceMeasure.equals(other.distanceMeasure)
            && this.stgy.equals(other.stgy)
            && this.candidates.equals(other.candidates);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toString(false);
    }

    /**
     * Builds a textual trace of the model's settings.
     *
     * <p>NOTE(review): the trace is titled "KNNClassificationModel" although this is the ANN model;
     * kept unchanged in case callers or tests match on it.
     */
    @Override
    public String toString(boolean z) {
        return ModelTrace.builder("KNNClassificationModel", z)
            .addField("k", String.valueOf(this.k))
            .addField("measure", this.distanceMeasure.getClass().getSimpleName())
            .addField("strategy", this.stgy.name())
            .addField("amount of candidates", String.valueOf(this.candidates.rowSize()))
            .toString();
    }
}
