com.aliyun.odps.graph.local.worker.Worker.java Source code

Java tutorial

Introduction

Here is the source code for com.aliyun.odps.graph.local.worker.Worker.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.graph.local.worker;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.aliyun.odps.counter.Counter;
import com.aliyun.odps.counter.Counters;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.JobConf;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexResolver;
import com.aliyun.odps.graph.WorkerComputer;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.local.BaseRecordReader;
import com.aliyun.odps.graph.local.COUNTER;
import com.aliyun.odps.graph.local.EmptyRecordReader;
import com.aliyun.odps.graph.local.GraphTaskAttemptID;
import com.aliyun.odps.graph.local.InputSplit;
import com.aliyun.odps.graph.local.LocalRecordReader;
import com.aliyun.odps.graph.local.LocalRecordWriter;
import com.aliyun.odps.graph.local.LocalVertexMutations;
import com.aliyun.odps.graph.local.RuntimeContext;
import com.aliyun.odps.graph.local.SQLRecord;
import com.aliyun.odps.graph.local.TaskContextImpl;
import com.aliyun.odps.graph.local.master.Master;
import com.aliyun.odps.graph.local.message.MsgManager;
import com.aliyun.odps.graph.local.utils.LocalGraphRunUtils;
import com.aliyun.odps.graph.utils.VerifyUtils;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.io.WritableUtils;
import com.aliyun.odps.utils.ReflectionUtils;

/**
 * One worker of the local graph job runner.
 *
 * <p>
 * A worker owns a set of vertices (keyed by vertex id), the message manager
 * that buffers messages between super steps, the per-step aggregator values,
 * the pending vertex mutations and the output writers of its task attempt.
 * The super step number is read from the {@link Master} passed to the
 * constructor.
 *
 * <p>
 * NOTE(review): nothing in this class synchronizes access to its state;
 * presumably each worker is driven by a single thread - confirm against the
 * caller.
 *
 * @param <VERTEX_ID> type of the vertex identifier
 * @param <VERTEX_VALUE> type of the vertex value
 * @param <EDGE_VALUE> type of the edge value
 * @param <MESSAGE> type of the messages exchanged between vertices
 * @param <VALUE> not referenced inside this class
 */
@SuppressWarnings("rawtypes")
public class Worker<VERTEX_ID extends WritableComparable<?>, VERTEX_VALUE extends Writable, EDGE_VALUE extends Writable, MESSAGE extends Writable, VALUE extends Writable> {

    private static final Log LOG = LogFactory.getLog(Worker.class);

    /** Aggregators configured on the job, resolved once in the constructor. */
    private List<Aggregator> mAggregators;
    /** Aggregator values of the current super step; rebuilt by {@link #processNextStep()}. */
    private List<Writable> mAggregatorValues;

    private Master master;
    private Counters mCounters;
    private RuntimeContext mCtx;
    /** Input split of this worker; may be {@code null} (see {@link #loadGraph()}). */
    private InputSplit mInput;
    private JobConf mJob;
    /** Aggregator values handed over via {@link #setLastAggregatedValue(List)}. */
    private List<Writable> mLastAggregatorValues;
    /** Messages delivered to each vertex for the super step being computed. */
    private Map<Vertex, Iterable<Writable>> mLastStepMessage = new HashMap<Vertex, Iterable<Writable>>();
    private Map<String, TableInfo> mOutputs;
    /** Mutations requested per vertex id; reset by {@link #processWorkerMutations(VertexResolver)}. */
    private Map<VERTEX_ID, LocalVertexMutations> mVertexMutations;
    private GraphTaskAttemptID mTaskAttemptID;
    private TaskContextImpl mTaskContext;
    private WorkerComputer mWorkerComputer;

    private int mWorkerID;
    private int mWorkerNum;
    /** Output writers keyed by output label; {@code null} before {@link #loadGraph()} and after {@link #close()}. */
    private Map<String, LocalRecordWriter> mWriters;

    private Map<VERTEX_ID, Vertex> vertices = new HashMap<VERTEX_ID, Vertex>();

    private Writable mWorkerValue;

    private MsgManager mMsgManager;

    /** Message combiner, or {@code null} when the job configures none. */
    private Combiner mCombiner;

    /**
     * Creates a worker and its task context. No input is read here; call
     * {@link #loadGraph()} and {@link #init()} afterwards.
     *
     * @param job job configuration
     * @param ctx runtime context of the local run
     * @param m master that provides the current super step
     * @param taskAttemptID attempt id, used to name the output and counter files
     * @param workerID id of this worker
     * @param workerNum total number of workers
     * @param input input split of this worker, may be {@code null}
     * @param outputs output tables keyed by label
     */
    @SuppressWarnings("unchecked")
    public Worker(JobConf job, RuntimeContext ctx, Master m, GraphTaskAttemptID taskAttemptID, int workerID,
            int workerNum, InputSplit input, Map<String, TableInfo> outputs)
            throws InstantiationException, IllegalAccessException, IOException, ClassNotFoundException {

        mJob = job;
        mCtx = ctx;
        master = m;
        mTaskAttemptID = taskAttemptID;
        mWorkerID = workerID;
        mWorkerNum = workerNum;

        mInput = input;
        mOutputs = outputs;
        mCounters = new Counters();

        mAggregators = LocalGraphRunUtils.getAggregator(mJob);
        mMsgManager = new MsgManager();

        mTaskContext = new TaskContextImpl(mCtx, mJob, this, mWorkerID, mWorkerNum, mOutputs, mCounters);
        mVertexMutations = new HashMap<VERTEX_ID, LocalVertexMutations>();

    }

    /**
     * Tells whether this worker has nothing left to do.
     *
     * @return {@code true} when every vertex is halted and the message
     *         manager reports no messages for the next step
     */
    public boolean allVertexVoltHalt() {
        for (Vertex v : vertices.values()) {
            if (!v.isHalted()) {
                // One active vertex is enough; no need to look at the rest.
                return false;
            }
        }
        return !mMsgManager.hasNextStepMessages();
    }

    /**
     * Instantiates and configures the job's combiner, or leaves
     * {@link #mCombiner} as {@code null} when no combiner class is set.
     */
    private void initCombiner() {
        Class<? extends Combiner> combinerCls = mJob.getCombinerClass();
        if (combinerCls == null) {
            mCombiner = null;
            return;
        }
        try {
            mCombiner = combinerCls.newInstance();
            mCombiner.configure(mJob);
        } catch (Exception e) {
            throw new RuntimeException("exception occored when Instantiate combiner ", e);
        }
    }

    /**
     * Instantiates the job's {@link WorkerComputer}, creates its worker value
     * from the declared type argument ({@link NullWritable} when none is
     * declared) and calls its {@code setup}.
     *
     * @throws IOException wrapping any failure, including the case of more
     *         than one declared worker value type
     */
    @SuppressWarnings("unchecked")
    private void initWorkerComputer() throws IOException {
        try {
            Class<? extends WorkerComputer> workerComputerClass = mJob.getWorkerComputerClass();

            mWorkerComputer = workerComputerClass.newInstance();

            List<Class<?>> workerValueClass = ReflectionUtils.<WorkerComputer>getTypeArguments(WorkerComputer.class,
                    workerComputerClass);

            if (workerValueClass.size() > 1) {
                throw new IOException("more than ONE workerValue Type Declared");
            }
            if (workerValueClass.isEmpty() || workerValueClass.get(0) == null) {
                // No (resolvable) type argument: the computer carries no value.
                mWorkerValue = NullWritable.get();
            } else {
                mWorkerValue = (Writable) ReflectionUtils.newInstance(workerValueClass.get(0), mJob);
            }
            mWorkerComputer.setup(mTaskContext, mWorkerValue);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Seeds {@link #mLastAggregatorValues} with each aggregator's startup
     * value.
     */
    private void initAggregatorValues() throws IOException {
        mLastAggregatorValues = new ArrayList<Writable>();
        for (Aggregator agg : mAggregators) {
            mLastAggregatorValues.add(agg.createStartupValue(mTaskContext));
        }
    }

    /**
     * Sets up the job's {@link GraphLoader}, creates one output writer per
     * output label and feeds every input record to the loader.
     *
     * <p>
     * When this worker has no input split (or the split has no table) the
     * loader is set up with a {@code null} table and an
     * {@link EmptyRecordReader} is used. The input reader is closed on every
     * path, including failures.
     *
     * @throws IOException if reading the input or creating a writer fails,
     *         or if no parent directory of the input file is named after the
     *         input table
     */
    @SuppressWarnings("unchecked")
    public void loadGraph() throws IOException {

        Class<? extends GraphLoader> loaderClass = mJob.getGraphLoaderClass();

        GraphLoader graphLoader = ReflectionUtils.newInstance(loaderClass, mJob);

        // Resolve the table once, guarding against a missing split before any
        // dereference of mInput.
        TableInfo inputTable = (mInput != null) ? mInput.getTable() : null;

        graphLoader.setup(mTaskContext.getConfiguration(), mTaskContext.getWorkerId(), inputTable,
                mTaskContext);

        BaseRecordReader<LongWritable, WritableRecord> reader;

        if (inputTable != null) {
            File dataDir = mInput.getFile().getParentFile();
            // Walk up from the input file until the directory named after the table.
            File tableDir = dataDir;
            while (tableDir != null && !tableDir.getName().equals(inputTable.getTableName())) {
                tableDir = tableDir.getParentFile();
            }
            if (tableDir == null) {
                throw new IOException("cannot find directory of table '" + inputTable.getTableName()
                        + "' above input file " + mInput.getFile());
            }
            reader = new LocalRecordReader(dataDir, tableDir,
                    mCounters.findCounter(COUNTER.TASK_INPUT_RECORD),
                    mCounters.findCounter(COUNTER.TASK_INPUT_BYTE));
        } else {
            reader = new EmptyRecordReader();
        }

        boolean loaded = false;
        try {
            mWriters = new HashMap<String, LocalRecordWriter>();
            for (String label : mOutputs.keySet()) {
                Counter outputRecordCounter = mCounters.findCounter(COUNTER.TASK_OUTPUT_RECORD);
                Counter outputByteCounter = mCounters.findCounter(COUNTER.TASK_OUTPUT_BYTE);
                LocalRecordWriter writer = new LocalRecordWriter(
                        new File(mCtx.getOutputDir(label), this.mTaskAttemptID.toString()), outputRecordCounter,
                        outputByteCounter);
                mWriters.put(label, writer);
            }

            while (reader.nextKeyValue()) {
                // Hand the loader its own key and record copies rather than
                // the objects owned by the reader.
                LongWritable recordNum = new LongWritable();
                recordNum.set((reader.getCurrentKey().get()));
                graphLoader.load(recordNum, ((SQLRecord) reader.getCurrentValue()).clone(), mTaskContext);
            }
            loaded = true;
        } finally {
            if (loaded) {
                reader.close();
            } else {
                // A failure is already propagating; do not let a close error mask it.
                try {
                    reader.close();
                } catch (Exception closeFailure) {
                    LOG.warn("failed to close input reader after load failure", closeFailure);
                }
            }
        }

        mTaskContext.setOutputWriters(mWriters);
    }

    /**
     * Initializes the combiner, the worker computer and the aggregator
     * startup values, then calls {@code setup} on every vertex.
     */
    @SuppressWarnings("unchecked")
    public void init() throws IOException {
        initCombiner();
        initWorkerComputer();
        initAggregatorValues();

        for (Vertex v : vertices.values()) {
            v.setup(mTaskContext);
        }

    }

    /**
     * Prepares the next super step: resets the aggregator values to each
     * aggregator's initial value and advances the message manager.
     *
     * @throws RuntimeException if an aggregator returns a {@code null}
     *         initial value
     */
    public void processNextStep() throws IOException {
        mAggregatorValues = new ArrayList<Writable>();
        for (Aggregator agg : mAggregators) {
            Writable initAggregatorValue = agg.createInitialValue(mTaskContext);
            if (initAggregatorValue == null) {
                throw new RuntimeException("ODPS-0730001: " + agg.getClass().getName()
                        + " createInitialValue return null");
            }
            mAggregatorValues.add(initAggregatorValue);
        }
        mMsgManager.nextSuperStep(mCtx);
    }

    /**
     * Runs one super step: collects the messages of every vertex, wakes up
     * halted vertices that received at least one message and calls
     * {@code compute} on every vertex that is not halted.
     */
    @SuppressWarnings("unchecked")
    public void Compute() throws IOException {

        prepareMsg();
        for (Vertex v : vertices.values()) {
            Iterable<Writable> msg = mLastStepMessage.get(v);
            if (v.isHalted() && msg.iterator().hasNext()) {
                v.wakeUp();
            }

            if (!v.isHalted()) {
                v.compute(mTaskContext, msg);
            }
        }
    }

    /**
     * Resolves the mutations of one vertex id and applies the result: the
     * vertex returned by the resolver replaces the stored one, a {@code null}
     * result removes the id.
     *
     * @param id vertex id being resolved
     * @param mutations mutations recorded for {@code id}, may be {@code null}
     *        when the id is resolved only because it received messages
     * @param vertexResolver resolver to apply
     * @throws IOException if {@code vertexResolver} is {@code null}
     */
    @SuppressWarnings("unchecked")
    public void processMutation(VERTEX_ID id, LocalVertexMutations mutations, VertexResolver vertexResolver)
            throws IOException {
        Vertex v = vertices.get(id);
        boolean hasMessage = mMsgManager.hasMessageForVertex(mCtx, master.getSuperStep(), id);
        if (vertexResolver == null) {
            throw new IOException(
                    "ODPS-0730001: encounter mutations in compute but not set the mutation resolver.");
        }
        Vertex new_v = vertexResolver.resolve(id, v, mutations, hasMessage);
        if (new_v == null) {
            vertices.remove(id);
        } else {
            VerifyUtils.verifyVertex(new_v);
            vertices.put(id, new_v);
        }
    }

    /**
     * Resolves every vertex id that has recorded mutations, plus every id
     * known to the message manager that has no vertex on this worker, then
     * clears the recorded mutations.
     *
     * @param vertexResolver resolver to apply; must be non-null whenever
     *        there is at least one id to resolve
     */
    @SuppressWarnings("unchecked")
    public void processWorkerMutations(VertexResolver vertexResolver) throws IOException {
        HashSet<VERTEX_ID> mutationIDs = new HashSet<VERTEX_ID>(mVertexMutations.keySet());

        for (WritableComparable<?> id : mMsgManager.getVertexIDList()) {
            if (vertices.get(id) == null) {
                mutationIDs.add((VERTEX_ID) id);
            }
        }

        for (VERTEX_ID id : mutationIDs) {
            processMutation(id, mVertexMutations.get(id), vertexResolver);
        }
        mVertexMutations = new HashMap<VERTEX_ID, LocalVertexMutations>();
    }

    /** @return the aggregator values of the current super step */
    public List<Writable> getAggregatorValues() {
        return mAggregatorValues;
    }

    /** @return the counters of this worker */
    public Counters getCounters() {
        return mCounters;
    }

    /** @return the sum of {@code getNumEdges()} over all vertices of this worker */
    public long getEgeNumber() {
        long egeNumber = 0;
        for (Vertex v : vertices.values()) {
            egeNumber += v.getNumEdges();
        }
        return egeNumber;
    }

    /** @return the aggregator values last set through {@link #setLastAggregatedValue(List)} or at init */
    public List<Writable> getLastAggregatedValue() {
        return mLastAggregatorValues;
    }

    /** @return the master this worker was created with */
    public Master getMaster() {
        return master;
    }

    /** @return the task context shared with vertices, loader and worker computer */
    public WorkerContext getTaskContext() {
        return mTaskContext;
    }

    /** @return the number of vertices held by this worker */
    public long getVertexNumber() {
        return vertices.size();
    }

    /**
     * Returns the mutation record of a vertex id, creating and registering an
     * empty one on first access.
     *
     * @param id vertex id
     * @return the (never {@code null}) mutation record for {@code id}
     */
    public LocalVertexMutations getVertexMutations(VERTEX_ID id) {
        LocalVertexMutations ret = mVertexMutations.get(id);
        if (ret == null) {
            ret = new LocalVertexMutations();
            mVertexMutations.put(id, ret);
        }
        return ret;
    }

    /** @return the same list as {@link #getAggregatorValues()} */
    public List<Writable> partialAggregate() {
        return mAggregatorValues;
    }

    /**
     * Forwards a message to this worker's message manager.
     *
     * @param context runtime context
     * @param superStep super step the message belongs to
     * @param vertexId target vertex id
     * @param msg message payload
     */
    public void pushMsg(RuntimeContext context, long superStep, WritableComparable<?> vertexId, Writable msg) {
        mMsgManager.pushMsg(context, superStep, vertexId, msg);
    }

    /**
     * Closes the output writers and writes the counters to a file named
     * after the task id inside the counter directory.
     *
     * <p>
     * Safe to call when no writers exist (before {@link #loadGraph()} or
     * after a previous {@code close()}); the counters are still written.
     */
    public void close() throws IOException {
        if (mWriters != null) {
            for (LocalRecordWriter writer : mWriters.values()) {
                writer.close();
            }
            mWriters = null;
        }

        FileUtils.writeStringToFile(new File(mCtx.getCounterDir(), String.valueOf(this.mTaskAttemptID.getTaskId())),
                mCounters.toString());
        LOG.debug(mCounters);
    }

    /**
     * Calls {@code cleanup} on every vertex and then on the worker computer.
     */
    @SuppressWarnings("unchecked")
    public void cleanup() throws IOException {
        for (Vertex v : vertices.values()) {
            v.cleanup(mTaskContext);
        }
        mWorkerComputer.cleanup(mTaskContext);
    }

    /**
     * Reduces the messages of one vertex to at most one message using the
     * combiner; returns {@code msgs} untouched when no combiner is set.
     *
     * <p>
     * The first message serves as the accumulator and every later message is
     * passed to {@code Combiner.combine} together with it. NOTE(review): the
     * result of {@code combine} is not read, so the combiner presumably
     * updates the accumulator in place - confirm against the Combiner
     * contract.
     *
     * @param id id of the receiving vertex
     * @param msgs messages addressed to the vertex
     * @return a list with zero or one combined message, or {@code msgs}
     */
    @SuppressWarnings("unchecked")
    private Iterable<Writable> combineMsg(WritableComparable id, Iterable<Writable> msgs) throws IOException {
        if (mCombiner == null) {
            return msgs;
        }
        Writable combineMsg = null;
        for (Writable msg : msgs) {
            if (combineMsg == null) {
                combineMsg = msg;
            } else {
                mCombiner.combine(id, combineMsg, msg);
            }
        }
        ArrayList<Writable> combinedMsgs = new ArrayList<Writable>();
        if (combineMsg != null) {
            combinedMsgs.add(combineMsg);
        }
        return combinedMsgs;
    }

    /**
     * Rebuilds {@link #mLastStepMessage}: pops the messages of the master's
     * current super step for every vertex and combines them when a combiner
     * is configured.
     */
    private void prepareMsg() throws IOException {
        long superStep = master.getSuperStep();
        LOG.debug("worker super step " + superStep + ", vertices count " + vertices.size());
        mLastStepMessage.clear();
        for (Vertex v : vertices.values()) {
            Iterable<Writable> msgs = mMsgManager.popMsges(mCtx, superStep, v.getId());
            if (mCombiner != null) {
                msgs = combineMsg(v.getId(), msgs);
            }
            mLastStepMessage.put(v, msgs);
        }

    }

    /**
     * Stores a private copy of the given aggregated values; {@code null}
     * entries are kept as {@code null}.
     *
     * @param lastAggrValues values to copy, one per aggregator
     */
    public void setLastAggregatedValue(List<Writable> lastAggrValues) {
        mLastAggregatorValues = new ArrayList<Writable>(lastAggrValues.size());
        for (Writable source : lastAggrValues) {
            Writable value = null;
            if (source != null) {
                value = WritableUtils.clone(source, mJob);
            }
            mLastAggregatorValues.add(value);
        }
    }

    /**
     * Publishes the job-wide vertex and edge totals to the task context.
     *
     * @param totalVertices total number of vertices
     * @param totalEdge total number of edges
     */
    public void setTotalNumVerticesAndEdges(int totalVertices, int totalEdge) {
        mTaskContext.setTotalNumVertices(totalVertices);
        mTaskContext.setTotalNumEdges(totalEdge);
    }

    /** @return the worker value created for the worker computer during {@link #init()} */
    public Writable getWorkerValue() {
        return mWorkerValue;
    }

}