Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.concurrent.Callable; import java.util.stream.IntStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.utils.Statistics; public class LocalPSWorker extends PSWorker implements Callable<Void> { protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { super(workerID, updFunc, freq, epochs, batchSize, valFeatures, valLabels, ec, ps); } @Override public Void call() throws Exception { if (DMLScript.STATISTICS) Statistics.incWorkerNumber(); try { long dataSize = _features.getNumRows(); int totalIter = (int) Math.ceil((double) dataSize / _batchSize); switch (_freq) { case BATCH: computeBatch(dataSize, totalIter); break; case EPOCH: computeEpoch(dataSize, totalIter); break; } if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); } } catch (Exception e) { throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e); } return null; } private void computeEpoch(long dataSize, int totalIter) { for (int i = 0; i < _epochs; i++) { // Pull the global parameters from ps ListObject params = pullModel(); ListObject accGradients = null; for (int j = 0; j < totalIter; j++) { _ec.setVariable(Statement.PS_MODEL, params); ListObject gradients = computeGradients(dataSize, totalIter, i, j); // Accumulate the intermediate gradients accGradients = (accGradients == null) ? ParamservUtils.copyList(gradients) : accrueGradients(accGradients, gradients); // Update the local model with gradients if (j < totalIter - 1) params = updateModel(params, gradients, i, j, totalIter); } // Push the gradients to ps pushGradients(accGradients); ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } } } private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null; globalParams = _ps.updateModel(_ec, gradients, globalParams); if (DMLScript.STATISTICS) Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); if (LOG.isDebugEnabled()) { LOG.debug(String.format( "Local worker_%d: Local global parameter [size:%d kb] updated. " + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter)); } return globalParams; } private void computeBatch(long dataSize, int totalIter) { for (int i = 0; i < _epochs; i++) { for (int j = 0; j < totalIter; j++) { ListObject globalParams = pullModel(); _ec.setVariable(Statement.PS_MODEL, globalParams); ListObject gradients = computeGradients(dataSize, totalIter, i, j); // Push the gradients to ps pushGradients(gradients); ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); } if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } } } private ListObject pullModel() { // Pull the global parameters from ps ListObject globalParams = (ListObject) _ps.pull(_workerID); if (LOG.isDebugEnabled()) { LOG.debug(String.format( "Local worker_%d: Successfully pull the global parameters " + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024)); } return globalParams; } private void pushGradients(ListObject gradients) { // Push the gradients to ps _ps.push(_workerID, gradients); if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Successfully push the gradients " + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); } } private ListObject computeGradients(long dataSize, int totalIter, int i, int j) { long begin = j * _batchSize + 1; long end = Math.min((j + 1) * _batchSize, dataSize); // Get batch features and labels Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null; MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); if (DMLScript.STATISTICS) Statistics.accPSBatchIndexingTime((long) tSlic.stop()); _ec.setVariable(Statement.PS_FEATURES, bFeatures); _ec.setVariable(Statement.PS_LABELS, bLabels); if (LOG.isDebugEnabled()) { LOG.debug(String.format( "Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. " + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs, j + 1, totalIter)); } // Invoke the update function Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null; _inst.processInstruction(_ec); if (DMLScript.STATISTICS) Statistics.accPSGradientComputeTime((long) tGrad.stop()); // Get the gradients ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); ParamservUtils.cleanupData(bFeatures); ParamservUtils.cleanupData(bLabels); return gradients; } private ListObject accrueGradients(ListObject accGradients, ListObject gradients) { IntStream.range(0, accGradients.getLength()).forEach(i -> { MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead(); MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead(); mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2); ((MatrixObject) accGradients.getData().get(i)).release(); ((MatrixObject) gradients.getData().get(i)).release(); }); return accGradients; } }