package org.apache.ignite.ml.optimization.updatecalculators;

import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.optimization.SmoothParametrized;

/**
 * {@link ParameterUpdateCalculator} implementing Nesterov accelerated gradient (momentum with a look-ahead).
 *
 * <p>The state carried between iterations is the previous update vector {@code prev}. From the second iteration
 * on, the gradient is evaluated at the look-ahead parameters {@code theta - momentum * prev}, and the new update is
 * {@code momentum * prev + learningRate * gradient}. {@link #update} subtracts the update from the model
 * parameters.
 *
 * <p>Not thread-safe: {@link #init} stores the loss function in a mutable field that
 * {@link #calculateNewUpdate} later reads.
 *
 * @param <M> Type of the model whose parameters are optimized.
 */
public class NesterovUpdateCalculator<M extends SmoothParametrized<M>>
    implements ParameterUpdateCalculator<M, NesterovParameterUpdate> {
    /** Step size applied to the gradient. */
    private final double learningRate;

    /** Loss function; assigned by {@link #init}, {@code null} until then. */
    private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Momentum coefficient: scales the previous update both for the look-ahead point and for velocity decay. */
    protected double momentum;

    /**
     * Constructs a Nesterov update calculator.
     *
     * @param learningRate Step size applied to the gradient.
     * @param momentum Momentum coefficient applied to the previous update.
     */
    public NesterovUpdateCalculator(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /**
     * Computes the Nesterov update for the current iteration.
     *
     * <p>When {@code iteration > 0}, the gradient is taken at the look-ahead parameters
     * {@code theta - momentum * prev}, and the returned update is {@code momentum * prev + learningRate * gradient}.
     * Otherwise the gradient is taken at the current parameters and added to {@code prev} unscaled.
     *
     * @param mdl Model being optimized. The look-ahead model is obtained via {@code withParameters}, which is
     *     expected (not verified here) to leave {@code mdl} unchanged.
     * @param updaterParameters Update state from the previous iteration.
     * @param iteration Iteration index; the look-ahead and the momentum decay apply only when it is positive.
     * @param inputs Input batch passed to {@code differentiateByParameters}.
     * @param groundTruth Ground-truth batch passed to {@code differentiateByParameters}.
     * @return New update state holding the update to subtract from the parameters.
     */
    @Override
    public NesterovParameterUpdate calculateNewUpdate(M mdl, NesterovParameterUpdate updaterParameters,
        int iteration, Matrix inputs, Matrix groundTruth) {
        Vector prevUpdates = updaterParameters.prevIterationUpdates();

        M lookAheadMdl = mdl;
        Vector velocity = prevUpdates;

        if (iteration > 0) {
            Vector decayedUpdates = prevUpdates.times(momentum);

            lookAheadMdl = mdl.withParameters(mdl.parameters().minus(decayedUpdates));

            // Decay the carried velocity by the momentum coefficient. Without this, the velocity is the undamped
            // sum of all past gradient steps (momentum effectively 1), and the momentum setting only moves the
            // look-ahead point.
            velocity = decayedUpdates;
        }

        Vector gradient = lookAheadMdl.differentiateByParameters(loss, inputs, groundTruth);

        return new NesterovParameterUpdate(velocity.plus(gradient.times(learningRate)));
    }

    /**
     * Stores the loss function for later iterations and creates the initial update state.
     *
     * @param mdl Model being optimized; only its parameter count is read.
     * @param loss Loss function used by {@link #calculateNewUpdate}.
     * @return Initial update state sized to the model's parameter count.
     */
    @Override
    public NesterovParameterUpdate init(M mdl,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;

        return new NesterovParameterUpdate(mdl.parametersCount());
    }

    /**
     * Applies an update by subtracting it from the model parameters.
     *
     * @param obj Model to update.
     * @param update Update state produced by {@link #calculateNewUpdate}.
     * @param <M1> Concrete model type.
     * @return Result of {@code obj.setParameters(obj.parameters() - update)}.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <M1 extends M> M1 update(M1 obj, NesterovParameterUpdate update) {
        Vector newParams = obj.parameters().minus(update.prevIterationUpdates());

        // setParameters is typed to return M, not M1. The unchecked cast assumes it returns an object of obj's
        // runtime type, as this generic signature requires.
        return (M1) obj.setParameters(newParams);
    }
}
