package org.apache.ignite.ml.optimization.updatecalculators;

import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.optimization.SmoothParametrized;

/* loaded from: input_file:org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.class */
/**
 * Parameter update calculator implementing a plain gradient descent step:
 * the model parameters are moved against the gradient of the loss, scaled by a
 * fixed learning rate.
 *
 * <p>The loss is not supplied at construction time; it is stored by
 * {@link #init(SmoothParametrized, IgniteFunction)} (or {@link #init2}) and is
 * {@code null} until one of them has been called.
 */
public class SimpleGDUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, SimpleGDParameterUpdate> {
    /** Learning rate used when no explicit value is given. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    /** Factor the gradient is multiplied by before it is subtracted from the parameters. */
    private final double learningRate;

    /** Loss factory handed to the model when differentiating; assigned by {@code init}/{@code init2}. */
    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /**
     * Creates a calculator with the default learning rate ({@value #DEFAULT_LEARNING_RATE}).
     */
    public SimpleGDUpdateCalculator() {
        this(DEFAULT_LEARNING_RATE);
    }

    /**
     * Creates a calculator with the given learning rate.
     *
     * @param d learning rate.
     */
    public SimpleGDUpdateCalculator(double d) {
        this.learningRate = d;
    }

    /**
     * Strongly typed initialisation entry point; {@link #init(SmoothParametrized, IgniteFunction)}
     * delegates here. Stores the loss and creates the initial update object.
     *
     * @param smoothParametrized model whose {@code parametersCount()} is passed to the initial update.
     * @param igniteFunction loss factory remembered for later calls to {@link #calculateNewUpdate}.
     * @return a new {@link SimpleGDParameterUpdate} constructed from the model's parameter count.
     */
    public SimpleGDParameterUpdate init2(SmoothParametrized smoothParametrized, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> igniteFunction) {
        this.loss = igniteFunction;
        return new SimpleGDParameterUpdate(smoothParametrized.parametersCount());
    }

    /**
     * Computes the next update by differentiating the model by its parameters with the stored loss.
     *
     * <p>The previous update and the iteration number are not used by this implementation.
     *
     * @param smoothParametrized model to differentiate.
     * @param simpleGDParameterUpdate previous update (ignored).
     * @param i iteration number (ignored).
     * @param matrix first matrix forwarded to {@code differentiateByParameters}
     *     (presumably the inputs batch; confirm against {@code SmoothParametrized}).
     * @param matrix2 second matrix forwarded to {@code differentiateByParameters}
     *     (presumably the expected outputs; confirm against {@code SmoothParametrized}).
     * @return a new update wrapping the result of the differentiation.
     */
    @Override // org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator
    public SimpleGDParameterUpdate calculateNewUpdate(SmoothParametrized smoothParametrized, SimpleGDParameterUpdate simpleGDParameterUpdate, int i, Matrix matrix, Matrix matrix2) {
        return new SimpleGDParameterUpdate(smoothParametrized.differentiateByParameters(this.loss, matrix, matrix2));
    }

    /**
     * Applies an update to a model: {@code parameters - gradient * learningRate}.
     *
     * @param m1 model to update.
     * @param simpleGDParameterUpdate update carrying the gradient.
     * @param <M1> concrete model type.
     * @return the value returned by {@code m1.setParameters(...)}, cast to {@code M1}.
     */
    // The cast relies on setParameters returning the receiver's own type, which the
    // raw SmoothParametrized reference cannot express to the compiler.
    @SuppressWarnings("unchecked")
    @Override // org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator
    public <M1 extends SmoothParametrized> M1 update(M1 m1, SimpleGDParameterUpdate simpleGDParameterUpdate) {
        return (M1) m1.setParameters(m1.parameters().minus(simpleGDParameterUpdate.gradient().times(this.learningRate)));
    }

    /**
     * Returns a new calculator that uses the given learning rate. This instance is not modified.
     *
     * <p>The currently stored loss (possibly {@code null}) is carried over, so a calculator
     * that was already initialised stays usable after its learning rate is changed.
     *
     * @param d learning rate for the returned calculator.
     * @return a new calculator with learning rate {@code d} and the same loss as this one.
     */
    public SimpleGDUpdateCalculator withLearningRate(double d) {
        SimpleGDUpdateCalculator copy = new SimpleGDUpdateCalculator(d);
        copy.loss = this.loss;
        return copy;
    }

    /**
     * Interface-level initialisation entry point taking a raw {@link IgniteFunction};
     * narrows the loss type and delegates to {@link #init2}.
     *
     * @param smoothParametrized model to initialise the update for.
     * @param igniteFunction loss factory; must actually map {@code Vector} to
     *     {@code IgniteDifferentiableVectorToDoubleFunction}.
     * @return the initial update produced by {@link #init2}.
     */
    // The raw parameter is kept for signature compatibility; the unchecked cast restores
    // the type arguments declared by ParameterUpdateCalculator.
    @SuppressWarnings({"unchecked", "rawtypes"})
    @Override // org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator
    public SimpleGDParameterUpdate init(SmoothParametrized smoothParametrized, IgniteFunction igniteFunction) {
        return init2(smoothParametrized, (IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction>) igniteFunction);
    }
}
