/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.example.ml.custom;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.UUID;
import org.apache.ignite.example.ml.custom.CustomInput;
import org.apache.ignite.example.ml.custom.CustomOutput;

public class CustomTranslator
implements NoBatchifyTranslator<CustomInput, CustomOutput> {
    private final HuggingFaceTokenizer tokenizer;
    private final boolean int32;
    private Predictor<NDList, NDList> predictor;

    private CustomTranslator(HuggingFaceTokenizer tokenizer, boolean int32) {
        this.tokenizer = tokenizer;
        this.int32 = int32;
    }

    public void prepare(TranslatorContext ctx) throws IOException, ModelException {
        Model model = ctx.getModel();
        this.predictor = model.newPredictor((Translator)new NoopTranslator((Batchifier)null));
        ctx.getPredictorManager().attachInternal(UUID.randomUUID().toString(), new AutoCloseable[]{this.predictor});
    }

    public NDList processInput(TranslatorContext ctx, CustomInput input) {
        ctx.setAttachment("input", (Object)input);
        return new NDList();
    }

    public CustomOutput processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
        CustomInput input = (CustomInput)ctx.getAttachment("input");
        String template = input.getHypothesisTemplate();
        String[] candidates = input.getCandidates();
        if (candidates != null && candidates.length != 0) {
            long lastColIndex;
            Shape shape;
            NDManager manager = ctx.getNDManager();
            NDList output = new NDList(candidates.length);
            for (String candidate : candidates) {
                String hypothesis = CustomTranslator.applyTemplate(template, candidate);
                Encoding encoding = this.tokenizer.encode(input.getText(), hypothesis);
                NDList in = encoding.toNDList(manager, false, this.int32);
                NDList batch = Batchifier.STACK.batchify(new NDList[]{in});
                output.add((Object)((NDArray)((NDList)this.predictor.predict((Object)batch)).get(0)));
            }
            NDArray logits = NDArrays.concat((NDList)output);
            if (input.isMultiLabel()) {
                shape = logits.getShape();
                lastColIndex = shape.get(shape.dimension() - 1) - 1L;
                logits = logits.get(":, {}", new Object[]{lastColIndex});
                int lastDim = logits.getShape().dimension() - 1;
                logits = logits.softmax(lastDim);
            } else {
                logits = logits.get(new NDIndex(":, {}", new Object[]{manager.create(new int[]{0, 2})}));
                logits = logits.softmax(1);
                shape = logits.getShape();
                lastColIndex = shape.get(shape.dimension() - 1) - 1L;
                logits = logits.get(":, {}", new Object[]{lastColIndex});
            }
            float[] probabilities = logits.toFloatArray();
            long[] indices = CustomTranslator.argSortDescending(probabilities);
            String[] labels = new String[candidates.length];
            double[] scores = new double[candidates.length];
            for (int i = 0; i < labels.length; ++i) {
                int index = (int)indices[i];
                labels[i] = candidates[index];
                scores[i] = probabilities[index];
            }
            return new CustomOutput(input.getText(), labels, scores);
        }
        throw new TranslateException("Missing candidates in input");
    }

    private static long[] argSortDescending(float[] values) {
        int i;
        int n = values.length;
        long[] indices = new long[n];
        for (i = 0; i < n; ++i) {
            indices[i] = i;
        }
        for (i = 0; i < n - 1; ++i) {
            int maxIdx = i;
            for (int j = i + 1; j < n; ++j) {
                if (!(values[(int)indices[j]] > values[(int)indices[maxIdx]])) continue;
                maxIdx = j;
            }
            if (maxIdx == i) continue;
            long temp = indices[i];
            indices[i] = indices[maxIdx];
            indices[maxIdx] = temp;
        }
        return indices;
    }

    private static String applyTemplate(String template, String arg) {
        int pos = template.indexOf("{}");
        if (pos == -1) {
            return template + arg;
        }
        int len = template.length();
        return template.substring(0, pos) + arg + template.substring(pos + 2, len);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = CustomTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    public static final class Builder {
        private final HuggingFaceTokenizer tokenizer;
        private boolean int32;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        void optInt32(boolean int32) {
            this.int32 = int32;
        }

        void configure(Map<String, ?> arguments) {
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
        }

        public CustomTranslator build() throws IOException {
            return new CustomTranslator(this.tokenizer, this.int32);
        }
    }
}

