/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.naivebayes.discrete;

import java.io.Serializable;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesSumsHolder;

/**
 * Discrete naive Bayes classification model.
 *
 * <p>Each feature value is mapped to a bucket via per-feature thresholds. Each class is then
 * scored as {@code log(classProbability) + sum(log(probability[class][feature][bucket]))}, and
 * the label of the highest-scoring class is returned.
 */
public final class DiscreteNaiveBayesModel
implements IgniteModel<Vector, Double>,
Exportable<DiscreteNaiveBayesModel>,
Serializable {
    private static final long serialVersionUID = -127386523291350345L;

    /** Probabilities indexed as {@code probabilities[class][feature][bucket]}. */
    private final double[][][] probabilities;

    /** Per-class probabilities; the log of entry {@code i} seeds the score of class {@code i}. */
    private final double[] clsProbabilities;

    /** Label returned by {@link #predict(Vector)} when class {@code i} wins: {@code labels[i]}. */
    private final double[] labels;

    /**
     * Per-feature bucket thresholds: {@code bucketThresholds[feature]}. A feature with {@code t}
     * thresholds has {@code t + 1} buckets (presumably thresholds are ascending — confirm with trainer).
     */
    private final double[][] bucketThresholds;

    /** Holder supplied at construction and exposed via {@link #getSumsHolder()}; not used by prediction. */
    private final DiscreteNaiveBayesSumsHolder sumsHolder;

    /**
     * Creates a model. Arrays are stored as given, without copying.
     *
     * @param probabilities Probabilities indexed as {@code [class][feature][bucket]}.
     * @param clsProbabilities Per-class probabilities.
     * @param labels Label value for each class index.
     * @param bucketThresholds Per-feature bucket thresholds.
     * @param sumsHolder Holder exposed via {@link #getSumsHolder()}.
     */
    public DiscreteNaiveBayesModel(double[][][] probabilities, double[] clsProbabilities, double[] labels, double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) {
        this.probabilities = probabilities;
        this.clsProbabilities = clsProbabilities;
        this.labels = labels;
        this.bucketThresholds = bucketThresholds;
        this.sumsHolder = sumsHolder;
    }

    /** {@inheritDoc} */
    @Override
    public <P> void saveModel(Exporter<DiscreteNaiveBayesModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /**
     * Predicts the label for the given feature vector.
     *
     * @param vector Feature vector; features {@code 0 .. probabilities[0].length - 1} are read.
     * @return Label of the class with the highest log-probability score (first class wins ties).
     * @throws IllegalStateException If the model has no classes, or no class has a score greater
     *     than {@code -Infinity} (e.g. all class probabilities are zero or scores are NaN).
     */
    @Override
    public Double predict(Vector vector) {
        if (clsProbabilities.length == 0)
            throw new IllegalStateException("Cannot predict: the model contains no classes.");

        int featureCnt = probabilities[0].length;

        // Bucket indices depend only on the input, so compute them once instead of once per class.
        int[] buckets = new int[featureCnt];
        for (int j = 0; j < featureCnt; j++)
            buckets[j] = toBucketNumber(vector.get(j), bucketThresholds[j]);

        double maxProbabilityPower = Double.NEGATIVE_INFINITY;
        int maxLabelIdx = -1;

        for (int i = 0; i < clsProbabilities.length; i++) {
            double probabilityPower = Math.log(clsProbabilities[i]);

            for (int j = 0; j < featureCnt; j++) {
                double p = probabilities[i][j][buckets[j]];

                // Zero-probability buckets are skipped rather than adding log(0) = -Infinity.
                probabilityPower += p > 0.0 ? Math.log(p) : 0.0;
            }

            // Strict comparison keeps the earliest class on ties and ignores NaN scores.
            if (probabilityPower > maxProbabilityPower) {
                maxLabelIdx = i;
                maxProbabilityPower = probabilityPower;
            }
        }

        if (maxLabelIdx < 0) {
            throw new IllegalStateException("Cannot predict: no class has a score greater than -Infinity "
                + "(all class scores are -Infinity or NaN; check class probabilities).");
        }

        return labels[maxLabelIdx];
    }

    /**
     * Returns the probabilities indexed as {@code [class][feature][bucket]}.
     *
     * @return The internal array (not a copy).
     */
    public double[][][] getProbabilities() {
        return this.probabilities;
    }

    /**
     * Returns the per-class probabilities.
     *
     * @return The internal array (not a copy).
     */
    public double[] getClsProbabilities() {
        return this.clsProbabilities;
    }

    /**
     * Returns the label value for each class index.
     *
     * @return The internal array (not a copy).
     */
    public double[] getLabels() {
        return this.labels;
    }

    /**
     * Returns the per-feature bucket thresholds.
     *
     * @return The internal array (not a copy).
     */
    public double[][] getBucketThresholds() {
        return this.bucketThresholds;
    }

    /**
     * Returns the holder supplied at construction.
     *
     * @return Sums holder.
     */
    public DiscreteNaiveBayesSumsHolder getSumsHolder() {
        return this.sumsHolder;
    }

    /**
     * Maps a value to a bucket index.
     *
     * @param val Feature value.
     * @param thresholds Thresholds for this feature.
     * @return Index of the first threshold strictly greater than {@code val}, or
     *     {@code thresholds.length} if there is none (NaN also lands in this last bucket).
     */
    private int toBucketNumber(double val, double[] thresholds) {
        for (int i = 0; i < thresholds.length; ++i) {
            if (val < thresholds[i])
                return i;
        }
        return thresholds.length;
    }
}

