package org.apache.ignite.ml.knn.ann;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer for the approximate nearest neighbours (ANN) classification model.
 * <p>
 * Training first runs {@link KMeansTrainer} to find {@code k} centroids (the model "candidates"). Every training
 * vector is then assigned to its closest centroid (according to {@link #getDistance()}), and for each centroid the
 * trainer counts how many vectors of each class label were assigned to it. Each candidate is finally labeled with a
 * {@link ProbableLabel} mapping every known class label to the fraction of that centroid's vectors carrying it.
 */
public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClassificationModel> {
    /** Amount of clusters (candidates) requested from k-means. */
    private int k = 2;

    /** Maximum amount of k-means iterations. */
    private int maxIterations = 10;

    /** Convergence threshold passed to k-means. */
    private double epsilon = 1.0E-4d;

    /** Distance measure used by k-means and when assigning vectors to their closest centroid. */
    private DistanceMeasure distance = new EuclideanDistance();

    /** Seed passed to k-means. */
    private long seed;

    /**
     * Per-centroid class-label statistics gathered from the training data.
     * <p>
     * One instance is built per dataset partition; partition results are combined with {@link #merge(CentroidStat)}.
     */
    public static class CentroidStat implements Serializable {
        /** Serial version UID. */
        private static final long serialVersionUID = 7624883170532045144L;

        /** Centroid index -> (class label -> amount of vectors with that label assigned to the centroid). */
        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();

        /** Centroid index -> total amount of vectors assigned to the centroid. */
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

        /** Every class label seen in the data. */
        ConcurrentSkipListSet<Double> clsLblsSet = new ConcurrentSkipListSet<>();

        /**
         * Merges {@code other} into this instance: per-centroid totals and per-label counts are summed and the
         * label sets are united.
         *
         * @param other Statistics to merge in; {@code null} is treated as empty.
         * @return This instance after the merge.
         */
        CentroidStat merge(CentroidStat other) {
            if (other == null)
                return this;

            counts = MapUtil.mergeMaps(counts, other.counts, Integer::sum, ConcurrentHashMap::new);

            // Inner maps (label -> count) for the same centroid are merged by summing the counts as well.
            centroidStat = MapUtil.mergeMaps(centroidStat, other.centroidStat,
                (m1, m2) -> MapUtil.mergeMaps(m1, m2, Integer::sum, ConcurrentHashMap::new),
                ConcurrentHashMap::new);

            clsLblsSet.addAll(other.clsLblsSet);

            return this;
        }

        /**
         * @return The set of all class labels seen so far (the internal set itself, not a copy).
         */
        public ConcurrentSkipListSet<Double> labels() {
            return clsLblsSet;
        }

        /**
         * @return Centroid index -> (class label -> count) map (the internal map itself, not a copy).
         */
        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat() {
            return centroidStat;
        }
    }

    /**
     * Trains a new model from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Trained model.
     */
    @Override
    public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Builds a new model or updates an existing one.
     * <p>
     * When {@code mdl} is {@code null}, candidates are computed with k-means. Otherwise the existing candidates are
     * kept, and the statistics of the new data are merged into the model's existing statistics.
     *
     * @param mdl Model to update, or {@code null} to train from scratch.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return New model instance.
     */
    @Override
    public <K, V> ANNClassificationModel updateModel(ANNClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Vector> centers;
        CentroidStat centroidStat;

        if (mdl != null) {
            centers = Arrays.stream(mdl.getCandidates().data())
                .map(row -> row.features())
                .collect(Collectors.toList());

            CentroidStat newStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);

            // Nothing new was computed: keep the old model untouched.
            if (newStat == null)
                return mdl;

            centroidStat = newStat.merge(mdl.getCentroindsStat());
        }
        else {
            centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
            centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);

            // Avoid an NPE when labeling candidates if no statistics could be gathered.
            if (centroidStat == null)
                centroidStat = new CentroidStat();
        }

        return new ANNClassificationModel(buildLabelsForCandidates(centers, centroidStat), centroidStat);
    }

    /**
     * Checks whether {@code mdl} can be updated by this trainer.
     *
     * @param mdl Model to check.
     * @return {@code true} if the model uses this trainer's distance measure and has exactly {@code k} candidates.
     */
    @Override
    public boolean checkState(ANNClassificationModel mdl) {
        return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k;
    }

    /**
     * Pairs every centroid with its label distribution.
     *
     * @param centers Centroids, indexed the same way as in {@code centroidStat}.
     * @param centroidStat Per-centroid label statistics.
     * @return Labeled candidate set.
     */
    @NotNull
    private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers,
        CentroidStat centroidStat) {
        LabeledVector[] arr = new LabeledVector[centers.size()];

        for (int i = 0; i < centers.size(); i++)
            arr[i] = new LabeledVector<>(centers.get(i), fillProbableLabel(i, centroidStat));

        return new LabeledVectorSet<>(arr);
    }

    /**
     * Runs k-means with this trainer's settings and returns the resulting cluster centers.
     *
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param datasetBuilder Dataset builder.
     * @return Cluster centers.
     */
    private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) {
        KMeansTrainer trainer = new KMeansTrainer()
            .withAmountOfClusters(k)
            .withMaxIterations(maxIterations)
            .withSeed(seed)
            .withDistance(distance)
            .withEpsilon(epsilon);

        return Arrays.asList(trainer.fit(datasetBuilder, featureExtractor, lbExtractor).getCenters());
    }

    /**
     * Computes the label distribution of one centroid.
     * <p>
     * Every known class label gets an entry. Its value is the fraction of the centroid's vectors carrying that
     * label, or {@code 0.0} when the centroid has no such vectors (or no vectors at all).
     *
     * @param centroidIdx Centroid index.
     * @param centroidStat Per-centroid label statistics.
     * @return Label distribution of the centroid.
     */
    private ProbableLabel fillProbableLabel(int centroidIdx, CentroidStat centroidStat) {
        TreeMap<Double, Double> clsLbls = new TreeMap<>();

        for (Double lbl : centroidStat.clsLblsSet)
            clsLbls.put(lbl, 0.0);

        ConcurrentHashMap<Double, Integer> lblDistribution = centroidStat.centroidStat().get(centroidIdx);
        Integer totalCnt = centroidStat.counts.get(centroidIdx);

        if (lblDistribution != null && totalCnt != null && totalCnt > 0) {
            double total = totalCnt;

            clsLbls.replaceAll((lbl, zero) -> lblDistribution.getOrDefault(lbl, 0) / total);
        }

        return new ProbableLabel(clsLbls);
    }

    /**
     * Assigns every vector of the dataset to its closest centroid and gathers label statistics per centroid.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param centers Centroids.
     * @return Aggregated statistics over all partitions.
     * @throws RuntimeException Wrapping any exception raised while building, computing on, or closing the dataset.
     */
    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor,
        List<Vector> centers) {
        LabeledDatasetPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lbExtractor);

        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder
        )) {
            return dataset.compute(
                data -> collectPartitionStat(data, centers),
                (a, b) -> {
                    if (a == null)
                        return b == null ? new CentroidStat() : b;

                    return b == null ? a : a.merge(b);
                }
            );
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Gathers label statistics for a single partition.
     *
     * @param data Partition data.
     * @param centers Centroids.
     * @return Statistics of this partition.
     */
    private CentroidStat collectPartitionStat(LabeledVectorSet<Double, LabeledVector> data, List<Vector> centers) {
        CentroidStat res = new CentroidStat();

        for (int i = 0; i < data.rowSize(); i++) {
            int centroidIdx = findClosestCentroid(centers, data.getRow(i)).get1();
            double lbl = data.label(i);

            res.labels().add(lbl);

            res.centroidStat
                .computeIfAbsent(centroidIdx, idx -> new ConcurrentHashMap<>())
                .merge(lbl, 1, Integer::sum);

            res.counts.merge(centroidIdx, 1, Integer::sum);
        }

        return res;
    }

    /**
     * Finds the centroid closest to {@code pnt}. Entries of {@code centers} that are {@code null} are skipped.
     *
     * @param centers Centroids.
     * @param pnt Point to classify.
     * @return Index of the closest centroid and the distance to it. If no non-null centroid exists, the result is
     *     index {@code 0} with distance {@link Double#POSITIVE_INFINITY}.
     */
    private IgniteBiTuple<Integer, Double> findClosestCentroid(List<Vector> centers, LabeledVector pnt) {
        double bestDist = Double.POSITIVE_INFINITY;
        int bestIdx = 0;

        for (int i = 0; i < centers.size(); i++) {
            Vector center = centers.get(i);

            if (center == null)
                continue;

            double dist = distance.compute(center, pnt.features());

            if (dist < bestDist) {
                bestDist = dist;
                bestIdx = i;
            }
        }

        return new IgniteBiTuple<>(bestIdx, bestDist);
    }

    /**
     * @return Amount of clusters (candidates).
     */
    public int getK() {
        return k;
    }

    /**
     * @param k Amount of clusters (candidates).
     * @return This trainer.
     */
    public ANNClassificationTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * @return Maximum amount of k-means iterations.
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * @param maxIterations Maximum amount of k-means iterations.
     * @return This trainer.
     */
    public ANNClassificationTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * @return Convergence threshold passed to k-means.
     */
    public double getEpsilon() {
        return epsilon;
    }

    /**
     * @param epsilon Convergence threshold passed to k-means.
     * @return This trainer.
     */
    public ANNClassificationTrainer withEpsilon(double epsilon) {
        this.epsilon = epsilon;
        return this;
    }

    /**
     * @return Distance measure.
     */
    public DistanceMeasure getDistance() {
        return distance;
    }

    /**
     * @param distance Distance measure used by k-means and for centroid assignment.
     * @return This trainer.
     */
    public ANNClassificationTrainer withDistance(DistanceMeasure distance) {
        this.distance = distance;
        return this;
    }

    /**
     * @return Seed passed to k-means.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * @param seed Seed passed to k-means.
     * @return This trainer.
     */
    public ANNClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}
