package org.apache.ignite.ml.clustering.gmm;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/**
 * Aggregator that accumulates, for a single GMM component, the sum of weighted outer products
 * {@code w * (x - mean) * (x - mean)^T} over dataset rows, together with the number of rows added.
 *
 * <p>Per-partition aggregators are built by {@link #map}, merged pairwise by {@link #reduce}, and
 * finally turned into covariance matrices by {@link #computeCovariances}.
 */
public class CovarianceMatricesAggregator implements Serializable {
    /** Serial version uid. */
    private static final long serialVersionUID = 4163253784526780812L;

    /** Component mean every row is centered on before the outer product is taken. */
    private final Vector mean;

    /** Running sum of weighted outer products; {@code null} until the first row is added. */
    private Matrix weightedSum;

    /** Number of rows added to this aggregator (every row counts, whatever its weight). */
    private int rowCount;

    /**
     * Creates an empty aggregator.
     *
     * @param mean Component mean.
     */
    CovarianceMatricesAggregator(Vector mean) {
        this.mean = mean;
    }

    /**
     * Creates an aggregator with already accumulated state.
     *
     * @param mean Component mean.
     * @param weightedSum Sum of weighted outer products, or {@code null} if no rows were seen.
     * @param rowCount Number of rows the sum was accumulated over.
     */
    CovarianceMatricesAggregator(Vector mean, Matrix weightedSum, int rowCount) {
        this.mean = mean;
        this.weightedSum = weightedSum;
        this.rowCount = rowCount;
    }

    /**
     * Computes one covariance matrix per component over the whole dataset.
     *
     * @param dataset Dataset holding GMM partition data.
     * @param clusterProbs Per-component scaling values: the i-th covariance is
     *     {@code weightedSum_i / (rowCount * clusterProbs.get(i))} (presumably the component
     *     prior probabilities — TODO confirm against caller).
     * @param means Component means, one per component.
     * @return Covariance matrices in the same order as {@code means}, or an empty list when the
     *     dataset reduction yields {@code null}.
     */
    public static List<Matrix> computeCovariances(Dataset<EmptyContext, GmmPartitionData> dataset,
        Vector clusterProbs, Vector[] means) {
        List<CovarianceMatricesAggregator> aggregators = dataset.compute(
            data -> map(data, means),
            CovarianceMatricesAggregator::reduce
        );

        if (aggregators == null) {
            return Collections.emptyList();
        }

        List<Matrix> covariances = new ArrayList<>(aggregators.size());
        for (int i = 0; i < aggregators.size(); i++) {
            covariances.add(aggregators.get(i).covariance(clusterProbs.get(i)));
        }
        return covariances;
    }

    /**
     * Adds the weighted outer product of a centered row to the running sum and increments the row
     * counter.
     *
     * @param x Row vector.
     * @param weight Weight applied to the outer product of {@code x - mean}.
     */
    void add(Vector x, double weight) {
        // Column matrix (x - mean), so that centered * centered^T is the outer product.
        Matrix centered = x.minus(mean).toMatrix(false);
        Matrix weighted = centered.times(centered.transpose()).times(weight);

        weightedSum = weightedSum == null ? weighted : weightedSum.plus(weighted);
        rowCount++;
    }

    /**
     * Merges this aggregator with another one built for the same mean.
     *
     * <p>An aggregator that never received a row (for example one produced from an empty
     * partition) has no weighted sum yet. It is treated as the identity element instead of causing
     * a {@code NullPointerException}.
     *
     * @param other Aggregator to merge with; must have an equal mean.
     * @return New aggregator holding the combined sum and row count.
     */
    CovarianceMatricesAggregator plus(CovarianceMatricesAggregator other) {
        A.ensure(mean.equals(other.mean), "this.mean == other.mean");

        Matrix sum;
        if (weightedSum == null) {
            // Copy so that the merged aggregator never shares a mutable matrix with an input.
            sum = other.weightedSum == null ? null : other.weightedSum.copy();
        }
        else if (other.weightedSum == null) {
            sum = weightedSum.copy();
        }
        else {
            sum = weightedSum.plus(other.weightedSum);
        }

        return new CovarianceMatricesAggregator(mean, sum, rowCount + other.rowCount);
    }

    /**
     * Builds one aggregator per component mean and feeds every row of the partition into each of
     * them.
     *
     * @param data Partition data.
     * @param means Component means.
     * @return Aggregators in the same order as {@code means}. The list is non-empty whenever
     *     {@code means} is, even if the partition has no rows.
     */
    public static List<CovarianceMatricesAggregator> map(GmmPartitionData data, Vector[] means) {
        List<CovarianceMatricesAggregator> aggregators = new ArrayList<>(means.length);
        for (Vector mean : means) {
            aggregators.add(new CovarianceMatricesAggregator(mean));
        }

        for (int row = 0; row < data.size(); row++) {
            Vector x = data.getX(row);
            for (int component = 0; component < means.length; component++) {
                // pcxi(component, row) is used as the row weight (likely P(component | x_row) — verify).
                aggregators.get(component).add(x, data.pcxi(component, row));
            }
        }
        return aggregators;
    }

    /**
     * Turns the accumulated sum into a covariance matrix.
     *
     * <p>NOTE(review): throws {@code NullPointerException} if no rows were ever added (the
     * weighted sum is still {@code null}).
     *
     * @param clusterProb Per-component scaling value; the divisor is {@code rowCount * clusterProb}.
     * @return Result of {@code weightedSum.divide(rowCount * clusterProb)}.
     */
    private Matrix covariance(double clusterProb) {
        return weightedSum.divide(rowCount * clusterProb);
    }

    /**
     * Merges two per-partition aggregator lists element by element.
     *
     * @param l Left list (may be {@code null} or empty).
     * @param r Right list (may be {@code null} or empty).
     * @return The other list if one side is {@code null} or empty; otherwise a new list whose i-th
     *     element is {@code l.get(i).plus(r.get(i))}.
     */
    static List<CovarianceMatricesAggregator> reduce(List<CovarianceMatricesAggregator> l,
        List<CovarianceMatricesAggregator> r) {
        A.ensure(l != null || r != null, "Both partitions cannot equal to null");

        if (l == null || l.isEmpty()) {
            return r;
        }
        if (r == null || r.isEmpty()) {
            return l;
        }

        A.ensure(l.size() == r.size(), "l.size() == r.size()");

        List<CovarianceMatricesAggregator> merged = new ArrayList<>(l.size());
        for (int i = 0; i < l.size(); i++) {
            merged.add(l.get(i).plus(r.get(i)));
        }
        return merged;
    }

    /**
     * @return Copy of the component mean.
     */
    Vector mean() {
        return mean.copy();
    }

    /**
     * @return Copy of the accumulated weighted sum. Throws {@code NullPointerException} if no row
     *     has been added yet.
     */
    Matrix weightedSum() {
        return weightedSum.copy();
    }

    /**
     * @return Number of rows added to this aggregator.
     */
    public int rowCount() {
        return rowCount;
    }
}
