org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang.StringUtils;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.MultiThreadedHop;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;

public class ParamservUtils {

    /**
     * Prefix passed to {@link #getCompleteFuncName(String, String)} when naming the function copies
     * created by {@link #createExecutionContexts}.
     */
    public static final String PS_FUNC_PREFIX = "_ps_";

    /**
     * Deep copy the list object.
     *
     * <p>Matrix entries are copied by slicing all of their rows into a new matrix object; list and
     * frame entries are rejected; any other entry is shared by reference with the original list.
     * An empty list is returned as-is (the same instance, not a copy).
     *
     * @param lo list object
     * @return a new copied list object (or {@code lo} itself when it is empty)
     * @throws DMLRuntimeException if the list contains a list or frame entry
     */
    public static ListObject copyList(ListObject lo) {
        if (lo.getLength() == 0) {
            // Nothing to copy; hand back the same empty instance.
            return lo;
        }
        List<Data> newData = IntStream.range(0, lo.getLength())
                .mapToObj(i -> copyListEntry(lo.slice(i)))
                .collect(Collectors.toList());
        return new ListObject(newData, lo.getNames());
    }

    /** Copies a single list entry following the rules documented on {@link #copyList}. */
    private static Data copyListEntry(Data oldData) {
        if (oldData instanceof MatrixObject) {
            MatrixObject mo = (MatrixObject) oldData;
            return sliceMatrix(mo, 1, mo.getNumRows());
        }
        if (oldData instanceof ListObject || oldData instanceof FrameObject) {
            throw new DMLRuntimeException("Copy list: does not support list or frame.");
        }
        return oldData;
    }

    /**
     * Remove the named list variable from the execution context and clean up its entries.
     *
     * <p>NOTE(review): the removed variable is cast to {@link ListObject} without a null check, so a
     * missing variable results in a NullPointerException.
     *
     * @param ec execution context holding the variable
     * @param lName name of the list variable to remove
     */
    public static void cleanupListObject(ExecutionContext ec, String lName) {
        ListObject lo = (ListObject) ec.removeVariable(lName);
        lo.getData().forEach(ParamservUtils::cleanupData);
    }

    /**
     * Enable cleanup on, and clear the data of, a cacheable data object. Non-cacheable data is
     * ignored.
     *
     * @param data data object to clean up
     */
    public static void cleanupData(Data data) {
        if (!(data instanceof CacheableData))
            return;
        CacheableData<?> cd = (CacheableData<?>) data;
        cd.enableCleanup(true);
        cd.clearData();
    }

    /**
     * Create a new DOUBLE matrix object bound to a unique temporary file name, with all matrix
     * characteristics passed as -1 and binary-block input/output formats.
     *
     * @return a new, empty matrix object
     */
    public static MatrixObject newMatrixObject() {
        return new MatrixObject(Expression.ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(),
                new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1), OutputInfo.BinaryBlockOutputInfo,
                        InputInfo.BinaryBlockInputInfo));
    }

    /**
     * Slice the matrix by rows into a new matrix object.
     *
     * <p>The bounds are 1-based and are converted to 0-based before calling
     * {@code MatrixBlock.slice}; {@link #copyList} passes {@code 1..numRows} to copy every row, so
     * the upper bound is presumably inclusive. The returned object has cleanup disabled.
     *
     * @param mo input matrix
     * @param rl low boundary (1-based)
     * @param rh high boundary (1-based)
     * @return new sliced matrix
     * @throws ArithmeticException if a 0-based bound does not fit into an {@code int}
     */
    public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
        MatrixObject result = newMatrixObject();
        MatrixBlock tmp = mo.acquireRead();
        try {
            // Reject out-of-int-range bounds instead of silently wrapping them.
            result.acquireModify(tmp.slice(Math.toIntExact(rl - 1), Math.toIntExact(rh - 1)));
        } finally {
            // Always unpin the input, even if slicing fails.
            mo.release();
        }
        result.release();
        result.enableCleanup(false);
        return result;
    }

    /**
     * Generate a {@code numEntries x numEntries} permutation matrix by building a contingency table
     * of the sequence {@code 1..numEntries} against a random sample of {@code numEntries} values
     * drawn without replacement.
     *
     * @param numEntries number of rows/columns of the permutation
     * @return the permutation matrix
     */
    public static MatrixBlock generatePermutation(int numEntries) {
        // Create a sequence and sample w/o replacement (seed -1; presumably "no fixed seed" — confirm)
        MatrixBlock seq = MatrixBlock.seqOperations(1, numEntries, 1);
        MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, -1);

        // Combine the sequence and sample as a table
        MatrixBlock permutation = new MatrixBlock(numEntries, numEntries, true);
        seq.ctableOperations(null, sample, 1.0, permutation);
        return permutation;
    }

    /**
     * Split a function key into its namespace and function name.
     *
     * <p>NOTE(review): both branches of the final conditional return {@code {ns, name}}, so
     * {@code prefix} (including {@link #PS_FUNC_PREFIX}) is currently ignored. If
     * {@code prefix + name} was intended, every caller that looks up the copied functions must be
     * updated together with this method; the behavior is intentionally left unchanged here.
     *
     * @param funcName function key, either {@code name} or a namespaced key
     * @param prefix optional prefix for the function name (currently not applied)
     * @return a two-element array {@code {namespace (or null), name}}
     */
    public static String[] getCompleteFuncName(String funcName, String prefix) {
        String[] keys = DMLProgram.splitFunctionKey(funcName);
        String ns = (keys.length == 2) ? keys[0] : null;
        String name = (keys.length == 2) ? keys[1] : keys[0];
        return StringUtils.isEmpty(prefix) ? new String[] { ns, name } : new String[] { ns, name };
    }

    /**
     * Recompile the program with the given degree of parallelism and create one execution context
     * per worker plus one for the aggregation service.
     *
     * <p>Each worker context gets a fresh program holding deep copies of the update and aggregation
     * functions; the aggregation context gets a program holding only a copy of the aggregation
     * function. Every context gets its own copy of {@code varsMap}.
     *
     * @param ec source execution context whose program is recompiled in place
     * @param varsMap variables copied into every new context
     * @param updFunc update function key
     * @param aggFunc aggregation function key
     * @param workerNum number of worker contexts to create
     * @param k degree of parallelism to assign during recompilation
     * @return the {@code workerNum} worker contexts followed by the aggregation context (last)
     * @throws DMLRuntimeException if either function cannot be found
     */
    public static List<ExecutionContext> createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap,
            String updFunc, String aggFunc, int workerNum, int k) {

        // Resolve both functions first so a bad name fails before any recompilation happens.
        FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc);
        FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc);

        Program prog = ec.getProgram();

        // 1. Recompile the internal program blocks
        recompileProgramBlocks(k, prog.getProgramBlocks());
        // 2. Recompile the imported function blocks
        prog.getFunctionProgramBlocks()
                .forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks()));

        // 3. Copy function for workers (collected into a mutable list so the agg context can be appended)
        List<ExecutionContext> result = IntStream.range(0, workerNum).mapToObj(i -> {
            FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB);
            FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
            Program newProg = new Program();
            putFunction(newProg, newUpdFunc);
            putFunction(newProg, newAggFunc);
            return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
        }).collect(Collectors.toCollection(ArrayList::new));

        // 4. Copy function for agg service
        FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
        Program newProg = new Program();
        putFunction(newProg, newAggFunc);
        result.add(ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg));
        return result;
    }

    /** Deep-copies a function program block and names the copy via {@link #getCompleteFuncName}. */
    private static FunctionProgramBlock copyFunction(String funcName, FunctionProgramBlock fpb) {
        FunctionProgramBlock copiedFunc = ProgramConverter.createDeepCopyFunctionProgramBlock(fpb, new HashSet<>(),
                new HashSet<>());
        String[] cfn = getCompleteFuncName(funcName, ParamservUtils.PS_FUNC_PREFIX);
        copiedFunc._namespace = cfn[0];
        copiedFunc._functionName = cfn[1];
        return copiedFunc;
    }

    /** Registers a function block in the program both as a named function and as a program block. */
    private static void putFunction(Program prog, FunctionProgramBlock fpb) {
        prog.addFunctionProgramBlock(fpb._namespace, fpb._functionName, fpb);
        prog.addProgramBlock(fpb);
    }

    /** Resets hop visit status, then reassigns parallelism {@code k} and recompiles where needed. */
    private static void recompileProgramBlocks(int k, ArrayList<ProgramBlock> pbs) {
        // Reset the visit status from root
        for (ProgramBlock pb : pbs)
            DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());

        // Should recursively assign the level of parallelism
        // and recompile the program block
        try {
            rAssignParallelism(pbs, k, false);
        } catch (IOException e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Recursively assigns the degree of parallelism over a list of program blocks.
     *
     * <p>Parfor blocks get {@code k} and their bodies get 1; other compound blocks pass {@code k}
     * through; leaf blocks set {@code k} on every multi-threaded hop in their DAG. The
     * {@code recompiled} flag is cumulative: once it becomes true, the current block and every block
     * visited afterwards (including later siblings) has its instructions recompiled.
     */
    private static boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, boolean recompiled)
            throws IOException {
        for (ProgramBlock pb : pbs) {
            if (pb instanceof ParForProgramBlock) {
                ParForProgramBlock pfpb = (ParForProgramBlock) pb;
                pfpb.setDegreeOfParallelism(k);
                // The parfor itself uses k workers, so its body runs single-threaded.
                recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
            } else if (pb instanceof ForProgramBlock) {
                recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
            } else if (pb instanceof WhileProgramBlock) {
                recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
            } else if (pb instanceof FunctionProgramBlock) {
                recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
            } else if (pb instanceof IfProgramBlock) {
                IfProgramBlock ipb = (IfProgramBlock) pb;
                recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
                if (ipb.getChildBlocksElseBody() != null)
                    recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
            } else {
                StatementBlock sb = pb.getStatementBlock();
                for (Hop hop : sb.getHops())
                    recompiled |= rAssignParallelism(hop, k, recompiled);
            }
            // Recompile the program block
            if (recompiled) {
                Recompiler.recompileProgramBlockInstructions(pb);
            }
        }
        return recompiled;
    }

    /**
     * Walks a hop DAG (each hop once, guarded by its visit flag) and sets {@code k} as the max
     * number of threads on every multi-threaded hop; returns true if any hop was reassigned or
     * {@code recompiled} was already true.
     */
    private static boolean rAssignParallelism(Hop hop, int k, boolean recompiled) {
        if (hop.isVisited()) {
            return recompiled;
        }
        if (hop instanceof MultiThreadedHop) {
            // Reassign the level of parallelism
            MultiThreadedHop mhop = (MultiThreadedHop) hop;
            mhop.setMaxNumThreads(k);
            recompiled = true;
        }
        ArrayList<Hop> inputs = hop.getInput();
        for (Hop h : inputs) {
            recompiled |= rAssignParallelism(h, k, recompiled);
        }
        hop.setVisited();
        return recompiled;
    }

    /**
     * Looks up a function program block in the context's program by its (unprefixed) key.
     *
     * @throws DMLRuntimeException if no such function exists
     */
    private static FunctionProgramBlock getFunctionBlock(ExecutionContext ec, String funcName) {
        String[] cfn = getCompleteFuncName(funcName, null);
        String ns = cfn[0];
        String fname = cfn[1];
        FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(ns, fname);
        if (fpb == null) {
            throw new DMLRuntimeException("Paramserv: function '" + funcName + "' not found.");
        }
        return fpb;
    }
}