/*
 * Decompiled with CFR 0.152.
 */
package org.gridgain.internal.ml;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import org.apache.ignite3.compute.IgniteCompute;
import org.apache.ignite3.compute.JobDescriptor;
import org.apache.ignite3.compute.JobTarget;
import org.apache.ignite3.internal.ml.MlJobDescriptorUtils;
import org.apache.ignite3.network.ClusterNode;
import org.apache.ignite3.sql.IgniteSql;
import org.apache.ignite3.table.Tuple;
import org.gridgain.internal.ml.compute.PredictionUtils;
import org.gridgain.ml.IgniteMl;
import org.gridgain.ml.InferenceException;
import org.gridgain.ml.model.MlBatchJobParameters;
import org.gridgain.ml.model.MlColocatedJobParameters;
import org.gridgain.ml.model.MlSimpleJobParameters;
import org.gridgain.ml.model.MlSqlJobParameters;

/**
 * {@link IgniteMl} implementation that runs predictions either locally through
 * {@link PredictionUtils} or remotely as compute jobs.
 *
 * <p>Every {@code *Async} method first checks {@code jobParams.hasUrl()}: when it returns
 * {@code true} the call is delegated to {@link PredictionUtils} together with an executor
 * obtained from the executor supplier; otherwise a job descriptor is created through
 * {@link MlJobDescriptorUtils} and submitted via {@link IgniteCompute}.
 *
 * <p>Every blocking method waits on its asynchronous counterpart for at most
 * {@code jobParams.timeoutSeconds()} seconds and reports any failure as an
 * {@link InferenceException}.
 */
public class IgniteMlImpl implements IgniteMl {
    /** Compute API used to submit prediction jobs. */
    private final IgniteCompute compute;

    /** SQL API handed to {@link PredictionUtils} for SQL-driven predictions. */
    private final IgniteSql sql;

    /** Supplies the nodes that "any node" jobs may be executed on. */
    private final Supplier<Collection<ClusterNode>> clusterNodesSupplier;

    /** Supplies the executor passed to {@link PredictionUtils}. */
    private final Supplier<Executor> executorSupplier;

    /**
     * Creates the ML facade.
     *
     * @param compute compute API, must not be {@code null}
     * @param sql SQL API, must not be {@code null}
     * @param clusterNodesSupplier supplier of candidate nodes, must not be {@code null}
     * @param executorSupplier supplier of the executor, must not be {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public IgniteMlImpl(IgniteCompute compute, IgniteSql sql, Supplier<Collection<ClusterNode>> clusterNodesSupplier, Supplier<Executor> executorSupplier) {
        IgniteMlImpl.validateParameters(compute, clusterNodesSupplier, executorSupplier, sql);
        this.compute = compute;
        this.sql = sql;
        this.clusterNodesSupplier = clusterNodesSupplier;
        this.executorSupplier = executorSupplier;
    }

    /**
     * Rejects {@code null} constructor arguments.
     *
     * @throws IllegalArgumentException naming the first argument found to be {@code null}
     */
    private static void validateParameters(IgniteCompute computeApi, Supplier<Collection<ClusterNode>> clusterNodesSupplier, Supplier<Executor> executorSupplier, IgniteSql sql) {
        if (computeApi == null) {
            throw new IllegalArgumentException("computeApi cannot be null");
        }
        if (clusterNodesSupplier == null) {
            throw new IllegalArgumentException("clusterNodesSupplier cannot be null");
        }
        if (sql == null) {
            throw new IllegalArgumentException("sqlApi cannot be null");
        }
        if (executorSupplier == null) {
            throw new IllegalArgumentException("executorSupplier cannot be null");
        }
    }

    /**
     * Blocking variant of {@link #predictAsync(MlSimpleJobParameters)}.
     *
     * @param jobParams prediction parameters; also provide the wait timeout in seconds
     * @return the prediction result
     * @throws InferenceException if the wait times out or the prediction fails for any other reason
     */
    @Override
    public <I, O> O predict(MlSimpleJobParameters<I> jobParams) {
        try {
            return this.predictAsync(jobParams).get(jobParams.timeoutSeconds(), TimeUnit.SECONDS);
        }
        catch (Exception e) {
            throw IgniteMlImpl.handleSyncException(e, jobParams.timeoutSeconds(), "Error during prediction");
        }
    }

    /**
     * Starts a single prediction.
     *
     * @param jobParams prediction parameters
     * @return future completed with the prediction result
     * @throws IllegalStateException if a compute job is required but no cluster nodes are available
     */
    @Override
    public <I, O> CompletableFuture<O> predictAsync(MlSimpleJobParameters<I> jobParams) {
        if (jobParams.hasUrl()) {
            return PredictionUtils.predictAsyncInternal(jobParams, this.executorSupplier.get());
        }
        JobDescriptor jobDescriptor = MlJobDescriptorUtils.createSimplePredictionDescriptor(jobParams);
        // Use the shared target factory so a missing node list fails with a descriptive error.
        return this.compute.executeAsync(this.createJobTarget(), jobDescriptor, jobParams, null);
    }

    /**
     * Blocking variant of {@link #batchPredictAsync(MlBatchJobParameters)}.
     *
     * @param jobParams batch prediction parameters; also provide the wait timeout in seconds
     * @return the prediction results
     * @throws InferenceException if the wait times out or the prediction fails for any other reason
     */
    @Override
    public <I, O> List<O> batchPredict(MlBatchJobParameters<I> jobParams) {
        try {
            return this.batchPredictAsync(jobParams).get(jobParams.timeoutSeconds(), TimeUnit.SECONDS);
        }
        catch (Exception e) {
            throw IgniteMlImpl.handleSyncException(e, jobParams.timeoutSeconds(), "Error during batch prediction");
        }
    }

    /**
     * Starts a batch prediction.
     *
     * @param jobParams batch prediction parameters
     * @return future completed with the prediction results
     * @throws IllegalStateException if a compute job is required but no cluster nodes are available
     */
    @Override
    public <I, O> CompletableFuture<List<O>> batchPredictAsync(MlBatchJobParameters<I> jobParams) {
        if (jobParams.hasUrl()) {
            return PredictionUtils.batchPredictAsyncInternal(jobParams, this.executorSupplier.get());
        }
        JobDescriptor jobDescriptor = MlJobDescriptorUtils.createBatchPredictionDescriptor(jobParams);
        // Use the shared target factory so a missing node list fails with a descriptive error.
        return this.compute.executeAsync(this.createJobTarget(), jobDescriptor, jobParams, null);
    }

    /**
     * Blocking variant of {@link #predictFromSqlAsync(MlSqlJobParameters)}.
     *
     * @param jobParams SQL prediction parameters; also provide the wait timeout in seconds
     * @return the prediction results
     * @throws InferenceException if the wait times out or the prediction fails for any other reason
     */
    @Override
    public <O> List<O> predictFromSql(MlSqlJobParameters jobParams) {
        try {
            return this.predictFromSqlAsync(jobParams).get(jobParams.timeoutSeconds(), TimeUnit.SECONDS);
        }
        catch (Exception e) {
            throw IgniteMlImpl.handleSyncException(e, jobParams.timeoutSeconds(), "Error during SQL prediction");
        }
    }

    /**
     * Starts an SQL-driven prediction.
     *
     * @param jobParams SQL prediction parameters
     * @return future completed with the prediction results
     * @throws IllegalStateException if a compute job is required but no cluster nodes are available
     */
    @Override
    public <O> CompletableFuture<List<O>> predictFromSqlAsync(MlSqlJobParameters jobParams) {
        if (jobParams.hasUrl()) {
            return PredictionUtils.predictFromSqlAsyncInternal(this.sql, jobParams, this.executorSupplier.get());
        }
        JobDescriptor jobDescriptor = MlJobDescriptorUtils.createSqlPredictionDescriptor(jobParams);
        return this.compute.executeAsync(this.createJobTarget(), jobDescriptor, jobParams, null);
    }

    /**
     * Blocking variant of {@link #predictColocatedAsync(MlColocatedJobParameters)}.
     *
     * @param jobParams colocated prediction parameters; also provide the wait timeout in seconds
     * @return the prediction result
     * @throws InferenceException if the wait times out or the prediction fails for any other reason
     */
    @Override
    public <I, O> O predictColocated(MlColocatedJobParameters<I> jobParams) {
        try {
            return this.predictColocatedAsync(jobParams).get(jobParams.timeoutSeconds(), TimeUnit.SECONDS);
        }
        catch (Exception e) {
            throw IgniteMlImpl.handleSyncException(e, jobParams.timeoutSeconds(), "Error during colocated prediction");
        }
    }

    /**
     * Starts a prediction targeted by table name and key.
     *
     * @param jobParams colocated prediction parameters; {@code tableName()} and {@code key()}
     *     select the job target when a compute job is submitted
     * @return future completed with the prediction result
     */
    @Override
    public <I, O> CompletableFuture<O> predictColocatedAsync(MlColocatedJobParameters<I> jobParams) {
        if (jobParams.hasUrl()) {
            return PredictionUtils.predictAsyncInternal(jobParams, this.executorSupplier.get());
        }
        JobDescriptor jobDescriptor = MlJobDescriptorUtils.createColocatedPredictionDescriptor(jobParams);
        return this.compute.executeAsync(IgniteMlImpl.createColocatedJobTarget(jobParams.tableName(), jobParams.key()), jobDescriptor, jobParams, null);
    }

    /**
     * Builds an "any node" target from the nodes currently returned by the supplier.
     *
     * @return target covering a snapshot of the supplied nodes
     * @throws IllegalStateException if the supplier returns {@code null} or an empty collection
     */
    private JobTarget createJobTarget() {
        Collection<ClusterNode> nodes = this.clusterNodesSupplier.get();
        if (nodes == null || nodes.isEmpty()) {
            throw new IllegalStateException("No cluster nodes available for ML execution");
        }
        return JobTarget.anyNode(Set.copyOf(nodes));
    }

    /**
     * Builds a colocated target for the given table and key.
     *
     * @param tableName table name passed to {@link JobTarget#colocated}
     * @param key key tuple passed to {@link JobTarget#colocated}
     * @return the colocated job target
     */
    private static JobTarget createColocatedJobTarget(String tableName, Tuple key) {
        return JobTarget.colocated(tableName, key);
    }

    /**
     * Converts an exception caught while waiting for a prediction into an {@link InferenceException}.
     *
     * <ul>
     *   <li>A {@link TimeoutException} becomes a "timed out after N seconds" error.</li>
     *   <li>Otherwise the exception's cause (or the exception itself when it has no cause) is
     *       returned as-is if it already is an {@link InferenceException}, else wrapped.</li>
     * </ul>
     *
     * @param e exception caught by the blocking method
     * @param timeoutSeconds timeout that was applied to the wait, used in the timeout message
     * @param messagePrefix operation description placed at the start of the message
     * @return the exception the caller should throw
     */
    private static InferenceException handleSyncException(Exception e, long timeoutSeconds, String messagePrefix) {
        if (e instanceof TimeoutException) {
            return new InferenceException(messagePrefix + " timed out after " + timeoutSeconds + " seconds", (Throwable) e);
        }
        if (e instanceof InterruptedException) {
            // The broad catch in the callers would otherwise swallow the interruption;
            // restore the flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
        }
        Throwable cause = e.getCause() != null ? e.getCause() : e;
        if (cause instanceof InferenceException) {
            return (InferenceException) cause;
        }
        return new InferenceException(messagePrefix + ": " + cause.getMessage(), cause);
    }

    /**
     * Returns the executor currently provided by the executor supplier.
     *
     * @return result of {@code executorSupplier.get()}
     */
    public Executor executor() {
        return this.executorSupplier.get();
    }
}

