List of usage examples for java.lang Math sqrt
@HotSpotIntrinsicCandidate public static double sqrt(double a)
From source file:Timer.java
/** * Testing this class.//from ww w .ja v a2 s . c o m * * @param args Not used. */ public static void main(String[] args) { Timer timer = new Timer(); for (int i = 0; i < 100000000; i++) { double b = 998.43678; double c = Math.sqrt(b); } System.out.println(timer); }
From source file:com.genentech.chemistry.openEye.apps.SDFSubRMSD.java
public static void main(String... args) throws IOException { // create command line Options object Options options = new Options(); Option opt = new Option("in", true, "input file oe-supported"); opt.setRequired(true);// w ww . ja va 2 s.c o m options.addOption(opt); opt = new Option("out", true, "output file oe-supported"); opt.setRequired(false); options.addOption(opt); opt = new Option("fragFile", true, "file with single 3d substructure query"); opt.setRequired(false); options.addOption(opt); opt = new Option("isMDL", false, "if given the fragFile is suposed to be an mdl query file, query features are supported."); opt.setRequired(false); options.addOption(opt); CommandLineParser parser = new PosixParser(); CommandLine cmd = null; try { cmd = parser.parse(options, args); } catch (Exception e) { System.err.println(e.getMessage()); exitWithHelp(options); } args = cmd.getArgs(); if (cmd.hasOption("d")) { System.err.println("Start debugger and press return:"); new BufferedReader(new InputStreamReader(System.in)).readLine(); } String inFile = cmd.getOptionValue("in"); String outFile = cmd.getOptionValue("out"); String fragFile = cmd.getOptionValue("fragFile"); // read fragment OESubSearch ss; oemolistream ifs = new oemolistream(fragFile); OEMolBase mol; if (!cmd.hasOption("isMDL")) { mol = new OEGraphMol(); oechem.OEReadMolecule(ifs, mol); ss = new OESubSearch(mol, OEExprOpts.AtomicNumber, OEExprOpts.BondOrder); } else { int aromodel = OEIFlavor.Generic.OEAroModelOpenEye; int qflavor = ifs.GetFlavor(ifs.GetFormat()); ifs.SetFlavor(ifs.GetFormat(), (qflavor | aromodel)); int opts = OEMDLQueryOpts.Default | OEMDLQueryOpts.SuppressExplicitH; OEQMol qmol = new OEQMol(); oechem.OEReadMDLQueryFile(ifs, qmol, opts); ss = new OESubSearch(qmol); mol = qmol; } double nSSatoms = mol.NumAtoms(); double sssCoords[] = new double[mol.GetMaxAtomIdx() * 3]; mol.GetCoords(sssCoords); mol.Clear(); ifs.close(); if (!ss.IsValid()) throw new Error("Invalid query " + args[0]); ifs = new oemolistream(inFile); oemolostream ofs = new oemolostream(outFile); int count = 0; while (oechem.OEReadMolecule(ifs, mol)) { count++; double rmsd = Double.MAX_VALUE; double molCoords[] = new double[mol.GetMaxAtomIdx() * 3]; mol.GetCoords(molCoords); for (OEMatchBase mb : ss.Match(mol, false)) { double r = 0; for (OEMatchPairAtom mp : mb.GetAtoms()) { OEAtomBase asss = mp.getPattern(); double sx = sssCoords[asss.GetIdx() * 3]; double sy = sssCoords[asss.GetIdx() * 3]; double sz = sssCoords[asss.GetIdx() * 3]; OEAtomBase amol = mp.getTarget(); double mx = molCoords[amol.GetIdx() * 3]; double my = molCoords[amol.GetIdx() * 3]; double mz = molCoords[amol.GetIdx() * 3]; r += Math.sqrt((sx - mx) * (sx - mx) + (sy - my) * (sy - my) + (sz - mz) * (sz - mz)); } r /= nSSatoms; rmsd = Math.min(rmsd, r); } if (rmsd != Double.MAX_VALUE) oechem.OESetSDData(mol, "SSSrmsd", String.format("%.3f", rmsd)); oechem.OEWriteMolecule(ofs, mol); mol.Clear(); } ifs.close(); ofs.close(); mol.delete(); ss.delete(); }
From source file:com.linkedin.pinotdruidbenchmark.PinotResponseTime.java
public static void main(String[] args) throws Exception { if (args.length != 4 && args.length != 5) { System.err.println(//from w w w . j a va 2s.c om "4 or 5 arguments required: QUERY_DIR, RESOURCE_URL, WARM_UP_ROUNDS, TEST_ROUNDS, RESULT_DIR (optional)."); return; } File queryDir = new File(args[0]); String resourceUrl = args[1]; int warmUpRounds = Integer.parseInt(args[2]); int testRounds = Integer.parseInt(args[3]); File resultDir; if (args.length == 4) { resultDir = null; } else { resultDir = new File(args[4]); if (!resultDir.exists()) { if (!resultDir.mkdirs()) { throw new RuntimeException("Failed to create result directory: " + resultDir); } } } File[] queryFiles = queryDir.listFiles(); assert queryFiles != null; Arrays.sort(queryFiles); try (CloseableHttpClient httpClient = HttpClients.createDefault()) { HttpPost httpPost = new HttpPost(resourceUrl); for (File queryFile : queryFiles) { String query = new BufferedReader(new FileReader(queryFile)).readLine(); httpPost.setEntity(new StringEntity("{\"pql\":\"" + query + "\"}")); System.out.println( "--------------------------------------------------------------------------------"); System.out.println("Running query: " + query); System.out.println( "--------------------------------------------------------------------------------"); // Warm-up Rounds System.out.println("Run " + warmUpRounds + " times to warm up..."); for (int i = 0; i < warmUpRounds; i++) { CloseableHttpResponse httpResponse = httpClient.execute(httpPost); httpResponse.close(); System.out.print('*'); } System.out.println(); // Test Rounds System.out.println("Run " + testRounds + " times to get response time statistics..."); long[] responseTimes = new long[testRounds]; long totalResponseTime = 0L; for (int i = 0; i < testRounds; i++) { long startTime = System.currentTimeMillis(); CloseableHttpResponse httpResponse = httpClient.execute(httpPost); httpResponse.close(); long responseTime = System.currentTimeMillis() - startTime; responseTimes[i] = responseTime; totalResponseTime += responseTime; System.out.print(responseTime + "ms "); } System.out.println(); // Store result. if (resultDir != null) { File resultFile = new File(resultDir, queryFile.getName() + ".result"); CloseableHttpResponse httpResponse = httpClient.execute(httpPost); try (BufferedInputStream bufferedInputStream = new BufferedInputStream( httpResponse.getEntity().getContent()); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(resultFile))) { int length; while ((length = bufferedInputStream.read(BYTE_BUFFER)) > 0) { bufferedWriter.write(new String(BYTE_BUFFER, 0, length)); } } httpResponse.close(); } // Process response times. double averageResponseTime = (double) totalResponseTime / testRounds; double temp = 0; for (long responseTime : responseTimes) { temp += (responseTime - averageResponseTime) * (responseTime - averageResponseTime); } double standardDeviation = Math.sqrt(temp / testRounds); System.out.println("Average response time: " + averageResponseTime + "ms"); System.out.println("Standard deviation: " + standardDeviation); } } }
From source file:edu.cmu.lti.oaqa.annographix.apps.SolrQueryApp.java
public static void main(String[] args) { Options options = new Options(); options.addOption("u", null, true, "Solr URI"); options.addOption("q", null, true, "Query"); options.addOption("n", null, true, "Max # of results"); options.addOption("o", null, true, "An optional TREC-style output file"); options.addOption("w", null, false, "Do a warm-up query call, before each query"); CommandLineParser parser = new org.apache.commons.cli.GnuParser(); BufferedWriter trecOutFile = null; try {//from w w w . j a v a 2 s .c o m CommandLine cmd = parser.parse(options, args); String queryFile = null, solrURI = null; if (cmd.hasOption("u")) { solrURI = cmd.getOptionValue("u"); } else { Usage("Specify Solr URI"); } SolrServerWrapper solr = new SolrServerWrapper(solrURI); if (cmd.hasOption("q")) { queryFile = cmd.getOptionValue("q"); } else { Usage("Specify Query file"); } int numRet = 100; if (cmd.hasOption("n")) { numRet = Integer.parseInt(cmd.getOptionValue("n")); } if (cmd.hasOption("o")) { trecOutFile = new BufferedWriter(new FileWriter(new File(cmd.getOptionValue("o")))); } List<String> fieldList = new ArrayList<String>(); fieldList.add(UtilConst.ID_FIELD); fieldList.add(UtilConst.SCORE_FIELD); double totalTime = 0; double retQty = 0; ArrayList<Double> queryTimes = new ArrayList<Double>(); boolean bDoWarmUp = cmd.hasOption("w"); if (bDoWarmUp) { System.out.println("Using a warmup step!"); } int queryQty = 0; for (String t : FileUtils.readLines(new File(queryFile))) { t = t.trim(); if (t.isEmpty()) continue; int ind = t.indexOf('|'); if (ind < 0) throw new Exception("Wrong format, line: '" + t + "'"); String qID = t.substring(0, ind); String q = t.substring(ind + 1); SolrDocumentList res = null; if (bDoWarmUp) { res = solr.runQuery(q, fieldList, numRet); } Long tm1 = System.currentTimeMillis(); res = solr.runQuery(q, fieldList, numRet); Long tm2 = System.currentTimeMillis(); retQty += res.getNumFound(); System.out.println(qID + " Obtained: " + res.getNumFound() + " entries in " + (tm2 - tm1) + " ms"); double delta = (tm2 - tm1); totalTime += delta; queryTimes.add(delta); ++queryQty; if (trecOutFile != null) { ArrayList<SolrRes> resArr = new ArrayList<SolrRes>(); for (SolrDocument doc : res) { String id = (String) doc.getFieldValue(UtilConst.ID_FIELD); float score = (Float) doc.getFieldValue(UtilConst.SCORE_FIELD); resArr.add(new SolrRes(id, "", score)); } SolrRes[] results = resArr.toArray(new SolrRes[resArr.size()]); Arrays.sort(results); SolrEvalUtils.saveTrecResults(qID, results, trecOutFile, TREC_RUN, results.length); } } double devTime = 0, meanTime = totalTime / queryQty; for (int i = 0; i < queryQty; ++i) { double d = queryTimes.get(i) - meanTime; devTime += d * d; } devTime = Math.sqrt(devTime / (queryQty - 1)); System.out.println(String.format("Query time, mean/standard dev: %.2f/%.2f (ms)", meanTime, devTime)); System.out.println(String.format("Avg # of docs returned: %.2f", retQty / queryQty)); solr.close(); trecOutFile.close(); } catch (ParseException e) { Usage("Cannot parse arguments"); } catch (Exception e) { System.err.println("Terminating due to an exception: " + e); System.exit(1); } }
From source file:dk.alexandra.fresco.demo.DistanceDemo.java
public static void main(String[] args) { CmdLineUtil cmdUtil = new CmdLineUtil(); SCEConfiguration sceConf = null;/*from w w w . j a va2s . c o m*/ int x, y; x = y = 0; try { cmdUtil.addOption(Option.builder("x").desc("The integer x coordinate of this party. " + "Note only party 1 and 2 should supply this input.").hasArg().build()); cmdUtil.addOption(Option.builder("y").desc( "The integer y coordinate of this party. " + "Note only party 1 and 2 should supply this input") .hasArg().build()); CommandLine cmd = cmdUtil.parse(args); sceConf = cmdUtil.getSCEConfiguration(); if (sceConf.getMyId() == 1 || sceConf.getMyId() == 2) { if (!cmd.hasOption("x") || !cmd.hasOption("y")) { throw new ParseException("Party 1 and 2 must submit input"); } else { x = Integer.parseInt(cmd.getOptionValue("x")); y = Integer.parseInt(cmd.getOptionValue("y")); } } else { if (cmd.hasOption("x") || cmd.hasOption("y")) throw new ParseException("Only party 1 and 2 should submit input"); } } catch (ParseException | IllegalArgumentException e) { System.out.println("Error: " + e); System.out.println(); cmdUtil.displayHelp(); System.exit(-1); } DistanceDemo distDemo = new DistanceDemo(sceConf.getMyId(), x, y); SCE sce = SCEFactory.getSCEFromConfiguration(sceConf); try { sce.runApplication(distDemo); } catch (Exception e) { System.out.println("Error while doing MPC: " + e.getMessage()); e.printStackTrace(); System.exit(-1); } double dist = distDemo.distance.getValue().doubleValue(); dist = Math.sqrt(dist); System.out.println("Distance between party 1 and 2 is " + dist); }
From source file:com.linkedin.pinotdruidbenchmark.DruidResponseTime.java
public static void main(String[] args) throws Exception { if (args.length != 4 && args.length != 5) { System.err.println(/*from www . j a v a 2s . com*/ "4 or 5 arguments required: QUERY_DIR, RESOURCE_URL, WARM_UP_ROUNDS, TEST_ROUNDS, RESULT_DIR (optional)."); return; } File queryDir = new File(args[0]); String resourceUrl = args[1]; int warmUpRounds = Integer.parseInt(args[2]); int testRounds = Integer.parseInt(args[3]); File resultDir; if (args.length == 4) { resultDir = null; } else { resultDir = new File(args[4]); if (!resultDir.exists()) { if (!resultDir.mkdirs()) { throw new RuntimeException("Failed to create result directory: " + resultDir); } } } File[] queryFiles = queryDir.listFiles(); assert queryFiles != null; Arrays.sort(queryFiles); try (CloseableHttpClient httpClient = HttpClients.createDefault()) { HttpPost httpPost = new HttpPost(resourceUrl); httpPost.addHeader("content-type", "application/json"); for (File queryFile : queryFiles) { StringBuilder stringBuilder = new StringBuilder(); try (BufferedReader bufferedReader = new BufferedReader(new FileReader(queryFile))) { int length; while ((length = bufferedReader.read(CHAR_BUFFER)) > 0) { stringBuilder.append(new String(CHAR_BUFFER, 0, length)); } } String query = stringBuilder.toString(); httpPost.setEntity(new StringEntity(query)); System.out.println( "--------------------------------------------------------------------------------"); System.out.println("Running query: " + query); System.out.println( "--------------------------------------------------------------------------------"); // Warm-up Rounds System.out.println("Run " + warmUpRounds + " times to warm up..."); for (int i = 0; i < warmUpRounds; i++) { CloseableHttpResponse httpResponse = httpClient.execute(httpPost); httpResponse.close(); System.out.print('*'); } System.out.println(); // Test Rounds System.out.println("Run " + testRounds + " times to get response time statistics..."); long[] responseTimes = new long[testRounds]; long totalResponseTime = 0L; for (int i = 0; i < testRounds; i++) { long startTime = System.currentTimeMillis(); CloseableHttpResponse httpResponse = httpClient.execute(httpPost); httpResponse.close(); long responseTime = System.currentTimeMillis() - startTime; responseTimes[i] = responseTime; totalResponseTime += responseTime; System.out.print(responseTime + "ms "); } System.out.println(); // Store result. if (resultDir != null) { File resultFile = new File(resultDir, queryFile.getName() + ".result"); CloseableHttpResponse httpResponse = httpClient.execute(httpPost); try (BufferedInputStream bufferedInputStream = new BufferedInputStream( httpResponse.getEntity().getContent()); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(resultFile))) { int length; while ((length = bufferedInputStream.read(BYTE_BUFFER)) > 0) { bufferedWriter.write(new String(BYTE_BUFFER, 0, length)); } } httpResponse.close(); } // Process response times. double averageResponseTime = (double) totalResponseTime / testRounds; double temp = 0; for (long responseTime : responseTimes) { temp += (responseTime - averageResponseTime) * (responseTime - averageResponseTime); } double standardDeviation = Math.sqrt(temp / testRounds); System.out.println("Average response time: " + averageResponseTime + "ms"); System.out.println("Standard deviation: " + standardDeviation); } } }
From source file:PinotResponseTime.java
public static void main(String[] args) throws Exception { try (CloseableHttpClient client = HttpClients.createDefault()) { HttpPost post = new HttpPost("http://localhost:8099/query"); CloseableHttpResponse res;//from ww w .j a v a2s.c o m if (STORE_RESULT) { File dir = new File(RESULT_DIR); if (!dir.exists()) { dir.mkdirs(); } } int length; // Make sure all segments online System.out.println("Test if number of records is " + RECORD_NUMBER); post.setEntity(new StringEntity("{\"pql\":\"select count(*) from tpch_lineitem\"}")); while (true) { System.out.print('*'); res = client.execute(post); boolean valid; try (BufferedInputStream in = new BufferedInputStream(res.getEntity().getContent())) { length = in.read(BUFFER); valid = new String(BUFFER, 0, length, "UTF-8").contains("\"value\":\"" + RECORD_NUMBER + "\""); } res.close(); if (valid) { break; } else { Thread.sleep(5000); } } System.out.println("Number of Records Test Passed"); // Start Benchmark for (int i = 0; i < QUERIES.length; i++) { System.out.println( "--------------------------------------------------------------------------------"); System.out.println("Start running query: " + QUERIES[i]); post.setEntity(new StringEntity("{\"pql\":\"" + QUERIES[i] + "\"}")); // Warm-up Rounds System.out.println("Run " + WARMUP_ROUND + " times to warm up cache..."); for (int j = 0; j < WARMUP_ROUND; j++) { res = client.execute(post); if (!isValid(res, null)) { System.out.println("\nInvalid Response, Sleep 20 Seconds..."); Thread.sleep(20000); } res.close(); System.out.print('*'); } System.out.println(); // Test Rounds int[] time = new int[TEST_ROUND]; int totalTime = 0; int validIdx = 0; System.out.println("Run " + TEST_ROUND + " times to get average time..."); while (validIdx < TEST_ROUND) { long startTime = System.currentTimeMillis(); res = client.execute(post); long endTime = System.currentTimeMillis(); boolean valid; if (STORE_RESULT && validIdx == 0) { valid = isValid(res, RESULT_DIR + File.separator + i + ".json"); } else { valid = isValid(res, null); } if (!valid) { System.out.println("\nInvalid Response, Sleep 20 Seconds..."); Thread.sleep(20000); res.close(); continue; } res.close(); time[validIdx] = (int) (endTime - startTime); totalTime += time[validIdx]; System.out.print(time[validIdx] + "ms "); validIdx++; } System.out.println(); // Process Results double avgTime = (double) totalTime / TEST_ROUND; double stdDev = 0; for (int temp : time) { stdDev += (temp - avgTime) * (temp - avgTime) / TEST_ROUND; } stdDev = Math.sqrt(stdDev); System.out.println("The average response time for the query is: " + avgTime + "ms"); System.out.println("The standard deviation is: " + stdDev); } } }
From source file:DruidResponseTime.java
public static void main(String[] args) throws Exception { try (CloseableHttpClient client = HttpClients.createDefault()) { HttpPost post = new HttpPost("http://localhost:8082/druid/v2/?pretty"); post.addHeader("content-type", "application/json"); CloseableHttpResponse res;//from w w w .ja va 2 s. c o m if (STORE_RESULT) { File dir = new File(RESULT_DIR); if (!dir.exists()) { dir.mkdirs(); } } int length; // Make sure all segments online System.out.println("Test if number of records is " + RECORD_NUMBER); post.setEntity(new StringEntity("{" + "\"queryType\":\"timeseries\"," + "\"dataSource\":\"tpch_lineitem\"," + "\"intervals\":[\"1992-01-01/1999-01-01\"]," + "\"granularity\":\"all\"," + "\"aggregations\":[{\"type\":\"count\",\"name\":\"count\"}]}")); while (true) { System.out.print('*'); res = client.execute(post); boolean valid; try (BufferedInputStream in = new BufferedInputStream(res.getEntity().getContent())) { length = in.read(BYTE_BUFFER); valid = new String(BYTE_BUFFER, 0, length, "UTF-8").contains("\"count\" : 6001215"); } res.close(); if (valid) { break; } else { Thread.sleep(5000); } } System.out.println("Number of Records Test Passed"); for (int i = 0; i < QUERIES.length; i++) { System.out.println( "--------------------------------------------------------------------------------"); System.out.println("Start running query: " + QUERIES[i]); try (BufferedReader reader = new BufferedReader( new FileReader(QUERY_FILE_DIR + File.separator + i + ".json"))) { length = reader.read(CHAR_BUFFER); post.setEntity(new StringEntity(new String(CHAR_BUFFER, 0, length))); } // Warm-up Rounds System.out.println("Run " + WARMUP_ROUND + " times to warm up cache..."); for (int j = 0; j < WARMUP_ROUND; j++) { res = client.execute(post); res.close(); System.out.print('*'); } System.out.println(); // Test Rounds int[] time = new int[TEST_ROUND]; int totalTime = 0; System.out.println("Run " + TEST_ROUND + " times to get average time..."); for (int j = 0; j < TEST_ROUND; j++) { long startTime = System.currentTimeMillis(); res = client.execute(post); long endTime = System.currentTimeMillis(); if (STORE_RESULT && j == 0) { try (BufferedInputStream in = new BufferedInputStream(res.getEntity().getContent()); BufferedWriter writer = new BufferedWriter( new FileWriter(RESULT_DIR + File.separator + i + ".json", false))) { while ((length = in.read(BYTE_BUFFER)) > 0) { writer.write(new String(BYTE_BUFFER, 0, length, "UTF-8")); } } } res.close(); time[j] = (int) (endTime - startTime); totalTime += time[j]; System.out.print(time[j] + "ms "); } System.out.println(); // Process Results double avgTime = (double) totalTime / TEST_ROUND; double stdDev = 0; for (int temp : time) { stdDev += (temp - avgTime) * (temp - avgTime) / TEST_ROUND; } stdDev = Math.sqrt(stdDev); System.out.println("The average response time for the query is: " + avgTime + "ms"); System.out.println("The standard deviation is: " + stdDev); } } }
From source file:ch.epfl.lsir.xin.test.GlobalMeanTest.java
/** * @param args/* ww w . j a va 2 s . c o m*/ */ public static void main(String[] args) throws Exception { // TODO Auto-generated method stub PrintWriter logger = new PrintWriter(".//results//GlobalMean"); PropertiesConfiguration config = new PropertiesConfiguration(); config.setFile(new File("conf//GlobalMean.properties")); try { config.load(); } catch (ConfigurationException e) { // TODO Auto-generated catch block e.printStackTrace(); } logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Read rating data..."); DataLoaderFile loader = new DataLoaderFile(".//data//MoveLens100k.txt"); loader.readSimple(); DataSetNumeric dataset = loader.getDataset(); System.out.println("Number of ratings: " + dataset.getRatings().size() + " Number of users: " + dataset.getUserIDs().size() + " Number of items: " + dataset.getItemIDs().size()); logger.println("Number of ratings: " + dataset.getRatings().size() + ", Number of users: " + dataset.getUserIDs().size() + ", Number of items: " + dataset.getItemIDs().size()); double totalMAE = 0; double totalRMSE = 0; int F = 5; logger.println(F + "- folder cross validation."); logger.flush(); ArrayList<ArrayList<NumericRating>> folders = new ArrayList<ArrayList<NumericRating>>(); for (int i = 0; i < F; i++) { folders.add(new ArrayList<NumericRating>()); } while (dataset.getRatings().size() > 0) { int index = new Random().nextInt(dataset.getRatings().size()); int r = new Random().nextInt(F); folders.get(r).add(dataset.getRatings().get(index)); dataset.getRatings().remove(index); } for (int folder = 1; folder <= F; folder++) { System.out.println("Folder: " + folder); logger.println("Folder: " + folder); ArrayList<NumericRating> trainRatings = new ArrayList<NumericRating>(); ArrayList<NumericRating> testRatings = new ArrayList<NumericRating>(); for (int i = 0; i < folders.size(); i++) { if (i == folder - 1)//test data { testRatings.addAll(folders.get(i)); } else {//training data trainRatings.addAll(folders.get(i)); } } //create rating matrix HashMap<String, Integer> userIDIndexMapping = new HashMap<String, Integer>(); HashMap<String, Integer> itemIDIndexMapping = new HashMap<String, Integer>(); for (int i = 0; i < dataset.getUserIDs().size(); i++) { userIDIndexMapping.put(dataset.getUserIDs().get(i), i); } for (int i = 0; i < dataset.getItemIDs().size(); i++) { itemIDIndexMapping.put(dataset.getItemIDs().get(i), i); } RatingMatrix trainRatingMatrix = new RatingMatrix(dataset.getUserIDs().size(), dataset.getItemIDs().size()); for (int i = 0; i < trainRatings.size(); i++) { trainRatingMatrix.set(userIDIndexMapping.get(trainRatings.get(i).getUserID()), itemIDIndexMapping.get(trainRatings.get(i).getItemID()), trainRatings.get(i).getValue()); } RatingMatrix testRatingMatrix = new RatingMatrix(dataset.getUserIDs().size(), dataset.getItemIDs().size()); for (int i = 0; i < testRatings.size(); i++) { testRatingMatrix.set(userIDIndexMapping.get(testRatings.get(i).getUserID()), itemIDIndexMapping.get(testRatings.get(i).getItemID()), testRatings.get(i).getValue()); } System.out.println("Training: " + trainRatingMatrix.getTotalRatingNumber() + " vs Test: " + testRatingMatrix.getTotalRatingNumber()); logger.println("Initialize a recommendation model based on global average method."); GlobalAverage algo = new GlobalAverage(trainRatingMatrix); algo.setLogger(logger); algo.build(); algo.saveModel(".//localModels//" + config.getString("NAME")); logger.println("Save the model."); logger.flush(); System.out.println(trainRatings.size() + " vs. " + testRatings.size()); double RMSE = 0; double MAE = 0; int count = 0; for (int i = 0; i < testRatings.size(); i++) { NumericRating rating = testRatings.get(i); double prediction = algo.predict(rating.getUserID(), rating.getItemID()); if (Double.isNaN(prediction)) { System.out.println("no prediction"); continue; } MAE = MAE + Math.abs(rating.getValue() - prediction); RMSE = RMSE + Math.pow((rating.getValue() - prediction), 2); count++; } MAE = MAE / count; RMSE = Math.sqrt(RMSE / count); // System.out.println("MAE: " + MAE + " RMSE: " + RMSE); logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " MAE: " + MAE + " RMSE: " + RMSE); logger.flush(); totalMAE = totalMAE + MAE; totalRMSE = totalRMSE + RMSE; } System.out.println("MAE: " + totalMAE / F + " RMSE: " + totalRMSE / F); logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Final results: MAE: " + totalMAE / F + " RMSE: " + totalRMSE / F); logger.flush(); logger.close(); //MAE: 0.9338607074893257 RMSE: 1.1170971131112037 (MovieLens1M) //MAE: 0.9446876509332618 RMSE: 1.1256517870920375 (MovieLens100K) }
From source file:ch.epfl.lsir.xin.test.UserAverageTest.java
/** * @param args//w w w.j a v a2 s .co m */ public static void main(String[] args) throws Exception { // TODO Auto-generated method stub PrintWriter logger = new PrintWriter(".//results//UserAverage"); PropertiesConfiguration config = new PropertiesConfiguration(); config.setFile(new File(".//conf//UserAverage.properties")); try { config.load(); } catch (ConfigurationException e) { // TODO Auto-generated catch block e.printStackTrace(); } logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Read rating data..."); DataLoaderFile loader = new DataLoaderFile(".//data//MoveLens100k.txt"); loader.readSimple(); DataSetNumeric dataset = loader.getDataset(); System.out.println("Number of ratings: " + dataset.getRatings().size() + " Number of users: " + dataset.getUserIDs().size() + " Number of items: " + dataset.getItemIDs().size()); logger.println("Number of ratings: " + dataset.getRatings().size() + " Number of users: " + dataset.getUserIDs().size() + " Number of items: " + dataset.getItemIDs().size()); logger.flush(); double totalMAE = 0; double totalRMSE = 0; int F = 5; logger.println(F + "- folder cross validation."); ArrayList<ArrayList<NumericRating>> folders = new ArrayList<ArrayList<NumericRating>>(); for (int i = 0; i < F; i++) { folders.add(new ArrayList<NumericRating>()); } while (dataset.getRatings().size() > 0) { int index = new Random().nextInt(dataset.getRatings().size()); int r = new Random().nextInt(F); folders.get(r).add(dataset.getRatings().get(index)); dataset.getRatings().remove(index); } for (int folder = 1; folder <= F; folder++) { logger.println("Folder: " + folder); System.out.println("Folder: " + folder); ArrayList<NumericRating> trainRatings = new ArrayList<NumericRating>(); ArrayList<NumericRating> testRatings = new ArrayList<NumericRating>(); for (int i = 0; i < folders.size(); i++) { if (i == folder - 1)//test data { testRatings.addAll(folders.get(i)); } else {//training data trainRatings.addAll(folders.get(i)); } } //create rating matrix HashMap<String, Integer> userIDIndexMapping = new HashMap<String, Integer>(); HashMap<String, Integer> itemIDIndexMapping = new HashMap<String, Integer>(); for (int i = 0; i < dataset.getUserIDs().size(); i++) { userIDIndexMapping.put(dataset.getUserIDs().get(i), i); } for (int i = 0; i < dataset.getItemIDs().size(); i++) { itemIDIndexMapping.put(dataset.getItemIDs().get(i), i); } RatingMatrix trainRatingMatrix = new RatingMatrix(dataset.getUserIDs().size(), dataset.getItemIDs().size()); for (int i = 0; i < trainRatings.size(); i++) { trainRatingMatrix.set(userIDIndexMapping.get(trainRatings.get(i).getUserID()), itemIDIndexMapping.get(trainRatings.get(i).getItemID()), trainRatings.get(i).getValue()); } trainRatingMatrix.calculateGlobalAverage(); RatingMatrix testRatingMatrix = new RatingMatrix(dataset.getUserIDs().size(), dataset.getItemIDs().size()); for (int i = 0; i < testRatings.size(); i++) { testRatingMatrix.set(userIDIndexMapping.get(testRatings.get(i).getUserID()), itemIDIndexMapping.get(testRatings.get(i).getItemID()), testRatings.get(i).getValue()); } System.out.println("Training: " + trainRatingMatrix.getTotalRatingNumber() + " vs Test: " + testRatingMatrix.getTotalRatingNumber()); logger.println("Initialize a recommendation model based on user average method."); UserAverage algo = new UserAverage(trainRatingMatrix); algo.setLogger(logger); algo.build(); algo.saveModel(".//localModels//" + config.getString("NAME")); logger.println("Save the model."); System.out.println(trainRatings.size() + " vs. " + testRatings.size()); double RMSE = 0; double MAE = 0; int count = 0; for (int i = 0; i < testRatings.size(); i++) { NumericRating rating = testRatings.get(i); double prediction = algo.predict(userIDIndexMapping.get(rating.getUserID()), itemIDIndexMapping.get(rating.getItemID())); if (Double.isNaN(prediction)) { System.out.println("no prediction"); continue; } MAE = MAE + Math.abs(rating.getValue() - prediction); RMSE = RMSE + Math.pow((rating.getValue() - prediction), 2); count++; } MAE = MAE / count; RMSE = Math.sqrt(RMSE / count); logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " MAE: " + MAE + " RMSE: " + RMSE); logger.flush(); totalMAE = totalMAE + MAE; totalRMSE = totalRMSE + RMSE; } System.out.println("MAE: " + totalMAE / F + " RMSE: " + totalRMSE / F); logger.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date()) + " Final results: MAE: " + totalMAE / F + " RMSE: " + totalRMSE / F); logger.flush(); logger.close(); //MAE: 0.8353035962363073 RMSE: 1.0422971886952053 (MovieLens 100k) }