package org.apache.ignite.ml.optimization;

import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.impls.vector.FunctionVector;
import org.apache.ignite.ml.optimization.util.SparseDistributedMatrixMapReducer;

/* loaded from: input_file:org/apache/ignite/ml/optimization/GradientDescent.class */
/**
 * Iterative optimizer that repeatedly evaluates a loss gradient and asks an {@link Updater}
 * for the next weight vector until the weights converge or the iteration budget is spent.
 *
 * <p>Column 0 of the data matrix is read as the ground truth; the inputs passed to the
 * gradient function are a copy of the data whose column 0 is overwritten with the constant
 * {@code 1.0} (presumably the intercept term; confirm against the {@link GradientFunction}
 * implementations).
 */
public class GradientDescent {
    /**
     * Storage mode of a {@link SparseDistributedMatrix} for which the gradient is computed with
     * {@link SparseDistributedMatrixMapReducer} instead of locally.
     * NOTE(review): presumably the library's row storage mode constant; confirm and reference
     * that constant directly once it is imported.
     */
    private static final int MAP_REDUCE_STORAGE_MODE = 2001;

    /** Function that computes the gradient of the loss for given inputs, ground truth and weights. */
    private final GradientFunction lossGradient;

    /** Strategy that derives the next weights from the previous and current weights/gradients. */
    private final Updater updater;

    /** Upper bound on the number of iterations performed by {@link #optimize(Matrix, Vector)}. */
    private int maxIterations = 1000;

    /** Relative convergence tolerance; {@code 0} disables the convergence check. */
    private double convergenceTol = 1.0E-8d;

    /**
     * Creates an optimizer with the default iteration limit (1000) and tolerance (1e-8).
     *
     * @param gradientFunction function used to compute the loss gradient
     * @param updater strategy used to compute the next weights
     */
    public GradientDescent(GradientFunction gradientFunction, Updater updater) {
        this.lossGradient = gradientFunction;
        this.updater = updater;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIterations iteration limit; must be non-negative (checked only when assertions are enabled)
     * @return this instance, for call chaining
     */
    public GradientDescent withMaxIterations(int maxIterations) {
        assert maxIterations >= 0;

        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param convergenceTol tolerance; must be non-negative (checked only when assertions are enabled),
     *                       {@code 0} disables the convergence check
     * @return this instance, for call chaining
     */
    public GradientDescent withConvergenceTol(double convergenceTol) {
        assert convergenceTol >= 0.0d;

        this.convergenceTol = convergenceTol;
        return this;
    }

    /**
     * Runs the optimization loop.
     *
     * <p>On the first iteration the updater receives {@code null} for both the previous weights
     * and the previous gradient.
     *
     * @param data matrix holding the ground truth in column 0 and the remaining data in the other columns
     * @param initWeights starting weights; returned unchanged when the iteration limit is {@code 0}
     * @return the first weights that satisfy the convergence test, or the weights reached after
     *         the configured number of iterations
     */
    public Vector optimize(Matrix data, Vector initWeights) {
        IgniteFunction<Vector, Vector> gradientFunction = getLossGradientFunction(data);

        Vector weights = initWeights;
        Vector prevWeights = null;
        Vector prevGradient = null;

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            Vector gradient = gradientFunction.apply(weights);
            Vector newWeights = updater.compute(prevWeights, prevGradient, weights, gradient, iteration);

            if (isConverged(weights, newWeights)) {
                return newWeights;
            }

            prevGradient = gradient;
            prevWeights = weights;
            weights = newWeights;
        }

        return weights;
    }

    /**
     * Computes the gradient by map-reduce: the loss gradient is evaluated per matrix slice and
     * the non-null partial gradients are averaged.
     *
     * @param data distributed matrix to process
     * @param weights current weights, handed to every mapper invocation
     * @return the average of the non-null partial gradients, or the untouched accumulator when
     *         no partial gradient was produced
     */
    private Vector calculateDistributedGradient(SparseDistributedMatrix data, Vector weights) {
        SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(data);

        return (Vector) mapReducer.mapReduce(
            (slice, args) -> lossGradient.compute(extractInputs(slice), extractGroundTruth(slice), args),
            gradients -> {
                int contributors = 0;
                Vector sum = new DenseLocalOnHeapVector(data.columnSize());

                // Elements are cast one by one so the loop does not depend on the element type
                // inferred for the reducer's collection.
                for (Object item : gradients) {
                    if (item != null) {
                        sum = sum.plus((Vector) item);
                        contributors++;
                    }
                }

                // No partial gradient at all: return the accumulator as is rather than
                // dividing by a count of zero.
                return contributors == 0 ? sum : sum.divide(contributors);
            },
            weights);
    }

    /**
     * Tests whether the step between two consecutive weight vectors is small enough.
     *
     * <p>The step size, measured with {@code kNorm(2.0)}, must be strictly smaller than the
     * tolerance scaled by {@code max(kNorm(2.0) of the new weights, 1.0)}.
     *
     * @param weights weights before the update
     * @param newWeights weights after the update
     * @return {@code true} if converged; always {@code false} when the tolerance is exactly {@code 0}
     */
    private boolean isConverged(Vector weights, Vector newWeights) {
        if (convergenceTol == 0.0d) {
            return false;
        }

        double stepNorm = weights.minus(newWeights).kNorm(2.0d);
        double scale = Math.max(newWeights.kNorm(2.0d), 1.0d);

        return stepNorm < convergenceTol * scale;
    }

    /**
     * Returns column 0 of the given matrix, which this class treats as the ground truth.
     *
     * @param data source matrix
     * @return column 0 of {@code data}
     */
    private Vector extractGroundTruth(Matrix data) {
        return data.getCol(0);
    }

    /**
     * Builds the input matrix: a copy of the data whose column 0 is replaced by the constant 1.0.
     * The argument itself is not modified by this method; only the copy is written to.
     *
     * @param data source matrix
     * @return copy of {@code data} with column 0 assigned a vector of ones
     */
    private Matrix extractInputs(Matrix data) {
        Matrix inputs = data.copy();

        inputs.assignColumn(0, new FunctionVector(inputs.rowSize(), row -> 1.0d));

        return inputs;
    }

    /**
     * Chooses how the gradient is evaluated for the given data.
     *
     * <p>A {@link SparseDistributedMatrix} whose storage mode equals
     * {@link #MAP_REDUCE_STORAGE_MODE} is processed via map-reduce on every call; any other
     * matrix has its inputs and ground truth extracted once, up front, and reused.
     *
     * @param data data matrix
     * @return function mapping weights to the loss gradient
     */
    private IgniteFunction<Vector, Vector> getLossGradientFunction(Matrix data) {
        if (data instanceof SparseDistributedMatrix) {
            SparseDistributedMatrix distributedData = (SparseDistributedMatrix) data;

            if (distributedData.getStorage().storageMode() == MAP_REDUCE_STORAGE_MODE) {
                return weights -> calculateDistributedGradient(distributedData, weights);
            }
        }

        Matrix inputs = extractInputs(data);
        Vector groundTruth = extractGroundTruth(data);

        return weights -> lossGradient.compute(inputs, groundTruth, weights);
    }
}
