/*
 * Decompiled with CFR 0.152.
 */
package org.gridgain.internal.ml.handler;

import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.ignite3.internal.logger.IgniteLogger;
import org.apache.ignite3.internal.logger.Loggers;
import org.gridgain.internal.ml.handler.DjlAdapter;
import org.gridgain.internal.ml.handler.ModelHandler;
import org.gridgain.internal.ml.storage.ModelSource;
import org.gridgain.ml.InferenceException;
import org.gridgain.ml.ModelInitializationException;
import org.gridgain.ml.model.ModelInfo;
import org.jetbrains.annotations.Nullable;

public class DjlModelHandler
implements ModelHandler {
    private static final IgniteLogger LOG = Loggers.forClass(DjlModelHandler.class);
    private final Executor executor;
    private ModelInfo modelInfo;
    @Nullable
    private ZooModel<Object, Object> model;
    private final AtomicBoolean initialized = new AtomicBoolean(false);
    private String engineName;
    private final Map<String, Object> modelMetadata = new HashMap<String, Object>();
    @Nullable
    private Path modelPath;

    public DjlModelHandler(Executor executor) {
        this.executor = executor;
    }

    @Override
    public CompletableFuture<Void> initialize(ModelInfo modelInfo, ModelSource modelSource) {
        if (this.checkAndMarkInitialized()) {
            return CompletableFuture.failedFuture(new IllegalStateException("Handler already initialized or initialization in progress"));
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Starting async initialization for model: " + modelInfo.id(), new Object[0]);
        }
        this.modelInfo = modelInfo;
        return modelSource.resolveToLocalPath().thenCompose(resolvedPath -> {
            this.modelPath = resolvedPath;
            if (LOG.isDebugEnabled()) {
                LOG.debug("Model path resolved to: " + this.modelPath, new Object[0]);
            }
            return CompletableFuture.runAsync(() -> {
                try {
                    this.model = DjlAdapter.loadModel(modelInfo, this.modelPath);
                    if (this.model == null) {
                        throw new ModelInitializationException("Model name is null for model: " + modelInfo.id(), null);
                    }
                    this.engineName = this.model.getName().split(":")[0];
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Model loaded successfully with engine: " + this.engineName, new Object[0]);
                    }
                    this.collectModelMetadata();
                    this.initialized.set(true);
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Model initialized successfully: " + modelInfo.id(), new Object[0]);
                    }
                }
                catch (Exception e) {
                    throw new ModelInitializationException("Failed to initialize DJL model: " + e.getMessage(), (Throwable)e);
                }
            }, this.executor);
        });
    }

    @Override
    public <I, O> CompletableFuture<O> predict(I input) {
        if (!this.initialized.get()) {
            return CompletableFuture.failedFuture(new IllegalStateException("Model handler not initialized"));
        }
        return CompletableFuture.supplyAsync(() -> {
            Object object;
            block9: {
                Thread.currentThread().setContextClassLoader(Thread.currentThread().getContextClassLoader());
                assert (this.model != null);
                Predictor predictor = this.model.newPredictor();
                try {
                    object = predictor.predict(input);
                    if (predictor == null) break block9;
                }
                catch (Throwable throwable) {
                    try {
                        if (predictor != null) {
                            try {
                                predictor.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (TranslateException e) {
                        String message = String.format("Error during model inference for %s (type %s): %s", new Object[]{this.modelInfo.id(), this.modelInfo.type(), e.getMessage()});
                        throw new InferenceException(message, (Throwable)e);
                    }
                }
                predictor.close();
            }
            return object;
        }, this.executor);
    }

    @Override
    public <I, O> CompletableFuture<List<O>> batchPredict(List<I> inputs) {
        if (!this.initialized.get()) {
            return CompletableFuture.failedFuture(new IllegalStateException("Model handler not initialized"));
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Processing batch prediction with " + inputs.size() + " inputs", new Object[0]);
        }
        return CompletableFuture.supplyAsync(() -> {
            List list;
            block9: {
                assert (this.model != null);
                Thread.currentThread().setContextClassLoader(Thread.currentThread().getContextClassLoader());
                Predictor predictor = this.model.newPredictor();
                try {
                    List results;
                    List objectInputs = inputs;
                    list = results = predictor.batchPredict(objectInputs);
                    if (predictor == null) break block9;
                }
                catch (Throwable throwable) {
                    try {
                        if (predictor != null) {
                            try {
                                predictor.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (TranslateException e) {
                        throw new InferenceException("Error during batch inference: " + e.getMessage(), (Throwable)e);
                    }
                }
                predictor.close();
            }
            return list;
        }, this.executor);
    }

    @Override
    public void close() {
        if (this.model != null) {
            this.model.close();
            this.model = null;
        }
        this.initialized.set(false);
    }

    @Override
    public boolean checkAndMarkInitialized() {
        return !this.initialized.compareAndSet(false, true);
    }

    @Override
    public ModelInfo getModelInfo() {
        return this.modelInfo;
    }

    @Override
    public Map<String, String> getEngineProperties() {
        HashMap<String, String> properties = new HashMap<String, String>();
        for (Map.Entry<String, String> entry : this.modelInfo.properties().entrySet()) {
            if (!entry.getKey().startsWith("engine.")) continue;
            properties.put(entry.getKey().substring(7), entry.getValue());
        }
        return properties;
    }

    @Override
    public Map<String, Object> getModelMetadata() {
        return new HashMap<String, Object>(this.modelMetadata);
    }

    private void collectModelMetadata() {
        this.modelMetadata.clear();
        if (this.model != null) {
            this.modelMetadata.put("name", this.model.getName());
            this.modelMetadata.put("engineName", this.engineName);
            if (this.modelPath != null) {
                this.modelMetadata.put("modelPath", this.modelPath.toString());
            }
            try {
                HashMap<String, String> modelProps = new HashMap<String, String>();
                for (String propertyName : this.model.getProperties().keySet()) {
                    try {
                        modelProps.put(propertyName, this.model.getProperty(propertyName));
                    }
                    catch (Exception exception) {}
                }
                this.modelMetadata.put("properties", modelProps);
            }
            catch (Exception exception) {
                // empty catch block
            }
            try {
                this.modelMetadata.put("modelClass", this.model.getClass().getName());
                this.modelMetadata.put("blockType", this.model.getBlock().getClass().getName());
            }
            catch (Exception exception) {
                // empty catch block
            }
        }
    }
}

