/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.xgboost.parser.visitor;

import java.util.HashMap;
import java.util.Map;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
import org.apache.ignite.ml.tree.DecisionTreeLeafNode;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.xgboost.parser.XGBoostModelBaseVisitor;
import org.apache.ignite.ml.xgboost.parser.XGBoostModelParser;

/**
 * ANTLR visitor that converts one parsed XGBoost {@code xgTree} into a graph of Ignite
 * {@link DecisionTreeNode}s and returns the node stored at {@link #ROOT_NODE_IDX}.
 */
public class XGTreeVisitor
extends XGBoostModelBaseVisitor<DecisionTreeNode> {
    /** Index that identifies the root node of a tree in the XGBoost model text. */
    private static final int ROOT_NODE_IDX = 0;

    /**
     * Mapping from the feature name that appears in a split condition to the integer passed as the
     * first argument of {@link DecisionTreeConditionalNode} (presumably the feature column index).
     */
    private final Map<String, Integer> dict;

    /**
     * Constructs a new instance of the tree visitor.
     *
     * @param dict Mapping from feature name to feature index used when building split nodes.
     */
    public XGTreeVisitor(Map<String, Integer> dict) {
        this.dict = dict;
    }

    /**
     * Builds the decision tree described by {@code ctx}.
     *
     * <p>This works in three passes: (1) create every split node without children, (2) create every
     * leaf node, (3) link each split node to its children by index. All nodes must exist before any
     * links can be resolved.
     *
     * @param ctx Parse-tree context of a single XGBoost tree.
     * @return The node stored under {@link #ROOT_NODE_IDX}. A split node takes precedence over a leaf
     *     with the same index.
     * @throws IllegalArgumentException If a split references a feature name missing from the dictionary.
     * @throws IllegalStateException If a split or the root references a node index that is not defined.
     */
    @Override
    public DecisionTreeNode visitXgTree(XGBoostModelParser.XgTreeContext ctx) {
        Map<Integer, DecisionTreeConditionalNode> splitNodes = new HashMap<>();
        Map<Integer, DecisionTreeLeafNode> leafNodes = new HashMap<>();

        // Pass 1: split nodes. Children are attached in pass 3.
        for (XGBoostModelParser.XgNodeContext nodeCtx : ctx.xgNode()) {
            int idx = Integer.parseInt(nodeCtx.INT(0).getText());
            String featureName = nodeCtx.STRING().getText();
            double threshold = parseXgValue(nodeCtx.xgValue());
            splitNodes.put(idx, new DecisionTreeConditionalNode(featureIndex(featureName), threshold, null, null, null));
        }

        // Pass 2: leaf nodes.
        for (XGBoostModelParser.XgLeafContext leafCtx : ctx.xgLeaf()) {
            int idx = Integer.parseInt(leafCtx.INT().getText());
            double val = parseXgValue(leafCtx.xgValue());
            leafNodes.put(idx, new DecisionTreeLeafNode(val));
        }

        // Pass 3: wire children. INT(1..3) are the yes / no / missing child indices.
        for (XGBoostModelParser.XgNodeContext nodeCtx : ctx.xgNode()) {
            int idx = Integer.parseInt(nodeCtx.INT(0).getText());
            int yesIdx = Integer.parseInt(nodeCtx.INT(1).getText());
            int noIdx = Integer.parseInt(nodeCtx.INT(2).getText());
            int missIdx = Integer.parseInt(nodeCtx.INT(3).getText());

            DecisionTreeConditionalNode node = splitNodes.get(idx);
            // XGBoost's "yes" branch is stored as the else node and "no" as the then node.
            // NOTE(review): this presumes DecisionTreeConditionalNode takes the then branch when the
            // feature value exceeds the threshold; confirm against its predict implementation.
            node.setElseNode(resolveNode(yesIdx, splitNodes, leafNodes));
            node.setThenNode(resolveNode(noIdx, splitNodes, leafNodes));
            node.setMissingNode(resolveNode(missIdx, splitNodes, leafNodes));
        }

        return resolveNode(ROOT_NODE_IDX, splitNodes, leafNodes);
    }

    /**
     * Looks up the feature index for a feature name used in a split condition.
     *
     * @param featureName Feature name taken from the model text.
     * @return Feature index from the dictionary.
     * @throws IllegalArgumentException If the name is not present in the dictionary.
     */
    private int featureIndex(String featureName) {
        Integer featureIdx = dict.get(featureName);
        if (featureIdx == null) {
            throw new IllegalArgumentException("Unknown feature name in XGBoost model: " + featureName);
        }
        return featureIdx;
    }

    /**
     * Returns the node with the given index. A split node is preferred over a leaf node with the same index.
     *
     * @param idx Node index from the model text.
     * @param splitNodes Split nodes keyed by index.
     * @param leafNodes Leaf nodes keyed by index.
     * @return Resolved node, never {@code null}.
     * @throws IllegalStateException If no node with this index was defined.
     */
    private static DecisionTreeNode resolveNode(int idx,
        Map<Integer, DecisionTreeConditionalNode> splitNodes,
        Map<Integer, DecisionTreeLeafNode> leafNodes) {
        DecisionTreeNode node = splitNodes.get(idx);
        if (node == null) {
            node = leafNodes.get(idx);
        }
        if (node == null) {
            throw new IllegalStateException("XGBoost tree references undefined node with index " + idx);
        }
        return node;
    }

    /**
     * Parses a numeric value that the grammar presents as either an INT or a DOUBLE token.
     *
     * @param valCtx Value context.
     * @return Parsed value as a double.
     */
    private double parseXgValue(XGBoostModelParser.XgValueContext valCtx) {
        TerminalNode terminalNode = valCtx.INT() != null ? valCtx.INT() : valCtx.DOUBLE();
        return Double.parseDouble(terminalNode.getText());
    }
}

