Usage examples for `java.util.concurrent.ExecutorService.invokeAll`
/**
 * Runs every task in {@code tasks} on this executor and returns one {@link Future} per task.
 * See the JDK {@code ExecutorService} documentation for the completion and ordering guarantees.
 *
 * @param <T> the result type shared by all tasks
 * @param tasks the tasks to execute
 * @return a list of futures for the submitted tasks
 * @throws InterruptedException if the calling thread is interrupted
 */
<T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException;
From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java
/** * Decompress block.//from www . ja v a 2 s . c o m * * @param k degree of parallelism * @return a new uncompressed matrix block containing the contents * of this block * @throws DMLRuntimeException if DMLRuntimeException occurs */ public MatrixBlock decompress(int k) throws DMLRuntimeException { //early abort for not yet compressed blocks if (!isCompressed()) return new MatrixBlock(this); if (k <= 1) return decompress(); Timing time = new Timing(true); MatrixBlock ret = new MatrixBlock(rlen, clen, sparse, nonZeros); ret.allocateDenseOrSparseBlock(); //multi-threaded decompression try { ExecutorService pool = Executors.newFixedThreadPool(k); int rlen = getNumRows(); int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ; int blklen = (int) (Math.ceil((double) rlen / k)); blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0; ArrayList<DecompressTask> tasks = new ArrayList<DecompressTask>(); for (int i = 0; i < k & i * blklen < getNumRows(); i++) tasks.add(new DecompressTask(_colGroups, ret, i * blklen, Math.min((i + 1) * blklen, rlen))); List<Future<Object>> rtasks = pool.invokeAll(tasks); pool.shutdown(); for (Future<Object> rt : rtasks) rt.get(); //error handling } catch (Exception ex) { throw new DMLRuntimeException(ex); } //post-processing ret.setNonZeros(nonZeros); if (LOG.isDebugEnabled()) LOG.debug("decompressed block w/ k=" + k + " in " + time.stop() + "ms."); return ret; }
From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java
/** * Multi-threaded version of rightMultByVector. * // w w w .ja v a 2 s .c o m * @param vector matrix block vector * @param result matrix block result * @param k number of threads * @throws DMLRuntimeException if DMLRuntimeException occurs */ private void rightMultByVector(MatrixBlock vector, MatrixBlock result, int k) throws DMLRuntimeException { // initialize and allocate the result result.allocateDenseBlock(); //multi-threaded execution of all groups try { ExecutorService pool = Executors.newFixedThreadPool(k); int rlen = getNumRows(); int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ; int blklen = (int) (Math.ceil((double) rlen / k)); blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0; ArrayList<RightMatrixMultTask> tasks = new ArrayList<RightMatrixMultTask>(); for (int i = 0; i < k & i * blklen < getNumRows(); i++) tasks.add(new RightMatrixMultTask(_colGroups, vector, result, i * blklen, Math.min((i + 1) * blklen, rlen))); pool.invokeAll(tasks); pool.shutdown(); } catch (Exception ex) { throw new DMLRuntimeException(ex); } // post-processing result.recomputeNonZeros(); }
From source file:org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java
@Override public void processInstruction(ExecutionContext ec) { Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; PSModeType mode = getPSMode();//from w w w .ja v a 2 s .co m int workerNum = getWorkerNum(mode); BasicThreadFactory factory = new BasicThreadFactory.Builder().namingPattern("workers-pool-thread-%d") .build(); ExecutorService es = Executors.newFixedThreadPool(workerNum, factory); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); int k = getParLevel(workerNum); // Get the compiled execution context // Create workers' execution context LocalVariableMap newVarsMap = createVarsMap(ec); List<ExecutionContext> newECs = ParamservUtils.createExecutionContexts(ec, newVarsMap, updFunc, aggFunc, workerNum, k); // Create workers' execution context List<ExecutionContext> workerECs = newECs.subList(0, newECs.size() - 1); // Create the agg service's execution context ExecutionContext aggServiceEC = newECs.get(newECs.size() - 1); PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); int epochs = getEpochs(); // Create the parameter server ListObject model = ec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); List<LocalPSWorker> workers = IntStream.range(0, workerNum).mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps)) .collect(Collectors.toList()); // Do data partition PSScheme scheme = getScheme(); doDataPartitioning(scheme, ec, workers); if (DMLScript.STATISTICS) Statistics.accPSSetupTime((long) tSetup.stop()); if (LOG.isDebugEnabled()) { LOG.debug(String.format( "\nConfiguration of paramserv func: " + "\nmode: %s \nworkerNum: %d \nupdate frequency: %s " + "\nstrategy: %s \ndata 
partitioner: %s", mode, workerNum, freq, updateType, scheme)); } try { // Launch the worker threads and wait for completion for (Future<Void> ret : es.invokeAll(workers)) ret.get(); //error handling // Fetch the final model from ps ListObject result = ps.getResult(); ec.setVariable(output.getName(), result); } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e); } finally { es.shutdownNow(); // Should shutdown the thread pool in param server ps.shutdown(); } }
From source file:org.openspaces.rest.space.ReplicationRESTController.java
private void deploySpaces(final DeployRequest req) { ExecutorService esvc = Executors.newFixedThreadPool(req.endpoints.size()); List<Callable<DeployResult>> tasks = new ArrayList<Callable<DeployResult>>(); try {/*from ww w . jav a 2 s . c o m*/ for (final CloudifyRestEndpoint ep : req.endpoints) { tasks.add(new Callable<DeployResult>() { @Override public DeployResult call() throws Exception { HttpClient client = HttpClientBuilder.create().build(); HttpPost post = new HttpPost("http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION + "/deployments/applications/" + req.tspec.name + "/services/repl-management/invoke"); log.fine("invoking custom command via REST :" + "http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION + "/deployments/applications/" + req.tspec.name + "/services/repl-management/invoke"); post.setEntity(new StringEntity(String.format( "{\"commandName\":\"add-repl-space\",\"parameters\":[\"%s\",%d,%d,\"%s\"]}", req.tspec.name, 1, 0, ep.siteId), ContentType.APPLICATION_JSON)); log.fine(" post entity:" + String.format( "{\"commandName\":\"add-repl-space\",\"parameters\":[\"%s\",%d,%d,\"%s\"]}", req.tspec.name, 1, 0, ep.siteId)); HttpResponse resp = client.execute(post); return new DeployResult(resp, ep); } }); } List<Future<DeployResult>> results = new ArrayList<Future<DeployResult>>(); try { results = esvc.invokeAll(tasks); } catch (InterruptedException e) { throw new RuntimeException(e); } StringBuilder sb = new StringBuilder("Repl space deploy failed:"); int initlength = sb.length(); DeployResult dr = null; for (Future<DeployResult> future : results) { try { dr = future.get(); } catch (Exception e) { sb.append(dr.ep.address).append(" caught exception - ").append(e.getMessage()).append(","); } try { if (dr.response.getStatusLine().getStatusCode() != 200) { String body = IOUtils.toString(dr.response.getEntity().getContent()); sb.append(dr.ep.address).append(" returned status ") 
.append(dr.response.getStatusLine().getStatusCode()).append(",reason=") .append(dr.response.getStatusLine().getReasonPhrase()).append(",").append(body); } else { log.info("got response 200: " + dr.response.getStatusLine().toString()); } } catch (Exception e) { sb.append(dr.ep.address).append(" caught exception processing response body- ") .append(e.getMessage()).append(","); } } if (sb.length() > initlength) {//error caught throw new RuntimeException(sb.toString()); } } finally { esvc.shutdown(); } }
From source file:com.topsec.tsm.sim.report.model.ReportModel.java
/** * 1.mail 2.??//from w w w .j a v a 2 s. com * * @param RptMasterTbService * rptMasterTbImp DAO * @param ExpStruct * exp * @param HttpServletRequest * request HttpServletRequest * @return LinkedHashMap<String, List> exp? * @throws Exception * 2. */ public static LinkedHashMap<String, List> expMstReport(RptMasterTbService rptMasterTbImp, ExpStruct exp, HttpServletRequest request) throws Exception { String mstRptId = exp.getMstrptid();// ID Integer mstRptIdInt = 0; if (!GlobalUtil.isNullOrEmpty(mstRptId)) { mstRptIdInt = Integer.valueOf(mstRptId); } String mstSql = ReportUiConfig.MstSubSql; Object[] subParam = { mstRptIdInt }; // List<Map<String,Object>> subResult = rptMasterTbImp.queryTmpList(mstSql, subParam); List<Map<String, Object>> subResult = new ArrayList<Map<String, Object>>(); Map<Integer, Integer> rowColumns = new HashMap<Integer, Integer>(); List<Map<String, Object>> subResultTemp = rptMasterTbImp.queryTmpList(mstSql, subParam); if (subResultTemp.size() > 0) { Map subMap = subResultTemp.get(0); String viewItem = StringUtil.toString(subMap.get("viewItem"), ""); if (viewItem.indexOf("2") < 0) { exp.setRptType(ReportUiConfig.rptDirection); String[] time = ReportUiUtil.getExpTime("month"); exp.setRptTimeS(time[0]); exp.setRptTimeE(time[1]); } } int evtRptsize = subResultTemp.size(); if (!GlobalUtil.isNullOrEmpty(subResultTemp)) { subResult.addAll(subResultTemp); } ReportBean bean = new ReportBean(); if (!GlobalUtil.isNullOrEmpty(request)) { bean = ReportUiUtil.tidyFormBean(bean, request); } String nodeType = bean.getNodeType(); String dvcaddress = bean.getDvcaddress(); DataSourceService dataSourceService = (DataSourceService) SpringContextServlet.springCtx .getBean("dataSourceService"); if (!GlobalUtil.isNullOrEmpty(bean.getDvctype()) && bean.getDvctype().startsWith("Profession/Group") && !GlobalUtil.isNullOrEmpty(nodeType) && !GlobalUtil.isNullOrEmpty(dvcaddress)) { Map map = TopoUtil.getAssetEvtMstMap(); String mstIds = null; List<SimDatasource> 
simDatasources = dataSourceService.getByIp(dvcaddress); if (!GlobalUtil.isNullOrEmpty(simDatasources)) { mstIds = ""; for (SimDatasource simDatasource : simDatasources) { if (map.containsKey(simDatasource.getSecurityObjectType())) { mstIds += map.get(simDatasource.getSecurityObjectType()).toString() + ":::"; } else { String keyString = getStartStringKey(map, simDatasource.getSecurityObjectType()); if (!GlobalUtil.isNullOrEmpty(keyString)) { mstIds += map.get(keyString).toString() + ":::"; } } } if (mstIds.length() > 3) { mstIds = mstIds.substring(0, mstIds.length() - 3); } } else { if (map.containsKey(nodeType)) { mstIds = map.get(nodeType).toString(); } else { String keyString = getStartStringKey(map, nodeType); if (!GlobalUtil.isNullOrEmpty(keyString)) { mstIds = map.get(keyString).toString(); } } } if (!GlobalUtil.isNullOrEmpty(mstIds)) { String[] mstIdArr = mstIds.split(":::"); for (String string : mstIdArr) { List<Map<String, Object>> subTemp = rptMasterTbImp.queryTmpList(mstSql, new Object[] { StringUtil.toInt(string, 5) }); if (!GlobalUtil.isNullOrEmpty(subTemp)) { int maxCol = 0; if (!GlobalUtil.isNullOrEmpty(rowColumns)) { maxCol = getMaxOrMinKey(rowColumns, 1); } for (Map map2 : subTemp) { Integer row = (Integer) map2.get("subRow") + maxCol; map2.put("subRow", row); } subResult.addAll(subTemp); } } } } if (!GlobalUtil.isNullOrEmpty(bean.getDvctype()) && bean.getDvctype().startsWith("Comprehensive")) { List<String> dvcTypes = dvcTypes = new ArrayList<String>(); dvcTypes.add(bean.getDvctype().replace("Comprehensive", "")); List<String> mstrptidAndNodeTypeList = new ArrayList<String>(); setMstIdAndScanNodeType(dvcTypes, mstrptidAndNodeTypeList); subResultTemp = null; if (!GlobalUtil.isNullOrEmpty(mstrptidAndNodeTypeList)) { subResultTemp = rptMasterTbImp.queryTmpList(ReportUiConfig.MstSubSql, new Object[] { StringUtil.toInt((mstrptidAndNodeTypeList.get(0).split("IDandNODEtype"))[0], StringUtil.toInt(bean.getTalTop(), 5)) }); Map<Integer, Integer> 
rowColumnsTeMap = ReportModel.getRowColumns(subResultTemp); evtRptsize = subResultTemp.size(); if (!GlobalUtil.isNullOrEmpty(subResultTemp)) { for (Map map2 : subResultTemp) { map2.put("subject", (mstrptidAndNodeTypeList.get(0).split("IDandNODEtype"))[1]); } subResult.addAll(subResultTemp); rowColumns.putAll(rowColumnsTeMap); } int len = mstrptidAndNodeTypeList.size(); for (int i = 1; i < len; i++) { String mstrptidAndNodeType = mstrptidAndNodeTypeList.get(i); String string = mstrptidAndNodeType.split("IDandNODEtype")[0]; List<Map<String, Object>> subTemp = rptMasterTbImp.queryTmpList(ReportUiConfig.MstSubSql, new Object[] { StringUtil.toInt(string, StringUtil.toInt(bean.getTalTop(), 5)) }); if (!GlobalUtil.isNullOrEmpty(subTemp)) { int maxCol = 0; if (!GlobalUtil.isNullOrEmpty(rowColumns)) { maxCol = getMaxOrMinKey(rowColumns, 1); } for (Map map2 : subTemp) { Integer row = (Integer) map2.get("subRow") + maxCol; map2.put("subRow", row); map2.put("subject", mstrptidAndNodeType.split("IDandNODEtype")[1]); } subResult.addAll(subTemp); Map<Integer, Integer> rowColTemp = ReportModel.getRowColumns(subTemp); rowColumns.putAll(rowColTemp); } } } } List<ExpDateStruct> expList = new ArrayList<ExpDateStruct>(); // ? 
Map<ReportExecutor.SubjectKey, Map<Integer, ExpDateStruct>> exportMap = Collections .synchronizedMap(new LinkedHashMap()); ExecutorService threadPool = Executors.newFixedThreadPool(subResult.size(), new TsmThreadFactory("ReportSubjectExport")); LinkedHashMap<String, List> expMap = null; try { List<ReportExecutor> tasks = new ArrayList<ReportExecutor>(subResult.size()); int order = 0; for (Map sub : subResult) { order += 100; tasks.add(new ReportExecutor(order, rptMasterTbImp, exp, exportMap, expList, sub, request, SID.currentUser())); } threadPool.invokeAll(tasks); expMap = new LinkedHashMap<String, List>(exportMap.size()); for (Map.Entry<ReportExecutor.SubjectKey, Map<Integer, ExpDateStruct>> entry : exportMap.entrySet()) { expMap.put(entry.getKey().subject, new ArrayList(entry.getValue().values())); } } finally { threadPool.shutdownNow(); } return expMap; }
From source file:org.openspaces.rest.space.ReplicationRESTController.java
private void deployGateways(final DeployRequest req) { ExecutorService esvc = Executors.newFixedThreadPool(req.endpoints.size()); List<Callable<DeployResult>> tasks = new ArrayList<Callable<DeployResult>>(); try {// w w w . j a v a 2 s . co m final Map<String, List<String>> gwaddresses = makeGatewayAddressMap(req); final String natArg = addressesToNat(gwaddresses); for (final CloudifyRestEndpoint ep : req.endpoints) { tasks.add(new Callable<DeployResult>() { @Override public DeployResult call() throws Exception { HttpClient client = HttpClientBuilder.create().build(); HttpPost post = new HttpPost("http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION + "/deployments/applications/" + req.tspec.name + "/services/repl-gateway/invoke"); log.info("invoking custom command via REST :" + "http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION + "/deployments/applications/" + req.tspec.name + "/services/repl-gateway/invoke"); String entityString = String.format( "{\"commandName\":\"install-gateway\",\"parameters\":[\"%s\",\"%s\",\"%s\",\"%s\",\"%s\"]}", req.tspec.name, ep.siteId, tspecToPairsArg(req.tspec), tspecToLookups(ep, req.tspec, gwaddresses), natArg); post.setEntity(new StringEntity(entityString, ContentType.APPLICATION_JSON)); log.info(" post entity:" + entityString); HttpResponse resp = client.execute(post); return new DeployResult(resp, ep); } }); } List<Future<DeployResult>> results = new ArrayList<Future<DeployResult>>(); try { results = esvc.invokeAll(tasks); } catch (InterruptedException e) { throw new RuntimeException(e); } StringBuilder sb = new StringBuilder("Gateway deploy failed:"); int initlength = sb.length(); DeployResult dr = null; for (Future<DeployResult> future : results) { try { dr = future.get(); } catch (Exception e) { sb.append(dr.ep.address).append(" caught exception - ").append(e.getMessage()).append(","); } try { if (dr.response.getStatusLine().getStatusCode() != 200) { String body = 
IOUtils.toString(dr.response.getEntity().getContent()); sb.append(dr.ep.address).append(" returned status ") .append(dr.response.getStatusLine().getStatusCode()).append(",reason=") .append(dr.response.getStatusLine().getReasonPhrase()).append(",").append(body); } else { log.info("got response 200: " + dr.response.getStatusLine().toString()); } } catch (Exception e) { sb.append(dr.ep.address).append(" caught exception processing response body- ") .append(e.getMessage()).append(","); } } if (sb.length() > initlength) {//error caught throw new RuntimeException(sb.toString()); } } finally { esvc.shutdown(); } }
From source file:de.uzk.hki.da.pkg.MetsConsistencyChecker.java
/** * Checks the package consistency based on the File elements * in the METS file.//from w w w . ja v a2s .c om * * This assumes that a file must exist for every File element, * which is not the case for delta-packages. * * @return true, if a file could be found for every File element * and the checksums matched, otherwise false * @throws Exception the exception */ public boolean checkPackageBasedOnMets() throws Exception { boolean result = true; Namespace metsNS = Namespace.getNamespace("mets", "http://www.loc.gov/METS/"); Namespace xlinkNS = Namespace.getNamespace("xlink", "http://www.w3.org/1999/xlink"); String metsPath = packagePath + "/export_mets.xml"; SAXBuilder builder = new SAXBuilder(false); Document doc = builder.build(new File(metsPath)); XPath xpath = XPath.newInstance("//mets:file"); xpath.addNamespace(metsNS); @SuppressWarnings("rawtypes") List nodes = xpath.selectNodes(doc); logger.debug("Found {} mets:file elements", nodes.size()); ExecutorService executor = Executors.newFixedThreadPool(8); List<FileChecksumVerifierThread> threads = new ArrayList<FileChecksumVerifierThread>(); for (@SuppressWarnings("rawtypes") Iterator iterator = nodes.iterator(); iterator.hasNext();) { Element elem = (Element) iterator.next(); String checksum = elem.getAttributeValue("CHECKSUM"); String checksumType = elem.getAttributeValue("CHECKSUMTYPE"); String path = elem.getChild("FLocat", metsNS).getAttributeValue("href", xlinkNS); logger.debug("Verifying file: {}", path); // check if required attributes are set if (checksum == null) { logger.warn( "METS File Element in {} does not contain attribute CHECKSUM. File consistency can not be verified.", metsPath); continue; } if (checksumType == null) { logger.warn( "METS File Element in {} does not contain attribute CHECKSUM TYPE. 
File consistency can not be verified.", metsPath); continue; } if (path == null) { logger.warn("METS File Element in {} does not contain path.", metsPath); continue; } // check if file exists at path File file = new File(packagePath + "/" + path); if (!file.exists()) { result = false; String msg = "Could not find file referenced in METS metadata: " + path; logger.error(msg); messages.add(msg); continue; } logger.debug("Checking with algorithm: {}", checksumType); // calculate and verify checksum checksumType = checksumType.replaceAll("-", ""); try { MessageDigest algorithm = MessageDigest.getInstance(checksumType); threads.add(new FileChecksumVerifierThread(checksum, file, algorithm)); } catch (NoSuchAlgorithmException e) { logger.warn( "METS File Element in {} contains unknown CHECKSUM TYPE: {}. File consistency can not be verified.", metsPath, checksumType); continue; } } List<Future<ChecksumResult>> futures = executor.invokeAll(threads); for (Future<ChecksumResult> future : futures) { ChecksumResult cResult = future.get(); if (!cResult.isSuccess()) { result = false; logger.error(cResult.getMessage()); messages.add(cResult.getMessage()); } } return result; }
From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java
@Override public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP) throws DMLRuntimeException { //call uncompressed matrix mult if necessary if (!isCompressed()) { return super.aggregateUnaryOperations(op, result, blockingFactorRow, blockingFactorCol, indexesIn, inCP);//from w ww.j av a 2 s . co m } //check for supported operations if (!(op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq || (op.aggOp.increOp.fn instanceof Builtin && (((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN || ((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) { throw new DMLRuntimeException("Unary aggregates other than sum/sumsq/min/max not supported yet."); } Timing time = LOG.isDebugEnabled() ? new Timing(true) : null; //prepare output dimensions CellIndex tempCellIndex = new CellIndex(-1, -1); op.indexFn.computeDimension(rlen, clen, tempCellIndex); if (op.aggOp.correctionExists) { switch (op.aggOp.correctionLocation) { case LASTROW: tempCellIndex.row++; break; case LASTCOLUMN: tempCellIndex.column++; break; case LASTTWOROWS: tempCellIndex.row += 2; break; case LASTTWOCOLUMNS: tempCellIndex.column += 2; break; default: throw new DMLRuntimeException("unrecognized correctionLocation: " + op.aggOp.correctionLocation); } } // initialize and allocate the result if (result == null) result = new MatrixBlock(tempCellIndex.row, tempCellIndex.column, false); else result.reset(tempCellIndex.row, tempCellIndex.column, false); MatrixBlock ret = (MatrixBlock) result; ret.allocateDenseBlock(); //special handling init value for rowmins/rowmax if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) { double val = Double.MAX_VALUE * ((((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX) ? 
-1 : 1); Arrays.fill(ret.getDenseBlock(), val); } //core unary aggregate if (op.getNumThreads() > 1 && getExactSizeOnDisk() > MIN_PAR_AGG_THRESHOLD) { //multi-threaded execution of all groups ArrayList<ColGroup>[] grpParts = createStaticTaskPartitioning( (op.indexFn instanceof ReduceCol) ? 1 : op.getNumThreads(), false); ColGroupUncompressed uc = getUncompressedColGroup(); try { //compute uncompressed column group in parallel (otherwise bottleneck) if (uc != null) ret = (MatrixBlock) uc.getData().aggregateUnaryOperations(op, ret, blockingFactorRow, blockingFactorCol, indexesIn, false); //compute all compressed column groups ExecutorService pool = Executors.newFixedThreadPool(op.getNumThreads()); ArrayList<UnaryAggregateTask> tasks = new ArrayList<UnaryAggregateTask>(); if (op.indexFn instanceof ReduceCol && grpParts.length > 0) { int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ; int blklen = (int) (Math.ceil((double) rlen / op.getNumThreads())); blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0; for (int i = 0; i < op.getNumThreads() & i * blklen < rlen; i++) tasks.add(new UnaryAggregateTask(grpParts[0], ret, i * blklen, Math.min((i + 1) * blklen, rlen), op)); } else for (ArrayList<ColGroup> grp : grpParts) tasks.add(new UnaryAggregateTask(grp, ret, 0, rlen, op)); List<Future<MatrixBlock>> rtasks = pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results if (op.indexFn instanceof ReduceAll) { double val = ret.quickGetValue(0, 0); for (Future<MatrixBlock> rtask : rtasks) val = op.aggOp.increOp.fn.execute(val, rtask.get().quickGetValue(0, 0)); ret.quickSetValue(0, 0, val); } } catch (Exception ex) { throw new DMLRuntimeException(ex); } } else { //process UC column group for (ColGroup grp : _colGroups) if (grp instanceof ColGroupUncompressed) grp.unaryAggregateOperations(op, ret); //process OLE/RLE column groups for (ColGroup grp : _colGroups) if (!(grp instanceof ColGroupUncompressed)) grp.unaryAggregateOperations(op, ret); } //special handling zeros 
for rowmins/rowmax if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) { int[] rnnz = new int[rlen]; for (ColGroup grp : _colGroups) grp.countNonZerosPerRow(rnnz, 0, rlen); Builtin builtin = (Builtin) op.aggOp.increOp.fn; for (int i = 0; i < rlen; i++) if (rnnz[i] < clen) ret.quickSetValue(i, 0, builtin.execute2(ret.quickGetValue(i, 0), 0)); } //drop correction if necessary if (op.aggOp.correctionExists && inCP) ret.dropLastRowsOrColums(op.aggOp.correctionLocation); //post-processing ret.recomputeNonZeros(); if (LOG.isDebugEnabled()) LOG.debug("Compressed uagg k=" + op.getNumThreads() + " in " + time.stop()); return ret; }
From source file:com.ibm.bi.dml.runtime.matrix.data.LibMatrixMult.java
/** * /*from w w w .j ava 2 s.c o m*/ * @param mX * @param mU * @param mV * @param ret * @param wt * @param k * @throws DMLRuntimeException */ public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WCeMMType wt, int k) throws DMLRuntimeException { //check for empty result if (mW.isEmptyBlock(false)) { ret.examSparsity(); //turn empty dense into sparse return; } //Timing time = new Timing(true); //pre-processing ret.sparse = false; ret.allocateDenseBlock(); try { ExecutorService pool = Executors.newFixedThreadPool(k); ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>(); int blklen = (int) (Math.ceil((double) mW.rlen / k)); for (int i = 0; i < k & i * blklen < mW.rlen; i++) tasks.add(new MatrixMultWCeTask(mW, mU, mV, wt, i * blklen, Math.min((i + 1) * blklen, mW.rlen))); pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results sumScalarResults(tasks, ret); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } //System.out.println("MMWCe "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" + // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop()); }
From source file:com.ibm.bi.dml.runtime.matrix.data.LibMatrixMult.java
/** * Performs a multi-threaded matrix multiplication and stores the result in the output matrix. * The parameter k (k>=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen). * //from w w w.j ava 2 s .c o m * @param m1 * @param m2 * @param ret * @param k * @throws DMLRuntimeException */ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException { //check inputs / outputs if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) { ret.examSparsity(); //turn empty dense into sparse return; } //check too high additional vector-matrix memory requirements (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if (m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen == 1 || m1.isUltraSparse() || m2.isUltraSparse()) || 2L * m1.rlen * m1.clen * m2.clen < PAR_MINFLOP_THRESHOLD) { matrixMult(m1, m2, ret); return; } //Timing time = new Timing(true); //pre-processing: output allocation (in contrast to single-threaded, //we need to allocate sparse as well in order to prevent synchronization) boolean tm2 = checkPrepMatrixMultRightInput(m1, m2); m2 = prepMatrixMultRightInput(m1, m2); ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse()); if (!ret.sparse) ret.allocateDenseBlock(); else ret.allocateSparseRowsBlock(); //prepare row-upper for special cases of vector-matrix / matrix-matrix boolean pm2 = checkParMatrixMultRightInput(m1, m2, k); int ru = pm2 ? 
m2.rlen : m1.rlen; //core multi-threaded matrix mult computation //(currently: always parallelization over number of rows) try { ExecutorService pool = Executors.newFixedThreadPool(k); ArrayList<MatrixMultTask> tasks = new ArrayList<MatrixMultTask>(); int blklen = (int) (Math.ceil((double) ru / k)); for (int i = 0; i < k & i * blklen < ru; i++) tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2, i * blklen, Math.min((i + 1) * blklen, ru))); pool.invokeAll(tasks); pool.shutdown(); //aggregate partial results (nnz, ret for vector/matrix) ret.nonZeros = 0; //reset after execute for (MatrixMultTask task : tasks) { if (pm2) vectAdd(task.getResult().denseBlock, ret.denseBlock, 0, 0, ret.rlen * ret.clen); else ret.nonZeros += task.getPartialNnz(); } if (pm2) ret.recomputeNonZeros(); } catch (Exception ex) { throw new DMLRuntimeException(ex); } //post-processing (nnz maintained in parallel) ret.examSparsity(); //System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" + // "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop()); }