package org.apache.ignite.ml.math.isolve.lsqr;

import com.github.fommil.netlib.BLAS;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;

/**
 * LSQR solver whose matrix-vector products are evaluated partition by partition over a
 * {@link Dataset} of {@link SimpleLabeledDatasetData}.
 *
 * <p>Each partition holds a slice of the rows of the system. Its feature array is handed to BLAS as a
 * column-major {@code rows x cols} matrix (leading dimension {@code rows}), and its labels are the
 * matching slice of the right-hand side {@code b}. The partition's {@link LSQRPartitionContext} keeps
 * that partition's slice of the LSQR vector {@code u} between calls.
 *
 * <p>Partitions without features or labels contribute nothing. Every reducer below treats a
 * {@code null} partial result as "no contribution" instead of failing.
 *
 * <p>The functions passed to the dataset are serializable lambdas. They intentionally capture only
 * local values and call only static helpers, so {@code this} is never captured.
 *
 * @param <K> type of upstream keys
 * @param <V> type of upstream values
 */
public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
    /** Dataset whose partitions hold row slices of the system together with their LSQR state. */
    private final Dataset<LSQRPartitionContext, SimpleLabeledDatasetData> dataset;

    /**
     * Builds the underlying dataset, attaching a fresh {@link LSQRPartitionContext} to every partition.
     *
     * @param datasetBuilder builder of the upstream dataset
     * @param partitionDataBuilder builder of the per-partition feature/label data
     */
    public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
        PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partitionDataBuilder) {
        this.dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new LSQRPartitionContext(),
            partitionDataBuilder
        );
    }

    /**
     * Initializes each partition's {@code u} with a copy of its labels and returns the Euclidean norm of
     * the full label vector {@code b}.
     *
     * @return {@code ||b||_2}, or {@code 0.0} when no partition holds labels
     */
    @Override
    protected double bnorm() {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            double[] labels = data.getLabels();

            // Empty partitions carry no labels; report "no contribution" instead of throwing.
            if (labels == null) {
                return null;
            }

            // Copy so that later in-place updates of u never touch the stored labels.
            ctx.setU(Arrays.copyOf(labels, labels.length));

            return BLAS.getInstance().dnrm2(labels.length, labels, 1);
        }, LSQROnHeap::combineNorms);

        return norm == null ? 0.0 : norm;
    }

    /**
     * Updates every partition's {@code u := alfa * A * x + beta * u} in place and returns the
     * Euclidean norm of the updated global {@code u}.
     *
     * @param x vector multiplied by the feature matrix; its length must equal the number of columns
     * @param alfa scale applied to {@code A * x}
     * @param beta scale applied to the current {@code u}
     * @return {@code ||u||_2} after the update, or {@code 0.0} when no partition holds features
     */
    @Override
    protected double beta(double[] x, double alfa, double beta) {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            Integer cols = columnsOf(data);

            if (cols == null) {
                return null;
            }

            int rows = data.getRows();
            double[] u = ctx.getU();

            BLAS.getInstance().dgemv("N", rows, cols, alfa, data.getFeatures(), Math.max(1, rows),
                x, 1, beta, u, 1);

            return BLAS.getInstance().dnrm2(u.length, u, 1);
        }, LSQROnHeap::combineNorms);

        return norm == null ? 0.0 : norm;
    }

    /**
     * Scales every partition's {@code u} by {@code 1 / bnorm} in place, computes {@code A^T * u}
     * across all partitions and adds it to {@code target}.
     *
     * @param bnorm divisor applied to {@code u} (expected to be non-zero)
     * @param target vector updated in place with {@code A^T * u}
     * @return {@code target}; it is left unchanged when no partition holds features
     */
    @Override
    protected double[] iter(double bnorm, double[] target) {
        double[] sum = dataset.computeWithCtx((ctx, data) -> {
            Integer cols = columnsOf(data);

            if (cols == null) {
                return null;
            }

            int rows = data.getRows();
            double[] u = ctx.getU();

            // Normalization is done in place, so the scaled u stays in the partition context.
            BLAS.getInstance().dscal(u.length, 1.0 / bnorm, u, 1);

            double[] partial = new double[cols];

            BLAS.getInstance().dgemv("T", rows, cols, 1.0, data.getFeatures(), Math.max(1, rows),
                u, 1, 0.0, partial, 1);

            return partial;
        }, LSQROnHeap::sumVectors);

        if (sum != null) {
            BLAS.getInstance().daxpy(sum.length, 1.0, sum, 1, target, 1);
        }

        return target;
    }

    /**
     * Returns the number of feature columns in the dataset.
     *
     * @return the column count reported by the partitions; {@code 0} when the reducer only sees
     *     partitions without features
     */
    @Override
    protected Integer getColumns() {
        return dataset.compute(data -> columnsOf(data), LSQROnHeap::mergeColumns);
    }

    /** Closes the underlying dataset. */
    @Override
    public void close() throws Exception {
        dataset.close();
    }

    /**
     * Returns the number of feature columns held by a partition.
     *
     * @param data partition data
     * @return {@code features.length / rows}, or {@code null} when the partition has no features or
     *     no rows (guards against division by zero)
     */
    private static Integer columnsOf(SimpleLabeledDatasetData data) {
        double[] features = data.getFeatures();
        int rows = data.getRows();

        if (features == null || rows == 0) {
            return null;
        }

        return features.length / rows;
    }

    /**
     * Combines two partial Euclidean norms into the norm of the concatenated vector.
     *
     * <p>Boxed {@code Double} is used throughout, so that two {@code null} inputs yield {@code null}
     * rather than an unboxing {@code NullPointerException}.
     *
     * @param a partial norm, or {@code null} for "no contribution"
     * @param b partial norm, or {@code null} for "no contribution"
     * @return {@code sqrt(a^2 + b^2)}, the non-null input, or {@code null} when both are null
     */
    private static Double combineNorms(Double a, Double b) {
        if (a == null) {
            return b;
        }

        if (b == null) {
            return a;
        }

        // hypot avoids the intermediate overflow/underflow of a * a + b * b.
        return Math.hypot(a, b);
    }

    /**
     * Adds two partial vectors.
     *
     * @param a partial vector, or {@code null} for "no contribution"
     * @param b partial vector, or {@code null} for "no contribution"; overwritten with {@code a + b}
     * @return {@code b} holding the sum, the non-null input, or {@code null} when both are null
     */
    private static double[] sumVectors(double[] a, double[] b) {
        if (a == null) {
            return b;
        }

        if (b == null) {
            return a;
        }

        BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);

        return b;
    }

    /**
     * Merges column counts reported by two partitions. Partitions are assumed to agree on the column
     * count, so when both are present the second one wins.
     *
     * @param a column count, or {@code null} when unknown
     * @param b column count, or {@code null} when unknown
     * @return the known column count, or {@code 0} when both are {@code null}
     */
    private static Integer mergeColumns(Integer a, Integer b) {
        if (a == null) {
            return b == null ? 0 : b;
        }

        return b == null ? a : b;
    }
}
