/*
 * Decompiled with CFR 0.152.
 */
package org.gridgain.shaded.org.apache.ignite.internal.client.ml;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import org.gridgain.ml.IgniteMl;
import org.gridgain.ml.InferenceException;
import org.gridgain.ml.ModelInitializationException;
import org.gridgain.ml.ModelNotFoundException;
import org.gridgain.ml.model.MlBatchJobParameters;
import org.gridgain.ml.model.MlColocatedJobParameters;
import org.gridgain.ml.model.MlSimpleJobParameters;
import org.gridgain.ml.model.MlSqlJobParameters;
import org.gridgain.shaded.org.apache.ignite.compute.JobDescriptor;
import org.gridgain.shaded.org.apache.ignite.compute.JobTarget;
import org.gridgain.shaded.org.apache.ignite.internal.client.compute.ClientCompute;
import org.gridgain.shaded.org.apache.ignite.internal.ml.MlJobDescriptorUtils;
import org.gridgain.shaded.org.apache.ignite.internal.util.ExceptionUtils;
import org.gridgain.shaded.org.apache.ignite.internal.util.ViewUtils;
import org.gridgain.shaded.org.apache.ignite.network.ClusterNode;
import org.gridgain.shaded.org.apache.ignite.table.Tuple;

public class ClientMl
implements IgniteMl {
    private final ClientCompute compute;
    private final Supplier<CompletableFuture<Collection<ClusterNode>>> clusterNodesSupplier;

    public ClientMl(ClientCompute compute, Supplier<CompletableFuture<Collection<ClusterNode>>> clusterNodesSupplier) {
        this.compute = compute;
        this.clusterNodesSupplier = clusterNodesSupplier;
    }

    @Override
    public <I, O> O predict(MlSimpleJobParameters<I> jobParams) {
        return ViewUtils.sync(this.predictAsync(jobParams));
    }

    @Override
    public <I, O> List<O> batchPredict(MlBatchJobParameters<I> jobParams) {
        return ViewUtils.sync(this.batchPredictAsync(jobParams));
    }

    @Override
    public <O> List<O> predictFromSql(MlSqlJobParameters jobParams) {
        return ViewUtils.sync(this.predictFromSqlAsync(jobParams));
    }

    @Override
    public <I, O> CompletableFuture<O> predictAsync(MlSimpleJobParameters<I> jobParams) {
        JobDescriptor descriptor = MlJobDescriptorUtils.createSimplePredictionDescriptor(jobParams);
        return ((CompletableFuture)this.createJobTarget().thenCompose(jobTarget -> this.compute.executeAsync((JobTarget)jobTarget, descriptor, jobParams))).exceptionally(throwable -> {
            Throwable cause = ExceptionUtils.unwrapCause(throwable);
            throw ClientMl.mapToMlException(cause, "Simple ML prediction failed");
        });
    }

    @Override
    public <I, O> CompletableFuture<List<O>> batchPredictAsync(MlBatchJobParameters<I> jobParams) {
        JobDescriptor descriptor = MlJobDescriptorUtils.createBatchPredictionDescriptor(jobParams);
        return ((CompletableFuture)this.createJobTarget().thenCompose(jobTarget -> this.compute.executeAsync((JobTarget)jobTarget, descriptor, jobParams))).exceptionally(throwable -> {
            Throwable cause = ExceptionUtils.unwrapCause(throwable);
            throw ClientMl.mapToMlException(cause, "Simple ML prediction failed");
        });
    }

    @Override
    public <O> CompletableFuture<List<O>> predictFromSqlAsync(MlSqlJobParameters jobParams) {
        JobDescriptor descriptor = MlJobDescriptorUtils.createSqlPredictionDescriptor(jobParams);
        return ((CompletableFuture)this.createJobTarget().thenCompose(jobTarget -> this.compute.executeAsync((JobTarget)jobTarget, descriptor, jobParams))).exceptionally(throwable -> {
            Throwable cause = ExceptionUtils.unwrapCause(throwable);
            throw ClientMl.mapToMlException(cause, "Simple ML prediction failed");
        });
    }

    @Override
    public <I, O> O predictColocated(MlColocatedJobParameters<I> jobParams) {
        try {
            return ViewUtils.sync(this.predictColocatedAsync(jobParams));
        }
        catch (Throwable e) {
            Throwable cause;
            Throwable throwable = cause = e.getCause() != null ? e.getCause() : e;
            if (cause instanceof InferenceException) {
                throw (InferenceException)cause;
            }
            throw new InferenceException("Error during SQL prediction: " + cause.getMessage(), cause);
        }
    }

    @Override
    public <I, O> CompletableFuture<O> predictColocatedAsync(MlColocatedJobParameters<I> jobParams) {
        JobDescriptor descriptor = MlJobDescriptorUtils.createColocatedPredictionDescriptor(jobParams);
        return this.compute.executeAsync(ClientMl.createColocatedJobTarget(jobParams.tableName(), jobParams.key()), descriptor, jobParams).exceptionally(throwable -> {
            Throwable cause = ExceptionUtils.unwrapCause(throwable);
            throw ClientMl.mapToMlException(cause, "Colocated ML prediction failed");
        });
    }

    private CompletableFuture<JobTarget> createJobTarget() {
        return this.clusterNodesSupplier.get().thenApply(nodes -> {
            if (nodes == null || nodes.isEmpty()) {
                throw new IllegalStateException("No cluster nodes available");
            }
            return JobTarget.anyNode(Set.copyOf(nodes));
        });
    }

    private static JobTarget createColocatedJobTarget(String tableName, Tuple key) {
        return JobTarget.colocated(tableName, key);
    }

    private static RuntimeException mapToMlException(Throwable cause, String context) {
        if (cause instanceof ModelNotFoundException) {
            return (ModelNotFoundException)cause;
        }
        if (cause instanceof InferenceException) {
            return (InferenceException)cause;
        }
        if (cause instanceof ModelInitializationException) {
            return (ModelInitializationException)cause;
        }
        Throwable rootCause = ExceptionUtils.unwrapCause(cause);
        if (rootCause instanceof ModelNotFoundException) {
            return (ModelNotFoundException)rootCause;
        }
        if (rootCause instanceof InferenceException) {
            return (InferenceException)rootCause;
        }
        if (rootCause instanceof ModelInitializationException) {
            return (ModelInitializationException)rootCause;
        }
        return new InferenceException(context + ": " + cause.getMessage(), cause);
    }
}

