// Java tutorial
/* * Copyright (C) 2016 University of Pittsburgh. * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, * MA 02110-1301 USA */ package edu.cmu.tetrad.cli.search; import edu.cmu.tetrad.cli.data.IKnowledgeFactory; import edu.cmu.tetrad.cli.util.Args; import edu.cmu.tetrad.cli.util.DateTime; import edu.cmu.tetrad.cli.util.FileIO; import edu.cmu.tetrad.cli.util.GraphmlSerializer; import edu.cmu.tetrad.cli.util.XmlPrint; import edu.cmu.tetrad.cli.validation.DataValidation; import edu.cmu.tetrad.cli.validation.LimitDiscreteCategory; import edu.cmu.tetrad.cli.validation.TabularDiscreteData; import edu.cmu.tetrad.cli.validation.UniqueVariableNames; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.io.DataReader; import edu.cmu.tetrad.io.VerticalTabularDiscreteDataReader; import edu.cmu.tetrad.search.BDeuScore; import edu.cmu.tetrad.search.Fgs; import java.io.BufferedOutputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.Collections; import java.util.Formatter; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; import 
org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Locale;

/**
 * Command-line program that runs the FGS search, scored with {@link BDeuScore},
 * on a discrete data file read by {@link VerticalTabularDiscreteDataReader}.
 * <p>
 * The run summary and resulting graph are written to
 * {@code <out>/<output-prefix>.txt}; with {@code --graphml} the graph is also
 * serialized as GraphML to {@code <out>/<output-prefix>_graph.txt}.
 * <p>
 * Mar 28, 2016 4:10:20 PM
 *
 * @author Kevin V. Bui (kvb2@pitt.edu)
 */
public class FgsDiscrete {

    private static final Logger LOGGER = LoggerFactory.getLogger(FgsDiscrete.class);

    /**
     * Category limit handed to the {@link LimitDiscreteCategory} validation
     * (skipped with {@code --skip-category-limit}).
     */
    public static final int CATEGORY_LIMIT = 10;

    /**
     * Regex matching one line break as produced by {@code %n} on any platform
     * ({@code \n} or {@code \r\n}).
     */
    private static final String LINE_BREAK = "\\r?\\n";

    /** Regex matching the blank line that separates sections of the run-info text. */
    private static final String SECTION_BREAK = "(?:\\r?\\n){2}";

    private static final Options MAIN_OPTIONS = new Options();

    static {
        // added required inputs
        Option requiredOption = new Option("f", "data", true, "Data file.");
        requiredOption.setRequired(true);
        MAIN_OPTIONS.addOption(requiredOption);

        // data file options
        MAIN_OPTIONS.addOption("d", "delimiter", true, "Data delimiter either comma, semicolon, space, colon, or tab. Default: comma for *.csv, else tab.");

        // run options
        MAIN_OPTIONS.addOption(null, "verbose", false, "Print additional information.");
        MAIN_OPTIONS.addOption(null, "thread", true, "Number of threads.");

        // algorithm parameters
        MAIN_OPTIONS.addOption(null, "structure-prior", true, "Structure prior.");
        MAIN_OPTIONS.addOption(null, "sample-prior", true, "Sample prior.");
        MAIN_OPTIONS.addOption(null, "depth", true, "Search depth. Must be an integer >= -1 (-1 means unlimited). Default is -1.");

        // search options
        MAIN_OPTIONS.addOption(null, "heuristic-speedup", false, "Heuristic speedup. Default is false.");

        // filter options
        MAIN_OPTIONS.addOption(null, "knowledge", true, "A file containing prior knowledge.");
        MAIN_OPTIONS.addOption(null, "exclude-variables", true, "A file containing variables to exclude.");

        // output results
        MAIN_OPTIONS.addOption(null, "graphml", false, "Create graphML output.");

        // data validations
        MAIN_OPTIONS.addOption(null, "skip-unique-var-name", false, "Skip 'unique variable name' check.");
        MAIN_OPTIONS.addOption(null, "skip-category-limit", false, "Skip 'limit number of categories' check.");

        // output
        MAIN_OPTIONS.addOption("o", "out", true, "Output directory.");
        MAIN_OPTIONS.addOption(null, "output-prefix", true, "Prefix name of output files.");
        MAIN_OPTIONS.addOption(null, "no-validation-output", false, "No validation output files created.");

        MAIN_OPTIONS.addOption(null, "help", false, "Show help.");
    }

    // Run settings; all are populated by parseArgs(String[]).
    private static Path dataFile;
    private static Path knowledgeFile;
    private static Path excludedVariableFile;
    private static char delimiter;
    private static double structurePrior;
    private static double samplePrior;
    private static int depth;
    private static boolean heuristicSpeedup;
    private static boolean graphML;
    private static boolean verbose;
    private static int numOfThreads;
    private static Path dirOut;
    private static String outputPrefix;
    private static boolean validationOutput;
    private static boolean skipUniqueVarName;
    private static boolean skipCategoryLimit;

    /**
     * Runs FGS Discrete from the command line.
     * <p>
     * Shows help and returns when no arguments or {@code --help} are given.
     * Otherwise parses the arguments, validates and reads the data, runs the
     * search and writes the results. Calls {@code System.exit} with a
     * non-zero status on failure.
     *
     * @param args the command line arguments
     */
    public static void main(String[] args) {
        if (args == null || args.length == 0 || Args.hasLongOption(args, "help")) {
            Args.showHelp("fgs-discrete", MAIN_OPTIONS);
            return;
        }

        parseArgs(args);

        System.out.println("================================================================================");
        System.out.printf("FGS Discrete (%s)%n", DateTime.printNow());
        System.out.println("================================================================================");

        String argInfo = createArgsInfo();
        System.out.println(argInfo);
        LOGGER.info("=== Starting FGS Discrete: {}", Args.toString(args, ' '));
        // Flatten the multi-line argument summary into a single log line.
        LOGGER.info(argInfo.trim().replaceAll(LINE_BREAK, ",").replace(" = ", "="));

        Set<String> excludedVariables = (excludedVariableFile == null)
                ? Collections.<String>emptySet()
                : getExcludedVariables();
        runPreDataValidations(excludedVariables, System.err);

        DataSet dataSet = readInDataSet(excludedVariables);
        runOptionalDataValidations(dataSet, System.err);

        Path outputFile = Paths.get(dirOut.toString(), outputPrefix + ".txt");
        try (PrintStream writer = newPrintStream(outputFile)) {
            String runInfo = createOutputRunInfo(excludedVariables, dataSet);
            writer.println(runInfo);

            // Log each blank-line separated section of the run info as one line.
            for (String section : runInfo.trim().split(SECTION_BREAK)) {
                LOGGER.info(section.trim()
                        .replaceAll(LINE_BREAK, ",")
                        .replace(":,", ":")
                        .replace(" = ", "="));
            }

            Graph graph = runFgsDiscrete(dataSet, writer);
            writer.println();
            writer.println(graph.toString());

            // PrintStream never throws; surface write failures explicitly.
            if (writer.checkError()) {
                throw new IOException(String.format("Failed when writing output file '%s'.", outputFile.getFileName()));
            }

            if (graphML) {
                writeOutGraphML(graph, Paths.get(dirOut.toString(), outputPrefix + "_graph.txt"));
            }
        } catch (IOException exception) {
            LOGGER.error("FGS Discrete failed.", exception);
            System.err.printf("%s: FGS Discrete failed.%n", DateTime.printNow());
            System.out.println("Please see log file for more information.");
            System.exit(-128);
        }

        String outputName = outputFile.getFileName().toString();
        System.out.printf("%s: FGS Discrete finished! Please see %s for details.%n", DateTime.printNow(), outputName);
        LOGGER.info("FGS Discrete finished! Please see {} for details.", outputName);
    }

    /**
     * Opens a buffered {@link PrintStream} on {@code file}, creating the file
     * if needed and truncating any existing content so that a re-run never
     * leaves stale bytes from a previous, longer output behind.
     *
     * @param file file to write
     * @return a new print stream; the caller must close it
     * @throws IOException if the file cannot be opened
     */
    private static PrintStream newPrintStream(Path file) throws IOException {
        return new PrintStream(new BufferedOutputStream(Files.newOutputStream(file,
                StandardOpenOption.CREATE,
                StandardOpenOption.TRUNCATE_EXISTING,
                StandardOpenOption.WRITE)));
    }

    /**
     * Serializes {@code graph} as pretty-printed GraphML into {@code outputFile}.
     * <p>
     * Best effort: any failure is reported on stderr and in the log but does
     * not abort the program. Does nothing when {@code graph} is null.
     *
     * @param graph graph to write; may be null
     * @param outputFile destination file
     */
    private static void writeOutGraphML(Graph graph, Path outputFile) {
        if (graph == null) {
            return;
        }

        String fileName = outputFile.getFileName().toString();
        try (PrintStream graphWriter = newPrintStream(outputFile)) {
            String msg = String.format("Writing out GraphML file '%s'.", fileName);
            System.out.printf("%s: %s%n", DateTime.printNow(), msg);
            LOGGER.info(msg);

            XmlPrint.printPretty(GraphmlSerializer.serialize(graph, outputPrefix), graphWriter);

            // PrintStream never throws; surface write failures explicitly.
            if (graphWriter.checkError()) {
                throw new IOException(String.format("Error while writing GraphML file '%s'.", fileName));
            }

            msg = String.format("Finished writing out GraphML file '%s'.", fileName);
            System.out.printf("%s: %s%n", DateTime.printNow(), msg);
            LOGGER.info(msg);
        } catch (Throwable throwable) {
            // Deliberately broad: GraphML output is optional and must not kill the run.
            String errMsg = String.format("Failed when writing out GraphML file '%s'.", fileName);
            System.err.println(errMsg);
            LOGGER.error(errMsg, throwable);
        }
    }

    /**
     * Configures and runs FGS with a BDeu score on {@code dataSet}.
     *
     * @param dataSet discrete data to search over
     * @param writer stream handed to FGS via {@code setOut}
     * @return the graph returned by {@code Fgs.search()}
     * @throws IOException if the knowledge file cannot be read
     */
    private static Graph runFgsDiscrete(DataSet dataSet, PrintStream writer) throws IOException {
        BDeuScore score = new BDeuScore(dataSet);
        score.setSamplePrior(samplePrior);
        score.setStructurePrior(structurePrior);

        Fgs fgs = new Fgs(score);
        fgs.setParallelism(numOfThreads);
        fgs.setVerbose(verbose);
        fgs.setNumPatternsToStore(0);
        fgs.setOut(writer);
        fgs.setHeuristicSpeedup(heuristicSpeedup);
        fgs.setDepth(depth);
        if (knowledgeFile != null) {
            fgs.setKnowledge(IKnowledgeFactory.readInKnowledge(knowledgeFile));
        }

        System.out.printf("%s: Start search.%n", DateTime.printNow());
        LOGGER.info("Start search.");
        Graph graph = fgs.search();
        System.out.printf("%s: End search.%n", DateTime.printNow());
        LOGGER.info("End search.");

        return graph;
    }

    /**
     * Builds the run summary written at the top of the output file. Sections
     * are separated by a blank line.
     *
     * @param excludedVariables variables excluded from the data
     * @param dataSet data that was read in
     * @return the formatted summary
     */
    private static String createOutputRunInfo(Set<String> excludedVariables, DataSet dataSet) {
        Formatter fmt = new Formatter();
        fmt.format("Runtime Parameters:%n");
        fmt.format("verbose = %s%n", verbose);
        fmt.format("number of threads = %s%n", numOfThreads);
        fmt.format("%n");

        fmt.format("Dataset:%n");
        fmt.format("file = %s%n", dataFile.getFileName());
        fmt.format("delimiter = %s%n", Args.getDelimiterName(delimiter));
        // In a DataSet, rows are cases and columns are variables.
        fmt.format("cases read in = %s%n", dataSet.getNumRows());
        fmt.format("variables read in = %s%n", dataSet.getNumColumns());
        fmt.format("%n");

        if (excludedVariableFile != null || knowledgeFile != null) {
            fmt.format("Filters:%n");
            if (excludedVariableFile != null) {
                fmt.format("excluded variables (%d variables) = %s%n", excludedVariables.size(), excludedVariableFile.getFileName());
            }
            if (knowledgeFile != null) {
                fmt.format("knowledge = %s%n", knowledgeFile.getFileName());
            }
            fmt.format("%n");
        }

        fmt.format("FGS Discrete Parameters:%n");
        fmt.format("structure prior = %f%n", structurePrior);
        fmt.format("sample prior = %f%n", samplePrior);
        fmt.format("depth = %d%n", depth);
        fmt.format("%n");

        fmt.format("Run Options:%n");
        fmt.format("heuristic speedup = %s%n", heuristicSpeedup);
        fmt.format("%n");

        fmt.format("Data Validations:%n");
        fmt.format("skip unique variable name check = %s%n", skipUniqueVarName);
        fmt.format("skip limit number of category check = %s%n", skipCategoryLimit);
        fmt.format("%n");

        return fmt.toString();
    }

    /**
     * Runs the unique-variable-name and category-limit validations unless
     * skipped. Every enabled validation runs even if an earlier one fails.
     * Exits the JVM with status -128 if any validation fails.
     *
     * @param dataSet data to validate
     * @param writer stream passed to each validation's {@code validate} call
     */
    private static void runOptionalDataValidations(DataSet dataSet, PrintStream writer) {
        String dir = dirOut.toString();
        List<DataValidation> validations = new LinkedList<>();
        if (!skipUniqueVarName) {
            validations.add(new UniqueVariableNames(dataSet, validationOutput ? Paths.get(dir, outputPrefix + "_duplicate_var_name.txt") : null));
        }
        if (!skipCategoryLimit) {
            validations.add(new LimitDiscreteCategory(dataSet, CATEGORY_LIMIT));
        }

        boolean isValid = true;
        for (DataValidation dataValidation : validations) {
            // validate() is called first so that no validation is short-circuited.
            isValid = dataValidation.validate(writer, verbose) && isValid;
        }

        if (!isValid) {
            System.exit(-128);
        }
    }

    /**
     * Reads the discrete data file, skipping {@code excludedVariables}.
     * Exits the JVM with status -128 if reading fails.
     *
     * @param excludedVariables variable names to leave out
     * @return the data set that was read in
     */
    private static DataSet readInDataSet(Set<String> excludedVariables) {
        DataSet dataSet = null;

        DataReader dataReader = new VerticalTabularDiscreteDataReader(dataFile, delimiter);
        try {
            System.out.printf("%s: Start reading in data.%n", DateTime.printNow());
            LOGGER.info("Start reading in data.");
            dataSet = dataReader.readInData(excludedVariables);
            System.out.printf("%s: End reading in data.%n", DateTime.printNow());
            LOGGER.info("End reading in data.");
        } catch (IOException exception) {
            String errMsg = String.format("Failed when reading data file '%s'.", dataFile.getFileName());
            System.err.println(errMsg);
            LOGGER.error(errMsg, exception);
            System.exit(-128);
        }

        return dataSet;
    }

    /**
     * Runs the {@link TabularDiscreteData} validation on the data file before
     * it is read in. Exits the JVM with status -128 if it fails.
     *
     * @param excludedVariables variable names to leave out
     * @param stderr stream passed to the validation's {@code validate} call
     */
    private static void runPreDataValidations(Set<String> excludedVariables, PrintStream stderr) {
        DataValidation dataValidation = new TabularDiscreteData(excludedVariables, dataFile, delimiter);
        if (!dataValidation.validate(stderr, verbose)) {
            System.exit(-128);
        }
    }

    /**
     * Reads the excluded-variable file via {@code FileIO.extractUniqueLine}.
     * Exits the JVM with status -128 if reading fails.
     *
     * @return the excluded variable names
     */
    private static Set<String> getExcludedVariables() {
        Set<String> variables = new HashSet<>();
        try {
            System.out.printf("%s: Start reading in excluded variable file.%n", DateTime.printNow());
            LOGGER.info("Start reading in excluded variable file.");
            variables.addAll(FileIO.extractUniqueLine(excludedVariableFile));
            System.out.printf("%s: End reading in excluded variable file.%n", DateTime.printNow());
            LOGGER.info("End reading in excluded variable file.");
        } catch (IOException exception) {
            String errMsg = String.format("Failed when reading excluded variable file '%s'.", excludedVariableFile.getFileName());
            System.err.println(errMsg);
            LOGGER.error(errMsg, exception);
            System.exit(-128);
        }

        return variables;
    }

    /**
     * Builds a {@code key = value} listing of the effective command-line
     * settings, one per line.
     *
     * @return the formatted settings
     */
    private static String createArgsInfo() {
        Formatter fmt = new Formatter();
        if (dataFile != null) {
            fmt.format("data = %s%n", dataFile.getFileName());
        }
        if (excludedVariableFile != null) {
            fmt.format("exclude-variables = %s%n", excludedVariableFile.getFileName());
        }
        if (knowledgeFile != null) {
            fmt.format("knowledge = %s%n", knowledgeFile.getFileName());
        }
        fmt.format("delimiter = %s%n", Args.getDelimiterName(delimiter));
        fmt.format("verbose = %s%n", verbose);
        fmt.format("thread = %s%n", numOfThreads);
        fmt.format("structure-prior = %f%n", structurePrior);
        fmt.format("sample-prior = %f%n", samplePrior);
        fmt.format("depth = %d%n", depth);
        fmt.format("heuristic-speedup = %s%n", heuristicSpeedup);
        fmt.format("graphml = %s%n", graphML);
        fmt.format("skip-unique-var-name = %s%n", skipUniqueVarName);
        fmt.format("skip-category-limit = %s%n", skipCategoryLimit);

        // getFileName() returns null for a root directory; fall back to the full path.
        Path outName = dirOut.getFileName();
        fmt.format("out = %s%n", (outName == null) ? dirOut : outName);

        fmt.format("output-prefix = %s%n", outputPrefix);
        fmt.format("no-validation-output = %s%n", !validationOutput);

        return fmt.toString();
    }

    /**
     * Parses {@code args} into the static run settings. On a parse error or a
     * missing file, prints the message and help, then exits with status -127.
     *
     * @param args the command line arguments
     */
    private static void parseArgs(String[] args) {
        try {
            CommandLineParser cmdParser = new DefaultParser();
            CommandLine cmd = cmdParser.parse(MAIN_OPTIONS, args);

            dataFile = Args.getPathFile(cmd.getOptionValue("data"), true);
            knowledgeFile = Args.getPathFile(cmd.getOptionValue("knowledge", null), false);
            excludedVariableFile = Args.getPathFile(cmd.getOptionValue("exclude-variables", null), false);

            // Default delimiter: comma for *.csv (any letter case), otherwise tab.
            String dataFileName = dataFile.getFileName().toString();
            boolean isCsv = dataFileName.toLowerCase(Locale.ROOT).endsWith(".csv");
            delimiter = Args.getDelimiterForName(cmd.getOptionValue("delimiter", isCsv ? "comma" : "tab"));

            structurePrior = Args.getDouble(cmd.getOptionValue("structure-prior", "1.0"));
            samplePrior = Args.getDouble(cmd.getOptionValue("sample-prior", "1.0"));
            depth = Args.getIntegerMin(cmd.getOptionValue("depth", "-1"), -1);
            heuristicSpeedup = cmd.hasOption("heuristic-speedup");
            graphML = cmd.hasOption("graphml");
            verbose = cmd.hasOption("verbose");

            // The search needs at least one worker thread.
            numOfThreads = Args.getIntegerMin(
                    cmd.getOptionValue("thread", Integer.toString(Runtime.getRuntime().availableProcessors())), 1);

            dirOut = Args.getPathDir(cmd.getOptionValue("out", "."), false);
            outputPrefix = cmd.getOptionValue("output-prefix", String.format("fgs_%s_%d", dataFile.getFileName(), System.currentTimeMillis()));
            validationOutput = !cmd.hasOption("no-validation-output");
            skipUniqueVarName = cmd.hasOption("skip-unique-var-name");
            skipCategoryLimit = cmd.hasOption("skip-category-limit");
        } catch (ParseException | FileNotFoundException exception) {
            System.err.println(exception.getLocalizedMessage());
            Args.showHelp("fgs-discrete", MAIN_OPTIONS);
            System.exit(-127);
        }
    }

}