/*
 * Decompiled with CFR 0.152.
 */
package org.gridgain.internal.ml.compute;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.ignite3.sql.IgniteSql;
import org.apache.ignite3.sql.SqlRow;
import org.apache.ignite3.sql.async.AsyncResultSet;
import org.gridgain.internal.ml.handler.DjlModelHandler;
import org.gridgain.internal.ml.handler.ModelHandler;
import org.gridgain.internal.ml.storage.ModelSource;
import org.gridgain.internal.ml.storage.ModelStorageProviderImpl;
import org.gridgain.ml.InferenceException;
import org.gridgain.ml.ModelInitializationException;
import org.gridgain.ml.model.MlBatchJobParameters;
import org.gridgain.ml.model.MlJobParameters;
import org.gridgain.ml.model.MlSimpleJobParameters;
import org.gridgain.ml.model.MlSqlJobParameters;
import org.jetbrains.annotations.Nullable;

public class PredictionUtils {
    public static <I, O> CompletableFuture<O> predictAsyncInternal(MlSimpleJobParameters<I> jobParams, Executor executor) {
        try {
            return ((CompletableFuture)PredictionUtils.getOrCreateHandler(jobParams, executor).thenCompose(handler -> handler.predict(jobParams.input()))).exceptionally(throwable -> {
                throw new CompletionException(new InferenceException("Error during async prediction: " + throwable.getMessage(), (Throwable)throwable));
            });
        }
        catch (Exception e) {
            return CompletableFuture.failedFuture(new InferenceException("Error during async prediction: " + e.getMessage(), (Throwable)e));
        }
    }

    public static <I, O> CompletableFuture<List<O>> batchPredictAsyncInternal(MlBatchJobParameters<I> jobParams, Executor executor) {
        try {
            if (jobParams.batchInput() == null || jobParams.batchInput().isEmpty()) {
                return CompletableFuture.completedFuture(new ArrayList());
            }
            return ((CompletableFuture)PredictionUtils.getOrCreateHandler(jobParams, executor).thenCompose(handler -> handler.batchPredict(jobParams.batchInput()))).exceptionally(throwable -> {
                throw new CompletionException(new InferenceException("Error during async batch prediction: " + throwable.getMessage(), (Throwable)throwable));
            });
        }
        catch (Exception e) {
            return CompletableFuture.failedFuture(new InferenceException("Error during async batch prediction: " + e.getMessage(), (Throwable)e));
        }
    }

    public static <O> CompletableFuture<List<O>> predictFromSqlAsyncInternal(IgniteSql igniteSql, MlSqlJobParameters jobParams, Executor executor) {
        try {
            return PredictionUtils.executeSqlQueryAsync(igniteSql, jobParams.sqlQuery(), jobParams.sqlParams()).thenCompose(asyncResultSet -> {
                if (asyncResultSet.currentPageSize() == 0 && !asyncResultSet.hasMorePages()) {
                    return CompletableFuture.completedFuture(new ArrayList());
                }
                return PredictionUtils.getOrCreateHandler(jobParams, executor).thenCompose(handler -> PredictionUtils.processResultSetWithLimit(asyncResultSet, handler, jobParams.limit(), jobParams));
            });
        }
        catch (Exception e) {
            return CompletableFuture.failedFuture(new InferenceException("Error during async SQL prediction: " + e.getMessage(), (Throwable)e));
        }
    }

    public static CompletableFuture<ModelHandler> getOrCreateHandler(MlJobParameters jobParams, Executor executor) {
        try {
            DjlModelHandler handler = new DjlModelHandler(executor);
            ModelStorageProviderImpl modelStorageProvider = new ModelStorageProviderImpl();
            return ((CompletableFuture)modelStorageProvider.createModelSource(jobParams.url(), jobParams.id(), jobParams.version(), jobParams.properties()).thenCompose(modelSource -> handler.initialize(jobParams, (ModelSource)modelSource).thenApply(v -> handler))).exceptionally(throwable -> {
                handler.close();
                throw new CompletionException(new ModelInitializationException("Failed to initialize async model handler: " + throwable.getMessage(), (Throwable)throwable));
            });
        }
        catch (Exception e) {
            return CompletableFuture.failedFuture(new ModelInitializationException("Failed to initialize async model handler: " + e.getMessage(), (Throwable)e));
        }
    }

    private static CompletableFuture<AsyncResultSet<SqlRow>> executeSqlQueryAsync(IgniteSql igniteSql, String sqlQuery, Object[] params) {
        if (sqlQuery == null || sqlQuery.isEmpty()) {
            return CompletableFuture.failedFuture(new IllegalArgumentException("SQL query cannot be null or empty"));
        }
        try {
            CompletableFuture<AsyncResultSet<SqlRow>> asyncResultSetFuture = params != null ? igniteSql.executeAsync(null, sqlQuery, params) : igniteSql.executeAsync(null, sqlQuery, new Object[0]);
            return asyncResultSetFuture.exceptionally(throwable -> {
                throw new CompletionException(new InferenceException("Error executing async SQL query: " + throwable.getMessage(), (Throwable)throwable));
            });
        }
        catch (Exception e) {
            return CompletableFuture.failedFuture(new InferenceException("Error executing async SQL query: " + e.getMessage(), (Throwable)e));
        }
    }

    private static <O> CompletableFuture<List<O>> processResultSetWithLimit(AsyncResultSet<SqlRow> resultSet, ModelHandler handler, int limit, MlSqlJobParameters jobParams) {
        ArrayList results = new ArrayList();
        AtomicInteger processedRows = new AtomicInteger(0);
        String inputColumn = jobParams.inputColumn();
        return PredictionUtils.processPageRecursivelyWithLimit(resultSet, handler, results, limit, processedRows, inputColumn);
    }

    private static <O> CompletableFuture<List<O>> processPageRecursivelyWithLimit(AsyncResultSet<SqlRow> resultSet, ModelHandler handler, List<O> results, int limit, AtomicInteger processedRows, @Nullable String inputColumn) {
        if (processedRows.get() >= limit) {
            return CompletableFuture.completedFuture(results);
        }
        int remainingNeeded = limit - processedRows.get();
        int rowsToProcess = Math.min(resultSet.currentPageSize(), remainingNeeded);
        List<CompletableFuture> pagePredictions = StreamSupport.stream(resultSet.currentPage().spliterator(), false).limit(rowsToProcess).map(row -> {
            Object inputValue = inputColumn != null && !inputColumn.isEmpty() ? row.value(inputColumn) : row.value(0);
            return handler.predict(inputValue).thenApply(result -> {
                try {
                    return result;
                }
                catch (ClassCastException e) {
                    throw new InferenceException("Prediction result type mismatch: " + e.getMessage(), (Throwable)e);
                }
            });
        }).collect(Collectors.toList());
        return ((CompletableFuture)CompletableFuture.allOf(pagePredictions.toArray(new CompletableFuture[0])).thenCompose(v -> {
            List pageResults = pagePredictions.stream().map(CompletableFuture::join).collect(Collectors.toList());
            results.addAll(pageResults);
            processedRows.addAndGet(rowsToProcess);
            if (processedRows.get() >= limit || !resultSet.hasMorePages()) {
                return CompletableFuture.completedFuture(results);
            }
            return resultSet.fetchNextPage().thenCompose(nextPage -> PredictionUtils.processPageRecursivelyWithLimit(nextPage, handler, results, limit, processedRows, inputColumn));
        })).exceptionally(throwable -> {
            if (throwable instanceof RuntimeException) {
                throw (RuntimeException)throwable;
            }
            throw new InferenceException("Error processing predictions: " + throwable.getMessage(), (Throwable)throwable);
        });
    }
}

