/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.sparkmodelparser;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Scanner;
import java.util.TreeMap;
import java.util.function.Consumer;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
import org.apache.ignite.ml.sparkmodelparser.UnsupportedSparkModelException;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
import org.apache.ignite.ml.tree.DecisionTreeLeafNode;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroup;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.InputFile;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.RecordReader;
import org.apache.parquet.io.api.RecordMaterializer;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public class SparkModelParser {
    public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
        File mdlDir = IgniteUtils.resolveIgnitePath((String)pathToMdl);
        if (mdlDir == null) {
            throw new IllegalArgumentException("Directory not found or empty [directory_path=" + pathToMdl + "]");
        }
        if (!mdlDir.isDirectory()) {
            throw new IllegalArgumentException("Spark Model Parser supports loading from directory only. The specified path " + pathToMdl + " is not the path to directory.");
        }
        String[] files = mdlDir.list();
        if (files.length == 0) {
            throw new IllegalArgumentException("Directory contain 0 files and sub-directories [directory_path=" + pathToMdl + "]");
        }
        if (Arrays.stream(files).noneMatch("data"::equals)) {
            throw new IllegalArgumentException("Directory should contain data sub-directory [directory_path=" + pathToMdl + "]");
        }
        if (Arrays.stream(files).noneMatch("metadata"::equals)) {
            throw new IllegalArgumentException("Directory should contain metadata sub-directory [directory_path=" + pathToMdl + "]");
        }
        String pathToData = pathToMdl + File.separator + "data";
        File dataDir = IgniteUtils.resolveIgnitePath((String)pathToData);
        File[] dataParquetFiles = dataDir.listFiles((dir, name) -> name.matches("^part-.*\\.snappy\\.parquet$"));
        if (dataParquetFiles.length == 0) {
            throw new IllegalArgumentException("Directory should contain parquet file with model [directory_path=" + pathToData + "]");
        }
        if (dataParquetFiles.length > 1) {
            throw new IllegalArgumentException("Directory should contain only one parquet file with model [directory_path=" + pathToData + "]");
        }
        String pathToMdlFile = dataParquetFiles[0].getPath();
        String pathToMetadata = pathToMdl + File.separator + "metadata";
        File metadataDir = IgniteUtils.resolveIgnitePath((String)pathToMetadata);
        String[] metadataFiles = metadataDir.list();
        if (Arrays.stream(metadataFiles).noneMatch("part-00000"::equals)) {
            throw new IllegalArgumentException("Directory should contain json file with model metadata with name part-00000 [directory_path=" + pathToMetadata + "]");
        }
        try {
            SparkModelParser.validateMetadata(pathToMetadata, parsedSparkMdl);
        }
        catch (FileNotFoundException e) {
            throw new IllegalArgumentException("Directory should contain json file with model metadata with name part-00000 [directory_path=" + pathToMetadata + "]");
        }
        if (SparkModelParser.shouldContainTreeMetadataSubDirectory(parsedSparkMdl)) {
            if (Arrays.stream(files).noneMatch("treesMetadata"::equals)) {
                throw new IllegalArgumentException("Directory should contain treeMetadata sub-directory [directory_path=" + pathToMdl + "]");
            }
            String pathToTreesMetadata = pathToMdl + File.separator + "treesMetadata";
            File treesMetadataDir = IgniteUtils.resolveIgnitePath((String)pathToTreesMetadata);
            File[] treesMetadataParquetFiles = treesMetadataDir.listFiles((dir, name) -> name.matches("^part-.*\\.snappy\\.parquet$"));
            if (treesMetadataParquetFiles.length == 0) {
                throw new IllegalArgumentException("Directory should contain parquet file with model treesMetadata [directory_path=" + pathToTreesMetadata + "]");
            }
            if (treesMetadataParquetFiles.length > 1) {
                throw new IllegalArgumentException("Directory should contain only one parquet file with model [directory_path=" + pathToTreesMetadata + "]");
            }
            String pathToTreesMetadataFile = treesMetadataParquetFiles[0].getPath();
            return SparkModelParser.parseDataWithMetadata(pathToMdlFile, pathToTreesMetadataFile, parsedSparkMdl);
        }
        return SparkModelParser.parseData(pathToMdlFile, parsedSparkMdl);
    }

    private static void validateMetadata(String pathToMetadata, SupportedSparkModels parsedSparkMdl) throws FileNotFoundException {
        File metadataFile = IgniteUtils.resolveIgnitePath((String)(pathToMetadata + File.separator + "part-00000"));
        if (metadataFile != null) {
            Scanner sc = new Scanner(metadataFile);
            boolean isInvalid = true;
            while (sc.hasNextLine()) {
                String line = sc.nextLine();
                if (!line.contains(parsedSparkMdl.getMdlClsNameInSpark())) continue;
                isInvalid = false;
            }
            if (isInvalid) {
                throw new IllegalArgumentException("The metadata file contains incorrect model metadata. It should contain " + parsedSparkMdl.getMdlClsNameInSpark() + " model metadata.");
            }
        }
    }

    private static boolean shouldContainTreeMetadataSubDirectory(SupportedSparkModels parsedSparkMdl) {
        return parsedSparkMdl == SupportedSparkModels.GRADIENT_BOOSTED_TREES || parsedSparkMdl == SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION;
    }

    private static Model parseData(String pathToMdl, SupportedSparkModels parsedSparkMdl) {
        File mdlRsrc = IgniteUtils.resolveIgnitePath((String)pathToMdl);
        if (mdlRsrc == null) {
            throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");
        }
        String ignitePathToMdl = mdlRsrc.getPath();
        switch (parsedSparkMdl) {
            case LOG_REGRESSION: {
                return SparkModelParser.loadLogRegModel(ignitePathToMdl);
            }
            case LINEAR_REGRESSION: {
                return SparkModelParser.loadLinRegModel(ignitePathToMdl);
            }
            case LINEAR_SVM: {
                return SparkModelParser.loadLinearSVMModel(ignitePathToMdl);
            }
            case DECISION_TREE: {
                return SparkModelParser.loadDecisionTreeModel(ignitePathToMdl);
            }
            case RANDOM_FOREST: {
                return SparkModelParser.loadRandomForestModel(ignitePathToMdl);
            }
            case KMEANS: {
                return SparkModelParser.loadKMeansModel(ignitePathToMdl);
            }
            case DECISION_TREE_REGRESSION: {
                return SparkModelParser.loadDecisionTreeRegressionModel(ignitePathToMdl);
            }
            case RANDOM_FOREST_REGRESSION: {
                return SparkModelParser.loadRandomForestRegressionModel(ignitePathToMdl);
            }
        }
        throw new UnsupportedSparkModelException(ignitePathToMdl);
    }

    private static Model parseDataWithMetadata(String pathToMdl, String pathToMetaData, SupportedSparkModels parsedSparkMdl) {
        File mdlRsrc1 = IgniteUtils.resolveIgnitePath((String)pathToMdl);
        if (mdlRsrc1 == null) {
            throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");
        }
        String ignitePathToMdl = mdlRsrc1.getPath();
        File mdlRsrc2 = IgniteUtils.resolveIgnitePath((String)pathToMetaData);
        if (mdlRsrc2 == null) {
            throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMetaData + "]");
        }
        String ignitePathToMdlMetaData = mdlRsrc2.getPath();
        switch (parsedSparkMdl) {
            case GRADIENT_BOOSTED_TREES: {
                return SparkModelParser.loadGBTClassifierModel(ignitePathToMdl, ignitePathToMdlMetaData);
            }
            case GRADIENT_BOOSTED_TREES_REGRESSION: {
                return SparkModelParser.loadGBTRegressionModel(ignitePathToMdl, ignitePathToMdlMetaData);
            }
        }
        throw new UnsupportedSparkModelException(ignitePathToMdl);
    }

    private static Model loadRandomForestRegressionModel(String pathToMdl) {
        List<IgniteModel<Vector, Double>> models = SparkModelParser.parseTreesForRandomForestAlgorithm(pathToMdl);
        if (models == null) {
            return null;
        }
        return new ModelsComposition(models, (PredictionsAggregator)new MeanValuePredictionsAggregator());
    }

    private static Model loadDecisionTreeRegressionModel(String pathToMdl) {
        return SparkModelParser.loadDecisionTreeModel(pathToMdl);
    }

    private static Model loadKMeansModel(String pathToMdl) {
        DenseVector[] centers = null;
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            PageReadStore pages;
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            while (null != (pages = r.readNextRowGroup())) {
                int rows = (int)pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                centers = new DenseVector[rows];
                for (int i = 0; i < rows; ++i) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    Group clusterCenterCoeff = g.getGroup(1, 0).getGroup(3, 0);
                    int amountOfCoefficients = clusterCenterCoeff.getFieldRepetitionCount(0);
                    centers[i] = new DenseVector(amountOfCoefficients);
                    for (int j = 0; j < amountOfCoefficients; ++j) {
                        double coefficient = clusterCenterCoeff.getGroup(0, j).getDouble(0, 0);
                        centers[i].set(j, coefficient);
                    }
                }
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
        }
        return new KMeansModel(centers, (DistanceMeasure)new EuclideanDistance());
    }

    private static Model loadGBTRegressionModel(String pathToMdl, String pathToMdlMetaData) {
        IgniteFunction & Serializable lbMapper = (IgniteFunction & Serializable)lb -> lb;
        return SparkModelParser.parseAndBuildGDBModel(pathToMdl, pathToMdlMetaData, (IgniteFunction<Double, Double>)lbMapper);
    }

    private static Model loadGBTClassifierModel(String pathToMdl, String pathToMdlMetaData) {
        IgniteFunction & Serializable lbMapper = (IgniteFunction & Serializable)lb -> lb > 0.5 ? 1.0 : 0.0;
        return SparkModelParser.parseAndBuildGDBModel(pathToMdl, pathToMdlMetaData, (IgniteFunction<Double, Double>)lbMapper);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    @Nullable
    private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData, IgniteFunction<Double, Double> lbMapper) {
        MessageColumnIO colIO;
        MessageType schema;
        Throwable throwable;
        double[] treeWeights = null;
        HashMap<Integer, Double> treeWeightsByTreeID = new HashMap<Integer, Double>();
        try {
            throwable = null;
            try (ParquetFileReader r2 = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdlMetaData), (Configuration)new Configuration()));){
                PageReadStore pagesMetaData;
                schema = r2.getFooter().getFileMetaData().getSchema();
                colIO = new ColumnIOFactory().getColumnIO(schema);
                while (null != (pagesMetaData = r2.readNextRowGroup())) {
                    long rows = pagesMetaData.getRowCount();
                    RecordReader recordReader = colIO.getRecordReader(pagesMetaData, (RecordMaterializer)new GroupRecordConverter(schema));
                    int i = 0;
                    while ((long)i < rows) {
                        SimpleGroup g = (SimpleGroup)recordReader.read();
                        int treeId = g.getInteger(0, 0);
                        double treeWeight = g.getDouble(2, 0);
                        treeWeightsByTreeID.put(treeId, treeWeight);
                        ++i;
                    }
                }
            }
            catch (Throwable pagesMetaData) {
                throwable = pagesMetaData;
                throw pagesMetaData;
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file with MetaData by the path: " + pathToMdlMetaData);
            e.printStackTrace();
        }
        treeWeights = new double[treeWeightsByTreeID.size()];
        for (int i = 0; i < treeWeights.length; ++i) {
            treeWeights[i] = (Double)treeWeightsByTreeID.get(i);
        }
        try {
            throwable = null;
            try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
                schema = r.getFooter().getFileMetaData().getSchema();
                colIO = new ColumnIOFactory().getColumnIO(schema);
                TreeMap nodesByTreeId = new TreeMap();
                block25: while (true) {
                    PageReadStore pages;
                    if (null == (pages = r.readNextRowGroup())) {
                        ArrayList models = new ArrayList();
                        nodesByTreeId.forEach((key, nodes) -> models.add(SparkModelParser.buildDecisionTreeModel(nodes)));
                        GDBTrainer.GDBModel gDBModel = new GDBTrainer.GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
                        return gDBModel;
                    }
                    long rows = pages.getRowCount();
                    RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                    int i = 0;
                    while (true) {
                        Map<Integer, NodeData> nodesByNodeId;
                        if ((long)i >= rows) continue block25;
                        SimpleGroup g = (SimpleGroup)recordReader.read();
                        int treeID = g.getInteger(0, 0);
                        SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
                        NodeData nodeData = SparkModelParser.extractNodeDataFromParquetRow(nodeDataGroup);
                        if (nodesByTreeId.containsKey(treeID)) {
                            nodesByNodeId = (Map)nodesByTreeId.get(treeID);
                            nodesByNodeId.put(nodeData.id, nodeData);
                        } else {
                            nodesByNodeId = new TreeMap<Integer, NodeData>();
                            ((TreeMap)nodesByNodeId).put(nodeData.id, nodeData);
                            nodesByTreeId.put(treeID, nodesByNodeId);
                        }
                        ++i;
                    }
                    break;
                }
            }
            catch (Throwable throwable3) {
                throwable = throwable3;
                throw throwable3;
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
            return null;
        }
    }

    private static Model loadRandomForestModel(String pathToMdl) {
        List<IgniteModel<Vector, Double>> models = SparkModelParser.parseTreesForRandomForestAlgorithm(pathToMdl);
        if (models == null) {
            return null;
        }
        return new ModelsComposition(models, (PredictionsAggregator)new OnMajorityPredictionsAggregator());
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private static List<IgniteModel<Vector, Double>> parseTreesForRandomForestAlgorithm(String pathToMdl) {
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            TreeMap nodesByTreeId = new TreeMap();
            block11: while (true) {
                PageReadStore pages;
                if (null == (pages = r.readNextRowGroup())) {
                    ArrayList<IgniteModel<Vector, Double>> models = new ArrayList<IgniteModel<Vector, Double>>();
                    nodesByTreeId.forEach((key, nodes) -> models.add((IgniteModel<Vector, Double>)SparkModelParser.buildDecisionTreeModel(nodes)));
                    ArrayList<IgniteModel<Vector, Double>> arrayList = models;
                    return arrayList;
                }
                long rows = pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                int i = 0;
                while (true) {
                    Map<Integer, NodeData> nodesByNodeId;
                    if ((long)i >= rows) continue block11;
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    int treeID = g.getInteger(0, 0);
                    SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
                    NodeData nodeData = SparkModelParser.extractNodeDataFromParquetRow(nodeDataGroup);
                    if (nodesByTreeId.containsKey(treeID)) {
                        nodesByNodeId = (Map)nodesByTreeId.get(treeID);
                        nodesByNodeId.put(nodeData.id, nodeData);
                    } else {
                        nodesByNodeId = new TreeMap<Integer, NodeData>();
                        ((TreeMap)nodesByNodeId).put(nodeData.id, nodeData);
                        nodesByTreeId.put(treeID, nodesByNodeId);
                    }
                    ++i;
                }
                break;
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
            return null;
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private static Model loadDecisionTreeModel(String pathToMdl) {
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            PageReadStore pages;
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            TreeMap<Integer, NodeData> nodes = new TreeMap<Integer, NodeData>();
            while (null != (pages = r.readNextRowGroup())) {
                long rows = pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                int i = 0;
                while ((long)i < rows) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    NodeData nodeData = SparkModelParser.extractNodeDataFromParquetRow(g);
                    nodes.put(nodeData.id, nodeData);
                    ++i;
                }
            }
            DecisionTreeNode decisionTreeNode = SparkModelParser.buildDecisionTreeModel(nodes);
            return decisionTreeNode;
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
            return null;
        }
    }

    private static DecisionTreeNode buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
        DecisionTreeNode mdl = null;
        if (!nodes.isEmpty()) {
            NodeData rootNodeData = (NodeData)((NavigableMap)nodes).firstEntry().getValue();
            mdl = SparkModelParser.buildTree(nodes, rootNodeData);
            return mdl;
        }
        return mdl;
    }

    @NotNull
    private static DecisionTreeNode buildTree(Map<Integer, NodeData> nodes, NodeData rootNodeData) {
        return rootNodeData.isLeafNode ? new DecisionTreeLeafNode(rootNodeData.prediction) : new DecisionTreeConditionalNode(rootNodeData.featureIdx, rootNodeData.threshold, SparkModelParser.buildTree(nodes, nodes.get(rootNodeData.rightChildId)), SparkModelParser.buildTree(nodes, nodes.get(rootNodeData.leftChildId)), null);
    }

    @NotNull
    private static NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
        NodeData nodeData = new NodeData();
        nodeData.id = g.getInteger(0, 0);
        nodeData.prediction = g.getDouble(1, 0);
        nodeData.leftChildId = g.getInteger(5, 0);
        nodeData.rightChildId = g.getInteger(6, 0);
        if (nodeData.leftChildId == -1 && nodeData.rightChildId == -1) {
            nodeData.featureIdx = -1;
            nodeData.threshold = -1.0;
            nodeData.isLeafNode = true;
        } else {
            SimpleGroup splitGrp = (SimpleGroup)g.getGroup(7, 0);
            nodeData.featureIdx = splitGrp.getInteger(0, 0);
            nodeData.threshold = splitGrp.getGroup(1, 0).getGroup(0, 0).getDouble(0, 0);
        }
        return nodeData;
    }

    private static void printGroup(Group g) {
        int fieldCnt = g.getType().getFieldCount();
        for (int field = 0; field < fieldCnt; ++field) {
            int valCnt = g.getFieldRepetitionCount(field);
            Type fieldType = g.getType().getType(field);
            String fieldName = fieldType.getName();
            for (int idx = 0; idx < valCnt; ++idx) {
                if (fieldType.isPrimitive()) {
                    System.out.println(fieldName + " " + g.getValueToString(field, idx));
                    continue;
                }
                SparkModelParser.printGroup(g.getGroup(field, idx));
            }
        }
        System.out.println();
    }

    private static Model loadLinearSVMModel(String pathToMdl) {
        Vector coefficients = null;
        double interceptor = 0.0;
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            PageReadStore pages;
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            while (null != (pages = r.readNextRowGroup())) {
                long rows = pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                int i = 0;
                while ((long)i < rows) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    interceptor = SparkModelParser.readSVMInterceptor(g);
                    coefficients = SparkModelParser.readSVMCoefficients(g);
                    ++i;
                }
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
        }
        return new SVMLinearClassificationModel(coefficients, interceptor);
    }

    private static Model loadLinRegModel(String pathToMdl) {
        Vector coefficients = null;
        double interceptor = 0.0;
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            PageReadStore pages;
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            while (null != (pages = r.readNextRowGroup())) {
                long rows = pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                int i = 0;
                while ((long)i < rows) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    interceptor = SparkModelParser.readLinRegInterceptor(g);
                    coefficients = SparkModelParser.readLinRegCoefficients(g);
                    ++i;
                }
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
        }
        return new LinearRegressionModel(coefficients, interceptor);
    }

    private static Model loadLogRegModel(String pathToMdl) {
        Vector coefficients = null;
        double interceptor = 0.0;
        try (ParquetFileReader r = ParquetFileReader.open((InputFile)HadoopInputFile.fromPath((Path)new Path(pathToMdl), (Configuration)new Configuration()));){
            PageReadStore pages;
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
            while (null != (pages = r.readNextRowGroup())) {
                long rows = pages.getRowCount();
                RecordReader recordReader = colIO.getRecordReader(pages, (RecordMaterializer)new GroupRecordConverter(schema));
                int i = 0;
                while ((long)i < rows) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();
                    interceptor = SparkModelParser.readInterceptor(g);
                    coefficients = SparkModelParser.readCoefficients(g);
                    ++i;
                }
            }
        }
        catch (IOException e) {
            System.out.println("Error reading parquet file.");
            e.printStackTrace();
        }
        return new LogisticRegressionModel(coefficients, interceptor);
    }

    private static double readSVMInterceptor(SimpleGroup g) {
        return g.getDouble(1, 0);
    }

    private static Vector readSVMCoefficients(SimpleGroup g) {
        Group coeffGroup = g.getGroup(0, 0).getGroup(3, 0);
        int amountOfCoefficients = coeffGroup.getFieldRepetitionCount(0);
        DenseVector coefficients = new DenseVector(amountOfCoefficients);
        for (int j = 0; j < amountOfCoefficients; ++j) {
            double coefficient = coeffGroup.getGroup(0, j).getDouble(0, 0);
            coefficients.set(j, coefficient);
        }
        return coefficients;
    }

    private static double readLinRegInterceptor(SimpleGroup g) {
        return g.getDouble(0, 0);
    }

    private static Vector readLinRegCoefficients(SimpleGroup g) {
        Group coeffGroup = g.getGroup(1, 0).getGroup(3, 0);
        int amountOfCoefficients = coeffGroup.getFieldRepetitionCount(0);
        DenseVector coefficients = new DenseVector(amountOfCoefficients);
        for (int j = 0; j < amountOfCoefficients; ++j) {
            double coefficient = coeffGroup.getGroup(0, j).getDouble(0, 0);
            coefficients.set(j, coefficient);
        }
        return coefficients;
    }

    private static double readInterceptor(SimpleGroup g) {
        SimpleGroup interceptVector = (SimpleGroup)g.getGroup(2, 0);
        SimpleGroup interceptVectorVal = (SimpleGroup)interceptVector.getGroup(3, 0);
        SimpleGroup interceptVectorValElement = (SimpleGroup)interceptVectorVal.getGroup(0, 0);
        double interceptor = interceptVectorValElement.getDouble(0, 0);
        return interceptor;
    }

    private static Vector readCoefficients(SimpleGroup g) {
        int amountOfCoefficients = g.getGroup(3, 0).getGroup(5, 0).getFieldRepetitionCount(0);
        DenseVector coefficients = new DenseVector(amountOfCoefficients);
        for (int j = 0; j < amountOfCoefficients; ++j) {
            double coefficient = g.getGroup(3, 0).getGroup(5, 0).getGroup(0, j).getDouble(0, 0);
            coefficients.set(j, coefficient);
        }
        return coefficients;
    }

    private static class NodeData {
        int id;
        double prediction;
        int leftChildId;
        int rightChildId;
        double threshold;
        int featureIdx;
        boolean isLeafNode;

        private NodeData() {
        }

        public String toString() {
            return "NodeData{id=" + this.id + ", prediction=" + this.prediction + ", leftChildId=" + this.leftChildId + ", rightChildId=" + this.rightChildId + ", threshold=" + this.threshold + ", featureIdx=" + this.featureIdx + ", isLeafNode=" + this.isLeafNode + '}';
        }
    }
}

