Example usage for java.util.concurrent CompletionService take

Introduction

On this page you can find example usage of java.util.concurrent CompletionService.take().

Prototype

Future<V> take() throws InterruptedException;

Document

Retrieves and removes the Future representing the next completed task, waiting if none are yet present.
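
Before the real-world examples, here is a minimal, self-contained sketch of the usual submit/take pattern. The pool size, task count, and sleep durations are illustrative values, not taken from the examples that follow.

import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class CompletionServiceTakeExample {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        CompletionService<Integer> completionService = new ExecutorCompletionService<Integer>(pool);

        // Submit tasks that finish at unpredictable times.
        final int taskCount = 8;
        for (int i = 0; i < taskCount; i++) {
            final int id = i;
            completionService.submit(new Callable<Integer>() {
                @Override
                public Integer call() throws Exception {
                    Thread.sleep((long) (Math.random() * 100));
                    return id * id;
                }
            });
        }

        try {
            // take() blocks until the next task completes, so results are
            // consumed in completion order, not submission order.
            for (int i = 0; i < taskCount; i++) {
                Future<Integer> future = completionService.take();
                System.out.println("Completed: " + future.get());
            }
        } finally {
            pool.shutdownNow();
        }
    }
}

Both examples below follow the same shape: submit N tasks, then call take() exactly N times so every completed Future is retrieved.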

Usage

From source file:org.apache.hadoop.hbase.regionserver.HRegion.java
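
Here, HRegion.doClose() submits one close task per store to a dedicated thread pool, then calls take() once per store so the closed store files can be collected as soon as each close completes.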

private Map<byte[], List<StoreFile>> doClose(final boolean abort, MonitoredTask status) throws IOException {
    if (isClosed()) {
        LOG.warn("Region " + this + " already closed");
        return null;
    }

    if (coprocessorHost != null) {
        status.setStatus("Running coprocessor pre-close hooks");
        this.coprocessorHost.preClose(abort);
    }

    status.setStatus("Disabling compacts and flushes for region");
    synchronized (writestate) {
        // Disable compacting and flushing by background threads for this
        // region.
        writestate.writesEnabled = false;
        LOG.debug("Closing " + this + ": disabling compactions & flushes");
        waitForFlushesAndCompactions();
    }
    // If we were not just flushing, is it worth doing a preflush...one
    // that will clear out the bulk of the memstore before we put up
    // the close flag?
    if (!abort && worthPreFlushing()) {
        status.setStatus("Pre-flushing region before close");
        LOG.info("Running close preflush of " + this.getRegionNameAsString());
        try {
            internalFlushcache(status);
        } catch (IOException ioe) {
            // Failed to flush the region. Keep going.
            status.setStatus("Failed pre-flush " + this + "; " + ioe.getMessage());
        }
    }

    this.closing.set(true);
    status.setStatus("Disabling writes for close");
    // block waiting for the lock for closing
    lock.writeLock().lock();
    try {
        if (this.isClosed()) {
            status.abort("Already got closed by another process");
            // SplitTransaction handles the null
            return null;
        }
        LOG.debug("Updates disabled for region " + this);
        // Don't flush the cache if we are aborting
        if (!abort) {
            int flushCount = 0;
            while (this.getMemstoreSize().get() > 0) {
                try {
                    if (flushCount++ > 0) {
                        int actualFlushes = flushCount - 1;
                        if (actualFlushes > 5) {
                            // If we tried 5 times and are unable to clear memory, abort
                            // so we do not lose data
                            throw new DroppedSnapshotException("Failed clearing memory after " + actualFlushes
                                    + " attempts on region: " + Bytes.toStringBinary(getRegionName()));
                        }
                        LOG.info("Running extra flush, " + actualFlushes + " (carrying snapshot?) " + this);
                    }
                    internalFlushcache(status);
                } catch (IOException ioe) {
                    status.setStatus("Failed flush " + this + ", putting online again");
                    synchronized (writestate) {
                        writestate.writesEnabled = true;
                    }
                    // Have to throw to upper layers.  I can't abort server from here.
                    throw ioe;
                }
            }
        }

        Map<byte[], List<StoreFile>> result = new TreeMap<byte[], List<StoreFile>>(Bytes.BYTES_COMPARATOR);
        if (!stores.isEmpty()) {
            // initialize the thread pool for closing stores in parallel.
            ThreadPoolExecutor storeCloserThreadPool = getStoreOpenAndCloseThreadPool(
                    "StoreCloserThread-" + this.getRegionNameAsString());
            CompletionService<Pair<byte[], Collection<StoreFile>>> completionService = new ExecutorCompletionService<Pair<byte[], Collection<StoreFile>>>(
                    storeCloserThreadPool);

            // close each store in parallel
            for (final Store store : stores.values()) {
                assert abort || store.getFlushableSize() == 0;
                completionService.submit(new Callable<Pair<byte[], Collection<StoreFile>>>() {
                    @Override
                    public Pair<byte[], Collection<StoreFile>> call() throws IOException {
                        return new Pair<byte[], Collection<StoreFile>>(store.getFamily().getName(),
                                store.close());
                    }
                });
            }
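            // Collect the store-close results in completion order: take()
            // blocks until the next store finishes closing.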
            try {
                for (int i = 0; i < stores.size(); i++) {
                    Future<Pair<byte[], Collection<StoreFile>>> future = completionService.take();
                    Pair<byte[], Collection<StoreFile>> storeFiles = future.get();
                    List<StoreFile> familyFiles = result.get(storeFiles.getFirst());
                    if (familyFiles == null) {
                        familyFiles = new ArrayList<StoreFile>();
                        result.put(storeFiles.getFirst(), familyFiles);
                    }
                    familyFiles.addAll(storeFiles.getSecond());
                }
            } catch (InterruptedException e) {
                throw (InterruptedIOException) new InterruptedIOException().initCause(e);
            } catch (ExecutionException e) {
                throw new IOException(e.getCause());
            } finally {
                storeCloserThreadPool.shutdownNow();
            }
        }
        this.closed.set(true);
        if (memstoreSize.get() != 0)
            LOG.error("Memstore size is " + memstoreSize.get());
        if (coprocessorHost != null) {
            status.setStatus("Running coprocessor post-close hooks");
            this.coprocessorHost.postClose(abort);
        }
        if (this.metricsRegion != null) {
            this.metricsRegion.close();
        }
        if (this.metricsRegionWrapper != null) {
            Closeables.closeQuietly(this.metricsRegionWrapper);
        }
        status.markComplete("Closed");
        LOG.info("Closed " + this);
        return result;
    } finally {
        lock.writeLock().unlock();
    }
}

From source file:ml.shifu.shifu.core.dtrain.dt.DTWorker.java
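
Here, DTWorker.doCompute() splits the training records into per-thread ranges, submits one statistics task per range, and calls take() in a loop to merge the per-thread node statistics as each task finishes.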

@Override
public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return new DTWorkerParams();
    }

    DTMasterParams lastMasterResult = context.getLastMasterResult();
    final List<TreeNode> trees = lastMasterResult.getTrees();
    final Map<Integer, TreeNode> todoNodes = lastMasterResult.getTodoNodes();
    if (todoNodes == null) {
        return new DTWorkerParams();
    }

    LOG.info("Start to work: todoNodes size is {}", todoNodes.size());

    Map<Integer, NodeStats> statistics = initTodoNodeStats(todoNodes);

    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    // renew random seed
    if (this.isGBDT && !this.gbdtSampleWithReplacement && lastMasterResult.isSwitchToNextTree()) {
        this.baggingRandomMap = new HashMap<Integer, Random>();
    }

    long start = System.nanoTime();
    for (Data data : this.trainingData) {
        if (this.isRF) {
            for (TreeNode treeNode : trees) {
                if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                    continue;
                }

                Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
                if (predictNode.getPredict() != null) {
                    // only update when not at the first node; the tree node has no predict statistics at that time
                    float weight = data.subsampleWeights[treeNode.getTreeId()];
                    if (Float.compare(weight, 0f) == 0) {
                        // oob data, no need to do weighting
                        validationError += data.significance * loss
                                .computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedValidationCount += data.significance;
                    } else {
                        trainError += weight * data.significance * loss
                                .computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += weight * data.significance;
                    }
                }
            }
        }

        if (this.isGBDT) {
            if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
                recoverGBTData(context, data.output, data.predict, data, false);
                trainError += data.significance * loss.computeError(data.predict, data.label);
                weightedTrainCount += data.significance;
            } else {
                if (isNeedRecoverGBDTPredict) {
                    if (this.recoverTrees == null) {
                        this.recoverTrees = recoverCurrentTrees();
                    }
                    // recover gbdt data for fail over
                    recoverGBTData(context, data.output, data.predict, data, true);
                }
                int currTreeIndex = trees.size() - 1;

                if (lastMasterResult.isSwitchToNextTree()) {
                    if (currTreeIndex >= 1) {
                        Node node = trees.get(currTreeIndex - 1).getNode();
                        Node predictNode = predictNodeIndex(node, data, false);
                        if (predictNode.getPredict() != null) {
                            double predict = predictNode.getPredict().getPredict();
                            // first tree logic, master must set it to first tree even second tree with ROOT is
                            // sending
                            if (context.getLastMasterResult().isFirstTree()) {
                                data.predict = (float) predict;
                            } else {
                                // random drop
                                boolean drop = (this.dropOutRate > 0.0
                                        && dropOutRandom.nextDouble() < this.dropOutRate);
                                if (!drop) {
                                    data.predict += (float) (this.learningRate * predict);
                                }
                            }
                            data.output = -1f * loss.computeGradient(data.predict, data.label);
                        }
                        // if not sampling with replacement in gbdt, renew bagging sample rate in next tree
                        if (!this.gbdtSampleWithReplacement) {
                            Random random = null;
                            int classValue = (int) (data.label + 0.01f);
                            if (this.isStratifiedSampling) {
                                random = baggingRandomMap.get(classValue);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(
                                            modelConfig.getTrain().getBaggingSampleSeed(),
                                            CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(classValue, random);
                                }
                            } else {
                                random = baggingRandomMap.get(0);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(
                                            modelConfig.getTrain().getBaggingSampleSeed(),
                                            CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(0, random);
                                }
                            }
                            if (random.nextDouble() <= modelConfig.getTrain().getBaggingSampleRate()) {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 1f;
                            } else {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 0f;
                            }
                        }
                    }
                }

                if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
                    Node currTree = trees.get(currTreeIndex).getNode();
                    Node predictNode = predictNodeIndex(currTree, data, true);
                    if (predictNode.getPredict() != null) {
                        trainError += data.significance * loss
                                .computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += data.significance;
                    }
                } else {
                    trainError += data.significance * loss.computeError(data.predict, data.label);
                    weightedTrainCount += data.significance;
                }
            }
        }
    }
    LOG.debug("Compute train error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));

    if (validationData != null) {
        start = System.nanoTime();
        for (Data data : this.validationData) {
            if (this.isRF) {
                for (TreeNode treeNode : trees) {
                    if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                        continue;
                    }
                    Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
                    if (predictNode.getPredict() != null) {
                        // only update when not at the first node; the tree node has no predict statistics at that time
                        validationError += data.significance * loss
                                .computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedValidationCount += data.significance;
                    }
                }
            }

            if (this.isGBDT) {
                if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
                    recoverGBTData(context, data.output, data.predict, data, false);
                    validationError += data.significance * loss.computeError(data.predict, data.label);
                    weightedValidationCount += data.significance;
                } else {
                    if (isNeedRecoverGBDTPredict) {
                        if (this.recoverTrees == null) {
                            this.recoverTrees = recoverCurrentTrees();
                        }
                        // recover gbdt data for fail over
                        recoverGBTData(context, data.output, data.predict, data, true);
                    }
                    int currTreeIndex = trees.size() - 1;
                    if (lastMasterResult.isSwitchToNextTree()) {
                        if (currTreeIndex >= 1) {
                            Node node = trees.get(currTreeIndex - 1).getNode();
                            Node predictNode = predictNodeIndex(node, data, false);
                            if (predictNode.getPredict() != null) {
                                double predict = predictNode.getPredict().getPredict();
                                if (context.getLastMasterResult().isFirstTree()) {
                                    data.predict = (float) predict;
                                } else {
                                    data.predict += (float) (this.learningRate * predict);
                                }
                                data.output = -1f * loss.computeGradient(data.predict, data.label);
                            }
                        }
                    }
                    if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
                        Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true);
                        if (predictNode.getPredict() != null) {
                            validationError += data.significance * loss
                                    .computeError((float) (predictNode.getPredict().getPredict()), data.label);
                            weightedValidationCount += data.significance;
                        }
                    } else {
                        validationError += data.significance * loss.computeError(data.predict, data.label);
                        weightedValidationCount += data.significance;
                    }
                }
            }
        }
        LOG.debug("Compute val error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
    }

    if (this.isGBDT) {
        // reset trees to null to save memory
        this.recoverTrees = null;
        if (this.isNeedRecoverGBDTPredict) {
            // no need recover again
            this.isNeedRecoverGBDTPredict = false;
        }
    }

    start = System.nanoTime();
    CompletionService<Map<Integer, NodeStats>> completionService = new ExecutorCompletionService<Map<Integer, NodeStats>>(
            this.threadPool);

    int realThreadCount = 0;
    LOG.debug("while todo size {}", todoNodes.size());

    int realRecords = this.trainingData.size();
    int realThreads = this.workerThreadCount > realRecords ? realRecords : this.workerThreadCount;

    int[] trainLows = new int[realThreads];
    int[] trainHighs = new int[realThreads];

    int stepCount = realRecords / realThreads;
    if (realRecords % realThreads != 0) {
        // increase the step count to spread the remainder, so the last thread
        // does not end up with close to 2*stepCount-1 records
        stepCount += (realRecords % realThreads) / stepCount;
    }
    for (int i = 0; i < realThreads; i++) {
        trainLows[i] = i * stepCount;
        if (i != realThreads - 1) {
            trainHighs[i] = trainLows[i] + stepCount - 1;
        } else {
            trainHighs[i] = realRecords - 1;
        }
    }

    for (int i = 0; i < realThreads; i++) {
        final Map<Integer, TreeNode> localTodoNodes = new HashMap<Integer, TreeNode>(todoNodes);
        final Map<Integer, NodeStats> localStatistics = initTodoNodeStats(todoNodes);

        final int startIndex = trainLows[i];
        final int endIndex = trainHighs[i];
        LOG.info("Thread {} todo size {} stats size {} start index {} end index {}", i, localTodoNodes.size(),
                localStatistics.size(), startIndex, endIndex);

        if (localTodoNodes.size() == 0) {
            continue;
        }
        realThreadCount += 1;
        completionService.submit(new Callable<Map<Integer, NodeStats>>() {
            @Override
            public Map<Integer, NodeStats> call() throws Exception {
                long start = System.nanoTime();
                List<Integer> nodeIndexes = new ArrayList<Integer>(trees.size());
                for (int j = startIndex; j <= endIndex; j++) {
                    Data data = DTWorker.this.trainingData.get(j);
                    nodeIndexes.clear();
                    if (DTWorker.this.isRF) {
                        for (TreeNode treeNode : trees) {
                            if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                                nodeIndexes.add(Node.INVALID_INDEX);
                            } else {
                                Node predictNode = predictNodeIndex(treeNode.getNode(), data, false);
                                nodeIndexes.add(predictNode.getId());
                            }
                        }
                    }

                    if (DTWorker.this.isGBDT) {
                        int currTreeIndex = trees.size() - 1;
                        Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, false);
                        // update node index
                        nodeIndexes.add(predictNode.getId());
                    }
                    for (Map.Entry<Integer, TreeNode> entry : localTodoNodes.entrySet()) {
                        // only do statistics on effective data
                        Node todoNode = entry.getValue().getNode();
                        int treeId = entry.getValue().getTreeId();
                        int currPredictIndex = 0;
                        if (DTWorker.this.isRF) {
                            currPredictIndex = nodeIndexes.get(entry.getValue().getTreeId());
                        }
                        if (DTWorker.this.isGBDT) {
                            currPredictIndex = nodeIndexes.get(0);
                        }

                        if (todoNode.getId() == currPredictIndex) {
                            List<Integer> features = entry.getValue().getFeatures();
                            if (features.isEmpty()) {
                                features = getAllValidFeatures();
                            }
                            for (Integer columnNum : features) {
                                double[] featureStatistic = localStatistics.get(entry.getKey())
                                        .getFeatureStatistics().get(columnNum);
                                float weight = data.subsampleWeights[treeId % data.subsampleWeights.length];
                                if (Float.compare(weight, 0f) != 0) {
                                    // only compute statistics when the weight is not 0
                                    short binIndex = data.inputs[DTWorker.this.inputIndexMap.get(columnNum)];
                                    DTWorker.this.impurity.featureUpdate(featureStatistic, binIndex,
                                            data.output, data.significance, weight);
                                }
                            }
                        }
                    }
                }
                LOG.debug("Thread computing stats time is {}ms in thread {}",
                        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start),
                        Thread.currentThread().getName());
                return localStatistics;
            }
        });
    }

    int rCnt = 0;
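    // Merge per-thread statistics in completion order; take() blocks until
    // the next partial result is available.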
    while (rCnt < realThreadCount) {
        try {
            Map<Integer, NodeStats> currNodeStatsmap = completionService.take().get();
            if (rCnt == 0) {
                statistics = currNodeStatsmap;
            } else {
                for (Entry<Integer, NodeStats> entry : statistics.entrySet()) {
                    NodeStats resultNodeStats = entry.getValue();
                    mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
                }
            }
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        rCnt += 1;
    }
    LOG.debug("Compute stats time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));

    LOG.info(
            "worker count is {}, error is {}, and stats size is {}. weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}",
            count, trainError, statistics.size(), weightedTrainCount, weightedValidationCount, trainError,
            validationError);
    return new DTWorkerParams(weightedTrainCount, weightedValidationCount, trainError, validationError,
            statistics);
}