/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.scoring.metric.regression;

import java.util.Iterator;
import org.apache.ignite.ml.selection.scoring.LabelPair;
import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics;
import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetricValues;

/**
 * Computes a set of regression quality metrics in a single pass over (prediction, truth) pairs:
 * the number of samples, the residual sum of squares (RSS), the mean absolute error (MAE) and the
 * coefficient of determination (R^2).
 */
public class RegressionMetrics
extends AbstractMetrics<RegressionMetricValues> {
    /** Tolerance below which RSS and TSS are treated as zero when computing R^2. */
    private static final double EPS = 1.0E-5;

    /**
     * Creates the metric set and selects {@link RegressionMetricValues#rmse} as the value used
     * when a single scalar metric is requested.
     */
    public RegressionMetrics() {
        this.metric = RegressionMetricValues::rmse;
    }

    /**
     * Scores all pairs produced by {@code iter}.
     *
     * <p>R^2 is {@code 1 - RSS / TSS}. When the labels are (numerically) constant, so that TSS is
     * below {@link #EPS}, R^2 is reported as {@code 1.0} if RSS is also below {@code EPS}
     * (perfect fit), otherwise {@code 0.0}.
     *
     * @param iter iterator over label pairs; it is fully consumed.
     * @return metric values; for an empty iterator the count and RSS are {@code 0} and both MAE
     *     and R^2 are {@link Double#NaN}, because they are undefined on an empty sample.
     */
    @Override
    public RegressionMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
        int totalAmount = 0;
        double rss = 0.0;
        double absErrSum = 0.0;
        // Running mean of the labels and running sum of squared deviations from it (the TSS).
        // This Welford-style update avoids the catastrophic cancellation of n*(E[y^2] - E[y]^2)
        // when labels have a large mean relative to their spread.
        double meanOfLbls = 0.0;
        double tss = 0.0;

        while (iter.hasNext()) {
            LabelPair<Double> e = iter.next();
            double prediction = e.getPrediction();
            double truth = e.getTruth();

            double err = prediction - truth;
            rss += err * err;
            absErrSum += Math.abs(err);

            totalAmount++;
            double delta = truth - meanOfLbls;
            meanOfLbls += delta / totalAmount;
            tss += delta * (truth - meanOfLbls);
        }

        if (totalAmount == 0) {
            // Keep the historical result for an empty sample: MAE and R^2 are undefined.
            return new RegressionMetricValues(0, 0.0, Double.NaN, Double.NaN);
        }

        double mae = absErrSum / totalAmount;

        double r2;
        if (Math.abs(tss) < EPS) {
            // Constant labels: R^2 is degenerate, so only distinguish a perfect fit from a non-fit.
            r2 = Math.abs(rss) < EPS ? 1.0 : 0.0;
        } else {
            r2 = 1.0 - rss / tss;
        }

        return new RegressionMetricValues(totalAmount, rss, mae, r2);
    }

    /**
     * Returns the human-readable name of this metric set.
     *
     * @return {@code "Regression metrics"}.
     */
    @Override
    public String name() {
        return "Regression metrics";
    }
}

