/*
 * Decompiled with CFR 0.152.
 */
package org.gridgain.internal.ml.handler;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.ignite3.internal.logger.IgniteLogger;
import org.apache.ignite3.internal.logger.Loggers;
import org.gridgain.ml.ModelInitializationException;
import org.gridgain.ml.TranslatorNotFoundException;
import org.gridgain.ml.model.MlJobParameters;
import org.gridgain.ml.model.ModelInfo;
import org.gridgain.ml.model.ModelType;

/**
 * Translates a {@link ModelInfo} descriptor into DJL {@link Criteria} and loads the resulting
 * {@link ZooModel}.
 *
 * <p>All functionality is exposed through static methods; {@link #loadModel(ModelInfo, Path)} is the
 * only non-private entry point.
 */
class DjlAdapter {
    private static final IgniteLogger LOG = Loggers.forClass(DjlAdapter.class);

    /** Model properties starting with this prefix are passed to DJL as model-loading options. */
    private static final String OPTION_PREFIX = "option.";

    DjlAdapter() {
    }

    /**
     * Builds the DJL criteria describing how the model should be loaded.
     *
     * <p>Input/output types, model location, engine, application, repository coordinates, translator
     * and the remaining properties are applied to the builder in that order.
     *
     * @param modelInfo   descriptor of the model to load
     * @param modelPath   location of the model files; must not be {@code null}
     * @param classLoader loader used to resolve input/output, translator and translator-factory classes
     * @return criteria ready for {@link Criteria#loadModel()}
     * @throws ModelInitializationException if {@code modelPath} is {@code null} or the translator
     *                                      factory cannot be initialised
     * @throws TranslatorNotFoundException  if the configured translator cannot be instantiated
     */
    private static Criteria<Object, Object> createCriteria(ModelInfo modelInfo, Path modelPath, ClassLoader classLoader) {
        // Explicit descriptor fields win; the generic property map is only a fallback.
        String inputClassName = modelInfo.inputClass();
        if (isNullOrEmpty(inputClassName)) {
            inputClassName = modelInfo.property("input_class", "java.lang.Object");
        }
        String outputClassName = modelInfo.outputClass();
        if (isNullOrEmpty(outputClassName)) {
            outputClassName = modelInfo.property("output_class", "java.lang.Object");
        }

        Class<?> inputClass = getClassForName(inputClassName, classLoader);
        Class<?> outputClass = getClassForName(outputClassName, classLoader);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Creating criteria for model: " + modelInfo.id() + "\n Input class: " + inputClass.getName() + "\n Output class: " + outputClass.getName());
        }

        // The I/O types are only known at runtime, so the builder is viewed as <Object, Object>.
        @SuppressWarnings("unchecked")
        Criteria.Builder<Object, Object> builder =
                (Criteria.Builder<Object, Object>) Criteria.builder().setTypes(inputClass, outputClass);

        configureModelLocation(builder, modelInfo, modelPath);

        String engine = getEngineFromModelType(modelInfo.type());
        if (engine != null) {
            builder.optEngine(engine);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Using engine: " + engine);
            }
        }

        configureApplication(builder, modelInfo);
        configureRepositoryCoordinates(builder, modelInfo);

        if (LOG.isDebugEnabled()) {
            LOG.debug("DjlAdapter classLoader is " + classLoader.hashCode());
        }
        configureTranslator(builder, modelInfo, classLoader);
        configureArgumentsAndOptions(builder, modelInfo);
        return builder.build();
    }

    /** Returns {@code true} when {@code value} is {@code null} or has zero length. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Applies the model path and, when present, the model name to the builder.
     *
     * @throws ModelInitializationException if {@code modelPath} is {@code null}
     */
    private static void configureModelLocation(Criteria.Builder<Object, Object> builder, ModelInfo modelInfo, Path modelPath) {
        if (modelPath == null) {
            throw new ModelInitializationException("Model path is null - ensure model is properly deployed as a deployment unit", null);
        }
        builder.optModelPath(modelPath);
        String modelName = modelInfo.name();
        if (!isNullOrEmpty(modelName)) {
            builder.optModelName(modelName);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Using model with a specific name: " + modelName);
            }
        }
    }

    /**
     * Applies the {@code application} property, if set. Dotted values are rewritten to a lower-case,
     * slash-separated path before being handed to {@link Application#of(String)}.
     */
    private static void configureApplication(Criteria.Builder<Object, Object> builder, ModelInfo modelInfo) {
        String application = modelInfo.property("application");
        if (application == null) {
            return;
        }
        try {
            if (application.contains(".")) {
                // Locale.ROOT keeps the result stable regardless of the JVM default locale
                // (e.g. Turkish would otherwise map 'I' to a dotless 'ı').
                application = application.replace('.', '/').toLowerCase(Locale.ROOT);
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Using application: " + application);
            }
            builder.optApplication(Application.of(application));
        }
        catch (Exception e) {
            // Best effort: fall back to passing the (possibly normalised) value as a plain option.
            if (LOG.isDebugEnabled()) {
                LOG.debug("Warning: Invalid application format: " + e.getMessage());
            }
            builder.optOption("application", application);
        }
    }

    /** Applies the {@code artifactId} and {@code groupId} properties to the builder when they are set. */
    private static void configureRepositoryCoordinates(Criteria.Builder<Object, Object> builder, ModelInfo modelInfo) {
        String artifactId = modelInfo.property("artifactId");
        if (artifactId != null) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Using artifactId: " + artifactId);
            }
            builder.optArtifactId(artifactId);
        }
        String groupId = modelInfo.property("groupId");
        if (groupId != null) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Using groupId: " + groupId);
            }
            builder.optGroupId(groupId);
        }
    }

    /**
     * Splits the remaining model properties into arguments and options and applies them.
     *
     * <p>Properties recognised by {@link #isSpecialProperty(String)} are skipped. Keys starting with
     * {@code "option."} become options (prefix stripped); every other key becomes an argument.
     * String-valued engine options from the model configuration are then added to the options and
     * overwrite property-derived options with the same key.
     */
    private static void configureArgumentsAndOptions(Criteria.Builder<Object, Object> builder, ModelInfo modelInfo) {
        Map<String, Object> arguments = new HashMap<>();
        Map<String, String> options = new HashMap<>();
        for (Map.Entry<String, String> entry : modelInfo.properties().entrySet()) {
            String key = entry.getKey();
            String value = entry.getValue();
            if (isSpecialProperty(key)) {
                continue;
            }
            if (key.startsWith(OPTION_PREFIX)) {
                String optionKey = key.substring(OPTION_PREFIX.length());
                options.put(optionKey, value);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Adding to options: " + optionKey + " = " + value);
                }
            } else {
                arguments.put(key, value);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Adding to arguments: " + key + " = " + value);
                }
            }
        }

        Map<String, Object> engineOptions = modelInfo.config().engineOptions();
        for (Map.Entry<String, Object> entry : engineOptions.entrySet()) {
            Object value = entry.getValue();
            // Only String values are forwarded; entries of any other type are ignored.
            if (value instanceof String) {
                options.put(entry.getKey(), (String) value);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Adding engine option: " + entry.getKey() + " = " + value);
                }
            }
        }

        if (!arguments.isEmpty()) {
            builder.optArguments(arguments);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Set " + arguments.size() + " arguments for TranslatorFactory");
            }
        }
        if (!options.isEmpty()) {
            builder.optOptions(options);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Set " + options.size() + " options for Model.load()");
            }
        }
    }

    /**
     * Tells whether a property key is consumed explicitly while building the criteria and therefore
     * must not be forwarded as a generic argument or option.
     */
    private static boolean isSpecialProperty(String key) {
        return "input_class".equals(key)
                || "output_class".equals(key)
                || "engine".equals(key)
                || "translator".equals(key)
                || "translatorFactory".equals(key)
                || "application".equals(key)
                || "groupId".equals(key)
                || "artifactId".equals(key);
    }

    /**
     * Loads the DJL model described by {@code modelInfo} from {@code modelPath}.
     *
     * <p>Side effect: sets the JVM-wide system properties {@code ai.djl.offline} and
     * {@code ai.djl.onnxruntime.disable_alternative} to {@code "true"} on every call.
     *
     * @param modelInfo descriptor of the model to load
     * @param modelPath location of the model files
     * @return the loaded model
     * @throws ModelInitializationException if the criteria cannot be built or DJL fails to load the model
     */
    static ZooModel<Object, Object> loadModel(ModelInfo modelInfo, Path modelPath) {
        try {
            System.setProperty("ai.djl.offline", "true");
            System.setProperty("ai.djl.onnxruntime.disable_alternative", "true");
            ClassLoader classLoader = getClassLoader(modelInfo);
            Criteria<Object, Object> criteria = createCriteria(modelInfo, modelPath, classLoader);
            return criteria.loadModel();
        }
        catch (MalformedModelException | ModelNotFoundException | IOException e) {
            throw new ModelInitializationException("Failed to load DJL model: " + modelInfo.id(), e);
        }
    }

    /**
     * Picks the class loader used to resolve user-supplied classes: the loader of the custom job
     * class when {@code modelInfo} is an {@link MlJobParameters} carrying custom components,
     * otherwise the loader of {@code modelInfo}'s own class.
     */
    private static ClassLoader getClassLoader(ModelInfo modelInfo) {
        if (modelInfo instanceof MlJobParameters) {
            MlJobParameters params = (MlJobParameters) modelInfo;
            if (params.hasCustomComponents() && params.customJobClass() != null) {
                ClassLoader customClassLoader = params.customJobClass().getClassLoader();
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Using ClassLoader from custom job class: " + customClassLoader.hashCode());
                }
                return customClassLoader;
            }
        }
        return modelInfo.getClass().getClassLoader();
    }

    /**
     * Configures either a translator or a translator factory on the builder.
     *
     * <p>A translator (descriptor field first, then the {@code translator} property) takes
     * precedence; the translator factory is only consulted when no translator is configured. When
     * neither is configured the builder is left untouched.
     *
     * @throws TranslatorNotFoundException  if the translator cannot be instantiated
     * @throws ModelInitializationException if the translator factory cannot be instantiated or applied
     */
    private static void configureTranslator(Criteria.Builder<Object, Object> builder, ModelInfo modelInfo, ClassLoader classLoader) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Configuring translator from properties: " + modelInfo.properties());
        }

        String translatorClass = modelInfo.translator();
        if (isNullOrEmpty(translatorClass)) {
            translatorClass = modelInfo.property("translator");
        }
        if (!isNullOrEmpty(translatorClass)) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Found translator class: " + translatorClass);
            }
            // The translator's type parameters are unknown at compile time; the builder is typed
            // <Object, Object>, so an unchecked view is required here.
            @SuppressWarnings("unchecked")
            Translator<Object, Object> translator =
                    (Translator<Object, Object>) createTranslator(translatorClass, classLoader);
            builder.optTranslator(translator);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Successfully configured translator: " + translator.getClass().getName());
            }
            return;
        }

        String translatorFactoryClassName = modelInfo.translatorFactory();
        if (isNullOrEmpty(translatorFactoryClassName)) {
            translatorFactoryClassName = modelInfo.property("translatorFactory");
        }
        if (isNullOrEmpty(translatorFactoryClassName)) {
            return;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Found translator factory class: " + translatorFactoryClassName);
        }
        try {
            TranslatorFactory factory = createTranslatorFactory(translatorFactoryClassName, classLoader);
            builder.optTranslatorFactory(factory);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Successfully configured translator factory: " + factory.getClass().getName());
            }
        }
        catch (Exception e) {
            throw new ModelInitializationException("Failed to initialize translator factory: " + translatorFactoryClassName, (Throwable) e);
        }
    }

    /**
     * Instantiates a translator through its no-argument constructor (which may be non-public).
     *
     * @param className   fully qualified translator class name
     * @param classLoader loader used to resolve the class
     * @return a new translator instance
     * @throws TranslatorNotFoundException if the class cannot be loaded, is not a {@link Translator},
     *                                     or cannot be instantiated
     */
    private static Translator<?, ?> createTranslator(String className, ClassLoader classLoader) throws TranslatorNotFoundException {
        Class<?> clazz;
        try {
            clazz = classLoader.loadClass(className);
        }
        catch (ClassNotFoundException e) {
            throw new TranslatorNotFoundException(e);
        }
        if (!Translator.class.isAssignableFrom(clazz)) {
            throw new TranslatorNotFoundException("Class is not a Translator: " + className);
        }
        try {
            Constructor<?> constructor = clazz.getDeclaredConstructor();
            constructor.setAccessible(true);
            return (Translator<?, ?>) constructor.newInstance();
        }
        catch (Throwable e) {
            throw new TranslatorNotFoundException(e);
        }
    }

    /**
     * Instantiates a translator factory, preferring a public static {@code newInstance()} method and
     * falling back to the no-argument constructor (which may be non-public).
     *
     * @param className   fully qualified translator factory class name
     * @param classLoader loader used to resolve the class
     * @return a new translator factory instance
     * @throws TranslatorNotFoundException if the class cannot be loaded, is not a
     *                                     {@link TranslatorFactory}, or both creation strategies fail
     */
    private static TranslatorFactory createTranslatorFactory(String className, ClassLoader classLoader) {
        Class<?> clazz;
        try {
            clazz = classLoader.loadClass(className);
        }
        catch (ClassNotFoundException e) {
            throw new TranslatorNotFoundException(e);
        }
        if (!TranslatorFactory.class.isAssignableFrom(clazz)) {
            throw new TranslatorNotFoundException("Class is not a TranslatorFactory: " + className);
        }

        Throwable factoryMethodFailure;
        try {
            return (TranslatorFactory) clazz.getMethod("newInstance").invoke(null);
        }
        catch (Throwable e) {
            // Expected when the class has no static factory method; remembered for diagnostics only.
            factoryMethodFailure = e;
        }

        try {
            Constructor<?> constructor = clazz.getDeclaredConstructor();
            constructor.setAccessible(true);
            return (TranslatorFactory) constructor.newInstance();
        }
        catch (Throwable e) {
            // Keep the first failure visible instead of silently dropping it.
            if (e != factoryMethodFailure) {
                e.addSuppressed(factoryMethodFailure);
            }
            throw new TranslatorNotFoundException(e);
        }
    }

    /**
     * Maps a model type to the DJL engine name.
     *
     * @param modelType model type, may be {@code null}
     * @return the engine name, or {@code null} when the type is {@code null} or has no mapping here
     */
    private static String getEngineFromModelType(ModelType modelType) {
        if (modelType == null) {
            // Switching on null would throw; treat it like any other unmapped type.
            return null;
        }
        switch (modelType) {
            case PYTORCH:
                return "PyTorch";
            case TENSORFLOW:
                return "TensorFlow";
            case ONNX:
                return "OnnxRuntime";
            default:
                return null;
        }
    }

    /**
     * Resolves a class by its binary name, including array descriptors such as {@code "[F"},
     * {@code "[[I"} or {@code "[Ljava.lang.String;"}.
     *
     * @param className   binary class name or array descriptor
     * @param classLoader loader used to resolve the class
     * @return the resolved class, or {@code Object.class} if it cannot be found
     */
    private static Class<?> getClassForName(String className, ClassLoader classLoader) {
        try {
            // Class.forName understands array descriptors, which ClassLoader.loadClass does not;
            // initialize=false matches loadClass, which does not run static initialisers either.
            return Class.forName(className, false, classLoader);
        }
        catch (ClassNotFoundException e) {
            LOG.error("Warning: Could not find class: " + className + " - " + e.getMessage());
            return Object.class;
        }
    }
}

