com.cloudera.branchreduce.impl.distributed.TaskMaster.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.branchreduce.impl.distributed.TaskMaster.java

Source

/**
 * Copyright (c) 2012, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */
package com.cloudera.branchreduce.impl.distributed;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;

import com.cloudera.branchreduce.GlobalState;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.AbstractScheduledService;

/**
 * Coordinates a fixed number of workers ("vassals") that process tasks of type {@code T} while
 * sharing a mergeable global state of type {@code G}.
 *
 * <p>Once {@code vassalCount} workers have registered, the service is started. {@link #startUp()}
 * hands each worker at most one initial task. Workers then request more work through
 * {@link #getWork}. When the master's own queue is empty, it tries to steal tasks from other
 * workers it believes are busy. Global-state changes merged via {@link #updateGlobalState} are
 * broadcast to every worker on the next scheduler tick (once per second). The service stops itself
 * once its queue is empty and no worker is marked as holding work.
 */
public class TaskMaster<T extends Writable, G extends GlobalState<G>> extends AbstractScheduledService {

    /**
     * Master-side handle for one registered worker.
     *
     * @param <T> task type
     * @param <G> global state type
     */
    public interface WorkerProxy<T, G> {
        /** Asks the worker to begin processing {@code tasks} using {@code globalState}. */
        void startTasks(List<T> tasks, G globalState);

        /**
         * Returns tasks the worker can give up so the master can redistribute them (work stealing).
         * NOTE(review): presumably the returned tasks are removed from the worker — confirm.
         */
        List<T> getTasks();

        /** Pushes the master's merged global state to the worker. */
        void updateGlobalState(G globalState);
    }

    private static final Log LOG = LogFactory.getLog(TaskMaster.class);

    private final int vassalCount;
    private final Class<T> taskClass;
    private final ExecutorService executor;

    // Mutated under its own lock in registerWorker(); other methods read it without locking.
    private final List<WorkerProxy<T, G>> workers;
    // Ids of workers believed to currently hold work; only the key set is meaningful.
    private final Map<Integer, Boolean> hasWork;

    private final BlockingQueue<T> tasks;

    private final G globalState;
    // Guarded by the globalState monitor.
    private boolean sendGlobalStateUpdate = false;
    // Written by the scheduler thread, read by arbitrary callers of hasStarted(): needs visibility.
    private volatile boolean hasStarted = false;

    /**
     * Creates a master that uses a new cached thread pool to push calls to workers.
     *
     * @param vassalCount number of workers that must register before the service starts
     * @param initialTasks tasks to distribute; copied into the master's internal queue
     * @param globalState shared state that worker updates are merged into
     */
    public TaskMaster(int vassalCount, List<T> initialTasks, G globalState) {
        this(vassalCount, initialTasks, globalState, Executors.newCachedThreadPool());
    }

    /**
     * Creates a master that uses {@code executor} to push calls to workers.
     *
     * @param vassalCount number of workers that must register before the service starts
     * @param initialTasks tasks to distribute; copied into the master's internal queue
     * @param globalState shared state that worker updates are merged into
     * @param executor runs the asynchronous {@code startTasks}/{@code updateGlobalState} calls
     */
    public TaskMaster(int vassalCount, List<T> initialTasks, G globalState, ExecutorService executor) {
        this.vassalCount = vassalCount;
        this.workers = Lists.newArrayList();
        this.hasWork = Maps.newConcurrentMap();
        this.tasks = new LinkedBlockingQueue<T>(initialTasks);
        this.globalState = globalState;
        if (!initialTasks.isEmpty()) {
            // getClass() yields Class<? extends Writable>; the runtime class of a T is a Class<T>.
            @SuppressWarnings("unchecked")
            Class<T> firstTaskClass = (Class<T>) initialTasks.get(0).getClass();
            this.taskClass = firstTaskClass;
        } else {
            // Task type cannot be inferred from an empty list.
            this.taskClass = null;
        }
        this.executor = executor;
    }

    /** Returns the runtime class of the first initial task, or {@code null} if there were none. */
    public Class<T> getTaskClass() {
        return taskClass;
    }

    /** Returns the runtime class of the global state object. */
    @SuppressWarnings("unchecked") // the runtime class of a G is a Class<G>
    public Class<G> getGlobalStateClass() {
        return (Class<G>) globalState.getClass();
    }

    /** Returns the live, shared global state object (not a copy). */
    public G getGlobalState() {
        return globalState;
    }

    /** Returns whether the scheduler has run at least one iteration. */
    public boolean hasStarted() {
        return hasStarted;
    }

    /**
     * Registers a worker and assigns it the next sequential id. Registering the
     * {@code vassalCount}-th worker starts the service.
     *
     * @return the zero-based id assigned to {@code worker}
     */
    public int registerWorker(final WorkerProxy<T, G> worker) {
        int id = -1;

        synchronized (workers) {
            id = workers.size();
            workers.add(worker);
            if (workers.size() == vassalCount) {
                start();
            }
        }

        LOG.info("Registered worker no. " + (id + 1));
        return id;
    }

    /**
     * Merges {@code workerState} into the global state and hands out at most one task to the
     * requestor. The task comes from the master's queue if possible; otherwise it is stolen from
     * another worker marked as busy, and the rest of the stolen batch is queued. A requestor that
     * receives a task is marked as busy. A requestor that receives nothing is marked idle.
     *
     * @param requestorId id returned by {@link #registerWorker}
     * @param workerState the requestor's view of the global state
     * @return a single-element list with the next task, or an empty list if no work was found
     */
    public List<T> getWork(int requestorId, G workerState) {
        updateGlobalState(workerState);

        // poll() instead of isEmpty()+take(): a concurrent requestor may drain the queue between
        // the check and the take, which would block this call indefinitely.
        T queued = tasks.poll();
        if (queued != null) {
            // Mark the requestor busy so runOneIteration() does not stop while it is working.
            hasWork.put(requestorId, true);
            LOG.info("Sending work to " + requestorId);
            return ImmutableList.of(queued);
        }

        // Need to get some more work, if anyone has any. Iterate over a snapshot of the busy ids.
        Set<Integer> workingIds = Sets.newHashSet(hasWork.keySet());
        for (Integer workingId : workingIds) {
            if (workingId.intValue() != requestorId) {
                WorkerProxy<T, G> worker = workers.get(workingId);
                List<T> stolen = worker.getTasks();
                if (!stolen.isEmpty()) {
                    List<T> ret = ImmutableList.of(stolen.get(0));
                    tasks.addAll(stolen.subList(1, stolen.size()));
                    hasWork.put(requestorId, true);
                    LOG.info("Sending stolen work to " + requestorId);
                    return ret;
                }
            }
        }
        LOG.info("Could not send work to " + requestorId);
        hasWork.remove(requestorId);
        return ImmutableList.of();
    }

    /**
     * Merges {@code other} into the global state. If the merge changed the state, schedules a
     * broadcast to all workers on the next scheduler iteration.
     *
     * @return {@code true} if the global state changed
     */
    public boolean updateGlobalState(G other) {
        LOG.info("Received global state update: " + other);
        boolean ret = false;
        synchronized (globalState) {
            if (globalState.mergeWith(other)) {
                this.sendGlobalStateUpdate = true;
                ret = true;
                LOG.info("New global state: " + globalState);
            }
        }
        return ret;
    }

    /**
     * Runs one scheduler tick. Broadcasts a pending global-state change to every worker
     * asynchronously, then stops the service if the queue is empty and no worker is marked busy.
     */
    @Override
    protected void runOneIteration() throws Exception {
        this.hasStarted = true;
        synchronized (globalState) {
            if (sendGlobalStateUpdate) {
                // Send notifications to all of the workers.
                for (final WorkerProxy<T, G> worker : workers) {
                    executor.submit(new Runnable() {
                        @Override
                        public void run() {
                            worker.updateGlobalState(globalState);
                        }
                    });
                }
                this.sendGlobalStateUpdate = false;
            }
        }

        if (tasks.isEmpty() && hasWork.isEmpty()) {
            LOG.info("Nothing to do, stopping");
            stop();
        }
    }

    /** Schedules {@link #runOneIteration()} every second, starting one second after startup. */
    @Override
    protected Scheduler scheduler() {
        return Scheduler.newFixedRateSchedule(1, 1, TimeUnit.SECONDS);
    }

    /**
     * Hands each registered worker at most one initial task. Stops immediately if there are no
     * tasks at all.
     */
    @Override
    protected void startUp() throws Exception {
        if (tasks.isEmpty()) {
            LOG.info("No tasks to perform, exiting");
            stop();
            return;
        }
        LOG.info("Initial task count = " + tasks.size());

        // Send tasks to all of the workers. poll() instead of take(): with fewer initial tasks
        // than workers, take() would block startUp() forever.
        for (int i = 0; i < vassalCount; i++) {
            final WorkerProxy<T, G> worker = workers.get(i);
            T first = tasks.poll();
            // NOTE(review): workers beyond the task count receive an empty list; this assumes
            // WorkerProxy.startTasks tolerates that and later asks for work — confirm.
            final List<T> task = (first == null) ? ImmutableList.<T>of() : ImmutableList.of(first);
            LOG.info("Starting " + task.size() + " units at worker " + i);
            executor.submit(new Runnable() {
                @Override
                public void run() {
                    worker.startTasks(task, globalState);
                }
            });
            if (first != null) {
                // Only workers that actually hold a task count as busy.
                hasWork.put(i, true);
            }
        }
    }

    @Override
    protected void shutDown() throws Exception {
        // No cleanup, surprisingly. The executor is deliberately left running: it may be supplied
        // by the caller through the four-argument constructor.
    }
}