List of usage examples for the java.util.concurrent.ExecutorCompletionService constructor ExecutorCompletionService(Executor)
public ExecutorCompletionService(Executor executor)
From source file:org.ocelotds.integration.AbstractOcelotTest.java
/** * * @param <T>/*from w w w. ja v a 2 s. c o m*/ * @param nb * @param client * @param returnClass * @param ds * @param methodName * @param params * @return */ protected <T> Collection<T> testCallMultiMethodsInClient(int nb, final Client client, final Class<T> returnClass, final Class ds, final String methodName, final String... params) { ExecutorCompletionService<ResultMonitored<T>> executorCompletionService = new ExecutorCompletionService( managedExecutor); Collection<T> results = new ArrayList<>(); long t0 = System.currentTimeMillis(); for (int i = 0; i < nb; i++) { final int num = i; Callable<ResultMonitored<T>> task = new Callable() { @Override public ResultMonitored<T> call() { Client cl = client; if (cl == null) { cl = getClient(); } long t0 = System.currentTimeMillis(); T result = getJava(returnClass, (String) testRSCallWithoutResult(cl, ds, methodName, params).getResponse()); ResultMonitored resultMonitored = new ResultMonitored(result, num); long t1 = System.currentTimeMillis(); resultMonitored.setTime(t1 - t0); return resultMonitored; } }; executorCompletionService.submit(task); } for (int i = 0; i < nb; i++) { try { Future<ResultMonitored<T>> fut = executorCompletionService.take(); ResultMonitored<T> res = fut.get(); // System.out.println("Time of execution of service " + res.getNum() + ": " + res.getTime() + " ms"); results.add(res.getResult()); } catch (InterruptedException | ExecutionException e) { } } long t1 = System.currentTimeMillis(); System.out.println("Time of execution of all services : " + (t1 - t0) + " ms"); assertThat(results).hasSize(nb); return results; }
From source file:org.apache.hadoop.hbase.regionserver.Store.java
/** * Close all the readers/*from www .j a va 2s .c o m*/ * * We don't need to worry about subsequent requests because the HRegion holds * a write lock that will prevent any more reads or writes. * * @throws IOException */ ImmutableList<StoreFile> close() throws IOException { this.lock.writeLock().lock(); try { ImmutableList<StoreFile> result = storefiles; // Clear so metrics doesn't find them. storefiles = ImmutableList.of(); if (!result.isEmpty()) { // initialize the thread pool for closing store files in parallel. ThreadPoolExecutor storeFileCloserThreadPool = this.region.getStoreFileOpenAndCloseThreadPool( "StoreFileCloserThread-" + this.family.getNameAsString()); // close each store file in parallel CompletionService<Void> completionService = new ExecutorCompletionService<Void>( storeFileCloserThreadPool); for (final StoreFile f : result) { completionService.submit(new Callable<Void>() { public Void call() throws IOException { f.closeReader(true); return null; } }); } try { for (int i = 0; i < result.size(); i++) { Future<Void> future = completionService.take(); future.get(); } } catch (InterruptedException e) { throw new IOException(e); } catch (ExecutionException e) { throw new IOException(e.getCause()); } finally { storeFileCloserThreadPool.shutdownNow(); } } LOG.info("Closed " + this); return result; } finally { this.lock.writeLock().unlock(); } }
From source file:org.apache.hadoop.hbase.regionserver.HFileReadWriteTest.java
public boolean runRandomReadWorkload() throws IOException { if (inputFileNames.size() != 1) { throw new IOException("Need exactly one input file for random reads: " + inputFileNames); }//ww w . j a v a 2 s. com Path inputPath = new Path(inputFileNames.get(0)); // Make sure we are using caching. StoreFile storeFile = openStoreFile(inputPath, true); StoreFile.Reader reader = storeFile.createReader(); LOG.info("First key: " + Bytes.toStringBinary(reader.getFirstKey())); LOG.info("Last key: " + Bytes.toStringBinary(reader.getLastKey())); KeyValue firstKV = KeyValue.createKeyValueFromKey(reader.getFirstKey()); firstRow = firstKV.getRow(); KeyValue lastKV = KeyValue.createKeyValueFromKey(reader.getLastKey()); lastRow = lastKV.getRow(); byte[] family = firstKV.getFamily(); if (!Bytes.equals(family, lastKV.getFamily())) { LOG.error("First and last key have different families: " + Bytes.toStringBinary(family) + " and " + Bytes.toStringBinary(lastKV.getFamily())); return false; } if (Bytes.equals(firstRow, lastRow)) { LOG.error("First and last row are the same, cannot run read workload: " + "firstRow=" + Bytes.toStringBinary(firstRow) + ", " + "lastRow=" + Bytes.toStringBinary(lastRow)); return false; } ExecutorService exec = Executors.newFixedThreadPool(numReadThreads + 1); int numCompleted = 0; int numFailed = 0; try { ExecutorCompletionService<Boolean> ecs = new ExecutorCompletionService<Boolean>(exec); endTime = System.currentTimeMillis() + 1000 * durationSec; boolean pread = true; for (int i = 0; i < numReadThreads; ++i) ecs.submit(new RandomReader(i, reader, pread)); ecs.submit(new StatisticsPrinter()); Future<Boolean> result; while (true) { try { result = ecs.poll(endTime + 1000 - System.currentTimeMillis(), TimeUnit.MILLISECONDS); if (result == null) break; try { if (result.get()) { ++numCompleted; } else { ++numFailed; } } catch (ExecutionException e) { LOG.error("Worker thread failure", e.getCause()); ++numFailed; } } catch (InterruptedException ex) { 
LOG.error("Interrupted after " + numCompleted + " workers completed"); Thread.currentThread().interrupt(); continue; } } } finally { storeFile.closeReader(true); exec.shutdown(); BlockCache c = cacheConf.getBlockCache(); if (c != null) { c.shutdown(); } } LOG.info("Worker threads completed: " + numCompleted); LOG.info("Worker threads failed: " + numFailed); return true; }
From source file:nl.systemsgenetics.eqtlinteractionanalyser.eqtlinteractionanalyser.TestEQTLDatasetForInteractions.java
public final String performInteractionAnalysis(String[] covsToCorrect, String[] covsToCorrect2, TextFile outputTopCovs, File snpsToSwapFile, HashMultimap<String, String> qtlProbeSnpMultiMap, String[] covariatesToTest, HashMap hashSamples, int numThreads, final TIntHashSet snpsToTest, boolean skipNormalization, boolean skipCovariateNormalization, HashMultimap<String, String> qtlProbeSnpMultiMapCovariates) throws IOException, Exception { //hashSamples = excludeOutliers(hashSamples); HashMap<String, Integer> covariatesToLoad = new HashMap(); if (covariatesToTest != null) { for (String c : covariatesToTest) { covariatesToLoad.put(c, null); }// w w w.ja v a 2 s . c om for (String c : covsToCorrect) { covariatesToLoad.put(c, null); } for (String c : covsToCorrect2) { covariatesToLoad.put(c, null); } for (int i = 1; i <= 50; ++i) { covariatesToLoad.put("Comp" + i, null); } } else { covariatesToLoad = null; } ExpressionDataset datasetExpression = new ExpressionDataset( inputDir + "/bigTableLude.txt.Expression.binary", '\t', null, hashSamples); ExpressionDataset datasetCovariates = new ExpressionDataset( inputDir + "/covariateTableLude.txt.Covariates.binary", '\t', covariatesToLoad, hashSamples); org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression regression = new org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression(); int nrSamples = datasetGenotypes.nrSamples; correctDosageDirectionForQtl(snpsToSwapFile, datasetGenotypes, datasetExpression); if (!skipNormalization) { correctExpressionData(covsToCorrect2, datasetGenotypes, datasetCovariates, datasetExpression); } ExpressionDataset datasetCovariatesPCAForceNormal = new ExpressionDataset( inputDir + "/covariateTableLude.txt.Covariates.binary", '\t', covariatesToLoad, hashSamples); if (!skipNormalization && !skipCovariateNormalization) { correctCovariateDataPCA(covsToCorrect2, covsToCorrect, datasetGenotypes, datasetCovariatesPCAForceNormal); } if (1 == 1) { if (!skipNormalization && 
!skipCovariateNormalization && covsToCorrect2.length != 0 && covsToCorrect.length != 0) { correctCovariateData(covsToCorrect2, covsToCorrect, datasetGenotypes, datasetCovariates); } if (!skipNormalization && !skipCovariateNormalization && !qtlProbeSnpMultiMapCovariates.isEmpty()) { correctCovariatesForQtls(datasetCovariates, datasetGenotypes, qtlProbeSnpMultiMapCovariates); } if (1 == 2) { saveCorrectedCovariates(datasetCovariates); } if (1 == 2) { icaCovariates(datasetCovariates); } if (!skipNormalization) { forceNormalCovariates(datasetCovariates, datasetGenotypes); } } ExpressionDataset datasetExpressionBeforeEQTLCorrection = new ExpressionDataset(datasetExpression.nrProbes, datasetExpression.nrSamples); for (int p = 0; p < datasetExpression.nrProbes; p++) { for (int s = 0; s < datasetExpression.nrSamples; s++) { datasetExpressionBeforeEQTLCorrection.rawData[p][s] = datasetExpression.rawData[p][s]; } } if (!skipNormalization && covsToCorrect.length != 0) { correctExpressionDataForInteractions(covsToCorrect, datasetCovariates, datasetGenotypes, nrSamples, datasetExpression, regression, qtlProbeSnpMultiMap); } if (!skipNormalization) { forceNormalExpressionData(datasetExpression); } datasetExpression.save(outputDir + "/expressionDataRound_" + covsToCorrect.length + ".txt"); datasetExpression.save(outputDir + "/expressionDataRound_" + covsToCorrect.length + ".binary"); datasetCovariates.save(outputDir + "/covariateData_" + covsToCorrect.length + ".binary"); if (1 == 1) { ExpressionDataset datasetZScores = new ExpressionDataset(datasetCovariates.nrProbes, datasetExpression.nrProbes); datasetZScores.probeNames = datasetCovariates.probeNames; datasetZScores.sampleNames = new String[datasetGenotypes.probeNames.length]; for (int i = 0; i < datasetGenotypes.probeNames.length; ++i) { datasetZScores.sampleNames[i] = datasetGenotypes.probeNames[i] + datasetExpression.probeNames[i] .substring(datasetExpression.probeNames[i].lastIndexOf('_')); } 
datasetZScores.recalculateHashMaps(); SkippedInteractionWriter skippedWriter = new SkippedInteractionWriter( new File(outputDir + "/skippedInteractionsRound_" + covsToCorrect.length + ".txt")); java.util.concurrent.ExecutorService threadPool = Executors.newFixedThreadPool(numThreads); CompletionService<DoubleArrayIntegerObject> pool = new ExecutorCompletionService<DoubleArrayIntegerObject>( threadPool); int nrTasks = 0; for (int cov = 0; cov < datasetCovariates.nrProbes; cov++) { double stdev = JSci.maths.ArrayMath.standardDeviation(datasetCovariates.rawData[cov]); if (stdev > 0) { PerformInteractionAnalysisPermutationTask task = new PerformInteractionAnalysisPermutationTask( datasetGenotypes, datasetExpression, datasetCovariates, datasetCovariatesPCAForceNormal, cov, skippedWriter, snpsToTest); pool.submit(task); nrTasks++; } } String maxChi2Cov = ""; int maxChi2CovI = 0; double maxChi2 = 0; try { // If gene annotation provided, for chi2sum calculation use only genes that are 1mb apart //if (geneDistanceMap != null) { for (int task = 0; task < nrTasks; task++) { try { //System.out.println("Waiting on thread for: " + datasetCovariates.probeNames[cov]); DoubleArrayIntegerObject result = pool.take().get(); int cov = result.intValue; double chi2Sum = 0; double[] covZ = datasetZScores.rawData[cov]; for (int snp = 0; snp < datasetGenotypes.nrProbes; snp++) { //if (genesFarAway(datasetZScores.sampleNames[snp], datasetZScores.probeNames[cov])) { double z = result.doubleArray[snp]; covZ[snp] = z; if (!Double.isNaN(z)) { chi2Sum += z * z; } //} } if (chi2Sum > maxChi2 && !datasetCovariates.probeNames[cov].startsWith("Comp") && !datasetCovariates.probeNames[cov].equals("LLS") && !datasetCovariates.probeNames[cov].equals("LLdeep") && !datasetCovariates.probeNames[cov].equals("RS") && !datasetCovariates.probeNames[cov].equals("CODAM")) { maxChi2 = chi2Sum; maxChi2CovI = cov; maxChi2Cov = datasetCovariates.probeNames[cov]; } //System.out.println(covsToCorrect.length + "\t" + 
cov + "\t" + datasetCovariates.probeNames[cov] + "\t" + chi2Sum); if ((task + 1) % 512 == 0) { System.out.println(task + 1 + " tasks processed"); } } catch (ExecutionException ex) { Logger.getLogger(PerformInteractionAnalysisPermutationTask.class.getName()) .log(Level.SEVERE, null, ex); } } /*} //If gene annotation not provided, use all gene pairs else { for (int task = 0; task < nrTasks; task++) { try { DoubleArrayIntegerObject result = pool.take().get(); int cov = result.intValue; double chi2Sum = 0; double[] covZ = datasetZScores.rawData[cov]; for (int snp = 0; snp < datasetGenotypes.nrProbes; snp++) { double z = result.doubleArray[snp]; covZ[snp] = z; if (!Double.isNaN(z)) { chi2Sum += z * z; } } if (chi2Sum > maxChi2) { maxChi2 = chi2Sum; maxChi2Cov = datasetCovariates.probeNames[cov]; } //System.out.println(covsToCorrect.length + "\t" + cov + "\t" + datasetCovariates.probeNames[cov] + "\t" + chi2Sum); if ((task + 1) % 512 == 0) { System.out.println(task + 1 + " tasks processed"); } } catch (ExecutionException ex) { Logger.getLogger(PerformInteractionAnalysisPermutationTask.class.getName()).log(Level.SEVERE, null, ex); } } }*/ threadPool.shutdown(); } catch (Exception e) { e.printStackTrace(); System.out.println(e.getMessage()); } System.out.println("Top covariate:\t" + maxChi2 + "\t" + maxChi2Cov); outputTopCovs.writeln("Top covariate:\t" + maxChi2 + "\t" + maxChi2Cov); outputTopCovs.flush(); skippedWriter.close(); datasetZScores.save(outputDir + "/InteractionZScoresMatrix-" + covsToCorrect.length + "Covariates.txt"); BufferedWriter writer = new BufferedWriter( new FileWriter(outputDir + "/" + "topCov" + maxChi2Cov + "_expression.txt")); double[] topCovExpression = datasetCovariates.rawData[maxChi2CovI]; for (int i = 0; i < topCovExpression.length; ++i) { writer.append(datasetCovariates.sampleNames[i]); writer.append('\t'); writer.append(String.valueOf(topCovExpression[i])); writer.append('\n'); } writer.close(); return maxChi2Cov; } return null; }
From source file:com.jolbox.benchmark.BenchmarkTests.java
/** * Helper function.//from ww w .j a va 2 s . co m * * @param threads * @param cpds * @param workDelay * @param doPreparedStatement * @return time taken * @throws InterruptedException */ public static long startThreadTest(int threads, DataSource cpds, int workDelay, boolean doPreparedStatement) throws InterruptedException { CountDownLatch startSignal = new CountDownLatch(1); CountDownLatch doneSignal = new CountDownLatch(threads); ExecutorService pool = Executors.newFixedThreadPool(threads); ExecutorCompletionService<Long> ecs = new ExecutorCompletionService<Long>(pool); for (int i = 0; i <= threads; i++) { // create and start threads ecs.submit(new ThreadTesterUtil(startSignal, doneSignal, cpds, workDelay, doPreparedStatement)); } startSignal.countDown(); // START TEST! doneSignal.await(); long time = 0; for (int i = 0; i <= threads; i++) { try { time = time + ecs.take().get(); } catch (ExecutionException e) { e.printStackTrace(); } } pool.shutdown(); return time; }
From source file:org.apache.hadoop.hbase.regionserver.HStore.java
@Override public ImmutableCollection<StoreFile> close() throws IOException { this.lock.writeLock().lock(); try {//from www . j a va2 s. c o m // Clear so metrics doesn't find them. ImmutableCollection<StoreFile> result = storeEngine.getStoreFileManager().clearFiles(); if (!result.isEmpty()) { // initialize the thread pool for closing store files in parallel. ThreadPoolExecutor storeFileCloserThreadPool = this.region .getStoreFileOpenAndCloseThreadPool("StoreFileCloserThread-" + this.getColumnFamilyName()); // close each store file in parallel CompletionService<Void> completionService = new ExecutorCompletionService<Void>( storeFileCloserThreadPool); for (final StoreFile f : result) { completionService.submit(new Callable<Void>() { @Override public Void call() throws IOException { f.closeReader(true); return null; } }); } IOException ioe = null; try { for (int i = 0; i < result.size(); i++) { try { Future<Void> future = completionService.take(); future.get(); } catch (InterruptedException e) { if (ioe == null) { ioe = new InterruptedIOException(); ioe.initCause(e); } } catch (ExecutionException e) { if (ioe == null) ioe = new IOException(e.getCause()); } } } finally { storeFileCloserThreadPool.shutdownNow(); } if (ioe != null) throw ioe; } LOG.info("Closed " + this); return result; } finally { this.lock.writeLock().unlock(); } }
From source file:org.paxle.filter.robots.impl.RobotsTxtManager.java
/** * Check a list of {@link URI URI} against the robots.txt file of the servers hosting the {@link URI}. * @param hostPort the web-server hosting the {@link URI URIs} * @param urlList a list of {@link URI}/*www. j a va 2 s . c o m*/ * * @return all {@link URI} that are blocked by the servers */ public List<URI> isDisallowed(Collection<URI> urlList) { if (urlList == null) throw new NullPointerException("The URI-list is null."); // group the URL list based on hostname:port HashMap<URI, List<URI>> uriBlocks = this.groupURI(urlList); ArrayList<URI> disallowedURI = new ArrayList<URI>(); /* * Asynchronous execution and parallel check of all blocks */ final CompletionService<Collection<URI>> execCompletionService = new ExecutorCompletionService<Collection<URI>>( this.execService); // loop through the blocks and start a worker for each block for (Entry<URI, List<URI>> uriBlock : uriBlocks.entrySet()) { URI baseUri = uriBlock.getKey(); List<URI> uriList = uriBlock.getValue(); execCompletionService.submit(new RobotsTxtManagerCallable(baseUri, uriList)); } // wait for the worker-threads to finish execution for (int i = 0; i < uriBlocks.size(); ++i) { try { Collection<URI> disallowedInGroup = execCompletionService.take().get(); if (disallowedInGroup != null) { disallowedURI.addAll(disallowedInGroup); } } catch (InterruptedException e) { this.logger.info(String.format("Interruption detected while waiting for robots.txt-check result.")); // XXX should we break here? } catch (ExecutionException e) { this.logger.error( String.format("Unexpected '%s' while performing robots.txt check.", e.getClass().getName()), e); } } return disallowedURI; }
From source file:org.apache.hadoop.hbase.io.hfile.TestHFileBlock.java
protected void testConcurrentReadingInternals() throws IOException, InterruptedException, ExecutionException { for (Compression.Algorithm compressAlgo : COMPRESSION_ALGORITHMS) { Path path = new Path(TEST_UTIL.getDataTestDir(), "concurrent_reading"); Random rand = defaultRandom(); List<Long> offsets = new ArrayList<Long>(); List<BlockType> types = new ArrayList<BlockType>(); writeBlocks(rand, compressAlgo, path, offsets, null, types, null); FSDataInputStream is = fs.open(path); long fileSize = fs.getFileStatus(path).getLen(); HFileContext meta = new HFileContextBuilder().withHBaseCheckSum(true) .withIncludesMvcc(includesMemstoreTS).withIncludesTags(includesTag) .withCompression(compressAlgo).build(); HFileBlock.FSReader hbr = new HFileBlock.FSReaderV2(is, fileSize, meta); Executor exec = Executors.newFixedThreadPool(NUM_READER_THREADS); ExecutorCompletionService<Boolean> ecs = new ExecutorCompletionService<Boolean>(exec); for (int i = 0; i < NUM_READER_THREADS; ++i) { ecs.submit(new BlockReaderThread("reader_" + (char) ('A' + i), hbr, offsets, types, fileSize)); }// w w w .ja va2 s . c o m for (int i = 0; i < NUM_READER_THREADS; ++i) { Future<Boolean> result = ecs.take(); assertTrue(result.get()); if (detailedLogging) { LOG.info(String.valueOf(i + 1) + " reader threads finished successfully (algo=" + compressAlgo + ")"); } } is.close(); } }
From source file:ml.shifu.shifu.core.dtrain.dt.DTWorker.java
@Override public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> context) { if (context.isFirstIteration()) { return new DTWorkerParams(); }/* w w w .j a v a 2s .c om*/ DTMasterParams lastMasterResult = context.getLastMasterResult(); final List<TreeNode> trees = lastMasterResult.getTrees(); final Map<Integer, TreeNode> todoNodes = lastMasterResult.getTodoNodes(); if (todoNodes == null) { return new DTWorkerParams(); } LOG.info("Start to work: todoNodes size is {}", todoNodes.size()); Map<Integer, NodeStats> statistics = initTodoNodeStats(todoNodes); double trainError = 0d, validationError = 0d; double weightedTrainCount = 0d, weightedValidationCount = 0d; // renew random seed if (this.isGBDT && !this.gbdtSampleWithReplacement && lastMasterResult.isSwitchToNextTree()) { this.baggingRandomMap = new HashMap<Integer, Random>(); } long start = System.nanoTime(); for (Data data : this.trainingData) { if (this.isRF) { for (TreeNode treeNode : trees) { if (treeNode.getNode().getId() == Node.INVALID_INDEX) { continue; } Node predictNode = predictNodeIndex(treeNode.getNode(), data, true); if (predictNode.getPredict() != null) { // only update when not in first node, for treeNode, no predict statistics at that time float weight = data.subsampleWeights[treeNode.getTreeId()]; if (Float.compare(weight, 0f) == 0) { // oob data, no need to do weighting validationError += data.significance * loss .computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } else { trainError += weight * data.significance * loss .computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedTrainCount += weight * data.significance; } } } } if (this.isGBDT) { if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) { recoverGBTData(context, data.output, data.predict, data, false); trainError += data.significance * loss.computeError(data.predict, data.label); weightedTrainCount 
+= data.significance; } else { if (isNeedRecoverGBDTPredict) { if (this.recoverTrees == null) { this.recoverTrees = recoverCurrentTrees(); } // recover gbdt data for fail over recoverGBTData(context, data.output, data.predict, data, true); } int currTreeIndex = trees.size() - 1; if (lastMasterResult.isSwitchToNextTree()) { if (currTreeIndex >= 1) { Node node = trees.get(currTreeIndex - 1).getNode(); Node predictNode = predictNodeIndex(node, data, false); if (predictNode.getPredict() != null) { double predict = predictNode.getPredict().getPredict(); // first tree logic, master must set it to first tree even second tree with ROOT is // sending if (context.getLastMasterResult().isFirstTree()) { data.predict = (float) predict; } else { // random drop boolean drop = (this.dropOutRate > 0.0 && dropOutRandom.nextDouble() < this.dropOutRate); if (!drop) { data.predict += (float) (this.learningRate * predict); } } data.output = -1f * loss.computeGradient(data.predict, data.label); } // if not sampling with replacement in gbdt, renew bagging sample rate in next tree if (!this.gbdtSampleWithReplacement) { Random random = null; int classValue = (int) (data.label + 0.01f); if (this.isStratifiedSampling) { random = baggingRandomMap.get(classValue); if (random == null) { random = DTrainUtils.generateRandomBySampleSeed( modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED); baggingRandomMap.put(classValue, random); } } else { random = baggingRandomMap.get(0); if (random == null) { random = DTrainUtils.generateRandomBySampleSeed( modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED); baggingRandomMap.put(0, random); } } if (random.nextDouble() <= modelConfig.getTrain().getBaggingSampleRate()) { data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 1f; } else { data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 0f; } } } } if (context.getLastMasterResult().isFirstTree() && 
!lastMasterResult.isSwitchToNextTree()) { Node currTree = trees.get(currTreeIndex).getNode(); Node predictNode = predictNodeIndex(currTree, data, true); if (predictNode.getPredict() != null) { trainError += data.significance * loss .computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedTrainCount += data.significance; } } else { trainError += data.significance * loss.computeError(data.predict, data.label); weightedTrainCount += data.significance; } } } } LOG.debug("Compute train error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); if (validationData != null) { start = System.nanoTime(); for (Data data : this.validationData) { if (this.isRF) { for (TreeNode treeNode : trees) { if (treeNode.getNode().getId() == Node.INVALID_INDEX) { continue; } Node predictNode = predictNodeIndex(treeNode.getNode(), data, true); if (predictNode.getPredict() != null) { // only update when not in first node, for treeNode, no predict statistics at that time validationError += data.significance * loss .computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } } } if (this.isGBDT) { if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) { recoverGBTData(context, data.output, data.predict, data, false); validationError += data.significance * loss.computeError(data.predict, data.label); weightedValidationCount += data.significance; } else { if (isNeedRecoverGBDTPredict) { if (this.recoverTrees == null) { this.recoverTrees = recoverCurrentTrees(); } // recover gbdt data for fail over recoverGBTData(context, data.output, data.predict, data, true); } int currTreeIndex = trees.size() - 1; if (lastMasterResult.isSwitchToNextTree()) { if (currTreeIndex >= 1) { Node node = trees.get(currTreeIndex - 1).getNode(); Node predictNode = predictNodeIndex(node, data, false); if (predictNode.getPredict() != null) { double predict = 
predictNode.getPredict().getPredict(); if (context.getLastMasterResult().isFirstTree()) { data.predict = (float) predict; } else { data.predict += (float) (this.learningRate * predict); } data.output = -1f * loss.computeGradient(data.predict, data.label); } } } if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) { Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true); if (predictNode.getPredict() != null) { validationError += data.significance * loss .computeError((float) (predictNode.getPredict().getPredict()), data.label); weightedValidationCount += data.significance; } } else { validationError += data.significance * loss.computeError(data.predict, data.label); weightedValidationCount += data.significance; } } } } LOG.debug("Compute val error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); } if (this.isGBDT) { // reset trees to null to save memory this.recoverTrees = null; if (this.isNeedRecoverGBDTPredict) { // no need recover again this.isNeedRecoverGBDTPredict = false; } } start = System.nanoTime(); CompletionService<Map<Integer, NodeStats>> completionService = new ExecutorCompletionService<Map<Integer, NodeStats>>( this.threadPool); int realThreadCount = 0; LOG.debug("while todo size {}", todoNodes.size()); int realRecords = this.trainingData.size(); int realThreads = this.workerThreadCount > realRecords ? 
realRecords : this.workerThreadCount; int[] trainLows = new int[realThreads]; int[] trainHighs = new int[realThreads]; int stepCount = realRecords / realThreads; if (realRecords % realThreads != 0) { // move step count to append last gap to avoid last thread worse 2*stepCount-1 stepCount += (realRecords % realThreads) / stepCount; } for (int i = 0; i < realThreads; i++) { trainLows[i] = i * stepCount; if (i != realThreads - 1) { trainHighs[i] = trainLows[i] + stepCount - 1; } else { trainHighs[i] = realRecords - 1; } } for (int i = 0; i < realThreads; i++) { final Map<Integer, TreeNode> localTodoNodes = new HashMap<Integer, TreeNode>(todoNodes); final Map<Integer, NodeStats> localStatistics = initTodoNodeStats(todoNodes); final int startIndex = trainLows[i]; final int endIndex = trainHighs[i]; LOG.info("Thread {} todo size {} stats size {} start index {} end index {}", i, localTodoNodes.size(), localStatistics.size(), startIndex, endIndex); if (localTodoNodes.size() == 0) { continue; } realThreadCount += 1; completionService.submit(new Callable<Map<Integer, NodeStats>>() { @Override public Map<Integer, NodeStats> call() throws Exception { long start = System.nanoTime(); List<Integer> nodeIndexes = new ArrayList<Integer>(trees.size()); for (int j = startIndex; j <= endIndex; j++) { Data data = DTWorker.this.trainingData.get(j); nodeIndexes.clear(); if (DTWorker.this.isRF) { for (TreeNode treeNode : trees) { if (treeNode.getNode().getId() == Node.INVALID_INDEX) { nodeIndexes.add(Node.INVALID_INDEX); } else { Node predictNode = predictNodeIndex(treeNode.getNode(), data, false); nodeIndexes.add(predictNode.getId()); } } } if (DTWorker.this.isGBDT) { int currTreeIndex = trees.size() - 1; Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, false); // update node index nodeIndexes.add(predictNode.getId()); } for (Map.Entry<Integer, TreeNode> entry : localTodoNodes.entrySet()) { // only do statistics on effective data Node todoNode = 
entry.getValue().getNode(); int treeId = entry.getValue().getTreeId(); int currPredictIndex = 0; if (DTWorker.this.isRF) { currPredictIndex = nodeIndexes.get(entry.getValue().getTreeId()); } if (DTWorker.this.isGBDT) { currPredictIndex = nodeIndexes.get(0); } if (todoNode.getId() == currPredictIndex) { List<Integer> features = entry.getValue().getFeatures(); if (features.isEmpty()) { features = getAllValidFeatures(); } for (Integer columnNum : features) { double[] featuerStatistic = localStatistics.get(entry.getKey()) .getFeatureStatistics().get(columnNum); float weight = data.subsampleWeights[treeId % data.subsampleWeights.length]; if (Float.compare(weight, 0f) != 0) { // only compute weight is not 0 short binIndex = data.inputs[DTWorker.this.inputIndexMap.get(columnNum)]; DTWorker.this.impurity.featureUpdate(featuerStatistic, binIndex, data.output, data.significance, weight); } } } } } LOG.debug("Thread computing stats time is {}ms in thread {}", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start), Thread.currentThread().getName()); return localStatistics; } }); } int rCnt = 0; while (rCnt < realThreadCount) { try { Map<Integer, NodeStats> currNodeStatsmap = completionService.take().get(); if (rCnt == 0) { statistics = currNodeStatsmap; } else { for (Entry<Integer, NodeStats> entry : statistics.entrySet()) { NodeStats resultNodeStats = entry.getValue(); mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey())); } } } catch (ExecutionException e) { throw new RuntimeException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } rCnt += 1; } LOG.debug("Compute stats time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); LOG.info( "worker count is {}, error is {}, and stats size is {}. 
weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", count, trainError, statistics.size(), weightedTrainCount, weightedValidationCount, trainError, validationError); return new DTWorkerParams(weightedTrainCount, weightedValidationCount, trainError, validationError, statistics); }
From source file:test.java.com.spotify.docker.client.DefaultDockerClientTest.java
@Test(expected = DockerTimeoutException.class) public void testConnectionRequestTimeout() throws Exception { final int connectionPoolSize = 1; final int callableCount = connectionPoolSize * 100; final ExecutorService executor = Executors.newCachedThreadPool(); final CompletionService completion = new ExecutorCompletionService(executor); // Spawn and wait on many more containers than the connection pool size. // This should cause a timeout once the connection pool is exhausted. final DockerClient dockerClient = DefaultDockerClient.fromEnv().connectionPoolSize(connectionPoolSize) .build();//from ww w .jav a 2 s .c o m try { // Create container final ContainerConfig config = ContainerConfig.builder().image(BUSYBOX_LATEST) .cmd("sh", "-c", "while :; do sleep 1; done").build(); final String name = randomName(); final ContainerCreation creation = dockerClient.createContainer(config, name); final String id = creation.id(); // Start the container dockerClient.startContainer(id); // Submit a bunch of waitContainer requests for (int i = 0; i < callableCount; i++) { completion.submit(new Callable<ContainerExit>() { @Override public ContainerExit call() throws Exception { return dockerClient.waitContainer(id); } }); } // Wait for the requests to complete or throw expected exception for (int i = 0; i < callableCount; i++) { try { completion.take().get(); } catch (ExecutionException e) { Throwables.propagateIfInstanceOf(e.getCause(), DockerTimeoutException.class); throw e; } } } finally { executor.shutdown(); dockerClient.close(); } }