org.apache.sysml.utils.GenerateClassesForMLContext.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.utils.GenerateClassesForMLContext.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysml.utils;

import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptExecutor;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.Statement;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtConstructor;
import javassist.CtField;
import javassist.CtMethod;
import javassist.CtNewConstructor;
import javassist.CtNewMethod;
import javassist.Modifier;
import javassist.NotFoundException;

/**
 * Automatically generate classes and methods for interaction with DML scripts
 * and functions through the MLContext API.
 * 
 */
public class GenerateClassesForMLContext {

    public static final String SOURCE = "scripts";
    public static final String DESTINATION = "target/classes";
    public static final String BASE_DEST_PACKAGE = "org.apache.sysml";
    public static final String CONVENIENCE_BASE_DEST_PACKAGE = "org.apache.sysml.api.mlcontext.convenience";
    public static final String PATH_TO_MLCONTEXT_CLASS = "org/apache/sysml/api/mlcontext/MLContext.class";
    public static final String PATH_TO_MLRESULTS_CLASS = "org/apache/sysml/api/mlcontext/MLResults.class";
    public static final String PATH_TO_SCRIPT_CLASS = "org/apache/sysml/api/mlcontext/Script.class";
    public static final String PATH_TO_SCRIPTTYPE_CLASS = "org/apache/sysml/api/mlcontext/ScriptType.class";
    public static final String PATH_TO_MATRIX_CLASS = "org/apache/sysml/api/mlcontext/Matrix.class";
    public static final String PATH_TO_FRAME_CLASS = "org/apache/sysml/api/mlcontext/Frame.class";

    public static String source = SOURCE;
    public static String destination = DESTINATION;
    public static boolean skipStagingDir = true;
    public static boolean skipPerfTestDir = true;
    public static boolean skipObsoleteDir = true;
    public static boolean skipCompareBackendsDir = true;

    public static void main(String[] args) throws Throwable {
        if (args.length == 2) {
            source = args[0];
            destination = args[1];
        } else if (args.length == 1) {
            source = args[0];
        }
        try {
            BasicConfigurator.configure();
            Logger.getRootLogger().setLevel(Level.WARN);

            DMLScript.VALIDATOR_IGNORE_ISSUES = true;
            System.out.println("************************************");
            System.out.println("**** MLContext Class Generation ****");
            System.out.println("************************************");
            System.out.println("Source: " + source);
            System.out.println("Destination: " + destination);
            makeCtClasses();
            recurseDirectoriesForClassGeneration(source);
            String fullDirClassName = recurseDirectoriesForConvenienceClassGeneration(source);
            addConvenienceMethodsToMLContext(source, fullDirClassName);
        } finally {
            DMLScript.VALIDATOR_IGNORE_ISSUES = false;
        }
    }

    /**
     * Create compile-time classes required for later class generation.
     */
    public static void makeCtClasses() {
        try {
            ClassPool pool = ClassPool.getDefault();
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_MLCONTEXT_CLASS)));
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_MLRESULTS_CLASS)));
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_SCRIPT_CLASS)));
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_SCRIPTTYPE_CLASS)));
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_MATRIX_CLASS)));
            pool.makeClass(new FileInputStream(new File(destination + File.separator + PATH_TO_FRAME_CLASS)));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (RuntimeException e) {
            e.printStackTrace();
        }
    }

    /**
     * Add methods to MLContext to allow tab-completion to folders/packages
     * (such as {@code ml.scripts()} and {@code ml.nn()}).
     * 
     * @param source
     *            path to source directory (typically, the scripts directory)
     * @param fullDirClassName
     *            the full name of the class representing the source (scripts)
     *            directory
     */
    public static void addConvenienceMethodsToMLContext(String source, String fullDirClassName) {
        try {
            ClassPool pool = ClassPool.getDefault();
            CtClass ctMLContext = pool.get(MLContext.class.getName());

            CtClass dirClass = pool.get(fullDirClassName);
            String methodName = convertFullClassNameToConvenienceMethodName(fullDirClassName);
            System.out.println("Adding " + methodName + "() to " + ctMLContext.getName());

            String methodBody = "{ " + fullDirClassName + " z = new " + fullDirClassName + "(); return z; }";
            CtMethod ctMethod = CtNewMethod.make(Modifier.PUBLIC, dirClass, methodName, null, null, methodBody,
                    ctMLContext);
            ctMLContext.addMethod(ctMethod);

            addPackageConvenienceMethodsToMLContext(source, ctMLContext);

            ctMLContext.writeFile(destination);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (RuntimeException e) {
            e.printStackTrace();
        } catch (NotFoundException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        }
    }

    /**
     * Add methods to MLContext to allow tab-completion to packages contained
     * within the source directory (such as {@code ml.nn()}).
     *
     * @param dirPath
     *            path to source directory (typically, the scripts directory)
     * @param ctMLContext
     *            javassist compile-time class representation of MLContext
     */
    public static void addPackageConvenienceMethodsToMLContext(String dirPath, CtClass ctMLContext) {

        try {
            if (!SOURCE.equalsIgnoreCase(dirPath)) {
                return;
            }
            File dir = new File(dirPath);
            File[] subdirs = dir.listFiles(new FileFilter() {
                @Override
                public boolean accept(File f) {
                    return f.isDirectory();
                }
            });
            for (File subdir : subdirs) {
                String subDirPath = dirPath + File.separator + subdir.getName();
                if (skipDir(subdir, false)) {
                    continue;
                }

                String fullSubDirClassName = dirPathToFullDirClassName(subDirPath);

                ClassPool pool = ClassPool.getDefault();
                CtClass subDirClass = pool.get(fullSubDirClassName);
                String subDirName = subdir.getName();
                subDirName = subDirName.replaceAll("-", "_");
                subDirName = subDirName.toLowerCase();

                System.out.println("Adding " + subDirName + "() to " + ctMLContext.getName());

                String methodBody = "{ " + fullSubDirClassName + " z = new " + fullSubDirClassName
                        + "(); return z; }";
                CtMethod ctMethod = CtNewMethod.make(Modifier.PUBLIC, subDirClass, subDirName, null, null,
                        methodBody, ctMLContext);
                ctMLContext.addMethod(ctMethod);

            }
        } catch (NotFoundException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        }

    }

    /**
     * Convert the full name of a class representing a directory to a method
     * name.
     * 
     * @param fullDirClassName
     *            the full name of the class representing a directory
     * @return method name
     */
    public static String convertFullClassNameToConvenienceMethodName(String fullDirClassName) {
        String m = fullDirClassName;
        m = m.substring(m.lastIndexOf(".") + 1);
        m = m.toLowerCase();
        return m;
    }

    /**
     * Generate convenience classes recursively. This allows for code such as
     * {@code ml.scripts.algorithms...}.
     * 
     * @param dirPath
     *            path to directory
     * @return the full name of the class representing the dirPath directory
     */
    public static String recurseDirectoriesForConvenienceClassGeneration(String dirPath) {
        try {
            File dir = new File(dirPath);

            String fullDirClassName = dirPathToFullDirClassName(dirPath);
            System.out.println("Generating Class: " + fullDirClassName);

            ClassPool pool = ClassPool.getDefault();
            CtClass ctDir = pool.makeClass(fullDirClassName);

            File[] subdirs = dir.listFiles(new FileFilter() {
                @Override
                public boolean accept(File f) {
                    return f.isDirectory();
                }
            });
            for (File subdir : subdirs) {
                String subDirPath = dirPath + File.separator + subdir.getName();
                if (skipDir(subdir, false)) {
                    continue;
                }
                String fullSubDirClassName = recurseDirectoriesForConvenienceClassGeneration(subDirPath);

                CtClass subDirClass = pool.get(fullSubDirClassName);
                String subDirName = subdir.getName();
                subDirName = subDirName.replaceAll("-", "_");
                subDirName = subDirName.toLowerCase();

                System.out.println("Adding " + subDirName + "() to " + fullDirClassName);

                String methodBody = "{ " + fullSubDirClassName + " z = new " + fullSubDirClassName
                        + "(); return z; }";
                CtMethod ctMethod = CtNewMethod.make(Modifier.PUBLIC, subDirClass, subDirName, null, null,
                        methodBody, ctDir);
                ctDir.addMethod(ctMethod);

            }

            File[] scriptFiles = dir.listFiles(new FilenameFilter() {
                @Override
                public boolean accept(File dir, String name) {
                    return (name.toLowerCase().endsWith(".dml") || name.toLowerCase().endsWith(".pydml"));
                }
            });
            for (File scriptFile : scriptFiles) {
                String scriptFilePath = scriptFile.getPath();
                String fullScriptClassName = BASE_DEST_PACKAGE + "."
                        + scriptFilePathToFullClassNameNoBase(scriptFilePath);
                CtClass scriptClass = pool.get(fullScriptClassName);
                String methodName = scriptFilePathToSimpleClassName(scriptFilePath);
                String methodBody = "{ " + fullScriptClassName + " z = new " + fullScriptClassName
                        + "(); return z; }";
                CtMethod ctMethod = CtNewMethod.make(Modifier.PUBLIC, scriptClass, methodName, null, null,
                        methodBody, ctDir);
                ctDir.addMethod(ctMethod);

            }

            ctDir.writeFile(destination);

            return fullDirClassName;
        } catch (RuntimeException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (NotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Convert a directory path to a full class name for a convenience class.
     * 
     * @param dirPath
     *            path to directory
     * @return the full name of the class representing the dirPath directory
     */
    public static String dirPathToFullDirClassName(String dirPath) {
        if (!dirPath.contains(File.separator)) {
            String c = dirPath;
            c = c.replace("-", "_");
            c = c.substring(0, 1).toUpperCase() + c.substring(1);
            c = CONVENIENCE_BASE_DEST_PACKAGE + "." + c;
            return c;
        }

        String p = dirPath;
        p = p.substring(0, p.lastIndexOf(File.separator));
        p = p.replace("-", "_");
        p = p.replace(File.separator, ".");
        p = p.toLowerCase();

        String c = dirPath;
        c = c.substring(c.lastIndexOf(File.separator) + 1);
        c = c.replace("-", "_");
        c = c.substring(0, 1).toUpperCase() + c.substring(1);

        return CONVENIENCE_BASE_DEST_PACKAGE + "." + p + "." + c;
    }

    /**
     * Whether or not the directory (and subdirectories of the directory) should
     * be skipped.
     * 
     * @param dir
     *            path to directory to check
     * @param displayMessage
     *            if {@code true}, display skip information to standard output
     * @return {@code true} if the directory should be skipped, {@code false}
     *         otherwise
     */
    public static boolean skipDir(File dir, boolean displayMessage) {
        if ("staging".equalsIgnoreCase(dir.getName()) && skipStagingDir) {
            if (displayMessage) {
                System.out.println("Skipping staging directory: " + dir.getPath());
            }
            return true;
        }
        if ("perftest".equalsIgnoreCase(dir.getName()) && skipPerfTestDir) {
            if (displayMessage) {
                System.out.println("Skipping perftest directory: " + dir.getPath());
            }
            return true;
        }
        if ("obsolete".equalsIgnoreCase(dir.getName()) && skipObsoleteDir) {
            if (displayMessage) {
                System.out.println("Skipping obsolete directory: " + dir.getPath());
            }
            return true;
        }
        if ("compare_backends".equalsIgnoreCase(dir.getName()) && skipCompareBackendsDir) {
            if (displayMessage) {
                System.out.println("Skipping compare_backends directory: " + dir.getPath());
            }
            return true;
        }
        return false;
    }

    /**
     * Recursively traverse the directories to create classes representing the
     * script files.
     * 
     * @param dirPath
     *            path to directory
     */
    public static void recurseDirectoriesForClassGeneration(String dirPath) {
        File dir = new File(dirPath);

        iterateScriptFilesInDirectory(dir);

        File[] subdirs = dir.listFiles(new FileFilter() {
            @Override
            public boolean accept(File f) {
                return f.isDirectory();
            }
        });
        for (File subdir : subdirs) {
            String subdirpath = dirPath + File.separator + subdir.getName();
            if (skipDir(subdir, true)) {
                continue;
            }
            recurseDirectoriesForClassGeneration(subdirpath);
        }
    }

    /**
     * Iterate through the script files in a directory and create a class for
     * each script file.
     * 
     * @param dir
     *            the directory to iterate through
     */
    public static void iterateScriptFilesInDirectory(File dir) {
        File[] scriptFiles = dir.listFiles(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String name) {
                return (name.toLowerCase().endsWith(".dml") || name.toLowerCase().endsWith(".pydml"));
            }
        });
        for (File scriptFile : scriptFiles) {
            String scriptFilePath = scriptFile.getPath();
            createScriptClass(scriptFilePath);
        }
    }

    /**
     * Obtain the relative package for a script file. For example,
     * {@code scripts/algorithms/LinearRegCG.dml} resolves to
     * {@code scripts.algorithms}.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @return the relative package for a script file
     */
    public static String scriptFilePathToPackageNoBase(String scriptFilePath) {
        String p = scriptFilePath;
        p = p.substring(0, p.lastIndexOf(File.separator));
        p = p.replace("-", "_");
        p = p.replace(File.separator, ".");
        p = p.toLowerCase();
        return p;
    }

    /**
     * Obtain the simple class name for a script file. For example,
     * {@code scripts/algorithms/LinearRegCG} resolves to {@code LinearRegCG}.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @return the simple class name for a script file
     */
    public static String scriptFilePathToSimpleClassName(String scriptFilePath) {
        String c = scriptFilePath;
        c = c.substring(c.lastIndexOf(File.separator) + 1);
        c = c.replace("-", "_");
        c = c.substring(0, 1).toUpperCase() + c.substring(1);
        c = c.substring(0, c.indexOf("."));
        return c;
    }

    /**
     * Obtain the relative full class name for a script file. For example,
     * {@code scripts/algorithms/LinearRegCG.dml} resolves to
     * {@code scripts.algorithms.LinearRegCG}.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @return the relative full class name for a script file
     */
    public static String scriptFilePathToFullClassNameNoBase(String scriptFilePath) {
        String p = scriptFilePathToPackageNoBase(scriptFilePath);
        String c = scriptFilePathToSimpleClassName(scriptFilePath);
        return p + "." + c;
    }

    /**
     * Convert a script file to a Java class that extends the MLContext API's
     * Script class.
     * 
     * @param scriptFilePath
     *            the path to a script file
     */
    public static void createScriptClass(String scriptFilePath) {
        try {
            String fullScriptClassName = BASE_DEST_PACKAGE + "."
                    + scriptFilePathToFullClassNameNoBase(scriptFilePath);
            System.out.println("Generating Class: " + fullScriptClassName);
            ClassPool pool = ClassPool.getDefault();
            CtClass ctNewScript = pool.makeClass(fullScriptClassName);

            CtClass ctScript = pool.get(Script.class.getName());
            ctNewScript.setSuperclass(ctScript);

            CtConstructor ctCon = new CtConstructor(null, ctNewScript);
            ctCon.setBody(scriptConstructorBody(scriptFilePath));
            ctNewScript.addConstructor(ctCon);

            addFunctionMethods(scriptFilePath, ctNewScript);

            ctNewScript.writeFile(destination);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (RuntimeException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        } catch (NotFoundException e) {
            e.printStackTrace();
        }

    }

    /**
     * Create a DMLProgram from a script file.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @return the DMLProgram generated by the script file
     */
    public static DMLProgram dmlProgramFromScriptFilePath(String scriptFilePath) {
        String scriptString = fileToString(scriptFilePath);
        Script script = new Script(scriptString);
        ScriptExecutor se = new ScriptExecutor() {
            @Override
            public MLResults execute(Script script) {
                setup(script);
                parseScript();
                return null;
            }
        };
        se.execute(script);
        DMLProgram dmlProgram = se.getDmlProgram();
        return dmlProgram;
    }

    /**
     * Add methods to a derived script class to allow invocation of script
     * functions.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param ctNewScript
     *            the javassist compile-time class representation of a script
     */
    public static void addFunctionMethods(String scriptFilePath, CtClass ctNewScript) {
        try {
            DMLProgram dmlProgram = dmlProgramFromScriptFilePath(scriptFilePath);
            if (dmlProgram == null) {
                System.out.println("Could not generate DML Program for: " + scriptFilePath);
                return;
            }
            Map<String, FunctionStatementBlock> defaultNsFsbsMap = dmlProgram
                    .getFunctionStatementBlocks(DMLProgram.DEFAULT_NAMESPACE);
            List<FunctionStatementBlock> fsbs = new ArrayList<>();
            fsbs.addAll(defaultNsFsbsMap.values());
            for (FunctionStatementBlock fsb : fsbs) {
                ArrayList<Statement> sts = fsb.getStatements();
                for (Statement st : sts) {
                    if (!(st instanceof FunctionStatement)) {
                        continue;
                    }
                    FunctionStatement fs = (FunctionStatement) st;

                    String dmlFunctionCall = generateDmlFunctionCall(scriptFilePath, fs);
                    String functionCallMethod = generateFunctionCallMethod(scriptFilePath, fs, dmlFunctionCall);

                    CtMethod m = CtNewMethod.make(functionCallMethod, ctNewScript);
                    ctNewScript.addMethod(m);

                    addDescriptionFunctionCallMethod(fs, scriptFilePath, ctNewScript, false);
                    addDescriptionFunctionCallMethod(fs, scriptFilePath, ctNewScript, true);
                }
            }
        } catch (LanguageException e) {
            System.out.println("Could not add function methods for " + ctNewScript.getName());
        } catch (CannotCompileException e) {
            System.out.println("Could not add function methods for " + ctNewScript.getName());
        } catch (RuntimeException e) {
            System.out.println("Could not add function methods for " + ctNewScript.getName());
        }
    }

    /**
     * Create a method that returns either: (1) the full function body, or (2)
     * the function body up to the end of the documentation comment for the
     * function. If (1) is generated, the method name will be followed
     * "__source". If (2) is generated, the method name will be followed by
     * "__docs". If (2) is generated but no end of documentation comment is
     * detected, the full function body will be displayed.
     * 
     * @param fs
     *            a SystemML function statement
     * @param scriptFilePath
     *            the path to a script file
     * @param ctNewScript
     *            the javassist compile-time class representation of a script
     * @param full
     *            if {@code true}, create method to return full function body;
     *            if {@code false}, create method to return the function body up
     *            to the end of the documentation comment
     */
    public static void addDescriptionFunctionCallMethod(FunctionStatement fs, String scriptFilePath,
            CtClass ctNewScript, boolean full) {

        try {
            int bl = fs.getBeginLine();
            int el = fs.getEndLine();
            File f = new File(scriptFilePath);
            List<String> lines = FileUtils.readLines(f);

            int end = el;
            if (!full) {
                for (int i = bl - 1; i < el; i++) {
                    String line = lines.get(i);
                    if (line.contains("*/")) {
                        end = i + 1;
                        break;
                    }
                }
            }
            List<String> sub = lines.subList(bl - 1, end);
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < sub.size(); i++) {
                String line = sub.get(i);
                String escapeLine = StringEscapeUtils.escapeJava(line);
                sb.append(escapeLine);
                sb.append("\\n");
            }

            String functionString = sb.toString();
            String docFunctionCallMethod = generateDescriptionFunctionCallMethod(fs, functionString, full);
            CtMethod m = CtNewMethod.make(docFunctionCallMethod, ctNewScript);
            ctNewScript.addMethod(m);

        } catch (IOException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        }

    }

    /**
     * Generate method for returning (1) the full function body, or (2) the
     * function body up to the end of the documentation comment. (1) will have
     * "__source" appended to the end of the function name. (2) will have
     * "__docs" appended to the end of the function name.
     * 
     * @param fs
     *            a SystemML function statement
     * @param functionString
     *            either the full function body or the function body up to the
     *            end of the documentation comment
     * @param full
     *            if {@code true}, append "__source" to the end of the function
     *            name; if {@code false}, append "__docs" to the end of the
     *            function name
     * @return string representation of the function description method
     */
    public static String generateDescriptionFunctionCallMethod(FunctionStatement fs, String functionString,
            boolean full) {
        StringBuilder sb = new StringBuilder();
        sb.append("public String ");
        sb.append(fs.getName());
        if (full) {
            sb.append("__source");
        } else {
            sb.append("__docs");
        }
        sb.append("() {\n");
        sb.append("String docString = \"" + functionString + "\";\n");
        sb.append("return docString;\n");
        sb.append("}\n");
        return sb.toString();
    }

    /**
     * Obtain a string representation of a parameter type, where a Matrix or
     * Frame is represented by its full class name.
     * 
     * @param param
     *            the function parameter
     * @return string representation of a parameter type
     */
    public static String getParamTypeAsString(DataIdentifier param) {
        DataType dt = param.getDataType();
        ValueType vt = param.getValueType();
        if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
            return "long";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
            return "double";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
            return "boolean";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
            return "String";
        } else if (dt == DataType.MATRIX) {
            return "org.apache.sysml.api.mlcontext.Matrix";
        } else if (dt == DataType.FRAME) {
            return "org.apache.sysml.api.mlcontext.Frame";
        }
        return null;
    }

    /**
     * Obtain a string representation of a parameter type, where a Matrix or
     * Frame is represented by its simple class name.
     * 
     * @param param
     *            the function parameter
     * @return string representation of a parameter type
     */
    public static String getSimpleParamTypeAsString(DataIdentifier param) {
        DataType dt = param.getDataType();
        ValueType vt = param.getValueType();
        if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
            return "long";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
            return "double";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
            return "boolean";
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
            return "String";
        } else if (dt == DataType.MATRIX) {
            return "Matrix";
        } else if (dt == DataType.FRAME) {
            return "Frame";
        }
        return null;
    }

    /**
     * Obtain the full class name for a class that encapsulates the outputs of a
     * function
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param fs
     *            a SystemML function statement
     * @return the full class name for a class that encapsulates the outputs of
     *         a function
     */
    public static String getFullFunctionOutputClassName(String scriptFilePath, FunctionStatement fs) {
        String p = scriptFilePath;
        p = p.replace("-", "_");
        p = p.replace(File.separator, ".");
        p = p.toLowerCase();
        p = p.substring(0, p.lastIndexOf("."));

        String c = fs.getName();
        c = c.substring(0, 1).toUpperCase() + c.substring(1);
        c = c + "_output";

        String functionOutputClassName = BASE_DEST_PACKAGE + "." + p + "." + c;
        return functionOutputClassName;
    }

    /**
     * Create a class that encapsulates the outputs of a function.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param fs
     *            a SystemML function statement
     */
    public static void createFunctionOutputClass(String scriptFilePath, FunctionStatement fs) {

        try {
            ArrayList<DataIdentifier> oparams = fs.getOutputParams();
            // Note: if a function returns 1 output, simply output it rather
            // than encapsulating it in a function output class
            if ((oparams.size() == 0) || (oparams.size() == 1)) {
                return;
            }

            String fullFunctionOutputClassName = getFullFunctionOutputClassName(scriptFilePath, fs);
            System.out.println("Generating Class: " + fullFunctionOutputClassName);
            ClassPool pool = ClassPool.getDefault();
            CtClass ctFuncOut = pool.makeClass(fullFunctionOutputClassName);

            // add fields
            for (int i = 0; i < oparams.size(); i++) {
                DataIdentifier oparam = oparams.get(i);
                String type = getParamTypeAsString(oparam);
                String name = oparam.getName();
                String fstring = "public " + type + " " + name + ";";
                CtField field = CtField.make(fstring, ctFuncOut);
                ctFuncOut.addField(field);
            }

            // add constructor
            String simpleFuncOutClassName = fullFunctionOutputClassName
                    .substring(fullFunctionOutputClassName.lastIndexOf(".") + 1);
            StringBuilder con = new StringBuilder();
            con.append("public " + simpleFuncOutClassName + "(");
            for (int i = 0; i < oparams.size(); i++) {
                if (i > 0) {
                    con.append(", ");
                }
                DataIdentifier oparam = oparams.get(i);
                String type = getParamTypeAsString(oparam);
                String name = oparam.getName();
                con.append(type + " " + name);
            }
            con.append(") {\n");
            for (int i = 0; i < oparams.size(); i++) {
                DataIdentifier oparam = oparams.get(i);
                String name = oparam.getName();
                con.append("this." + name + "=" + name + ";\n");
            }
            con.append("}\n");
            String cstring = con.toString();
            CtConstructor ctCon = CtNewConstructor.make(cstring, ctFuncOut);
            ctFuncOut.addConstructor(ctCon);

            // add toString
            StringBuilder s = new StringBuilder();
            s.append("public String toString(){\n");
            s.append("StringBuilder sb = new StringBuilder();\n");
            for (int i = 0; i < oparams.size(); i++) {
                DataIdentifier oparam = oparams.get(i);
                String name = oparam.getName();
                s.append("sb.append(\"" + name + " (" + getSimpleParamTypeAsString(oparam) + "): \" + " + name
                        + " + \"\\n\");\n");
            }
            s.append("String str = sb.toString();\nreturn str;\n");
            s.append("}\n");
            String toStr = s.toString();
            CtMethod toStrMethod = CtNewMethod.make(toStr, ctFuncOut);
            ctFuncOut.addMethod(toStrMethod);

            ctFuncOut.writeFile(destination);

        } catch (RuntimeException e) {
            e.printStackTrace();
        } catch (CannotCompileException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Obtain method for invoking a script function.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param fs
     *            a SystemML function statement
     * @param dmlFunctionCall
     *            a string representing the invocation of a script function
     * @return string representation of a method that performs a function call
     */
    public static String generateFunctionCallMethod(String scriptFilePath, FunctionStatement fs,
            String dmlFunctionCall) {

        createFunctionOutputClass(scriptFilePath, fs);

        StringBuilder sb = new StringBuilder();
        sb.append("public ");

        // begin return type
        ArrayList<DataIdentifier> oparams = fs.getOutputParams();
        if (oparams.size() == 0) {
            sb.append("void");
        } else if (oparams.size() == 1) {
            // if 1 output, no need to encapsulate it, so return the output
            // directly
            DataIdentifier oparam = oparams.get(0);
            String type = getParamTypeAsString(oparam);
            sb.append(type);
        } else {
            String fullFunctionOutputClassName = getFullFunctionOutputClassName(scriptFilePath, fs);
            sb.append(fullFunctionOutputClassName);
        }
        sb.append(" ");
        // end return type

        sb.append(fs.getName());
        sb.append("(");

        ArrayList<DataIdentifier> inputParams = fs.getInputParams();
        for (int i = 0; i < inputParams.size(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            DataIdentifier inputParam = inputParams.get(i);
            /*
             * Note: Using Object is currently preferrable to using
             * datatype/valuetype to explicitly set the input type to
             * Integer/Double/Boolean/String since Object allows the automatic
             * handling of things such as automatic conversions from longs to
             * ints.
             */
            sb.append("Object ");
            sb.append(inputParam.getName());
        }

        sb.append(") {\n");
        sb.append("String scriptString = \"" + dmlFunctionCall + "\";\n");
        sb.append(
                "org.apache.sysml.api.mlcontext.Script script = new org.apache.sysml.api.mlcontext.Script(scriptString);\n");

        if ((inputParams.size() > 0) || (oparams.size() > 0)) {
            sb.append("script");
        }
        for (int i = 0; i < inputParams.size(); i++) {
            DataIdentifier inputParam = inputParams.get(i);
            String name = inputParam.getName();
            sb.append(".in(\"" + name + "\", " + name + ")");
        }
        for (int i = 0; i < oparams.size(); i++) {
            DataIdentifier outputParam = oparams.get(i);
            String name = outputParam.getName();
            sb.append(".out(\"" + name + "\")");
        }
        if ((inputParams.size() > 0) || (oparams.size() > 0)) {
            sb.append(";\n");
        }

        sb.append("org.apache.sysml.api.mlcontext.MLResults results = script.execute();\n");

        if (oparams.size() == 0) {
            sb.append("return;\n");
        } else if (oparams.size() == 1) {
            DataIdentifier o = oparams.get(0);
            DataType dt = o.getDataType();
            ValueType vt = o.getValueType();
            if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
                sb.append("long res = results.getLong(\"" + o.getName() + "\");\nreturn res;\n");
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
                sb.append("double res = results.getDouble(\"" + o.getName() + "\");\nreturn res;\n");
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
                sb.append("boolean res = results.getBoolean(\"" + o.getName() + "\");\nreturn res;\n");
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
                sb.append("String res = results.getString(\"" + o.getName() + "\");\nreturn res;\n");
            } else if (dt == DataType.MATRIX) {
                sb.append("org.apache.sysml.api.mlcontext.Matrix res = results.getMatrix(\"" + o.getName()
                        + "\");\nreturn res;\n");
            } else if (dt == DataType.FRAME) {
                sb.append("org.apache.sysml.api.mlcontext.Frame res = results.getFrame(\"" + o.getName()
                        + "\");\nreturn res;\n");
            }
        } else {

            for (int i = 0; i < oparams.size(); i++) {
                DataIdentifier outputParam = oparams.get(i);
                String name = outputParam.getName().toLowerCase();
                String type = getParamTypeAsString(outputParam);
                DataType dt = outputParam.getDataType();
                ValueType vt = outputParam.getValueType();
                if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
                    sb.append(type + " " + name + " = results.getLong(\"" + outputParam.getName() + "\");\n");
                } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
                    sb.append(type + " " + name + " = results.getDouble(\"" + outputParam.getName() + "\");\n");
                } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
                    sb.append(type + " " + name + " = results.getBoolean(\"" + outputParam.getName() + "\");\n");
                } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
                    sb.append(type + " " + name + " = results.getString(\"" + outputParam.getName() + "\");\n");
                } else if (dt == DataType.MATRIX) {
                    sb.append(type + " " + name + " = results.getMatrix(\"" + outputParam.getName() + "\");\n");
                } else if (dt == DataType.FRAME) {
                    sb.append(type + " " + name + " = results.getFrame(\"" + outputParam.getName() + "\");\n");
                }
            }
            String ffocn = getFullFunctionOutputClassName(scriptFilePath, fs);
            sb.append(ffocn + " res = new " + ffocn + "(");
            for (int i = 0; i < oparams.size(); i++) {
                if (i > 0) {
                    sb.append(", ");
                }
                DataIdentifier outputParam = oparams.get(i);
                String name = outputParam.getName().toLowerCase();
                sb.append(name);
            }
            sb.append(");\nreturn res;\n");
        }

        sb.append("}\n");
        return sb.toString();
    }

    /**
     * Obtain method for invoking a script function and returning the results as
     * an MLResults object. Currently this method is not used.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param fs
     *            a SystemML function statement
     * @param dmlFunctionCall
     *            a string representing the invocation of a script function
     * @return string representation of a method that performs a function call
     */
    public static String generateFunctionCallMethodMLResults(String scriptFilePath, FunctionStatement fs,
            String dmlFunctionCall) {
        StringBuilder sb = new StringBuilder();

        sb.append("public org.apache.sysml.api.mlcontext.MLResults ");
        sb.append(fs.getName());
        sb.append("(");

        ArrayList<DataIdentifier> inputParams = fs.getInputParams();
        for (int i = 0; i < inputParams.size(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            DataIdentifier inputParam = inputParams.get(i);
            /*
             * Note: Using Object is currently preferrable to using
             * datatype/valuetype to explicitly set the input type to
             * Integer/Double/Boolean/String since Object allows the automatic
             * handling of things such as automatic conversions from longs to
             * ints.
             */
            sb.append("Object ");
            sb.append(inputParam.getName());
        }

        sb.append(") {\n");
        sb.append("String scriptString = \"" + dmlFunctionCall + "\";\n");
        sb.append(
                "org.apache.sysml.api.mlcontext.Script script = new org.apache.sysml.api.mlcontext.Script(scriptString);\n");

        ArrayList<DataIdentifier> outputParams = fs.getOutputParams();
        if ((inputParams.size() > 0) || (outputParams.size() > 0)) {
            sb.append("script");
        }
        for (int i = 0; i < inputParams.size(); i++) {
            DataIdentifier inputParam = inputParams.get(i);
            String name = inputParam.getName();
            sb.append(".in(\"" + name + "\", " + name + ")");
        }
        for (int i = 0; i < outputParams.size(); i++) {
            DataIdentifier outputParam = outputParams.get(i);
            String name = outputParam.getName();
            sb.append(".out(\"" + name + "\")");
        }
        if ((inputParams.size() > 0) || (outputParams.size() > 0)) {
            sb.append(";\n");
        }

        sb.append("org.apache.sysml.api.mlcontext.MLResults results = script.execute();\n");
        sb.append("return results;\n");
        sb.append("}\n");
        return sb.toString();
    }

    /**
     * Obtain the DML representing a function invocation.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @param fs
     *            a SystemML function statement
     * @return string representation of a DML function invocation
     */
    public static String generateDmlFunctionCall(String scriptFilePath, FunctionStatement fs) {
        StringBuilder sb = new StringBuilder();
        sb.append("source('" + scriptFilePath + "') as mlcontextns;");

        ArrayList<DataIdentifier> outputParams = fs.getOutputParams();
        if (outputParams.size() == 0) {
            sb.append("mlcontextns::");
        }
        if (outputParams.size() == 1) {
            DataIdentifier outputParam = outputParams.get(0);
            sb.append(outputParam.getName());
            sb.append(" = mlcontextns::");
        } else if (outputParams.size() > 1) {
            sb.append("[");
            for (int i = 0; i < outputParams.size(); i++) {
                if (i > 0) {
                    sb.append(", ");
                }
                sb.append(outputParams.get(i).getName());
            }
            sb.append("] = mlcontextns::");
        }
        sb.append(fs.getName());
        sb.append("(");
        ArrayList<DataIdentifier> inputParams = fs.getInputParams();
        for (int i = 0; i < inputParams.size(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            DataIdentifier inputParam = inputParams.get(i);
            sb.append(inputParam.getName());
        }
        sb.append(");");
        return sb.toString();
    }

    /**
     * Obtain the content of a file as a string.
     * 
     * @param filePath
     *            the path to a file
     * @return the file content as a string
     */
    public static String fileToString(String filePath) {
        try {
            File f = new File(filePath);
            FileReader fr = new FileReader(f);
            StringBuilder sb = new StringBuilder();
            int n;
            char[] charArray = new char[1024];
            while ((n = fr.read(charArray)) > 0) {
                sb.append(charArray, 0, n);
            }
            fr.close();
            return sb.toString();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Obtain a constructor body for a Script subclass that sets the
     * scriptString based on the content of a script file.
     * 
     * @param scriptFilePath
     *            the path to a script file
     * @return constructor body for a Script subclass that sets the scriptString
     *         based on the content of a script file
     */
    public static String scriptConstructorBody(String scriptFilePath) {
        StringBuilder sb = new StringBuilder();
        sb.append("{");
        sb.append("String scriptFilePath = \"" + scriptFilePath + "\";");
        sb.append(
                "java.io.InputStream is = org.apache.sysml.api.mlcontext.Script.class.getResourceAsStream(\"/\"+scriptFilePath);");
        sb.append("java.io.InputStreamReader isr = new java.io.InputStreamReader(is);");
        sb.append("int n;");
        sb.append("char[] charArray = new char[1024];");
        sb.append("StringBuilder s = new StringBuilder();");
        sb.append("try {");
        sb.append("  while ((n = isr.read(charArray)) > 0) {");
        sb.append("    s.append(charArray, 0, n);");
        sb.append("  }");
        sb.append("} catch (java.io.IOException e) {");
        sb.append("  e.printStackTrace();");
        sb.append("}");
        sb.append("setScriptString(s.toString());");
        sb.append("}");
        return sb.toString();
    }

}