package org.apache.ignite.ml.optimization.updatecalculators;

import java.io.Serializable;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;

/**
 * State carried by the RProp parameter-update rule, bundled so that updates produced separately
 * can be combined with {@link #sumLocal(List)}, {@link #sum(List)} or {@link #avg(List)}.
 *
 * <p>Holds four vectors: the updates and the gradient from the previous iteration, a vector of
 * {@code deltas} (presumably the per-component RProp step sizes) and an {@code updatesMask} that is
 * multiplied element-wise into {@code prevIterationUpdates} whenever updates are summed.
 */
public class RPropParameterUpdate implements Serializable {
    /** Serial version uid. */
    private static final long serialVersionUID = -165584242642323332L;

    /** Updates from the previous iteration. */
    protected Vector prevIterationUpdates;

    /** Gradient from the previous iteration. */
    protected Vector prevIterationGradient;

    /** Per-component deltas (presumably RProp step sizes). */
    protected Vector deltas;

    /** Mask multiplied element-wise into {@link #prevIterationUpdates} when updates are summed. */
    protected Vector updatesMask;

    /**
     * Creates an update whose vectors all have {@code paramsCnt} components.
     *
     * @param paramsCnt Number of components of every vector held by this update.
     * @param initDelta Value assigned to every component of {@link #deltas}; the other three vectors
     *     are freshly allocated {@link DenseVector}s of the same size.
     */
    public RPropParameterUpdate(int paramsCnt, double initDelta) {
        this.prevIterationUpdates = new DenseVector(paramsCnt);
        this.prevIterationGradient = new DenseVector(paramsCnt);
        this.deltas = new DenseVector(paramsCnt).assign(initDelta);
        this.updatesMask = new DenseVector(paramsCnt);
    }

    /**
     * Creates an update from explicit vectors. The vectors are stored as given, without copying.
     *
     * @param prevIterationUpdates Updates from the previous iteration.
     * @param prevIterationGradient Gradient from the previous iteration.
     * @param deltas Per-component deltas.
     * @param updatesMask Mask applied to {@code prevIterationUpdates} when updates are summed.
     */
    public RPropParameterUpdate(Vector prevIterationUpdates, Vector prevIterationGradient,
        Vector deltas, Vector updatesMask) {
        this.prevIterationUpdates = prevIterationUpdates;
        this.prevIterationGradient = prevIterationGradient;
        this.deltas = deltas;
        this.updatesMask = updatesMask;
    }

    /**
     * Returns the per-component deltas.
     *
     * @return Deltas vector (the stored instance, not a copy).
     */
    public Vector deltas() {
        return deltas;
    }

    /**
     * Returns the updates from the previous iteration.
     *
     * @return Previous-iteration updates (the stored instance, not a copy).
     */
    public Vector prevIterationUpdates() {
        return prevIterationUpdates;
    }

    /**
     * Replaces the previous-iteration updates.
     *
     * @param prevIterationUpdates New value.
     * @return This object, for chaining.
     */
    private RPropParameterUpdate setPrevIterationUpdates(Vector prevIterationUpdates) {
        this.prevIterationUpdates = prevIterationUpdates;
        return this;
    }

    /**
     * Returns the gradient from the previous iteration.
     *
     * @return Previous-iteration gradient (the stored instance, not a copy).
     */
    public Vector prevIterationGradient() {
        return prevIterationGradient;
    }

    /**
     * Replaces the previous-iteration gradient.
     *
     * @param prevIterationGradient New value.
     * @return This object, for chaining.
     */
    private RPropParameterUpdate setPrevIterationGradient(Vector prevIterationGradient) {
        this.prevIterationGradient = prevIterationGradient;
        return this;
    }

    /**
     * Returns the updates mask.
     *
     * @return Updates mask (the stored instance, not a copy).
     */
    public Vector updatesMask() {
        return updatesMask;
    }

    /**
     * Replaces the updates mask.
     *
     * @param updatesMask New value.
     * @return This object, for chaining.
     */
    public RPropParameterUpdate setUpdatesMask(Vector updatesMask) {
        this.updatesMask = updatesMask;
        return this;
    }

    /**
     * Replaces the deltas.
     *
     * @param deltas New value.
     * @return This object, for chaining.
     */
    public RPropParameterUpdate setDeltas(Vector deltas) {
        this.deltas = deltas;
        return this;
    }

    /**
     * Combines updates by summing their masked previous-iteration updates. The gradient and the
     * deltas are taken from the last non-null update in {@code updates}. The deltas instance is
     * shared with that update, not copied. The mask of the result is all ones.
     *
     * @param updates Updates to combine; {@code null} elements are ignored.
     * @return Combined update, or {@code null} if {@code updates} has no non-null element.
     */
    public static RPropParameterUpdate sumLocal(List<RPropParameterUpdate> updates) {
        List<RPropParameterUpdate> nonNull = nonNullUpdates(updates);
        if (nonNull.isEmpty()) {
            return null;
        }

        RPropParameterUpdate last = nonNull.get(nonNull.size() - 1);
        Vector lastDeltas = last.deltas();

        return new RPropParameterUpdate(
            sumMaskedUpdates(nonNull),
            last.prevIterationGradient(),
            lastDeltas,
            allOnes(lastDeltas.size()));
    }

    /**
     * Sums updates component-wise. The result holds:
     * <ul>
     *   <li>the sum of the masked previous-iteration updates;</li>
     *   <li>the sum of the gradients;</li>
     *   <li>the sum of the deltas;</li>
     *   <li>an all-ones mask.</li>
     * </ul>
     *
     * @param updates Updates to sum; {@code null} elements are ignored.
     * @return Summed update, or {@code null} if {@code updates} has no non-null element.
     */
    public static RPropParameterUpdate sum(List<RPropParameterUpdate> updates) {
        List<RPropParameterUpdate> nonNull = nonNullUpdates(updates);
        if (nonNull.isEmpty()) {
            return null;
        }

        Vector totalUpdates = sumMaskedUpdates(nonNull);
        Vector totalDeltas = nonNull.stream()
            .map(RPropParameterUpdate::deltas)
            .reduce((a, b) -> a.plus(b))
            .orElse(null);
        Vector totalGradient = nonNull.stream()
            .map(RPropParameterUpdate::prevIterationGradient)
            .reduce((a, b) -> a.plus(b))
            .orElse(null);

        return new RPropParameterUpdate(totalUpdates, totalGradient, totalDeltas,
            allOnes(Objects.requireNonNull(totalDeltas).size()));
    }

    /**
     * Averages updates. The result of {@link #sum(List)} has its gradient, updates and deltas
     * divided by the number of non-null elements of {@code updates}.
     *
     * @param updates Updates to average; {@code null} elements are ignored.
     * @return Averaged update, or {@code null} if {@code updates} has no non-null element.
     */
    public static RPropParameterUpdate avg(List<RPropParameterUpdate> updates) {
        int cnt = nonNullUpdates(updates).size();
        RPropParameterUpdate total = sum(updates);
        if (total == null) {
            return null;
        }

        return total
            .setPrevIterationGradient(total.prevIterationGradient().divide(cnt))
            .setPrevIterationUpdates(total.prevIterationUpdates().divide(cnt))
            .setDeltas(total.deltas().divide(cnt));
    }

    /** Returns the non-null elements of {@code updates}, in their original order. */
    private static List<RPropParameterUpdate> nonNullUpdates(List<RPropParameterUpdate> updates) {
        return updates.stream().filter(Objects::nonNull).collect(Collectors.toList());
    }

    /** Sums the masked previous-iteration updates of {@code updates}; {@code null} if it is empty. */
    private static Vector sumMaskedUpdates(List<RPropParameterUpdate> updates) {
        return updates.stream()
            .map(RPropParameterUpdate::maskedPrevIterationUpdates)
            .reduce((a, b) -> a.plus(b))
            .orElse(null);
    }

    /** Returns the element-wise product of this update's mask and its previous-iteration updates. */
    private Vector maskedPrevIterationUpdates() {
        // The mask is copied first, presumably because elementWiseTimes writes into its first
        // argument; the copy keeps this update's own mask untouched (verify in VectorUtils).
        return VectorUtils.elementWiseTimes(updatesMask().copy(), prevIterationUpdates());
    }

    /** Creates a dense vector of the given size with every component set to {@code 1.0}. */
    private static Vector allOnes(int size) {
        return new DenseVector(size).assign(1.0d);
    }
}
