/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.scoring.metric.classification;

import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;
import org.apache.commons.math3.util.Pair;
import org.apache.ignite.ml.selection.scoring.LabelPair;
import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics;
import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
import org.apache.ignite.ml.selection.scoring.metric.classification.ROCAUC;
import org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException;

/**
 * Computes binary classification metrics (confusion-matrix counts and, optionally, ROC AUC)
 * from an iterator of {@link LabelPair} values.
 *
 * <p>Every prediction and every ground-truth value must be equal to either the configured
 * positive class label (default {@code 1.0}) or the configured negative class label
 * (default {@code 0.0}); any other value causes an {@link UnknownClassLabelException}.
 *
 * <p>The metric selected by default is accuracy.
 */
public class BinaryClassificationMetrics
extends AbstractMetrics<BinaryClassificationMetricValues> {
    /** Label treated as the positive class. */
    private double positiveClsLb = 1.0;

    /** Label treated as the negative class. */
    private double negativeClsLb = 0.0;

    /** Whether {@link #scoreAll(Iterator)} also computes ROC AUC. */
    private boolean enableROCAUC;

    /** Creates metrics with the default class labels; the selected metric is accuracy. */
    public BinaryClassificationMetrics() {
        this.metric = BinaryClassificationMetricValues::accuracy;
    }

    /**
     * Consumes {@code iter} and accumulates true/false positive/negative counts.
     *
     * <p>When ROC AUC is enabled, the (prediction, truth) pairs are additionally collected in a
     * queue ordered by prediction and passed to {@code ROCAUC.calculateROCAUC} together with the
     * number of positive and negative ground-truth values; otherwise ROC AUC is {@code NaN}.
     *
     * @param iter Source of (truth, prediction) pairs; it is fully consumed.
     * @return Computed metric values.
     * @throws UnknownClassLabelException If a prediction or truth value matches neither
     *     configured class label (the prediction is checked first).
     */
    @Override
    public BinaryClassificationMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
        long tp = 0L;
        long tn = 0L;
        long fp = 0L;
        long fn = 0L;
        long pos = 0L;
        long neg = 0L;

        // Ordered by prediction; populated only when ROC AUC is enabled.
        PriorityQueue<Pair<Double, Double>> queue =
            new PriorityQueue<Pair<Double, Double>>(Comparator.comparingDouble(Pair::getKey));

        while (iter.hasNext()) {
            LabelPair<Double> e = iter.next();
            double prediction = e.getPrediction();
            double truth = e.getTruth();

            checkLabel(prediction);
            checkLabel(truth);

            // After validation each value equals one of the two labels, so "not positive"
            // means "negative". If both labels are equal, everything counts as positive,
            // matching the previous if/else-if chain.
            boolean truthIsPositive = truth == positiveClsLb;
            boolean predictionIsPositive = prediction == positiveClsLb;

            if (truthIsPositive) {
                if (predictionIsPositive) {
                    tp++;
                } else {
                    fn++;
                }
            } else {
                if (predictionIsPositive) {
                    fp++;
                } else {
                    tn++;
                }
            }

            if (enableROCAUC) {
                queue.add(new Pair<>(prediction, truth));
                if (truthIsPositive) {
                    pos++;
                } else {
                    neg++;
                }
            }
        }

        double rocauc = enableROCAUC
            ? ROCAUC.calculateROCAUC(queue, pos, neg, positiveClsLb)
            : Double.NaN;

        return new BinaryClassificationMetricValues(tp, tn, fp, fn, rocauc);
    }

    /**
     * Ensures {@code lb} is one of the configured class labels.
     *
     * @param lb Value to check.
     * @throws UnknownClassLabelException If {@code lb} is neither the positive nor the negative label.
     */
    private void checkLabel(double lb) {
        if (lb != negativeClsLb && lb != positiveClsLb) {
            throw new UnknownClassLabelException(lb, positiveClsLb, negativeClsLb);
        }
    }

    /** @return Label treated as the positive class. */
    public double positiveClsLb() {
        return this.positiveClsLb;
    }

    /**
     * Sets the positive class label. Non-finite values (NaN, infinities) are ignored and the
     * current label is kept.
     *
     * @param positiveClsLb New positive class label.
     * @return This instance, for chaining.
     */
    public BinaryClassificationMetrics withPositiveClsLb(double positiveClsLb) {
        if (Double.isFinite(positiveClsLb)) {
            this.positiveClsLb = positiveClsLb;
        }
        return this;
    }

    /** @return Label treated as the negative class. */
    public double negativeClsLb() {
        return this.negativeClsLb;
    }

    /**
     * Sets the negative class label. Non-finite values (NaN, infinities) are ignored and the
     * current label is kept.
     *
     * @param negativeClsLb New negative class label.
     * @return This instance, for chaining.
     */
    public BinaryClassificationMetrics withNegativeClsLb(double negativeClsLb) {
        if (Double.isFinite(negativeClsLb)) {
            this.negativeClsLb = negativeClsLb;
        }
        return this;
    }

    /**
     * Enables or disables ROC AUC computation in {@link #scoreAll(Iterator)}.
     *
     * @param enableROCAUC {@code true} to compute ROC AUC.
     * @return This instance, for chaining.
     */
    public BinaryClassificationMetrics withEnablingROCAUC(boolean enableROCAUC) {
        this.enableROCAUC = enableROCAUC;
        return this;
    }

    /** @return {@code true} if ROC AUC computation is enabled. */
    public boolean isROCAUCenabled() {
        return this.enableROCAUC;
    }

    /** {@inheritDoc} */
    @Override
    public String name() {
        return "Binary classification metrics";
    }
}

