org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.cp;

import static org.apache.sysml.parser.Statement.PSFrequency;
import static org.apache.sysml.parser.Statement.PSModeType;
import static org.apache.sysml.parser.Statement.PSScheme;
import static org.apache.sysml.parser.Statement.PSUpdateType;
import static org.apache.sysml.parser.Statement.PS_AGGREGATION_FUN;
import static org.apache.sysml.parser.Statement.PS_BATCH_SIZE;
import static org.apache.sysml.parser.Statement.PS_EPOCHS;
import static org.apache.sysml.parser.Statement.PS_FEATURES;
import static org.apache.sysml.parser.Statement.PS_FREQUENCY;
import static org.apache.sysml.parser.Statement.PS_HYPER_PARAMS;
import static org.apache.sysml.parser.Statement.PS_LABELS;
import static org.apache.sysml.parser.Statement.PS_MODE;
import static org.apache.sysml.parser.Statement.PS_MODEL;
import static org.apache.sysml.parser.Statement.PS_PARALLELISM;
import static org.apache.sysml.parser.Statement.PS_SCHEME;
import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN;
import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC;
import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR;
import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDRR;
import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.utils.Statistics;

public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {

    private static final int DEFAULT_BATCH_SIZE = 64;
    private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH;
    private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;

    //internal local debug level
    private static final boolean LDEBUG = false;
    protected static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());

    static {
        // for internal debugging only
        if (LDEBUG) {
            Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
            Logger.getLogger(ParamservBuiltinCPInstruction.class.getName()).setLevel(Level.DEBUG);
        }
    }

    public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
            String opcode, String istr) {
        super(op, paramsMap, out, opcode, istr);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;

        PSModeType mode = getPSMode();
        int workerNum = getWorkerNum(mode);
        BasicThreadFactory factory = new BasicThreadFactory.Builder().namingPattern("workers-pool-thread-%d")
                .build();
        ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
        String updFunc = getParam(PS_UPDATE_FUN);
        String aggFunc = getParam(PS_AGGREGATION_FUN);

        int k = getParLevel(workerNum);

        // Get the compiled execution context
        // Create workers' execution context
        LocalVariableMap newVarsMap = createVarsMap(ec);
        List<ExecutionContext> newECs = ParamservUtils.createExecutionContexts(ec, newVarsMap, updFunc, aggFunc,
                workerNum, k);

        // Create workers' execution context
        List<ExecutionContext> workerECs = newECs.subList(0, newECs.size() - 1);

        // Create the agg service's execution context
        ExecutionContext aggServiceEC = newECs.get(newECs.size() - 1);

        PSFrequency freq = getFrequency();
        PSUpdateType updateType = getUpdateType();
        int epochs = getEpochs();

        // Create the parameter server
        ListObject model = ec.getListObject(getParam(PS_MODEL));
        ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC);

        // Create the local workers
        MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
        MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
        List<LocalPSWorker> workers = IntStream.range(0, workerNum).mapToObj(i -> new LocalPSWorker(i, updFunc,
                freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
                .collect(Collectors.toList());

        // Do data partition
        PSScheme scheme = getScheme();
        doDataPartitioning(scheme, ec, workers);

        if (DMLScript.STATISTICS)
            Statistics.accPSSetupTime((long) tSetup.stop());

        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format(
                    "\nConfiguration of paramserv func: " + "\nmode: %s \nworkerNum: %d \nupdate frequency: %s "
                            + "\nstrategy: %s \ndata partitioner: %s",
                    mode, workerNum, freq, updateType, scheme));
        }

        try {
            // Launch the worker threads and wait for completion
            for (Future<Void> ret : es.invokeAll(workers))
                ret.get(); //error handling
            // Fetch the final model from ps
            ListObject result = ps.getResult();
            ec.setVariable(output.getName(), result);
        } catch (InterruptedException | ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        } finally {
            es.shutdownNow();
            // Should shutdown the thread pool in param server
            ps.shutdown();
        }
    }

    private LocalVariableMap createVarsMap(ExecutionContext ec) {
        // Put the hyperparam into the variables table
        LocalVariableMap varsMap = new LocalVariableMap();
        ListObject hyperParams = getHyperParams(ec);
        if (hyperParams != null) {
            varsMap.put(PS_HYPER_PARAMS, hyperParams);
        }
        return varsMap;
    }

    private PSModeType getPSMode() {
        PSModeType mode;
        try {
            mode = PSModeType.valueOf(getParam(PS_MODE));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(
                    String.format("Paramserv function: not support ps execution mode '%s'", getParam(PS_MODE)));
        }
        if (mode == PSModeType.REMOTE_SPARK)
            throw new DMLRuntimeException("Do not support remote spark.");
        return mode;
    }

    private int getEpochs() {
        int epochs = Integer.valueOf(getParam(PS_EPOCHS));
        if (epochs <= 0) {
            throw new DMLRuntimeException(String.format(
                    "Paramserv function: " + "The argument '%s' could not be less than or equal to 0.", PS_EPOCHS));
        }
        return epochs;
    }

    private int getParLevel(int workerNum) {
        return Math.max((int) Math.ceil((double) getRemainingCores() / workerNum), 1);
    }

    private PSUpdateType getUpdateType() {
        PSUpdateType updType;
        try {
            updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(
                    String.format("Paramserv function: not support update type '%s'.", getParam(PS_UPDATE_TYPE)));
        }
        if (updType == PSUpdateType.SSP)
            throw new DMLRuntimeException("Not support update type SSP.");
        return updType;
    }

    private PSFrequency getFrequency() {
        if (!getParameterMap().containsKey(PS_FREQUENCY)) {
            return DEFAULT_UPDATE_FREQUENCY;
        }
        try {
            return PSFrequency.valueOf(getParam(PS_FREQUENCY));
        } catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String
                    .format("Paramserv function: " + "not support '%s' update frequency.", getParam(PS_FREQUENCY)));
        }
    }

    private int getRemainingCores() {
        return InfrastructureAnalyzer.getLocalParallelism() - 1;
    }

    /**
     * Get the worker numbers according to the vcores
     *
     * @param mode execution mode
     * @return worker numbers
     */
    private int getWorkerNum(PSModeType mode) {
        switch (mode) {
        case LOCAL:
            // default worker number: available cores - 1 (assign one process for agg service)
            int workerNum = getRemainingCores();
            if (getParameterMap().containsKey(PS_PARALLELISM))
                workerNum = Integer.valueOf(getParam(PS_PARALLELISM));
            return workerNum;
        default:
            throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
    }

    /**
     * Create a server which serves the local or remote workers
     *
     * @return parameter server
     */
    private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum,
            ListObject model, ExecutionContext ec) {
        switch (mode) {
        case LOCAL:
            return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
        default:
            throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
    }

    private long getBatchSize() {
        if (!getParameterMap().containsKey(PS_BATCH_SIZE)) {
            return DEFAULT_BATCH_SIZE;
        }
        long batchSize = Integer.valueOf(getParam(PS_BATCH_SIZE));
        if (batchSize <= 0) {
            throw new DMLRuntimeException(String.format(
                    "Paramserv function: the number " + "of argument '%s' could not be less than or equal to 0.",
                    PS_BATCH_SIZE));
        }
        return batchSize;
    }

    private ListObject getHyperParams(ExecutionContext ec) {
        ListObject hyperparams = null;
        if (getParameterMap().containsKey(PS_HYPER_PARAMS)) {
            hyperparams = ec.getListObject(getParam(PS_HYPER_PARAMS));
        }
        return hyperparams;
    }

    private void doDataPartitioning(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) {
        MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
        MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
        switch (scheme) {
        case DISJOINT_CONTIGUOUS:
            doDataPartitioning(new DataPartitionerDC(), features, labels, workers);
            break;
        case DISJOINT_ROUND_ROBIN:
            doDataPartitioning(new DataPartitionerDRR(), features, labels, workers);
            break;
        case DISJOINT_RANDOM:
            doDataPartitioning(new DataPartitionerDR(), features, labels, workers);
            break;
        case OVERLAP_RESHUFFLE:
            doDataPartitioning(new DataPartitionerOR(), features, labels, workers);
            break;
        }
    }

    private PSScheme getScheme() {
        PSScheme scheme = DEFAULT_SCHEME;
        if (getParameterMap().containsKey(PS_SCHEME)) {
            try {
                scheme = PSScheme.valueOf(getParam(PS_SCHEME));
            } catch (IllegalArgumentException e) {
                throw new DMLRuntimeException(String
                        .format("Paramserv function: not support data partition scheme '%s'", getParam(PS_SCHEME)));
            }
        }
        return scheme;
    }

    private void doDataPartitioning(DataPartitioner dp, MatrixObject features, MatrixObject labels,
            List<LocalPSWorker> workers) {
        DataPartitioner.Result result = dp.doPartitioning(workers.size(), features, labels);
        List<MatrixObject> pfs = result.pFeatures;
        List<MatrixObject> pls = result.pLabels;
        if (pfs.size() < workers.size()) {
            if (LOG.isWarnEnabled()) {
                LOG.warn(String.format(
                        "There is only %d batches of data but has %d workers. "
                                + "Hence, reset the number of workers with %d.",
                        pfs.size(), workers.size(), pfs.size()));
            }
            workers = workers.subList(0, pfs.size());
        }
        for (int i = 0; i < workers.size(); i++) {
            workers.get(i).setFeatures(pfs.get(i));
            workers.get(i).setLabels(pls.get(i));
        }
    }
}