package org.apache.ignite.ml.knn.classification;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;

/**
 * kNN classification model.
 * <p>
 * A vector is classified by collecting its nearest neighbours from every training dataset held by the model, then
 * letting each neighbour vote for its label. A neighbour's vote is computed from the configured {@link NNStrategy}
 * and its distance to the query vector.
 */
public class KNNClassificationModel extends NNClassificationModel implements Exportable<KNNModelFormat> {
    /** Serial version UID. */
    private static final long serialVersionUID = -127386523291350345L;

    /** Training datasets searched for neighbours; {@link #copyStateFrom} may append more. */
    private final List<Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>> datasets = new ArrayList<>();

    /**
     * Builds the model on top of a prepared training dataset.
     *
     * @param dataset Training dataset. If {@code null}, the model starts without data and {@link #predict} throws
     *     until datasets are added via {@link #copyStateFrom}.
     */
    public KNNClassificationModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        if (dataset != null) {
            datasets.add(dataset);
        }
    }

    /**
     * Classifies {@code vector} by a vote among its nearest neighbours.
     *
     * @param vector Vector to classify.
     * @return Label selected by {@code getClassWithMaxVotes} from the accumulated per-label votes.
     * @throws IllegalStateException If the model holds no training dataset.
     */
    @Override
    public Double predict(Vector vector) {
        if (datasets.isEmpty()) {
            throw new IllegalStateException("The train kNN dataset is null");
        }
        return classify(findKNearestNeighbors(vector), vector, stgy);
    }

    /**
     * Exports the model hyper-parameters: {@code k}, the distance measure and the strategy.
     * The training datasets are not part of the exported format.
     *
     * @param exporter Exporter to write with.
     * @param p Exporter-specific destination.
     * @param <P> Destination type.
     */
    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P p) {
        exporter.save(new KNNModelFormat(k, distanceMeasure, stgy), p);
    }

    /**
     * Finds the nearest neighbours of {@code vector} across all training datasets.
     * <p>
     * Each dataset first yields its own local nearest candidates. The closest among the merged candidates are then
     * selected with {@code getKClosestVectors}.
     * <p>
     * The decompiled build reports this method as originally {@code protected}. It is kept {@code public} so
     * existing callers keep compiling.
     *
     * @param vector Query vector.
     * @return Nearest neighbours, as a fixed-size list.
     */
    public List<LabeledVector> findKNearestNeighbors(Vector vector) {
        List<LabeledVector> candidates = datasets.stream()
            .flatMap(dataset -> findKNearestNeighborsInDataset(vector, dataset).stream())
            .collect(Collectors.toList());

        return selectKClosest(vector, buildLabeledDatasetOnListOfVectors(candidates));
    }

    /**
     * Finds the nearest neighbours of {@code vector} inside a single (possibly partitioned) dataset.
     *
     * @param vector Query vector.
     * @param dataset Dataset to search.
     * @return Nearest neighbours in this dataset, or an empty list if the computation produced no result.
     */
    private List<LabeledVector> findKNearestNeighborsInDataset(Vector vector,
        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        // The mapper runs per partition and returns that partition's local nearest candidates; the reducer merges
        // partial results. Both are Ignite serializable functional types (IgniteFunction / IgniteBinaryOperator);
        // the mapper captures this model and the query vector.
        List<LabeledVector> partitionCandidates = dataset.compute(
            data -> selectKClosest(vector, data),
            KNNClassificationModel::concatCandidates);

        if (partitionCandidates == null) {
            return Collections.emptyList();
        }
        // Partitions each contributed their own nearest set, so filter the union down again.
        return selectKClosest(vector, buildLabeledDatasetOnListOfVectors(partitionCandidates));
    }

    /**
     * Merges two partial candidate lists, treating {@code null} as "no candidates".
     *
     * @param left First partial result, may be {@code null}.
     * @param right Second partial result, may be {@code null}.
     * @return Concatenation of both lists; a new empty list when both are {@code null}.
     */
    private static List<LabeledVector> concatCandidates(List<LabeledVector> left, List<LabeledVector> right) {
        if (left == null) {
            return right == null ? new ArrayList<>() : right;
        }
        if (right == null) {
            return left;
        }
        return Stream.concat(left.stream(), right.stream()).collect(Collectors.toList());
    }

    /**
     * Selects from {@code candidates} the vectors closest to {@code vector}, using the parent's
     * {@code getDistances} and {@code getKClosestVectors}.
     *
     * @param vector Query vector.
     * @param candidates Vectors to choose from.
     * @return Selected vectors, as a fixed-size list.
     */
    private List<LabeledVector> selectKClosest(Vector vector, LabeledVectorSet<Double, LabeledVector> candidates) {
        return Arrays.asList(getKClosestVectors(candidates, getDistances(vector, candidates)));
    }

    /**
     * Accumulates per-label votes from {@code neighbors} and returns the winning label.
     *
     * @param neighbors Neighbours of {@code vector}.
     * @param vector Query vector, used to compute each neighbour's distance.
     * @param strategy Strategy that turns a distance into a vote.
     * @return Label selected by {@code getClassWithMaxVotes}.
     */
    private double classify(List<LabeledVector> neighbors, Vector vector, NNStrategy strategy) {
        Map<Double, Double> clsVotes = new HashMap<>();

        for (LabeledVector neighbor : neighbors) {
            // Unbox eagerly: a null label fails here with an NPE instead of becoming a null map key.
            double lbl = (Double) neighbor.label();
            double distance = distanceMeasure.compute(vector, neighbor.features());

            clsVotes.merge(lbl, getClassVoteForVector(strategy, distance), Double::sum);
        }

        return getClassWithMaxVotes(clsVotes);
    }

    /**
     * Copies the hyper-parameters of {@code other} into this model and appends its training datasets.
     * <p>
     * The dataset instances are shared with {@code other}, not copied.
     * NOTE(review): closing both models closes the shared datasets twice — confirm Dataset.close() tolerates that.
     *
     * @param other Model to copy from.
     */
    public void copyStateFrom(KNNClassificationModel other) {
        copyParametersFrom(other);
        datasets.addAll(other.datasets);
    }

    /**
     * Closes every training dataset held by this model.
     * <p>
     * Closing is best-effort: a failure on one dataset is reported and the remaining datasets are still closed.
     */
    @Override
    public void close() {
        for (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset : datasets) {
            try {
                dataset.close();
            }
            catch (Exception e) {
                // NOTE(review): consider reporting through the Ignite logger rather than stderr.
                e.printStackTrace();
            }
        }
    }
}
