package org.apache.ignite.ml.clustering.kmeans;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/* loaded from: input_file:org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.class */
public class KMeansTrainer implements SingleLabelDatasetTrainer<KMeansModel> {
    private int k = 2;
    private int maxIterations = 10;
    private double epsilon = 1.0E-4d;
    private DistanceMeasure distance = new EuclideanDistance();
    private long seed;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/ignite/ml/clustering/kmeans/KMeansTrainer$TotalCostAndCounts.class */
    public static class TotalCostAndCounts {
        double totalCost;
        ConcurrentHashMap<Integer, Vector> sums;
        ConcurrentHashMap<Integer, Integer> counts;

        private TotalCostAndCounts() {
            this.sums = new ConcurrentHashMap<>();
            this.counts = new ConcurrentHashMap<>();
        }

        TotalCostAndCounts merge(TotalCostAndCounts totalCostAndCounts) {
            this.totalCost += this.totalCost;
            this.sums = (ConcurrentHashMap) MapUtil.mergeMaps(this.sums, totalCostAndCounts.sums, (v0, v1) -> {
                return v0.plus(v1);
            }, ConcurrentHashMap::new);
            this.counts = (ConcurrentHashMap) MapUtil.mergeMaps(this.counts, totalCostAndCounts.counts, (num, num2) -> {
                return Integer.valueOf(num.intValue() + num2.intValue());
            }, ConcurrentHashMap::new);
            return this;
        }
    }

    @Override // org.apache.ignite.ml.trainers.DatasetTrainer
    public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> igniteBiFunction, IgniteBiFunction<K, V, Double> igniteBiFunction2) {
        if (!$assertionsDisabled && datasetBuilder == null) {
            throw new AssertionError();
        }
        try {
            Dataset<C, D> build = datasetBuilder.build((it, j) -> {
                return new EmptyContext();
            }, new LabeledDatasetPartitionDataBuilderOnHeap(igniteBiFunction, igniteBiFunction2));
            Throwable th = null;
            try {
                try {
                    int intValue = ((Integer) build.compute((v0) -> {
                        return v0.colSize();
                    }, (num, num2) -> {
                        return num == null ? num2 : num;
                    })).intValue();
                    Vector[] initClusterCentersRandomly = initClusterCentersRandomly(build, this.k);
                    boolean z = false;
                    int i = 0;
                    while (i < this.maxIterations && !z) {
                        Vector[] vectorArr = new DenseLocalOnHeapVector[this.k];
                        TotalCostAndCounts calcDataForNewCentroids = calcDataForNewCentroids(initClusterCentersRandomly, build, intValue);
                        z = true;
                        Iterator<K> it2 = calcDataForNewCentroids.sums.keySet().iterator();
                        while (it2.hasNext()) {
                            Integer num3 = (Integer) it2.next();
                            Vector times = calcDataForNewCentroids.sums.get(num3).times(1.0d / calcDataForNewCentroids.counts.get(num3).intValue());
                            if (z && this.distance.compute(times, initClusterCentersRandomly[num3.intValue()]) > this.epsilon * this.epsilon) {
                                z = false;
                            }
                            vectorArr[num3.intValue()] = times;
                        }
                        i++;
                        initClusterCentersRandomly = vectorArr;
                    }
                    if (build != null) {
                        if (0 != 0) {
                            try {
                                build.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            build.close();
                        }
                    }
                    return new KMeansModel(initClusterCentersRandomly, this.distance);
                } finally {
                }
            } finally {
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private TotalCostAndCounts calcDataForNewCentroids(Vector[] vectorArr, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int i) {
        return (TotalCostAndCounts) dataset.compute(labeledDataset -> {
            TotalCostAndCounts totalCostAndCounts = new TotalCostAndCounts();
            for (int i2 = 0; i2 < labeledDataset.rowSize(); i2++) {
                IgniteBiTuple<Integer, Double> findClosestCentroid = findClosestCentroid(vectorArr, (LabeledVector) labeledDataset.getRow(i2));
                int intValue = ((Integer) findClosestCentroid.get1()).intValue();
                labeledDataset.setLabel(i2, intValue);
                totalCostAndCounts.totalCost += ((Double) findClosestCentroid.get2()).doubleValue();
                totalCostAndCounts.sums.putIfAbsent(Integer.valueOf(intValue), VectorUtils.zeroes(i));
                int i3 = i2;
                totalCostAndCounts.sums.compute(Integer.valueOf(intValue), (num, vector) -> {
                    return vector.plus(((LabeledVector) labeledDataset.getRow(i3)).features());
                });
                totalCostAndCounts.counts.merge(Integer.valueOf(intValue), 1, (num2, num3) -> {
                    return Integer.valueOf(num2.intValue() + num3.intValue());
                });
            }
            return totalCostAndCounts;
        }, (totalCostAndCounts, totalCostAndCounts2) -> {
            return totalCostAndCounts == null ? totalCostAndCounts2 : totalCostAndCounts.merge(totalCostAndCounts2);
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r2v4, types: [org.apache.ignite.ml.math.Vector] */
    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] vectorArr, LabeledVector labeledVector) {
        double d = Double.POSITIVE_INFINITY;
        int i = 0;
        for (int i2 = 0; i2 < vectorArr.length; i2++) {
            double compute = this.distance.compute(vectorArr[i2], (Vector) labeledVector.features());
            if (compute < d) {
                d = compute;
                i = i2;
            }
        }
        return new IgniteBiTuple<>(Integer.valueOf(i), Double.valueOf(d));
    }

    /* JADX WARN: Multi-variable type inference failed */
    private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int i) {
        DenseLocalOnHeapVector[] denseLocalOnHeapVectorArr = new DenseLocalOnHeapVector[i];
        List list = (List) dataset.compute(labeledDataset -> {
            ArrayList arrayList = new ArrayList();
            arrayList.add(labeledDataset.getRow(new Random(this.seed).nextInt(labeledDataset.rowSize())));
            return arrayList;
        }, (list2, list3) -> {
            return list2 == null ? list3 : (List) Stream.concat(list2.stream(), list3.stream()).collect(Collectors.toList());
        });
        for (int i2 = 0; i2 < i; i2++) {
            LabeledVector labeledVector = (LabeledVector) list.get(new Random(this.seed).nextInt(list.size()));
            list.remove(labeledVector);
            denseLocalOnHeapVectorArr[i2] = labeledVector.features();
        }
        return denseLocalOnHeapVectorArr;
    }

    public int getK() {
        return this.k;
    }

    public KMeansTrainer withK(int i) {
        this.k = i;
        return this;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public KMeansTrainer withMaxIterations(int i) {
        this.maxIterations = i;
        return this;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public KMeansTrainer withEpsilon(double d) {
        this.epsilon = d;
        return this;
    }

    public DistanceMeasure getDistance() {
        return this.distance;
    }

    public KMeansTrainer withDistance(DistanceMeasure distanceMeasure) {
        this.distance = distanceMeasure;
        return this;
    }

    public long getSeed() {
        return this.seed;
    }

    public KMeansTrainer withSeed(long j) {
        this.seed = j;
        return this;
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -1523934578:
                if (implMethodName.equals("lambda$calcDataForNewCentroids$1e52db3a$1")) {
                    z = true;
                    break;
                }
                break;
            case -987229453:
                if (implMethodName.equals("lambda$initClusterCentersRandomly$4dba08e1$1")) {
                    z = 7;
                    break;
                }
                break;
            case -943736098:
                if (implMethodName.equals("lambda$initClusterCentersRandomly$1a004068$1")) {
                    z = false;
                    break;
                }
                break;
            case -407537119:
                if (implMethodName.equals("lambda$calcDataForNewCentroids$a375c1d6$1")) {
                    z = 2;
                    break;
                }
                break;
            case 437668609:
                if (implMethodName.equals("lambda$fit$7187901e$1")) {
                    z = 3;
                    break;
                }
                break;
            case 948704673:
                if (implMethodName.equals("colSize")) {
                    z = 8;
                    break;
                }
                break;
            case 1079737696:
                if (implMethodName.equals("lambda$null$dd43bf2b$1")) {
                    z = 6;
                    break;
                }
                break;
            case 1211021620:
                if (implMethodName.equals("lambda$fit$ae875f25$1")) {
                    z = 4;
                    break;
                }
                break;
            case 1496801165:
                if (implMethodName.equals("lambda$null$44d9a8b6$1")) {
                    z = 5;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/structures/LabeledDataset;)Ljava/util/List;")) {
                    KMeansTrainer kMeansTrainer = (KMeansTrainer) serializedLambda.getCapturedArg(0);
                    return labeledDataset -> {
                        ArrayList arrayList = new ArrayList();
                        arrayList.add(labeledDataset.getRow(new Random(this.seed).nextInt(labeledDataset.rowSize())));
                        return arrayList;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("([Lorg/apache/ignite/ml/math/Vector;ILorg/apache/ignite/ml/structures/LabeledDataset;)Lorg/apache/ignite/ml/clustering/kmeans/KMeansTrainer$TotalCostAndCounts;")) {
                    KMeansTrainer kMeansTrainer2 = (KMeansTrainer) serializedLambda.getCapturedArg(0);
                    Vector[] vectorArr = (Vector[]) serializedLambda.getCapturedArg(1);
                    int intValue = ((Integer) serializedLambda.getCapturedArg(2)).intValue();
                    return labeledDataset2 -> {
                        TotalCostAndCounts totalCostAndCounts = new TotalCostAndCounts();
                        for (int i2 = 0; i2 < labeledDataset2.rowSize(); i2++) {
                            IgniteBiTuple<Integer, Double> findClosestCentroid = findClosestCentroid(vectorArr, (LabeledVector) labeledDataset2.getRow(i2));
                            int intValue2 = ((Integer) findClosestCentroid.get1()).intValue();
                            labeledDataset2.setLabel(i2, intValue2);
                            totalCostAndCounts.totalCost += ((Double) findClosestCentroid.get2()).doubleValue();
                            totalCostAndCounts.sums.putIfAbsent(Integer.valueOf(intValue2), VectorUtils.zeroes(intValue));
                            int i3 = i2;
                            totalCostAndCounts.sums.compute(Integer.valueOf(intValue2), (num, vector) -> {
                                return vector.plus(((LabeledVector) labeledDataset2.getRow(i3)).features());
                            });
                            totalCostAndCounts.counts.merge(Integer.valueOf(intValue2), 1, (num2, num3) -> {
                                return Integer.valueOf(num2.intValue() + num3.intValue());
                            });
                        }
                        return totalCostAndCounts;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBinaryOperator") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/clustering/kmeans/KMeansTrainer$TotalCostAndCounts;Lorg/apache/ignite/ml/clustering/kmeans/KMeansTrainer$TotalCostAndCounts;)Lorg/apache/ignite/ml/clustering/kmeans/KMeansTrainer$TotalCostAndCounts;")) {
                    return (totalCostAndCounts, totalCostAndCounts2) -> {
                        return totalCostAndCounts == null ? totalCostAndCounts2 : totalCostAndCounts.merge(totalCostAndCounts2);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBinaryOperator") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Integer;Ljava/lang/Integer;)Ljava/lang/Integer;")) {
                    return (num, num2) -> {
                        return num == null ? num2 : num;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/dataset/PartitionContextBuilder") && serializedLambda.getFunctionalInterfaceMethodName().equals("build") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/util/Iterator;J)Ljava/io/Serializable;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/Iterator;J)Lorg/apache/ignite/ml/dataset/primitive/context/EmptyContext;")) {
                    return (it, j) -> {
                        return new EmptyContext();
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBiFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/ignite/ml/structures/LabeledDataset;ILjava/lang/Integer;Lorg/apache/ignite/ml/math/Vector;)Lorg/apache/ignite/ml/math/Vector;")) {
                    LabeledDataset labeledDataset3 = (LabeledDataset) serializedLambda.getCapturedArg(0);
                    int intValue2 = ((Integer) serializedLambda.getCapturedArg(1)).intValue();
                    return (num3, vector) -> {
                        return vector.plus(((LabeledVector) labeledDataset3.getRow(intValue2)).features());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBiFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Integer;Ljava/lang/Integer;)Ljava/lang/Integer;")) {
                    return (num22, num32) -> {
                        return Integer.valueOf(num22.intValue() + num32.intValue());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteBinaryOperator") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/kmeans/KMeansTrainer") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/List;Ljava/util/List;)Ljava/util/List;")) {
                    return (list2, list3) -> {
                        return list2 == null ? list3 : (List) Stream.concat(list2.stream(), list3.stream()).collect(Collectors.toList());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 5 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/structures/Dataset") && serializedLambda.getImplMethodSignature().equals("()I")) {
                    return (v0) -> {
                        return v0.colSize();
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    static {
        $assertionsDisabled = !KMeansTrainer.class.desiredAssertionStatus();
    }
}
