package org.apache.ignite.ml.optimization.updatecalculators;

import java.lang.invoke.SerializedLambda;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.util.MatrixUtil;
import org.apache.ignite.ml.optimization.SmoothParametrized;

/**
 * Parameter update calculator implementing resilient backpropagation (RProp) with step backtracking.
 *
 * <p>Every model parameter has its own step size ("delta"). The step size changes according to the
 * sign of {@code g(t-1) * g(t)}, where {@code g} is the gradient for that parameter:
 * <ul>
 *     <li>If the sign is positive, the step size is multiplied by {@code accelerationRate} and capped at
 *     {@code UPDATE_MAX}.</li>
 *     <li>If the sign is negative, it is multiplied by {@code deaccelerationRate} and floored at
 *     {@code UPDATE_MIN}.</li>
 *     <li>If the sign is zero, it is left unchanged.</li>
 * </ul>
 * When the sign is negative, the previous step for that parameter is undone instead of taking a new one.
 */
public class RPropUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> {
    /** Default initial value of every per-parameter step size. */
    private static final double DFLT_INIT_UPDATE = 0.1d;

    /** Default factor applied to a step size when the gradient sign is kept. */
    private static final double DFLT_ACCELERATION_RATE = 1.2d;

    /** Default factor applied to a step size when the gradient sign flips. */
    private static final double DFLT_DEACCELERATION_RATE = 0.5d;

    /** Upper bound for a step size after acceleration. */
    private static final double UPDATE_MAX = 50.0d;

    /** Lower bound for a step size after deacceleration. */
    private static final double UPDATE_MIN = 1.0E-6d;

    /** Initial value of every per-parameter step size. */
    private final double initUpdate;

    /** Factor applied to a step size when the gradient sign is kept. */
    private final double accelerationRate;

    /** Factor applied to a step size when the gradient sign flips. */
    private final double deaccelerationRate;

    /** Loss function provider; set by {@link #init}. */
    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /**
     * Creates a calculator with explicit hyperparameters.
     *
     * @param initUpdate Initial value of every per-parameter step size.
     * @param accelerationRate Factor applied to a step size when the gradient sign is kept.
     * @param deaccelerationRate Factor applied to a step size when the gradient sign flips.
     */
    public RPropUpdateCalculator(double initUpdate, double accelerationRate, double deaccelerationRate) {
        this.initUpdate = initUpdate;
        this.accelerationRate = accelerationRate;
        this.deaccelerationRate = deaccelerationRate;
    }

    /** Creates a calculator with the default hyperparameters (0.1, 1.2, 0.5). */
    public RPropUpdateCalculator() {
        this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE);
    }

    /**
     * Computes the RProp state for the next step from the current gradient of the model.
     *
     * @param mdl Model to differentiate.
     * @param updaterParams State produced by the previous call (or by {@link #init}).
     * @param iteration Iteration number (not used by this calculator).
     * @param inputs Input batch.
     * @param groundTruth Expected outputs for {@code inputs}.
     * @return New state holding the per-parameter updates, the stored gradient, the adapted step sizes
     *     and the update mask (+1 = apply update, -1 = undo previous update).
     */
    @Override
    public RPropParameterUpdate calculateNewUpdate(SmoothParametrized mdl, RPropParameterUpdate updaterParams,
        int iteration, Matrix inputs, Matrix groundTruth) {
        Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth);
        Vector prevGradient = updaterParams.prevIterationGradient();

        // Per-parameter sign of g(t-1) * g(t). Without a previous gradient every sign is treated as +1.
        Vector derSigns = prevGradient != null
            ? VectorUtils.zipWith(prevGradient, gradient, (prev, cur) -> Math.signum(prev * cur))
            : gradient.like(gradient.size()).assign(1.0d);

        // Sign kept or zero: step against the gradient; sign flipped: remember the previous update so it
        // can be undone.
        // NOTE(review): classic RProp scales the new step by the freshly adapted delta; this uses the
        // previous deltas — confirm this is intended.
        Vector newPrevIterationUpdates = MatrixUtil.zipWith(gradient, updaterParams.deltas(),
            (der, delta, i) -> derSigns.getX(i) >= 0.0d
                ? -Math.signum(der) * delta
                : updaterParams.prevIterationUpdates().getX(i));

        Vector newDeltas = updaterParams.deltas().copy().map(derSigns, (prevDelta, sign) -> {
            if (sign > 0.0d)
                return Math.min(prevDelta * accelerationRate, UPDATE_MAX);

            if (sign < 0.0d)
                return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN);

            return prevDelta;
        });

        // Side effect: zero the gradient where the sign flipped, so the next iteration sees sign 0 there
        // instead of backtracking the same step a second time.
        Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(),
            (sign, ignoredUpd, i) -> {
                if (sign < 0.0d)
                    gradient.setX(i, 0.0d);

                return sign >= 0.0d ? 1.0d : -1.0d;
            });

        // The gradient must be copied only after the mask pass above has zeroed flipped entries;
        // copying it earlier (as the previous code did) made that zeroing ineffective.
        Vector storedGradient = gradient.copy();

        return new RPropParameterUpdate(newPrevIterationUpdates, storedGradient, newDeltas, updatesMask);
    }

    /**
     * Remembers the loss function and creates the initial RProp state.
     *
     * @param mdl Model whose parameter count sizes the state.
     * @param loss Loss function provider used by {@link #calculateNewUpdate}.
     * @return Initial state with every step size equal to {@code initUpdate}.
     */
    @Override
    public RPropParameterUpdate init(SmoothParametrized mdl,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;
        return new RPropParameterUpdate(mdl.parametersCount(), initUpdate);
    }

    /**
     * Alias of {@link #init} kept for compatibility with code compiled against the decompiled class.
     *
     * @param mdl Model whose parameter count sizes the state.
     * @param loss Loss function provider.
     * @return Initial state.
     * @deprecated Use {@link #init} instead.
     */
    @Deprecated
    public RPropParameterUpdate init2(SmoothParametrized mdl,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        return init(mdl, loss);
    }

    /**
     * Applies an update to the model: adds {@code updatesMask * prevIterationUpdates} to its parameters.
     * Where the mask is -1, this undoes the previously applied step for that parameter.
     *
     * @param obj Model to update.
     * @param update State produced by {@link #calculateNewUpdate}.
     * @param <M1> Concrete model type.
     * @return Result of {@code obj.setParameters(...)}, cast to {@code M1}.
     */
    @SuppressWarnings("unchecked") // NOTE(review): assumes setParameters returns an instance of M1 — confirm.
    @Override
    public <M1 extends SmoothParametrized> M1 update(M1 obj, RPropParameterUpdate update) {
        // Copy the mask so the stored state is not modified in case elementWiseTimes writes into its
        // first argument.
        Vector updatesToAdd = VectorUtils.elementWiseTimes(update.updatesMask().copy(),
            update.prevIterationUpdates());

        return (M1) obj.setParameters(obj.parameters().plus(updatesToAdd));
    }
}
