/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.kmeans;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.clustering.kmeans.ClusterizationModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.util.ModelTrace;

/**
 * Clustering model made of a fixed set of centroids plus the {@link DistanceMeasure} used to
 * assign an input vector to its nearest centroid.
 *
 * <p>The centroid array is copied on construction, on {@link #getCenters()} and when exporting.
 * Holders of an array therefore cannot replace this model's centroids. The copies are shallow:
 * the {@link Vector} instances themselves are shared.
 */
public final class KMeansModel
implements ClusterizationModel<Vector, Integer>,
Exportable<KMeansModelFormat>,
DeployableObject {
    /** Cluster centroids; index {@code i} is the cluster id returned by {@link #predict(Vector)}. */
    private final Vector[] centers;

    /** Measure used to compare an input vector with each centroid. */
    private final DistanceMeasure distanceMeasure;

    /**
     * Creates a model from the given centroids and distance measure.
     *
     * @param centers Cluster centroids. The array is copied (shallowly), so later writes to it
     *     do not affect this model.
     * @param distanceMeasure Measure used by {@link #predict(Vector)} to find the nearest centroid.
     * @throws NullPointerException If {@code centers} or {@code distanceMeasure} is {@code null}.
     */
    public KMeansModel(Vector[] centers, DistanceMeasure distanceMeasure) {
        // Copy so that the model is isolated from the caller's array, matching getCenters().
        this.centers = Objects.requireNonNull(centers, "centers").clone();
        this.distanceMeasure = Objects.requireNonNull(distanceMeasure, "distanceMeasure");
    }

    /**
     * Returns the distance measure used for prediction.
     *
     * @return Distance measure.
     */
    public DistanceMeasure distanceMeasure() {
        return this.distanceMeasure;
    }

    /**
     * Returns the number of clusters, i.e. the number of centroids.
     *
     * @return Number of centroids.
     */
    @Override
    public int getAmountOfClusters() {
        return this.centers.length;
    }

    /**
     * Returns a shallow copy of the centroid array. Replacing elements of the returned array does
     * not affect the model, but the {@link Vector} objects themselves are shared with it.
     *
     * @return Copy of the centroid array.
     */
    public Vector[] getCenters() {
        return Arrays.copyOf(this.centers, this.centers.length);
    }

    /**
     * Returns the index of the centroid nearest to {@code vec} under {@link #distanceMeasure()}.
     *
     * @param vec Vector to classify.
     * @return Index of the closest centroid. On ties the lowest index wins. Returns {@code -1} if
     *     no centroid has a distance strictly below {@code +Infinity}, e.g. when there are no
     *     centroids or every distance is {@code NaN} or infinite.
     */
    @Override
    public Integer predict(Vector vec) {
        int res = -1;
        double minDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.centers.length; ++i) {
            double curDist = this.distanceMeasure.compute(this.centers[i], vec);
            // Strict '<': the earliest centroid wins ties, and NaN never compares as smaller.
            if (curDist < minDist) {
                minDist = curDist;
                res = i;
            }
        }
        return res;
    }

    /**
     * Exports this model's centroids and distance measure via the given exporter.
     *
     * @param exporter Exporter that writes the model data.
     * @param path Destination understood by the exporter.
     * @param <P> Type of the destination.
     */
    @Override
    public <P> void saveModel(Exporter<KMeansModelFormat, P> exporter, P path) {
        // Hand the exporter a copy so it cannot alias this model's internal array.
        KMeansModelFormat mdlData = new KMeansModelFormat(getCenters(), this.distanceMeasure);
        exporter.save(mdlData, path);
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + this.distanceMeasure.hashCode();
        res = res * 37 + Arrays.hashCode(this.centers);
        return res;
    }

    /**
     * Two models are equal when they are of the same class, their distance measures are equal,
     * and their centroid arrays are element-wise equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        KMeansModel that = (KMeansModel) obj;
        // Vector elements can never be arrays, so Arrays.equals matches deepEquals here and
        // pairs with Arrays.hashCode used in hashCode().
        return this.distanceMeasure.equals(that.distanceMeasure)
            && Arrays.equals(this.centers, that.centers);
    }

    /** Returns the compact (non-pretty) textual representation of this model. */
    @Override
    public String toString() {
        return this.toString(false);
    }

    /**
     * Renders the distance measure's simple class name and each centroid formatted with
     * {@code "%.4f"}.
     *
     * @param pretty Whether to produce the pretty (multi-line) form.
     * @return Textual representation of this model.
     */
    @Override
    public String toString(boolean pretty) {
        String measureName = this.distanceMeasure.getClass().getSimpleName();
        List<String> centersList = Arrays.stream(this.centers)
            .map(x -> Tracer.asAscii(x, "%.4f", false))
            .collect(Collectors.toList());
        return ModelTrace.builder("KMeansModel", pretty)
            .addField("distance measure", measureName)
            .addField("centroids", centersList)
            .toString();
    }

    /**
     * Returns the objects this model depends on for deployment: only the distance measure.
     *
     * @return Singleton list containing the distance measure.
     */
    @Override
    public List<Object> getDependencies() {
        return Collections.singletonList(this.distanceMeasure);
    }
}

