Example usage of java.util.concurrent.ExecutorService.invokeAll

List of usage examples of java.util.concurrent.ExecutorService.invokeAll

Introduction

On this page you can find example usages of java.util.concurrent.ExecutorService.invokeAll, taken from real open-source projects.

Prototype

<T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException;

Source Link

Document

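Executes the given tasks, returning a list of Futures holding their status and results when all complete. Future.isDone() is true for each element of the returned list, and the futures are in the same order as the task collection's iterator. A task that threw an exception still counts as "complete"; the failure only surfaces when Future.get() is called.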

Usage

From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java

/**
 * Decompresses this block into a new uncompressed matrix block using up to
 * {@code k} threads.
 *
 * <p>Rows are split into contiguous ranges whose length is rounded up to a
 * multiple of {@code BitmapEncoder.BITMAP_BLOCK_SZ}; each range is handled
 * by one {@code DecompressTask} writing into the shared output block.
 *
 * @param k degree of parallelism; values {@code <= 1} fall back to the
 *          single-threaded {@link #decompress()}
 * @return a new uncompressed matrix block containing the contents
 *         of this block
 * @throws DMLRuntimeException if any decompression task fails or the
 *         calling thread is interrupted while waiting for the tasks
 */
public MatrixBlock decompress(int k) throws DMLRuntimeException {
    //early abort for not yet compressed blocks
    if (!isCompressed())
        return new MatrixBlock(this);
    if (k <= 1)
        return decompress();

    Timing time = new Timing(true);

    MatrixBlock ret = new MatrixBlock(rlen, clen, sparse, nonZeros);
    ret.allocateDenseOrSparseBlock();

    //multi-threaded decompression
    ExecutorService pool = null;
    try {
        pool = Executors.newFixedThreadPool(k);
        int numRows = getNumRows();
        int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ;
        int blklen = (int) (Math.ceil((double) numRows / k));
        //round the row range length up to a multiple of BITMAP_BLOCK_SZ
        blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0;
        ArrayList<DecompressTask> tasks = new ArrayList<DecompressTask>();
        for (int i = 0; i < k && i * blklen < numRows; i++)
            tasks.add(new DecompressTask(_colGroups, ret, i * blklen, Math.min((i + 1) * blklen, numRows)));
        for (Future<Object> rt : pool.invokeAll(tasks))
            rt.get(); //error handling: rethrows task failures as ExecutionException
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    } finally {
        //release worker threads even if task setup or waiting failed
        if (pool != null)
            pool.shutdown();
    }

    //post-processing 
    ret.setNonZeros(nonZeros);

    if (LOG.isDebugEnabled())
        LOG.debug("decompressed block w/ k=" + k + " in " + time.stop() + "ms.");

    return ret;
}

From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java

/**
 * Multi-threaded version of rightMultByVector.
 *
 * <p>Rows are split into contiguous ranges (length rounded up to a multiple of
 * {@code BitmapEncoder.BITMAP_BLOCK_SZ}); each range is processed by one
 * {@code RightMatrixMultTask} over all column groups, writing into {@code result}.
 *
 * @param vector matrix block vector
 * @param result matrix block result; its dense block is allocated here and its
 *               non-zero count is recomputed afterwards
 * @param k number of threads
 * @throws DMLRuntimeException if any task fails or the calling thread is
 *         interrupted while waiting for the tasks
 */
private void rightMultByVector(MatrixBlock vector, MatrixBlock result, int k) throws DMLRuntimeException {
    // initialize and allocate the result
    result.allocateDenseBlock();

    //multi-threaded execution of all groups
    ExecutorService pool = null;
    try {
        pool = Executors.newFixedThreadPool(k);
        int numRows = getNumRows();
        int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ;
        int blklen = (int) (Math.ceil((double) numRows / k));
        //round the row range length up to a multiple of BITMAP_BLOCK_SZ
        blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0;
        ArrayList<RightMatrixMultTask> tasks = new ArrayList<RightMatrixMultTask>();
        for (int i = 0; i < k && i * blklen < numRows; i++)
            tasks.add(new RightMatrixMultTask(_colGroups, vector, result, i * blklen,
                    Math.min((i + 1) * blklen, numRows)));
        //invokeAll does not rethrow task failures; get() does, so check every future
        for (Future<?> rt : pool.invokeAll(tasks))
            rt.get();
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    } finally {
        if (pool != null)
            pool.shutdown();
    }

    // post-processing
    result.recomputeNonZeros();
}

From source file:org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java

/**
 * Executes the paramserv builtin locally.
 *
 * <p>Creates one execution context per worker plus one for the aggregation
 * service, builds the parameter server and {@code workerNum} local workers,
 * partitions the data, runs all workers on a fixed-size thread pool and stores
 * the final model from the parameter server under {@code output.getName()}.
 *
 * @param ec the calling execution context
 * @throws DMLRuntimeException if a worker fails or the calling thread is
 *         interrupted while waiting for the workers
 */
@Override
public void processInstruction(ExecutionContext ec) {
    Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;

    PSModeType mode = getPSMode();
    int workerNum = getWorkerNum(mode);
    String updFunc = getParam(PS_UPDATE_FUN);
    String aggFunc = getParam(PS_AGGREGATION_FUN);

    int k = getParLevel(workerNum);

    // Get the compiled execution context
    // Create workers' execution context
    LocalVariableMap newVarsMap = createVarsMap(ec);
    List<ExecutionContext> newECs = ParamservUtils.createExecutionContexts(ec, newVarsMap, updFunc, aggFunc,
            workerNum, k);

    // Create workers' execution context (all but the last one)
    List<ExecutionContext> workerECs = newECs.subList(0, newECs.size() - 1);

    // Create the agg service's execution context (the last one)
    ExecutionContext aggServiceEC = newECs.get(newECs.size() - 1);

    PSFrequency freq = getFrequency();
    PSUpdateType updateType = getUpdateType();
    int epochs = getEpochs();

    // Create the parameter server
    ListObject model = ec.getListObject(getParam(PS_MODEL));
    ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC);

    // Create the local workers
    MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
    MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
    List<LocalPSWorker> workers = IntStream.range(0, workerNum).mapToObj(i -> new LocalPSWorker(i, updFunc,
            freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
            .collect(Collectors.toList());

    // Do data partition
    PSScheme scheme = getScheme();
    doDataPartitioning(scheme, ec, workers);

    if (DMLScript.STATISTICS)
        Statistics.accPSSetupTime((long) tSetup.stop());

    if (LOG.isDebugEnabled()) {
        LOG.debug(String.format(
                "\nConfiguration of paramserv func: " + "\nmode: %s \nworkerNum: %d \nupdate frequency: %s "
                        + "\nstrategy: %s \ndata partitioner: %s",
                mode, workerNum, freq, updateType, scheme));
    }

    // Create the worker pool only now: if any setup step above throws, no
    // threads have been started yet, so nothing can leak.
    BasicThreadFactory factory = new BasicThreadFactory.Builder().namingPattern("workers-pool-thread-%d")
            .build();
    ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
    try {
        // Launch the worker threads and wait for completion
        for (Future<Void> ret : es.invokeAll(workers))
            ret.get(); //error handling
        // Fetch the final model from ps
        ListObject result = ps.getResult();
        ec.setVariable(output.getName(), result);
    } catch (InterruptedException e) {
        // preserve the interrupt status for callers up the stack
        Thread.currentThread().interrupt();
        throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
    } catch (ExecutionException e) {
        throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
    } finally {
        es.shutdownNow();
        // Should shutdown the thread pool in param server
        ps.shutdown();
    }
}

From source file:org.openspaces.rest.space.ReplicationRESTController.java

/**
 * Invokes the {@code add-repl-space} custom command on every endpoint of the
 * request in parallel, one HTTP POST per endpoint.
 *
 * @param req deployment request; {@code req.endpoints} are the targets and
 *            {@code req.tspec.name} is sent as the space name
 * @throws RuntimeException if any endpoint failed (message lists every failed
 *         endpoint) or the calling thread was interrupted
 */
private void deploySpaces(final DeployRequest req) {
    // Snapshot the endpoints in iteration order. invokeAll returns futures in the
    // same order as the task list, so index i refers to the same endpoint in both.
    final List<CloudifyRestEndpoint> endpoints = new ArrayList<CloudifyRestEndpoint>();
    for (CloudifyRestEndpoint ep : req.endpoints) {
        endpoints.add(ep);
    }
    if (endpoints.isEmpty()) {
        // nothing to deploy; newFixedThreadPool(0) would throw IllegalArgumentException
        return;
    }

    ExecutorService esvc = Executors.newFixedThreadPool(endpoints.size());
    List<Callable<DeployResult>> tasks = new ArrayList<Callable<DeployResult>>();

    try {
        for (final CloudifyRestEndpoint ep : endpoints) {
            tasks.add(new Callable<DeployResult>() {
                @Override
                public DeployResult call() throws Exception {
                    HttpClient client = HttpClientBuilder.create().build();
                    String url = "http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION
                            + "/deployments/applications/" + req.tspec.name
                            + "/services/repl-management/invoke";
                    HttpPost post = new HttpPost(url);
                    log.fine("invoking custom command via REST :" + url);
                    String entityString = String.format(
                            "{\"commandName\":\"add-repl-space\",\"parameters\":[\"%s\",%d,%d,\"%s\"]}",
                            req.tspec.name, 1, 0, ep.siteId);
                    post.setEntity(new StringEntity(entityString, ContentType.APPLICATION_JSON));
                    log.fine("  post entity:" + entityString);
                    HttpResponse resp = client.execute(post);
                    return new DeployResult(resp, ep);
                }
            });
        }

        List<Future<DeployResult>> results;
        try {
            results = esvc.invokeAll(tasks);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }

        StringBuilder sb = new StringBuilder("Repl space deploy failed:");
        int initlength = sb.length();
        for (int i = 0; i < results.size(); i++) {
            CloudifyRestEndpoint ep = endpoints.get(i);
            DeployResult dr;
            try {
                dr = results.get(i).get();
            } catch (Exception e) {
                // the request itself failed: there is no response to inspect
                sb.append(ep.address).append(" caught exception - ").append(e.getMessage()).append(",");
                continue;
            }
            try {
                if (dr.response.getStatusLine().getStatusCode() != 200) {
                    String body = IOUtils.toString(dr.response.getEntity().getContent());
                    sb.append(ep.address).append(" returned status ")
                            .append(dr.response.getStatusLine().getStatusCode()).append(",reason=")
                            .append(dr.response.getStatusLine().getReasonPhrase()).append(",").append(body);
                } else {
                    log.info("got response 200: " + dr.response.getStatusLine().toString());
                }
            } catch (Exception e) {
                sb.append(ep.address).append(" caught exception processing response body- ")
                        .append(e.getMessage()).append(",");
            }
        }

        if (sb.length() > initlength) {//error caught
            throw new RuntimeException(sb.toString());
        }

    } finally {
        esvc.shutdown();
    }

}

From source file:com.topsec.tsm.sim.report.model.ReportModel.java

/**
 * Builds the export data for a master report.
 *
 * <p>Loads the sub-report definitions of master report {@code exp.getMstrptid()}.
 * If the request selects a "Profession/Group" device type with a node type and
 * device address, or a "Comprehensive" device type, the sub-reports of the
 * related master reports are appended (with their {@code subRow} shifted).
 * Every collected sub-report is then rendered in parallel by a
 * {@code ReportExecutor}, and the results are grouped by subject.
 *
 * @param rptMasterTbImp DAO used to query sub-report definitions
 * @param exp export settings; when the first sub-report's {@code viewItem} does
 *            not contain "2", its report type and time range are overwritten
 *            with a one-month range
 * @param request current HTTP request (may be null); used to fill the
 *                {@code ReportBean}
 * @return export data keyed by subject, in the entry order of the shared export map
 * @throws Exception if a query fails or the thread is interrupted while
 *         waiting for the export tasks
 */
public static LinkedHashMap<String, List> expMstReport(RptMasterTbService rptMasterTbImp, ExpStruct exp,
        HttpServletRequest request) throws Exception {
    String mstRptId = exp.getMstrptid();// master report ID
    Integer mstRptIdInt = 0;
    if (!GlobalUtil.isNullOrEmpty(mstRptId)) {
        mstRptIdInt = Integer.valueOf(mstRptId);
    }
    String mstSql = ReportUiConfig.MstSubSql;
    Object[] subParam = { mstRptIdInt };

    List<Map<String, Object>> subResult = new ArrayList<Map<String, Object>>();
    Map<Integer, Integer> rowColumns = new HashMap<Integer, Integer>();

    List<Map<String, Object>> subResultTemp = rptMasterTbImp.queryTmpList(mstSql, subParam);
    if (subResultTemp.size() > 0) {
        Map subMap = subResultTemp.get(0);
        String viewItem = StringUtil.toString(subMap.get("viewItem"), "");
        if (viewItem.indexOf("2") < 0) {
            exp.setRptType(ReportUiConfig.rptDirection);
            String[] time = ReportUiUtil.getExpTime("month");
            exp.setRptTimeS(time[0]);
            exp.setRptTimeE(time[1]);
        }
    }
    if (!GlobalUtil.isNullOrEmpty(subResultTemp)) {
        subResult.addAll(subResultTemp);
    }
    ReportBean bean = new ReportBean();
    if (!GlobalUtil.isNullOrEmpty(request)) {
        bean = ReportUiUtil.tidyFormBean(bean, request);
    }
    String nodeType = bean.getNodeType();
    String dvcaddress = bean.getDvcaddress();
    DataSourceService dataSourceService = (DataSourceService) SpringContextServlet.springCtx
            .getBean("dataSourceService");
    if (!GlobalUtil.isNullOrEmpty(bean.getDvctype()) && bean.getDvctype().startsWith("Profession/Group")
            && !GlobalUtil.isNullOrEmpty(nodeType) && !GlobalUtil.isNullOrEmpty(dvcaddress)) {
        Map map = TopoUtil.getAssetEvtMstMap();
        String mstIds = null;
        List<SimDatasource> simDatasources = dataSourceService.getByIp(dvcaddress);
        if (!GlobalUtil.isNullOrEmpty(simDatasources)) {
            // collect master report ids for all data sources at this address, ":::"-separated
            mstIds = "";
            for (SimDatasource simDatasource : simDatasources) {
                if (map.containsKey(simDatasource.getSecurityObjectType())) {
                    mstIds += map.get(simDatasource.getSecurityObjectType()).toString() + ":::";
                } else {
                    String keyString = getStartStringKey(map, simDatasource.getSecurityObjectType());
                    if (!GlobalUtil.isNullOrEmpty(keyString)) {
                        mstIds += map.get(keyString).toString() + ":::";
                    }
                }
            }
            if (mstIds.length() > 3) {
                mstIds = mstIds.substring(0, mstIds.length() - 3);
            }
        } else {
            if (map.containsKey(nodeType)) {
                mstIds = map.get(nodeType).toString();
            } else {
                String keyString = getStartStringKey(map, nodeType);
                if (!GlobalUtil.isNullOrEmpty(keyString)) {
                    mstIds = map.get(keyString).toString();
                }
            }
        }
        if (!GlobalUtil.isNullOrEmpty(mstIds)) {
            String[] mstIdArr = mstIds.split(":::");
            for (String string : mstIdArr) {
                List<Map<String, Object>> subTemp = rptMasterTbImp.queryTmpList(mstSql,
                        new Object[] { StringUtil.toInt(string, 5) });
                if (!GlobalUtil.isNullOrEmpty(subTemp)) {
                    int maxCol = 0;
                    if (!GlobalUtil.isNullOrEmpty(rowColumns)) {
                        maxCol = getMaxOrMinKey(rowColumns, 1);
                    }
                    // shift appended sub-reports below the rows already laid out
                    for (Map map2 : subTemp) {
                        Integer row = (Integer) map2.get("subRow") + maxCol;
                        map2.put("subRow", row);
                    }
                    subResult.addAll(subTemp);
                }
            }
        }
    }

    if (!GlobalUtil.isNullOrEmpty(bean.getDvctype()) && bean.getDvctype().startsWith("Comprehensive")) {
        List<String> dvcTypes = new ArrayList<String>();
        dvcTypes.add(bean.getDvctype().replace("Comprehensive", ""));

        // entries look like "<mstrptid>IDandNODEtype<subject>"
        List<String> mstrptidAndNodeTypeList = new ArrayList<String>();
        setMstIdAndScanNodeType(dvcTypes, mstrptidAndNodeTypeList);
        subResultTemp = null;
        if (!GlobalUtil.isNullOrEmpty(mstrptidAndNodeTypeList)) {
            subResultTemp = rptMasterTbImp.queryTmpList(ReportUiConfig.MstSubSql,
                    new Object[] { StringUtil.toInt((mstrptidAndNodeTypeList.get(0).split("IDandNODEtype"))[0],
                            StringUtil.toInt(bean.getTalTop(), 5)) });
            Map<Integer, Integer> rowColumnsTeMap = ReportModel.getRowColumns(subResultTemp);
            if (!GlobalUtil.isNullOrEmpty(subResultTemp)) {
                for (Map map2 : subResultTemp) {
                    map2.put("subject", (mstrptidAndNodeTypeList.get(0).split("IDandNODEtype"))[1]);
                }
                subResult.addAll(subResultTemp);
                rowColumns.putAll(rowColumnsTeMap);
            }

            int len = mstrptidAndNodeTypeList.size();
            for (int i = 1; i < len; i++) {
                String mstrptidAndNodeType = mstrptidAndNodeTypeList.get(i);
                String string = mstrptidAndNodeType.split("IDandNODEtype")[0];
                List<Map<String, Object>> subTemp = rptMasterTbImp.queryTmpList(ReportUiConfig.MstSubSql,
                        new Object[] { StringUtil.toInt(string, StringUtil.toInt(bean.getTalTop(), 5)) });
                if (!GlobalUtil.isNullOrEmpty(subTemp)) {
                    int maxCol = 0;
                    if (!GlobalUtil.isNullOrEmpty(rowColumns)) {
                        maxCol = getMaxOrMinKey(rowColumns, 1);
                    }
                    for (Map map2 : subTemp) {
                        Integer row = (Integer) map2.get("subRow") + maxCol;
                        map2.put("subRow", row);
                        map2.put("subject", mstrptidAndNodeType.split("IDandNODEtype")[1]);
                    }
                    subResult.addAll(subTemp);
                    Map<Integer, Integer> rowColTemp = ReportModel.getRowColumns(subTemp);
                    rowColumns.putAll(rowColTemp);
                }
            }
        }
    }

    List<ExpDateStruct> expList = new ArrayList<ExpDateStruct>(); // shared list handed to every ReportExecutor
    Map<ReportExecutor.SubjectKey, Map<Integer, ExpDateStruct>> exportMap = Collections
            .synchronizedMap(new LinkedHashMap());
    // one thread per sub-report; at least one so an empty subResult does not make
    // newFixedThreadPool throw IllegalArgumentException
    ExecutorService threadPool = Executors.newFixedThreadPool(Math.max(1, subResult.size()),
            new TsmThreadFactory("ReportSubjectExport"));
    LinkedHashMap<String, List> expMap = null;
    try {
        List<ReportExecutor> tasks = new ArrayList<ReportExecutor>(subResult.size());
        int order = 0;
        for (Map sub : subResult) {
            order += 100;
            tasks.add(new ReportExecutor(order, rptMasterTbImp, exp, exportMap, expList, sub, request,
                    SID.currentUser()));
        }
        // NOTE(review): the returned futures are not inspected, so a failing
        // ReportExecutor is silently omitted from the export — confirm this is intended.
        threadPool.invokeAll(tasks);
        expMap = new LinkedHashMap<String, List>(exportMap.size());
        for (Map.Entry<ReportExecutor.SubjectKey, Map<Integer, ExpDateStruct>> entry : exportMap.entrySet()) {
            expMap.put(entry.getKey().subject, new ArrayList(entry.getValue().values()));
        }
    } finally {
        threadPool.shutdownNow();
    }
    return expMap;
}

From source file:org.openspaces.rest.space.ReplicationRESTController.java

/**
 * Invokes the {@code install-gateway} custom command on every endpoint of the
 * request in parallel, one HTTP POST per endpoint.
 *
 * @param req deployment request; {@code req.endpoints} are the targets and
 *            {@code req.tspec} supplies the space name, pairs and lookups
 * @throws RuntimeException if any endpoint failed (message lists every failed
 *         endpoint) or the calling thread was interrupted
 */
private void deployGateways(final DeployRequest req) {
    // Snapshot the endpoints in iteration order. invokeAll returns futures in the
    // same order as the task list, so index i refers to the same endpoint in both.
    final List<CloudifyRestEndpoint> endpoints = new ArrayList<CloudifyRestEndpoint>();
    for (CloudifyRestEndpoint ep : req.endpoints) {
        endpoints.add(ep);
    }
    if (endpoints.isEmpty()) {
        // nothing to deploy; newFixedThreadPool(0) would throw IllegalArgumentException
        return;
    }

    ExecutorService esvc = Executors.newFixedThreadPool(endpoints.size());
    List<Callable<DeployResult>> tasks = new ArrayList<Callable<DeployResult>>();

    try {
        final Map<String, List<String>> gwaddresses = makeGatewayAddressMap(req);
        final String natArg = addressesToNat(gwaddresses);
        for (final CloudifyRestEndpoint ep : endpoints) {
            tasks.add(new Callable<DeployResult>() {
                @Override
                public DeployResult call() throws Exception {
                    HttpClient client = HttpClientBuilder.create().build();
                    String url = "http://" + ep.address + ":" + ep.port + "/" + CLOUDIFY_VERSION
                            + "/deployments/applications/" + req.tspec.name + "/services/repl-gateway/invoke";
                    HttpPost post = new HttpPost(url);
                    log.info("invoking custom command via REST :" + url);
                    String entityString = String.format(
                            "{\"commandName\":\"install-gateway\",\"parameters\":[\"%s\",\"%s\",\"%s\",\"%s\",\"%s\"]}",
                            req.tspec.name, ep.siteId, tspecToPairsArg(req.tspec),
                            tspecToLookups(ep, req.tspec, gwaddresses), natArg);
                    post.setEntity(new StringEntity(entityString, ContentType.APPLICATION_JSON));
                    log.info("  post entity:" + entityString);
                    HttpResponse resp = client.execute(post);
                    return new DeployResult(resp, ep);
                }

            });
        }

        List<Future<DeployResult>> results;
        try {
            results = esvc.invokeAll(tasks);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }

        StringBuilder sb = new StringBuilder("Gateway deploy failed:");
        int initlength = sb.length();
        for (int i = 0; i < results.size(); i++) {
            CloudifyRestEndpoint ep = endpoints.get(i);
            DeployResult dr;
            try {
                dr = results.get(i).get();
            } catch (Exception e) {
                // the request itself failed: there is no response to inspect
                sb.append(ep.address).append(" caught exception - ").append(e.getMessage()).append(",");
                continue;
            }
            try {
                if (dr.response.getStatusLine().getStatusCode() != 200) {
                    String body = IOUtils.toString(dr.response.getEntity().getContent());
                    sb.append(ep.address).append(" returned status ")
                            .append(dr.response.getStatusLine().getStatusCode()).append(",reason=")
                            .append(dr.response.getStatusLine().getReasonPhrase()).append(",").append(body);
                } else {
                    log.info("got response 200: " + dr.response.getStatusLine().toString());
                }
            } catch (Exception e) {
                sb.append(ep.address).append(" caught exception processing response body- ")
                        .append(e.getMessage()).append(",");
            }
        }

        if (sb.length() > initlength) {//error caught
            throw new RuntimeException(sb.toString());
        }

    } finally {
        esvc.shutdown();
    }

}

From source file:de.uzk.hki.da.pkg.MetsConsistencyChecker.java

/**
 * Checks the package consistency based on the File elements
 * in the METS file {@code packagePath + "/export_mets.xml"}.
 *
 * This assumes that a file must exist for every File element,
 * which is not the case for delta-packages.
 *
 * <p>File elements without CHECKSUM, CHECKSUMTYPE or an FLocat href, or with an
 * unknown checksum algorithm, are only logged as warnings and skipped; they do
 * not make the result {@code false}. Missing files and checksum mismatches are
 * logged, appended to {@code messages}, and make the result {@code false}.
 * Checksums are verified on a pool of 8 threads.
 *
 * @return true, if a file could be found for every checked File element
 * and the checksums matched, otherwise false
 * @throws Exception if the METS file cannot be parsed, a verification task
 *         fails, or the thread is interrupted while waiting
 */
public boolean checkPackageBasedOnMets() throws Exception {

    boolean result = true;

    Namespace metsNS = Namespace.getNamespace("mets", "http://www.loc.gov/METS/");
    Namespace xlinkNS = Namespace.getNamespace("xlink", "http://www.w3.org/1999/xlink");
    String metsPath = packagePath + "/export_mets.xml";

    // NOTE(review): parser runs with default settings; if packages can come from
    // untrusted sources, consider disabling DTDs / external entities (XXE).
    SAXBuilder builder = new SAXBuilder(false);
    Document doc = builder.build(new File(metsPath));
    XPath xpath = XPath.newInstance("//mets:file");
    xpath.addNamespace(metsNS);

    @SuppressWarnings("rawtypes")
    List nodes = xpath.selectNodes(doc);

    logger.debug("Found {} mets:file elements", nodes.size());

    List<FileChecksumVerifierThread> threads = new ArrayList<FileChecksumVerifierThread>();

    for (Object node : nodes) {

        Element elem = (Element) node;
        String checksum = elem.getAttributeValue("CHECKSUM");
        String checksumType = elem.getAttributeValue("CHECKSUMTYPE");
        // a File element without FLocat child is treated like one without a path
        Element flocat = elem.getChild("FLocat", metsNS);
        String path = (flocat == null) ? null : flocat.getAttributeValue("href", xlinkNS);

        logger.debug("Verifying file: {}", path);

        // check if required attributes are set
        if (checksum == null) {
            logger.warn(
                    "METS File Element in {} does not contain attribute CHECKSUM. File consistency can not be verified.",
                    metsPath);
            continue;
        }
        if (checksumType == null) {
            logger.warn(
                    "METS File Element in {} does not contain attribute CHECKSUM TYPE. File consistency can not be verified.",
                    metsPath);
            continue;
        }
        if (path == null) {
            logger.warn("METS File Element in {} does not contain path.", metsPath);
            continue;
        }

        // check if file exists at path
        File file = new File(packagePath + "/" + path);
        if (!file.exists()) {
            result = false;
            String msg = "Could not find file referenced in METS metadata: " + path;
            logger.error(msg);
            messages.add(msg);
            continue;
        }

        logger.debug("Checking with algorithm: {}", checksumType);

        // calculate and verify checksum; METS names like "SHA-256" lose the dash
        checksumType = checksumType.replaceAll("-", "");
        try {
            MessageDigest algorithm = MessageDigest.getInstance(checksumType);
            threads.add(new FileChecksumVerifierThread(checksum, file, algorithm));
        } catch (NoSuchAlgorithmException e) {
            logger.warn(
                    "METS File Element in {} contains unknown CHECKSUM TYPE: {}. File consistency can not be verified.",
                    metsPath, checksumType);
            continue;
        }

    }

    // The pool is created only after all tasks are collected and is always shut
    // down; previously its 8 threads were never released.
    ExecutorService executor = Executors.newFixedThreadPool(8);
    try {
        List<Future<ChecksumResult>> futures = executor.invokeAll(threads);

        for (Future<ChecksumResult> future : futures) {
            ChecksumResult cResult = future.get();
            if (!cResult.isSuccess()) {
                result = false;
                logger.error(cResult.getMessage());
                messages.add(cResult.getMessage());
            }
        }
    } finally {
        executor.shutdown();
    }

    return result;
}

From source file:org.apache.sysml.runtime.compress.CompressedMatrixBlock.java

/**
 * Computes a unary aggregate (sum, sum of squares, min or max) over this block.
 *
 * <p>Uncompressed blocks delegate to the superclass. For compressed blocks the
 * aggregate is computed per column group, multi-threaded when
 * {@code op.getNumThreads() > 1} and the block exceeds
 * {@code MIN_PAR_AGG_THRESHOLD} on disk.
 *
 * @param op the aggregate operator; only KahanPlus, KahanPlusSq and Builtin
 *           MIN/MAX are supported
 * @param result optional output block to reuse; a new block is created if null
 * @param blockingFactorRow row blocking factor (forwarded to uncompressed paths)
 * @param blockingFactorCol column blocking factor (forwarded to uncompressed paths)
 * @param indexesIn block indexes (forwarded to uncompressed paths)
 * @param inCP if true, the correction rows/columns are dropped from the result
 * @return the aggregated result block
 * @throws DMLRuntimeException for unsupported operators or correction locations,
 *         or if any parallel task fails or the thread is interrupted
 */
@Override
public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result,
        int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP)
        throws DMLRuntimeException {
    //call uncompressed unary aggregate if necessary
    if (!isCompressed()) {
        return super.aggregateUnaryOperations(op, result, blockingFactorRow, blockingFactorCol, indexesIn,
                inCP);
    }

    //check for supported operations
    if (!(op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq
            || (op.aggOp.increOp.fn instanceof Builtin
                    && (((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
                            || ((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) {
        throw new DMLRuntimeException("Unary aggregates other than sum/sumsq/min/max not supported yet.");
    }

    Timing time = LOG.isDebugEnabled() ? new Timing(true) : null;

    //prepare output dimensions
    CellIndex tempCellIndex = new CellIndex(-1, -1);
    op.indexFn.computeDimension(rlen, clen, tempCellIndex);
    if (op.aggOp.correctionExists) {
        switch (op.aggOp.correctionLocation) {
        case LASTROW:
            tempCellIndex.row++;
            break;
        case LASTCOLUMN:
            tempCellIndex.column++;
            break;
        case LASTTWOROWS:
            tempCellIndex.row += 2;
            break;
        case LASTTWOCOLUMNS:
            tempCellIndex.column += 2;
            break;
        default:
            throw new DMLRuntimeException("unrecognized correctionLocation: " + op.aggOp.correctionLocation);
        }
    }

    // initialize and allocate the result
    if (result == null)
        result = new MatrixBlock(tempCellIndex.row, tempCellIndex.column, false);
    else
        result.reset(tempCellIndex.row, tempCellIndex.column, false);
    MatrixBlock ret = (MatrixBlock) result;
    ret.allocateDenseBlock();

    //special handling init value for rowmins/rowmax
    if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) {
        double val = Double.MAX_VALUE
                * ((((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX) ? -1 : 1);
        Arrays.fill(ret.getDenseBlock(), val);
    }

    //core unary aggregate
    if (op.getNumThreads() > 1 && getExactSizeOnDisk() > MIN_PAR_AGG_THRESHOLD) {
        //multi-threaded execution of all groups 
        ArrayList<ColGroup>[] grpParts = createStaticTaskPartitioning(
                (op.indexFn instanceof ReduceCol) ? 1 : op.getNumThreads(), false);
        ColGroupUncompressed uc = getUncompressedColGroup();
        ExecutorService pool = null;
        try {
            //compute uncompressed column group in parallel (otherwise bottleneck)
            if (uc != null)
                ret = (MatrixBlock) uc.getData().aggregateUnaryOperations(op, ret, blockingFactorRow,
                        blockingFactorCol, indexesIn, false);
            //compute all compressed column groups
            pool = Executors.newFixedThreadPool(op.getNumThreads());
            ArrayList<UnaryAggregateTask> tasks = new ArrayList<UnaryAggregateTask>();
            if (op.indexFn instanceof ReduceCol && grpParts.length > 0) {
                //row aggregates: split rows into BITMAP_BLOCK_SZ-aligned ranges
                int seqsz = BitmapEncoder.BITMAP_BLOCK_SZ;
                int blklen = (int) (Math.ceil((double) rlen / op.getNumThreads()));
                blklen += (blklen % seqsz != 0) ? seqsz - blklen % seqsz : 0;
                for (int i = 0; i < op.getNumThreads() && i * blklen < rlen; i++)
                    tasks.add(new UnaryAggregateTask(grpParts[0], ret, i * blklen,
                            Math.min((i + 1) * blklen, rlen), op));
            } else
                for (ArrayList<ColGroup> grp : grpParts)
                    tasks.add(new UnaryAggregateTask(grp, ret, 0, rlen, op));
            List<Future<MatrixBlock>> rtasks = pool.invokeAll(tasks);

            //aggregate partial results; get() also rethrows any task failure
            if (op.indexFn instanceof ReduceAll) {
                double val = ret.quickGetValue(0, 0);
                for (Future<MatrixBlock> rtask : rtasks)
                    val = op.aggOp.increOp.fn.execute(val, rtask.get().quickGetValue(0, 0));
                ret.quickSetValue(0, 0, val);
            } else {
                //tasks wrote into ret directly; only check for errors
                for (Future<MatrixBlock> rtask : rtasks)
                    rtask.get();
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        } finally {
            if (pool != null)
                pool.shutdown();
        }
    } else {
        //process UC column group
        for (ColGroup grp : _colGroups)
            if (grp instanceof ColGroupUncompressed)
                grp.unaryAggregateOperations(op, ret);

        //process OLE/RLE column groups
        for (ColGroup grp : _colGroups)
            if (!(grp instanceof ColGroupUncompressed))
                grp.unaryAggregateOperations(op, ret);
    }

    //special handling zeros for rowmins/rowmax: rows with fewer than clen
    //non-zeros contain an implicit 0 that must take part in min/max
    if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) {
        int[] rnnz = new int[rlen];
        for (ColGroup grp : _colGroups)
            grp.countNonZerosPerRow(rnnz, 0, rlen);
        Builtin builtin = (Builtin) op.aggOp.increOp.fn;
        for (int i = 0; i < rlen; i++)
            if (rnnz[i] < clen)
                ret.quickSetValue(i, 0, builtin.execute2(ret.quickGetValue(i, 0), 0));
    }

    //drop correction if necessary
    if (op.aggOp.correctionExists && inCP)
        ret.dropLastRowsOrColums(op.aggOp.correctionLocation);

    //post-processing
    ret.recomputeNonZeros();

    //guard on time as well: debug level may have changed since the timer was created
    if (time != null && LOG.isDebugEnabled())
        LOG.debug("Compressed uagg k=" + op.getNumThreads() + " in " + time.stop());

    return ret;
}

From source file:com.ibm.bi.dml.runtime.matrix.data.LibMatrixMult.java

/**
 * /*from w  w w .j ava  2  s.c  o m*/
 * @param mX
 * @param mU
 * @param mV
 * @param ret
 * @param wt
 * @param k
 * @throws DMLRuntimeException
 */
public static void matrixMultWCeMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret,
        WCeMMType wt, int k) throws DMLRuntimeException {
    //check for empty result 
    if (mW.isEmptyBlock(false)) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //Timing time = new Timing(true);

    //pre-processing
    ret.sparse = false;
    ret.allocateDenseBlock();

    try {
        ExecutorService pool = Executors.newFixedThreadPool(k);
        ArrayList<ScalarResultTask> tasks = new ArrayList<ScalarResultTask>();
        int blklen = (int) (Math.ceil((double) mW.rlen / k));
        for (int i = 0; i < k & i * blklen < mW.rlen; i++)
            tasks.add(new MatrixMultWCeTask(mW, mU, mV, wt, i * blklen, Math.min((i + 1) * blklen, mW.rlen)));
        pool.invokeAll(tasks);
        pool.shutdown();
        //aggregate partial results
        sumScalarResults(tasks, ret);
    } catch (InterruptedException e) {
        throw new DMLRuntimeException(e);
    }

    //System.out.println("MMWCe "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
    //                 "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}

From source file:com.ibm.bi.dml.runtime.matrix.data.LibMatrixMult.java

/**
 * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
 *
 * <p>Falls back to the single-threaded {@code matrixMult(m1, m2, ret)} for small
 * workloads ({@code < PAR_MINFLOP_THRESHOLD} flops) and for vector-matrix cases
 * with high memory overhead or ultra-sparse inputs. Otherwise rows (of {@code m1},
 * or of {@code m2} when the right input is partitioned) are split into
 * {@code ceil(ru/k)}-sized ranges, one {@code MatrixMultTask} each. At most
 * {@code k} threads are used; any clamping of {@code k} to the number of virtual
 * cores presumably happens in the caller (not visible here).
 *
 * @param m1 left input
 * @param m2 right input
 * @param ret output block; allocated dense or sparse here
 * @param k number of threads (k>=1)
 * @throws DMLRuntimeException if any task fails or the thread is interrupted
 */
public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k)
        throws DMLRuntimeException {
    //check inputs / outputs
    if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
        ret.examSparsity(); //turn empty dense into sparse
        return;
    }

    //check too high additional vector-matrix memory requirements (fallback to sequential)
    //check too small workload in terms of flops (fallback to sequential too)
    if (m1.rlen == 1
            && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD || !LOW_LEVEL_OPTIMIZATION || m2.clen == 1
                    || m1.isUltraSparse() || m2.isUltraSparse())
            || 2L * m1.rlen * m1.clen * m2.clen < PAR_MINFLOP_THRESHOLD) {
        matrixMult(m1, m2, ret);
        return;
    }

    //pre-processing: output allocation (in contrast to single-threaded,
    //we need to allocate sparse as well in order to prevent synchronization)
    boolean tm2 = checkPrepMatrixMultRightInput(m1, m2);
    m2 = prepMatrixMultRightInput(m1, m2);
    ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
    if (!ret.sparse)
        ret.allocateDenseBlock();
    else
        ret.allocateSparseRowsBlock();

    //prepare row-upper for special cases of vector-matrix / matrix-matrix
    boolean pm2 = checkParMatrixMultRightInput(m1, m2, k);
    int ru = pm2 ? m2.rlen : m1.rlen;

    //core multi-threaded matrix mult computation
    //(currently: always parallelization over number of rows)
    ExecutorService pool = null;
    try {
        pool = Executors.newFixedThreadPool(k);
        ArrayList<MatrixMultTask> tasks = new ArrayList<MatrixMultTask>();
        int blklen = (int) (Math.ceil((double) ru / k));
        for (int i = 0; i < k && i * blklen < ru; i++)
            tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2, i * blklen, Math.min((i + 1) * blklen, ru)));
        //invokeAll does not rethrow task failures; get() does, so partial
        //results/nnz below are only aggregated when every task succeeded
        for (java.util.concurrent.Future<?> task : pool.invokeAll(tasks))
            task.get();
        //aggregate partial results (nnz, ret for vector/matrix)
        ret.nonZeros = 0; //reset after execute
        for (MatrixMultTask task : tasks) {
            if (pm2)
                vectAdd(task.getResult().denseBlock, ret.denseBlock, 0, 0, ret.rlen * ret.clen);
            else
                ret.nonZeros += task.getPartialNnz();
        }
        if (pm2)
            ret.recomputeNonZeros();
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    } finally {
        if (pool != null)
            pool.shutdown();
    }

    //post-processing (nnz maintained in parallel)
    ret.examSparsity();
}