package org.apache.ignite.ml.clustering;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.ignite.internal.util.GridArgumentCheck;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.exceptions.ConvergenceException;
import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.util.MatrixUtil;

/* loaded from: input_file:org/apache/ignite/ml/clustering/KMeansLocalClusterer.class */
/**
 * Weighted k-means clusterer that works on a {@link DenseLocalOnHeapMatrix} whose rows are the points.
 *
 * <p>Clustering runs in two phases: a randomized seeding phase that picks the initial centers with a
 * probability driven by the point weights and the current per-point cost, followed by iterative
 * refinement (assign every point to its closest center, then recompute each center as the weighted mean
 * of its points) that stops when no assignment changes or after {@code maxIterations} passes.
 */
public class KMeansLocalClusterer extends BaseKMeansClusterer<DenseLocalOnHeapMatrix> implements WeightedClusterer<DenseLocalOnHeapMatrix, KMeansModel> {
    /** Upper bound on the number of refinement passes. */
    private final int maxIterations;

    /** Random source used for seeding and for re-seeding clusters that end up empty. */
    private final Random rand;

    /**
     * Creates a clusterer.
     *
     * @param distanceMeasure distance measure handed to the base clusterer.
     * @param i maximal number of refinement iterations.
     * @param l seed for the random generator, or {@code null} to use an unseeded {@link Random}.
     */
    public KMeansLocalClusterer(DistanceMeasure distanceMeasure, int i, Long l) {
        super(distanceMeasure);
        this.maxIterations = i;
        this.rand = l != null ? new Random(l.longValue()) : new Random();
    }

    /**
     * Clusters the rows of the matrix giving every row the same weight of {@code 1.0}.
     *
     * @param denseLocalOnHeapMatrix matrix whose rows are the points to cluster.
     * @param i number of clusters.
     * @return model built from the final centers and this clusterer's distance measure.
     * @throws IllegalArgumentException if {@code i} is not positive.
     */
    @Override // org.apache.ignite.ml.clustering.BaseKMeansClusterer, org.apache.ignite.ml.clustering.Clusterer
    public KMeansModel cluster(DenseLocalOnHeapMatrix denseLocalOnHeapMatrix, int i) throws MathIllegalArgumentException, ConvergenceException {
        // Validate before dereferencing so a null matrix is reported with its argument name.
        GridArgumentCheck.notNull(denseLocalOnHeapMatrix, "points");
        List<Double> unitWeights = new ArrayList<>(Collections.nCopies(denseLocalOnHeapMatrix.rowSize(), 1.0d));
        return cluster2(denseLocalOnHeapMatrix, i, unitWeights);
    }

    /**
     * Clusters the rows of the matrix using the supplied per-row weights.
     *
     * <p>Renamed from {@code cluster} by the decompiler to avoid a collision with the raw-typed
     * overload further down in this class, which simply delegates here.
     *
     * @param denseLocalOnHeapMatrix matrix whose rows are the points to cluster.
     * @param i number of clusters.
     * @param list weight of each row; element {@code n} is the weight of row {@code n}.
     * @return model built from the final centers and this clusterer's distance measure.
     * @throws IllegalArgumentException if {@code i} is not positive or there are fewer weights than rows.
     */
    public KMeansModel cluster2(DenseLocalOnHeapMatrix denseLocalOnHeapMatrix, int i, List<Double> list) throws MathIllegalArgumentException, ConvergenceException {
        GridArgumentCheck.notNull(denseLocalOnHeapMatrix, "points");
        GridArgumentCheck.notNull(list, "weights");

        final DenseLocalOnHeapMatrix points = denseLocalOnHeapMatrix;
        final List<Double> weights = list;
        final int k = i;
        final int rows = points.rowSize();

        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive: " + k);
        }
        if (weights.size() < rows) {
            throw new IllegalArgumentException("Expected a weight for each of the " + rows + " points but got " + weights.size());
        }

        final int dim = points.columnSize();
        final Vector[] centers = seedCenters(points, k, weights);

        // Index of the center each row was assigned to in the previous pass; -1 means "not assigned yet".
        int[] assignment = new int[rows];
        Arrays.fill(assignment, -1);

        boolean moved = true;
        for (int iter = 0; moved && iter < this.maxIterations; iter++) {
            moved = false;

            double[] clusterWeights = new double[k];
            Vector[] sums = new Vector[k];
            // Every slot starts out referencing one shared zero vector.
            // NOTE(review): this relies on Vector.plus returning a new instance rather than mutating
            // the receiver - confirm against the Vector implementation.
            Arrays.fill(sums, VectorUtils.zeroes(dim));

            for (int row = 0; row < rows; row++) {
                Vector point = MatrixUtil.localCopyOf(points.viewRow(row));
                double weight = weights.get(row).doubleValue();
                int closest = ((Integer) findClosest(centers, point).get1()).intValue();

                sums[closest] = sums[closest].plus(point.times(weight));
                clusterWeights[closest] += weight;

                if (closest != assignment[row]) {
                    moved = true;
                    assignment[row] = closest;
                }
            }

            for (int c = 0; c < k; c++) {
                if (clusterWeights[c] == 0.0d) {
                    // Nothing was assigned to this cluster: restart it from a random row. A local copy is
                    // stored (like every other center) so the model does not keep a view of the input.
                    centers[c] = MatrixUtil.localCopyOf(points.viewRow(this.rand.nextInt(rows)));
                } else {
                    // Weighted mean of the points assigned to this cluster.
                    centers[c] = sums[c].times(1.0d / clusterWeights[c]);
                }
            }
        }

        return new KMeansModel(centers, getDistanceMeasure());
    }

    /**
     * Picks the {@code k} initial centers.
     *
     * <p>A first row is drawn by weight only and used to initialise the per-row costs. Then, for every
     * center slot, a row is drawn with probability proportional to {@code weight * cost} and the costs
     * are lowered to the distance to the newly chosen center where that is smaller.
     *
     * @param points matrix whose rows are the points.
     * @param k number of centers to pick; must be positive.
     * @param weights per-row weights.
     * @return array of {@code k} centers, each a local copy of a row.
     */
    private Vector[] seedCenters(DenseLocalOnHeapMatrix points, int k, List<Double> weights) {
        final int rows = points.rowSize();
        final Vector[] centers = new Vector[k];

        centers[0] = pickWeighted(points, weights);
        Vector costs = points.foldRows(row -> Double.valueOf(distance(row, centers[0])));

        // NOTE(review): the loop starts at 0, so the center drawn above only initialises the costs and
        // is then overwritten. Left untouched to keep results for a given seed stable - confirm intent.
        for (int c = 0; c < k; c++) {
            double threshold = this.rand.nextDouble() * weightedSum(costs, weights);
            double acc = 0.0d;
            int scanned = 0;
            while (scanned < rows && acc < threshold) {
                acc += weights.get(scanned).doubleValue() * costs.get(scanned);
                scanned++;
            }

            // scanned == 0 happens when the threshold is 0; fall back to the first row in that case.
            centers[c] = MatrixUtil.localCopyOf(points.viewRow(Math.max(scanned - 1, 0)));

            for (int row = 0; row < rows; row++) {
                double toNewCenter = getDistanceMeasure().compute(MatrixUtil.localCopyOf(points.viewRow(row)), centers[c]);
                costs.setX(row, Math.min(toNewCenter, costs.get(row)));
            }
        }

        return centers;
    }

    /**
     * Draws a row with probability proportional to its weight.
     *
     * @param matrix matrix whose rows are the candidates.
     * @param list per-row weights.
     * @return local copy of the chosen row; the first row when the random threshold is 0.
     */
    private Vector pickWeighted(Matrix matrix, List<Double> list) {
        double threshold = this.rand.nextDouble() * list.stream().mapToDouble(Double::doubleValue).sum();
        int rows = matrix.rowSize();
        int scanned = 0;
        double acc = 0.0d;
        while (scanned < rows && acc < threshold) {
            acc += list.get(scanned).doubleValue();
            scanned++;
        }
        // Random.nextDouble() may return 0.0 (and the weights may sum to 0), in which case the loop
        // above never runs; clamp so that row -1 is never requested.
        return MatrixUtil.localCopyOf(matrix.viewRow(Math.max(scanned - 1, 0)));
    }

    /**
     * Computes the sum of {@code vector[n] * list[n]} over all elements of the vector.
     *
     * @param vector values to weight.
     * @param list weights; must hold at least {@code vector.size()} elements.
     * @return the weighted sum.
     */
    private double weightedSum(Vector vector, List<Double> list) {
        double total = 0.0d;
        for (int idx = 0; idx < vector.size(); idx++) {
            total += vector.getX(idx) * list.get(idx).doubleValue();
        }
        return total;
    }

    /**
     * Raw-typed overload, marked bridge/synthetic in the decompiled output, that forwards to
     * {@link #cluster2(DenseLocalOnHeapMatrix, int, List)}.
     *
     * @param denseLocalOnHeapMatrix matrix whose rows are the points to cluster.
     * @param i number of clusters.
     * @param list per-row weights; cast unchecked to {@code List<Double>}.
     * @return model produced by {@code cluster2}.
     */
    @Override // org.apache.ignite.ml.clustering.WeightedClusterer
    public /* bridge */ /* synthetic */ KMeansModel cluster(DenseLocalOnHeapMatrix denseLocalOnHeapMatrix, int i, List list) throws MathIllegalArgumentException, ConvergenceException {
        return cluster2(denseLocalOnHeapMatrix, i, (List<Double>) list);
    }

    /*
     * Decompiler rendering of the synthetic lambda-deserialization hook; the code is kept verbatim.
     * NOTE(review): as written this does not look like valid Java source (a boolean initialised with -1,
     * a switch over a boolean, an instance method called from a static context), and the compiler
     * presumably generates this method itself for serializable lambdas - confirm whether it should be
     * dropped before recompiling.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 415014198:
                if (implMethodName.equals("lambda$cluster$f9d34d2b$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/ignite/ml/math/functions/IgniteFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/ignite/ml/clustering/KMeansLocalClusterer") && serializedLambda.getImplMethodSignature().equals("([Lorg/apache/ignite/ml/math/Vector;Lorg/apache/ignite/ml/math/Vector;)Ljava/lang/Double;")) {
                    KMeansLocalClusterer kMeansLocalClusterer = (KMeansLocalClusterer) serializedLambda.getCapturedArg(0);
                    Vector[] vectorArr = (Vector[]) serializedLambda.getCapturedArg(1);
                    return vector -> {
                        return Double.valueOf(distance(vector, vectorArr[0]));
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
