package org.apache.ignite.ml.regressions;

import java.lang.invoke.SerializedLambda;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.decompositions.QRDecomposition;
import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
import org.apache.ignite.ml.math.functions.Functions;
import org.apache.ignite.ml.math.util.MatrixUtil;

/**
 * Ordinary least squares (OLS) multiple linear regression backed by a QR decomposition of the
 * design matrix X.
 * <p>
 * Every time new X data is loaded, X is factored with a {@link QRDecomposition}. The
 * coefficient estimates, the hat matrix and the coefficient variance are all computed from that
 * cached decomposition.
 * <p>
 * The decomposition is {@code null} until sample data has been loaded. Methods that read it,
 * e.g. {@link #calculateHat()}, fail if they are called before that.
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
    /** Cached QR decomposition of the current X matrix; {@code null} until data is loaded. */
    private QRDecomposition qr;

    /**
     * Threshold forwarded to every {@link QRDecomposition} this class builds.
     * Presumably the singularity threshold used by the decomposition; see {@link QRDecomposition}.
     */
    private final double threshold;

    /** Creates a regression whose QR decompositions use a threshold of {@code 0.0}. */
    public OLSMultipleLinearRegression() {
        this(0.0d);
    }

    /**
     * Creates a regression with the given QR decomposition threshold.
     *
     * @param threshold Threshold passed to each {@link QRDecomposition} of X.
     */
    public OLSMultipleLinearRegression(double threshold) {
        this.qr = null;
        this.threshold = threshold;
    }

    /**
     * Validates and loads new sample data, then recomputes the QR decomposition of X.
     *
     * @param y Observed responses.
     * @param x Design matrix (one row per observation).
     * @throws MathIllegalArgumentException If {@code validateSampleData} rejects the input.
     */
    public void newSampleData(Vector y, Matrix x) throws MathIllegalArgumentException {
        validateSampleData(x, y);
        newYSampleData(y);
        // Loading X last triggers the QR recomputation in newXSampleData.
        newXSampleData(x);
    }

    /**
     * Loads sample data from a flat array via the base class.
     * <p>
     * The base class setter is not overridden here, so the QR decomposition of X is recomputed
     * explicitly afterwards.
     *
     * @param data Flat array of observations, in the layout expected by the base class.
     * @param nobs Number of observations.
     * @param nvars Number of regressors.
     * @param like Matrix used by the base class as a template for the created X/Y storage.
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars, Matrix like) {
        super.newSampleData(data, nobs, nvars, like);
        qr = new QRDecomposition(getX(), threshold);
    }

    /**
     * Computes the hat matrix {@code H = Q * I_p * Q^T}.
     * <p>
     * Here {@code Q} is the orthogonal factor of X and {@code I_p} is an {@code n x n} matrix with
     * ones on the first {@code p} diagonal entries ({@code p} = number of columns of R) and zeros
     * elsewhere.
     *
     * @return The hat matrix.
     */
    public Matrix calculateHat() {
        Matrix q = qr.getQ();
        int n = q.columnSize();
        int p = qr.getR().columnSize();

        Matrix augI = MatrixUtil.like(q, n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                augI.setX(i, j, (i == j && i < p) ? 1.0d : 0.0d);
            }
        }

        return q.times(augI).times(q.transpose());
    }

    /**
     * Computes the total sum of squares of Y.
     * <p>
     * Without an intercept this is the raw sum of squared responses. With an intercept it is
     * the sum of squared deviations from the mean response.
     *
     * @return The total sum of squares.
     */
    public double calculateTotalSumOfSquares() {
        if (isNoIntercept()) {
            return getY().foldMap(Functions.PLUS, Functions.SQUARE, 0.0d);
        }

        final double mean = getY().sum() / getY().size();
        return getY().foldMap(Functions.PLUS, v -> (mean - v) * (mean - v), 0.0d);
    }

    /**
     * Computes the residual sum of squares as the dot product of the residual vector with itself.
     *
     * @return The residual sum of squares.
     */
    public double calculateResidualSumOfSquares() {
        Vector residuals = calculateResiduals();
        return residuals.dot(residuals);
    }

    /**
     * Computes the coefficient of determination {@code R^2 = 1 - SSR / SSTO}.
     *
     * @return The R-squared statistic.
     */
    public double calculateRSquared() {
        return 1.0d - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
    }

    /**
     * Computes the adjusted R-squared statistic. With {@code n} rows and {@code p} columns in X:
     * <ul>
     *   <li>no intercept: {@code 1 - (1 - R^2) * n / (n - p)};</li>
     *   <li>otherwise: {@code 1 - SSR * (n - 1) / (SSTO * (n - p))}.</li>
     * </ul>
     *
     * @return The adjusted R-squared statistic.
     */
    public double calculateAdjustedRSquared() {
        final double n = getX().rowSize();
        final int p = getX().columnSize();

        if (isNoIntercept()) {
            return 1.0d - (1.0d - calculateRSquared()) * (n / (n - p));
        }

        return 1.0d - (calculateResidualSumOfSquares() * (n - 1.0d))
            / (calculateTotalSumOfSquares() * (n - p));
    }

    /**
     * Loads a new X matrix via the base class and recomputes its QR decomposition.
     * <p>
     * The configured {@code threshold} is used here as well. Previously it was only honored by
     * {@link #newSampleData(double[], int, int, Matrix)}, so the value given to the constructor
     * was ignored for data loaded through {@link #newSampleData(Vector, Matrix)}.
     *
     * @param x Design matrix.
     */
    @Override
    public void newXSampleData(Matrix x) {
        super.newXSampleData(x);
        qr = new QRDecomposition(getX(), threshold);
    }

    /**
     * Solves for the regression coefficients using the cached QR decomposition of X.
     *
     * @return The estimated coefficient vector.
     */
    @Override
    protected Vector calculateBeta() {
        return qr.solve(getY());
    }

    /**
     * Computes {@code R^-1 * (R^-1)^T}. Here {@code R} is the leading {@code p x p} block of the
     * upper-triangular QR factor and {@code p} is the number of columns of X.
     * <p>
     * NOTE(review): presumably the base class scales this by the error variance to obtain the
     * coefficient covariance; confirm in {@code AbstractMultipleLinearRegression}.
     *
     * @return The unscaled coefficient variance matrix.
     */
    @Override
    protected Matrix calculateBetaVariance() {
        int p = getX().columnSize();
        // Copy the view so inverse() works on an independent matrix rather than on a view of R.
        Matrix rInv = MatrixUtil.copy(qr.getR().viewPart(0, p, 0, p)).inverse();
        return rInv.times(rInv.transpose());
    }
}
