/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.example.ml;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.example.ml.MlBaseExample;
import org.apache.ignite.example.ml.ModelUtils;
import org.apache.ignite.example.ml.custom.CustomBatchComputeJob;
import org.apache.ignite.example.ml.custom.CustomComputeJob;
import org.apache.ignite.example.ml.custom.CustomInput;
import org.apache.ignite.example.ml.custom.CustomInputMarshaller;
import org.apache.ignite.example.ml.custom.CustomOutput;
import org.apache.ignite.example.ml.custom.CustomOutputListMarshaller;
import org.apache.ignite.example.ml.custom.CustomOutputMarshaller;
import org.apache.ignite.example.ml.custom.CustomTranslatorFactory;
import org.gridgain.ml.model.MlBatchJobParameters;
import org.gridgain.ml.model.MlSimpleJobParameters;
import org.gridgain.ml.model.ModelConfig;

/**
 * Example that downloads a model, deploys it together with the custom classes from the
 * {@code custom} package, and then runs one single-input and one batch prediction that
 * use {@link CustomTranslatorFactory} with custom job classes and marshallers.
 *
 * <p>The artifact URLs point at a quantized ONNX model and its tokenizer; judging by the
 * URL and the printed output this is a zero-shot text classifier (not verified here).
 *
 * <p>{@code start}, {@code stop} and {@code mlApi} are not declared in this class; they are
 * presumably inherited from {@link MlBaseExample}.
 */
public class MlCustomTranslatorExample
extends MlBaseExample {
    /** Model identifier used for deployment and in every job request. */
    private static final String MODEL_ID = "zeroshot";

    /** Model version used for deployment and in every job request. */
    private static final String MODEL_VERSION = "1.0.0";

    /** Model name passed to the job parameter builders. */
    private static final String MODEL_NAME = "model_quantized";

    /** Package path of the custom classes, relative to the Java source root. */
    private static final String CUSTOM_CLASSES_PACKAGE = "org/apache/ignite/example/ml/custom";

    /** Location of the custom classes' sources, relative to the working directory. */
    private static final String CUSTOM_CLASSES_SOURCE_DIR = "java/src/main/java/" + CUSTOM_CLASSES_PACKAGE;

    /** Candidate labels offered to the model; cloned per input so that no array is shared. */
    private static final String[] CANDIDATE_LABELS = {"travel", "cooking", "dancing", "exploration"};

    /** Artifacts handed to {@link ModelUtils#downloadAndDeployModelAndJar}: model file and tokenizer. */
    private static final List<String> MODEL_ARTIFACT_URLS = Arrays.asList(
        "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/onnx/model_quantized.onnx",
        "https://huggingface.co/MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33/resolve/main/tokenizer.json?download=true");

    /** Directory returned by the deployment helper; {@code null} until deployment succeeded. */
    private static Path tempModelDir;

    /**
     * Runs the example: starts the embedded mode, deploys the model and custom classes,
     * executes both predictions and finally cleans up and stops.
     *
     * @param args command-line arguments; not used
     * @throws IOException declared for compatibility with the original signature
     * @throws RuntimeException wrapping any failure raised while the example runs
     */
    public static void main(String[] args) throws IOException {
        try {
            System.out.println("MlCustomTranslatorExample started!");
            start(MlBaseExample.MODE.EMBEDDED);
            Path workingDir = Paths.get(System.getProperty("user.dir"));
            Path sourceDir = workingDir.resolve(CUSTOM_CLASSES_SOURCE_DIR);
            tempModelDir = ModelUtils.downloadAndDeployModelAndJar(MODEL_ARTIFACT_URLS, MODEL_ID, MODEL_VERSION, sourceDir);
            simpleApiPrediction();
            batchInputPrediction();
            System.out.println("MlCustomTranslatorExample completed successfully!");
        }
        catch (Throwable e) {
            System.err.println("MlCustomTranslatorExample failed: " + e.getMessage());
            throw new RuntimeException(e);
        }
        finally {
            // cleanup() may throw; the nested finally guarantees that stop() still runs.
            try {
                cleanup();
            }
            finally {
                stop();
            }
        }
    }

    /**
     * Sends a single {@link CustomInput} through {@code mlApi.predict} and prints the
     * first label and the first score of the returned {@link CustomOutput}.
     */
    private static void simpleApiPrediction() {
        System.out.println("Simple Prediction");
        String prompt = "one day I will see the world";
        boolean multiLabels = true;
        CustomInput input = new CustomInput(prompt, CANDIDATE_LABELS.clone(), multiLabels);
        try {
            MlSimpleJobParameters jobParams = MlSimpleJobParameters.builder()
                .id(MODEL_ID)
                .version(MODEL_VERSION)
                .name(MODEL_NAME)
                .config(ModelConfig.builder().build())
                .inputClass(CustomInput.class.getName())
                .outputClass(CustomOutput.class.getName())
                .translatorFactory(CustomTranslatorFactory.class.getName())
                // Explicit cast kept from the original call site so overload resolution is unchanged.
                .input((Object)input)
                .customJobClass(CustomComputeJob.class)
                .customInputMarshaller(new CustomInputMarshaller())
                .customOutputMarshaller(new CustomOutputMarshaller())
                .build();
            CustomOutput result = (CustomOutput)mlApi.predict(jobParams);
            System.out.println("Classification Results:");
            System.out.println("Text: " + prompt);
            System.out.println("Predictions: " + result.getLabels()[0]);
            System.out.println("Scores: " + result.getScores()[0]);
        }
        catch (Throwable e) {
            // Report what failed before propagating; main() wraps and reports it again.
            System.err.println("Error in simple API prediction: " + e);
            throw e;
        }
    }

    /**
     * Sends four {@link CustomInput}s through {@code mlApi.batchPredict} and prints, for
     * each input that has a matching result, its text plus the first label and score.
     */
    private static void batchInputPrediction() {
        System.out.println("Batch Prediction");
        List<CustomInput> batchInputs = Arrays.asList(
            new CustomInput("one day I will see the world", CANDIDATE_LABELS.clone(), true),
            new CustomInput("I love preparing delicious meals", CANDIDATE_LABELS.clone(), true),
            new CustomInput("moving to the rhythm makes me happy", CANDIDATE_LABELS.clone(), true),
            new CustomInput("discovering new places is my passion", CANDIDATE_LABELS.clone(), true));
        try {
            MlBatchJobParameters jobParameters = MlBatchJobParameters.builder()
                .id(MODEL_ID)
                .version(MODEL_VERSION)
                .name(MODEL_NAME)
                .config(ModelConfig.builder().build())
                .inputClass(CustomInput.class.getName())
                .outputClass(CustomOutput.class.getName())
                .translatorFactory(CustomTranslatorFactory.class.getName())
                .batchInput(batchInputs)
                .customJobClass(CustomBatchComputeJob.class)
                .customInputMarshaller(new CustomInputMarshaller())
                .customOutputMarshaller(new CustomOutputListMarshaller())
                .build();
            List<?> results = mlApi.batchPredict(jobParameters);
            System.out.println("Classifications:");
            System.out.println("Candidates : " + Arrays.toString(batchInputs.get(0).getCandidates()));
            if (results.size() != batchInputs.size()) {
                System.err.println("Expected " + batchInputs.size() + " results but received " + results.size());
            }
            // Pair inputs with results positionally; never index past the shorter list.
            int pairCount = Math.min(results.size(), batchInputs.size());
            for (int i = 0; i < pairCount; ++i) {
                CustomInput input = batchInputs.get(i);
                CustomOutput output = (CustomOutput)results.get(i);
                System.out.println("Input : " + input.getText());
                System.out.println("Output : " + output.getLabels()[0]);
                System.out.println("Score : " + output.getScores()[0] + "\n");
            }
        }
        catch (Throwable e) {
            // Report what failed before propagating; main() wraps and reports it again.
            System.err.println("Error in batch API prediction: " + e);
            throw e;
        }
    }

    /**
     * Deletes the temporary model directory if one was created and still exists.
     *
     * @throws RuntimeException wrapping the {@link IOException} if deletion fails
     */
    private static void cleanup() {
        if (tempModelDir == null || !Files.exists(tempModelDir)) {
            return;
        }
        try {
            ModelUtils.deleteDirectory(tempModelDir);
            System.out.println("Temporary model directory deleted: " + tempModelDir);
        }
        catch (IOException e) {
            System.err.println("Temporary model directory could NOT be deleted: " + tempModelDir);
            throw new RuntimeException(e);
        }
    }
}

