edu.berkeley.sparrow.daemon.scheduler.Scheduler.java Source code

Java tutorial

Introduction

Here is the source code for edu.berkeley.sparrow.daemon.scheduler.Scheduler.java

Source

/*
 * Copyright 2013 The Regents of The University California
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package edu.berkeley.sparrow.daemon.scheduler;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.configuration.Configuration;
import org.apache.log4j.Logger;
import org.apache.thrift.TException;
import org.apache.thrift.async.AsyncMethodCallback;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import edu.berkeley.sparrow.daemon.SparrowConf;
import edu.berkeley.sparrow.daemon.util.Logging;
import edu.berkeley.sparrow.daemon.util.Network;
import edu.berkeley.sparrow.daemon.util.Serialization;
import edu.berkeley.sparrow.daemon.util.ThriftClientPool;
import edu.berkeley.sparrow.thrift.FrontendService;
import edu.berkeley.sparrow.thrift.FrontendService.AsyncClient.frontendMessage_call;
import edu.berkeley.sparrow.thrift.InternalService;
import edu.berkeley.sparrow.thrift.InternalService.AsyncClient;
import edu.berkeley.sparrow.thrift.InternalService.AsyncClient.enqueueTaskReservations_call;
import edu.berkeley.sparrow.thrift.TEnqueueTaskReservationsRequest;
import edu.berkeley.sparrow.thrift.TFullTaskId;
import edu.berkeley.sparrow.thrift.THostPort;
import edu.berkeley.sparrow.thrift.TPlacementPreference;
import edu.berkeley.sparrow.thrift.TSchedulingRequest;
import edu.berkeley.sparrow.thrift.TTaskLaunchSpec;
import edu.berkeley.sparrow.thrift.TTaskSpec;

/**
 * This class implements the Sparrow scheduler functionality.
 */
/**
 * This class implements the Sparrow scheduler functionality.
 *
 * <p>Each submitted job is assigned a unique request ID and a {@link TaskPlacer}. The task placer
 * chooses the node monitors at which task reservations are enqueued (via asynchronous
 * {@code enqueueTaskReservations} RPCs); {@link #getTask(String, THostPort)} later asks the same
 * task placer which task, if any, should be handed to a given node monitor for that request.
 */
public class Scheduler {
    private final static Logger LOG = Logger.getLogger(Scheduler.class);
    private final static Logger AUDIT_LOG = Logging.getAuditLogger(Scheduler.class);

    /**
     * Number of SPREAD_EVENLY jobs the special case is designed for; each such job is restricted
     * to a disjoint 1/SPREAD_EVENLY_NUM_JOBS slice of the backends.
     */
    private static final int SPREAD_EVENLY_NUM_JOBS = 3;

    /** Used to uniquely identify requests arriving at this scheduler. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /** How many times the special case has been triggered. */
    private final AtomicInteger specialCaseCounter = new AtomicInteger(0);

    private THostPort address;

    /** Socket addresses for each frontend, keyed by application ID. */
    HashMap<String, InetSocketAddress> frontendSockets = new HashMap<String, InetSocketAddress>();

    /**
     * Service that handles cancelling outstanding reservations for jobs that have already been
     * scheduled.  Only instantiated if {@code SparrowConf.CANCELLATION} is set to true.
     */
    private CancellationService cancellationService;
    private boolean useCancellation;

    /** Thrift client pool for communicating with node monitors */
    ThriftClientPool<InternalService.AsyncClient> nodeMonitorClientPool = new ThriftClientPool<InternalService.AsyncClient>(
            new ThriftClientPool.InternalServiceMakerFactory());

    /** Thrift client pool for communicating with front ends. */
    private final ThriftClientPool<FrontendService.AsyncClient> frontendClientPool = new ThriftClientPool<FrontendService.AsyncClient>(
            new ThriftClientPool.FrontendServiceMakerFactory());

    /** Information about cluster workload due to other schedulers. */
    private SchedulerState state;

    /** Probe ratios to use if the probe ratio is not explicitly set in the request. */
    private double defaultProbeRatioUnconstrained;
    private double defaultProbeRatioConstrained;

    /**
     * For each request, the task placer that should be used to place the request's tasks. Indexed
     * by the request ID.
     */
    private ConcurrentMap<String, TaskPlacer> requestTaskPlacers;

    /**
     * When a job includes SPREAD_EVENLY in the description and has this number of tasks,
     * Sparrow spreads the tasks evenly over machines to evenly cache data. We need this (in
     * addition to the SPREAD_EVENLY descriptor) because only the reduce phase -- not the map
     * phase -- should be spread.
     */
    private int spreadEvenlyTaskSetSize;

    private Configuration conf;

    /**
     * Initializes the scheduler from configuration.
     *
     * @param conf   configuration; {@code SparrowConf.DEPLYOMENT_MODE} must be "standalone" or
     *               "configbased"
     * @param socket the address this scheduler listens on (used in request IDs' audit logs and
     *               passed to task placers)
     * @throws IOException if the scheduler state fails to initialize
     * @throws RuntimeException if the deployment mode is not supported
     */
    public void initialize(Configuration conf, InetSocketAddress socket) throws IOException {
        address = Network.socketAddressToThrift(socket);
        String mode = conf.getString(SparrowConf.DEPLYOMENT_MODE, "unspecified");
        this.conf = conf;
        if (mode.equals("standalone")) {
            state = new StandaloneSchedulerState();
        } else if (mode.equals("configbased")) {
            state = new ConfigSchedulerState();
        } else {
            throw new RuntimeException("Unsupported deployment mode: " + mode);
        }

        state.initialize(conf);

        defaultProbeRatioUnconstrained = conf.getDouble(SparrowConf.SAMPLE_RATIO, SparrowConf.DEFAULT_SAMPLE_RATIO);
        defaultProbeRatioConstrained = conf.getDouble(SparrowConf.SAMPLE_RATIO_CONSTRAINED,
                SparrowConf.DEFAULT_SAMPLE_RATIO_CONSTRAINED);

        requestTaskPlacers = Maps.newConcurrentMap();

        useCancellation = conf.getBoolean(SparrowConf.CANCELLATION, SparrowConf.DEFAULT_CANCELLATION);
        if (useCancellation) {
            LOG.debug("Initializing cancellation service");
            cancellationService = new CancellationService(nodeMonitorClientPool);
            new Thread(cancellationService).start();
        } else {
            LOG.debug("Not using cancellation");
        }

        spreadEvenlyTaskSetSize = conf.getInt(SparrowConf.SPREAD_EVENLY_TASK_SET_SIZE,
                SparrowConf.DEFAULT_SPREAD_EVENLY_TASK_SET_SIZE);
    }

    /**
     * Records the socket address of a frontend and asks the scheduler state to watch the app.
     *
     * @param appId application ID the frontend serves
     * @param addr  frontend address in the string form accepted by {@code Serialization.strToSocket}
     * @return false if {@code addr} cannot be parsed; otherwise the result of
     *         {@code state.watchApplication(appId)}
     */
    public boolean registerFrontend(String appId, String addr) {
        LOG.debug(Logging.functionCall(appId, addr));
        Optional<InetSocketAddress> socketAddress = Serialization.strToSocket(addr);
        if (!socketAddress.isPresent()) {
            LOG.error("Bad address from frontend: " + addr);
            return false;
        }
        frontendSockets.put(appId, socketAddress.get());
        return state.watchApplication(appId);
    }

    /**
     * Callback for enqueueTaskReservations() that returns the client to the client pool.
     */
    private class EnqueueTaskReservationsCallback implements AsyncMethodCallback<enqueueTaskReservations_call> {
        private final String requestId;
        private final InetSocketAddress nodeMonitorAddress;
        private final long startTimeMillis;

        public EnqueueTaskReservationsCallback(String requestId, InetSocketAddress nodeMonitorAddress) {
            this.requestId = requestId;
            this.nodeMonitorAddress = nodeMonitorAddress;
            this.startTimeMillis = System.currentTimeMillis();
        }

        public void onComplete(enqueueTaskReservations_call response) {
            AUDIT_LOG.debug(Logging.auditEventString("scheduler_complete_enqueue_task", requestId,
                    nodeMonitorAddress.getAddress().getHostAddress()));
            long totalTime = System.currentTimeMillis() - startTimeMillis;
            LOG.debug("Enqueue Task RPC to " + nodeMonitorAddress.getAddress().getHostAddress() + " for request "
                    + requestId + " completed in " + totalTime + "ms");
            try {
                nodeMonitorClientPool.returnClient(nodeMonitorAddress, (AsyncClient) response.getClient());
            } catch (Exception e) {
                LOG.error("Error returning client to node monitor client pool: " + e, e);
            }
        }

        public void onError(Exception exception) {
            // Do not return error client to pool: its connection state is unknown.
            LOG.error("Error executing enqueueTaskReservation RPC for request " + requestId + " on node "
                    + nodeMonitorAddress + ": " + exception, exception);
        }
    }

    /** Adds constraints such that tasks in the job will be spread evenly across the cluster.
     *
     *  We expect three of these special jobs to be submitted; 3 sequential calls to this
     *  method will result in spreading the tasks for the 3 jobs across the cluster such that no
     *  more than 1 task is assigned to each machine.
     *
     *  @return a copy of {@code req} in which each task is constrained to a single machine, or
     *          {@code null} if the special case cannot be applied (this call exceeded the expected
     *          number of special jobs, or there are fewer than three machines per task)
     */
    private TSchedulingRequest addConstraintsToSpreadTasks(TSchedulingRequest req) throws TException {
        LOG.info("Handling spread tasks request: " + req);
        int specialCaseIndex = specialCaseCounter.incrementAndGet();
        if (specialCaseIndex < 1 || specialCaseIndex > SPREAD_EVENLY_NUM_JOBS) {
            // The slice selected below would be empty, so there is nothing to constrain to.
            LOG.error("Invalid special case index: " + specialCaseIndex);
            return null;
        }

        // We assume the below always returns the same order (invalid assumption?)
        List<InetSocketAddress> allBackends = Lists.newArrayList(state.getBackends(req.app));
        int numTasks = req.getTasksSize();
        if (allBackends.size() < numTasks * SPREAD_EVENLY_NUM_JOBS) {
            LOG.error("Special case expects at least three times as many machines as tasks.");
            return null;
        }

        // Each call restricts to a disjoint 1/3 of the nodes in the cluster: the nodes whose
        // position is congruent to (specialCaseIndex - 1) mod 3.
        List<InetSocketAddress> backends = Lists.newArrayList();
        for (int i = specialCaseIndex - 1; i < allBackends.size(); i += SPREAD_EVENLY_NUM_JOBS) {
            backends.add(allBackends.get(i));
        }
        Collections.shuffle(backends);
        LOG.info(backends);

        // No tasks have preferences and we have the magic number of tasks
        TSchedulingRequest newReq = new TSchedulingRequest();
        newReq.user = req.user;
        newReq.app = req.app;
        // Thrift tracks "is set" for primitive fields separately; assigning the public field
        // directly would leave isSetProbeRatio() false and the requested ratio would be ignored.
        if (req.isSetProbeRatio()) {
            newReq.setProbeRatio(req.getProbeRatio());
        }

        for (int i = 0; i < numTasks; i++) {
            TTaskSpec task = req.getTasks().get(i);
            TTaskSpec newTask = new TTaskSpec();
            newTask.message = task.message;
            newTask.taskId = task.taskId;
            newTask.preference = new TPlacementPreference();
            newTask.preference.addToNodes(backends.get(i).getHostName());
            newReq.addToTasks(newTask);
        }
        LOG.info("New request: " + newReq);
        return newReq;
    }

    /** Checks whether we should add constraints to this job to evenly spread tasks over machines.
     *
     * This is a hack used to force Spark to cache data in 3 locations: we run 3 select * queries
     * on the same table and spread the tasks for those queries evenly across the cluster such that
     * the input data for the query is triple replicated and spread evenly across the cluster.
     *
     * We signal that Sparrow should use this hack by adding SPREAD_EVENLY to the job's description.
     */
    private boolean isSpreadTasksJob(TSchedulingRequest request) {
        String description = request.getDescription();
        if (description == null || !description.contains("SPREAD_EVENLY")) {
            LOG.debug("Not special case: description did not contain SPREAD_EVENLY");
            return false;
        }
        // Need to check to see if there are 3 constraints; if so, it's the map phase of the
        // first job that reads the data from HDFS, so we shouldn't override the constraints.
        for (TTaskSpec t : request.getTasks()) {
            if (t.getPreference() != null && (t.getPreference().getNodes() != null)
                    && (t.getPreference().getNodes().size() == 3)) {
                LOG.debug("Not special case: one of request's tasks had 3 preferences");
                return false;
            }
        }
        if (request.getTasks().size() != spreadEvenlyTaskSetSize) {
            LOG.debug("Not special case: job had " + request.getTasks().size()
                    + " tasks rather than the expected " + spreadEvenlyTaskSetSize);
            return false;
        }
        if (specialCaseCounter.get() >= SPREAD_EVENLY_NUM_JOBS) {
            LOG.error("Not using special case because special case code has already been "
                    + "called 3 or more times!");
            return false;
        }
        LOG.debug("Spreading tasks for job with " + request.getTasks().size() + " tasks");
        return true;
    }

    /**
     * Entry point for job submission. Requests with no tasks are ignored (they are used for
     * liveness checking). SPREAD_EVENLY jobs are rewritten with per-task constraints when
     * possible; if the rewrite cannot be applied the original request is scheduled unchanged.
     */
    public void submitJob(TSchedulingRequest request) throws TException {
        // Short-circuit case that is used for liveness checking
        if (request.getTasksSize() == 0) {
            return;
        }
        TSchedulingRequest toSchedule = request;
        if (isSpreadTasksJob(request)) {
            TSchedulingRequest spreadRequest = addConstraintsToSpreadTasks(request);
            if (spreadRequest != null) {
                toSchedule = spreadRequest;
            } else {
                LOG.warn("Could not spread tasks evenly; scheduling request without added constraints");
            }
        }
        handleJobSubmission(toSchedule);
    }

    /**
     * Assigns a request ID and task placer to {@code request}, then asynchronously enqueues task
     * reservations at the node monitors chosen by the placer. Failures to enqueue at an
     * individual node are logged and do not abort the remaining enqueues.
     */
    public void handleJobSubmission(TSchedulingRequest request) throws TException {
        LOG.debug(Logging.functionCall(request));

        long start = System.currentTimeMillis();

        String requestId = getRequestId();

        String user = "";
        if (request.getUser() != null && request.getUser().getUser() != null) {
            user = request.getUser().getUser();
        }
        String description = "";
        if (request.getDescription() != null) {
            description = request.getDescription();
        }

        String app = request.getApp();
        List<TTaskSpec> tasks = request.getTasks();
        Set<InetSocketAddress> backends = state.getBackends(app);
        LOG.debug("NumBackends: " + backends.size());

        // A job is constrained if any of its tasks lists at least one preferred node.
        boolean constrained = false;
        for (TTaskSpec task : tasks) {
            if (task.preference != null && task.preference.nodes != null
                    && !task.preference.nodes.isEmpty()) {
                constrained = true;
                break;
            }
        }
        // Logging the address here is somewhat redundant, since all of the
        // messages in this particular log file come from the same address.
        // However, it simplifies the process of aggregating the logs, and will
        // also be useful when we support multiple daemons running on a single
        // machine.
        AUDIT_LOG.info(Logging.auditEventString("arrived", requestId, request.getTasks().size(), address.getHost(),
                address.getPort(), user, description, constrained));

        // An explicit probe ratio in the request wins over the configured default.
        double probeRatio;
        if (request.isSetProbeRatio()) {
            probeRatio = request.getProbeRatio();
        } else if (constrained) {
            probeRatio = defaultProbeRatioConstrained;
        } else {
            probeRatio = defaultProbeRatioUnconstrained;
        }
        TaskPlacer taskPlacer;
        if (constrained) {
            taskPlacer = new ConstrainedTaskPlacer(requestId, probeRatio);
        } else {
            taskPlacer = new UnconstrainedTaskPlacer(requestId, probeRatio);
        }
        requestTaskPlacers.put(requestId, taskPlacer);

        Map<InetSocketAddress, TEnqueueTaskReservationsRequest> enqueueTaskReservationsRequests =
                taskPlacer.getEnqueueTaskReservationsRequests(request, requestId, backends, address);

        // Request to enqueue a task at each of the selected nodes.
        for (Entry<InetSocketAddress, TEnqueueTaskReservationsRequest> entry : enqueueTaskReservationsRequests
                .entrySet()) {
            InetSocketAddress nodeMonitor = entry.getKey();
            try {
                InternalService.AsyncClient client = nodeMonitorClientPool.borrowClient(nodeMonitor);
                LOG.debug("Launching enqueueTask for request " + requestId + " on node: " + nodeMonitor);
                AUDIT_LOG.debug(Logging.auditEventString("scheduler_launch_enqueue_task",
                        entry.getValue().requestId, nodeMonitor.getAddress().getHostAddress()));
                client.enqueueTaskReservations(entry.getValue(),
                        new EnqueueTaskReservationsCallback(requestId, nodeMonitor));
            } catch (Exception e) {
                LOG.error("Error enqueuing task on node " + nodeMonitor.toString() + ":" + e, e);
            }
        }

        long end = System.currentTimeMillis();
        LOG.debug("All tasks enqueued for request " + requestId + "; returning. Total time: " + (end - start)
                + " milliseconds");
    }

    /**
     * Asks the task placer for {@code requestId} which task (if any) to give to the node monitor
     * at {@code nodeMonitorAddress}.
     *
     * @return a list with zero or one task launch specs; empty if the request is unknown, has no
     *         more unplaced tasks, or the placer returned an invalid placement
     */
    public List<TTaskLaunchSpec> getTask(String requestId, THostPort nodeMonitorAddress) {
        /* TODO: Consider making this synchronized to avoid the need for synchronization in
         * the task placers (although then we'd lose the ability to parallelize over task placers). */
        LOG.debug(Logging.functionCall(requestId, nodeMonitorAddress));
        TaskPlacer taskPlacer = requestTaskPlacers.get(requestId);
        if (taskPlacer == null) {
            LOG.debug("Received getTask() request for request " + requestId + ", which had no more "
                    + "unplaced tasks");
            return Lists.newArrayList();
        }

        synchronized (taskPlacer) {
            List<TTaskLaunchSpec> taskLaunchSpecs = taskPlacer.assignTask(nodeMonitorAddress);
            if (taskLaunchSpecs == null || taskLaunchSpecs.size() > 1) {
                // String concatenation renders a null list as "null" instead of throwing.
                LOG.error("Received invalid task placement for request " + requestId + ": "
                        + taskLaunchSpecs);
                return Lists.newArrayList();
            } else if (taskLaunchSpecs.size() == 1) {
                AUDIT_LOG.info(Logging.auditEventString("scheduler_assigned_task", requestId,
                        taskLaunchSpecs.get(0).taskId, nodeMonitorAddress.getHost()));
            } else {
                AUDIT_LOG.info(Logging.auditEventString("scheduler_get_task_no_task", requestId,
                        nodeMonitorAddress.getHost()));
            }

            if (taskPlacer.allTasksPlaced()) {
                LOG.debug("All tasks placed for request " + requestId);
                requestTaskPlacers.remove(requestId);
                if (useCancellation) {
                    Set<THostPort> outstandingNodeMonitors = taskPlacer.getOutstandingNodeMonitorsForCancellation();
                    for (THostPort nodeMonitorToCancel : outstandingNodeMonitors) {
                        cancellationService.addCancellation(requestId, nodeMonitorToCancel);
                    }
                }
            }
            return taskLaunchSpecs;
        }
    }

    /**
     * Returns an ID that identifies a request uniquely (across all Sparrow schedulers).
     *
     * This should only be called once for each request (it will return a different
     * identifier if called a second time).
     *
     * TODO: Include the port number, so this works when there are multiple schedulers
     * running on a single machine.
     */
    private String getRequestId() {
        /* The request id is a string that includes the IP address of this scheduler followed
         * by the counter.  We use a counter rather than a hash of the request because there
         * may be multiple requests to run an identical job. */
        return String.format("%s_%d", Network.getIPAddress(conf), counter.getAndIncrement());
    }

    /** Callback for frontendMessage() that returns the client to the frontend client pool. */
    private class SendFrontendMessageCallback implements AsyncMethodCallback<frontendMessage_call> {
        private final InetSocketAddress frontendSocket;
        private final FrontendService.AsyncClient client;

        public SendFrontendMessageCallback(InetSocketAddress socket, FrontendService.AsyncClient client) {
            frontendSocket = socket;
            this.client = client;
        }

        public void onComplete(frontendMessage_call response) {
            try {
                frontendClientPool.returnClient(frontendSocket, client);
            } catch (Exception e) {
                LOG.error("Error returning client to frontend client pool: " + e, e);
            }
        }

        public void onError(Exception exception) {
            // Do not return error client to pool: its connection state is unknown.
            LOG.error("Error sending frontend message callback: " + exception, exception);
        }
    }

    /**
     * Asynchronously forwards a task message to the frontend registered for {@code app}. Does
     * nothing (other than logging) if no frontend is registered for {@code app}.
     */
    public void sendFrontendMessage(String app, TFullTaskId taskId, int status, ByteBuffer message) {
        LOG.debug(Logging.functionCall(app, taskId, message));
        InetSocketAddress frontend = frontendSockets.get(app);
        if (frontend == null) {
            LOG.error("Requested message sent to unregistered app: " + app);
            return;
        }
        try {
            FrontendService.AsyncClient client = frontendClientPool.borrowClient(frontend);
            client.frontendMessage(taskId, status, message, new SendFrontendMessageCallback(frontend, client));
        } catch (Exception e) {
            // Covers IOException from borrowing a client and TException from issuing the RPC.
            LOG.error("Error launching message on frontend: " + app, e);
        }
    }
}