org.apache.giraph.worker.WorkerAggregatorHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.giraph.worker.WorkerAggregatorHandler.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.giraph.worker;

import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.bsp.CentralizedServiceWorker;
import org.apache.giraph.comm.aggregators.WorkerAggregatorRequestProcessor;
import org.apache.giraph.comm.aggregators.AggregatedValueOutputStream;
import org.apache.giraph.comm.aggregators.AggregatorUtils;
import org.apache.giraph.comm.aggregators.AllAggregatorServerData;
import org.apache.giraph.comm.aggregators.OwnerAggregatorServerData;
import org.apache.giraph.aggregators.Aggregator;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.Progressable;
import org.apache.log4j.Logger;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

/**
 * Handler for aggregators on worker. Provides the aggregated values and
 * performs aggregations from user vertex code (thread-safe). Also has
 * methods for all superstep coordination related to aggregators.
 *
 * At the beginning of any superstep any worker calls prepareSuperstep(),
 * which blocks until the final aggregates from the previous superstep have
 * been delivered to the worker.
 * Next, during the superstep worker can call aggregate() and
 * getAggregatedValue() (both methods are thread safe) the former
 * computes partial aggregates for this superstep from the worker,
 * the latter returns (read-only) final aggregates from the previous superstep.
 * Finally, at the end of the superstep, the worker calls finishSuperstep(),
 * which propagates non-owned partial aggregates to the owner workers,
 * and sends the final aggregate from the owner worker to the master.
 */
public class WorkerAggregatorHandler implements WorkerThreadAggregatorUsage {
    /** Class logger */
    private static final Logger LOG = Logger.getLogger(WorkerAggregatorHandler.class);
    /** Map of values from previous superstep (refilled in place by prepareSuperstep()) */
    private final Map<String, Writable> previousAggregatedValueMap = Maps.newHashMap();
    /** Map of aggregators for current superstep (refilled in place by prepareSuperstep()) */
    private final Map<String, Aggregator<Writable>> currentAggregatorMap = Maps.newHashMap();
    /** Service worker */
    private final CentralizedServiceWorker<?, ?, ?, ?> serviceWorker;
    /** Progressable for reporting progress */
    private final Progressable progressable;
    /** How big a single aggregator request can be */
    private final int maxBytesPerAggregatorRequest;
    /** Giraph configuration */
    private final ImmutableClassesGiraphConfiguration conf;

    /**
     * Constructor
     *
     * @param serviceWorker Service worker
     * @param conf          Giraph configuration
     * @param progressable  Progressable for reporting progress
     */
    public WorkerAggregatorHandler(CentralizedServiceWorker<?, ?, ?, ?> serviceWorker,
            ImmutableClassesGiraphConfiguration conf, Progressable progressable) {
        this.serviceWorker = serviceWorker;
        this.progressable = progressable;
        this.conf = conf;
        maxBytesPerAggregatorRequest = conf.getInt(AggregatorUtils.MAX_BYTES_PER_AGGREGATOR_REQUEST,
                AggregatorUtils.MAX_BYTES_PER_AGGREGATOR_REQUEST_DEFAULT);
    }

    /**
     * Add a value to the named aggregator of the current superstep. The
     * aggregator is locked while the value is applied, so several compute
     * threads may share this handler.
     *
     * @param name  Name of the aggregator
     * @param value Value to aggregate
     * @param <A>   Type of the aggregated value
     * @throws IllegalStateException if no aggregator is registered under
     *                               {@code name} for the current superstep
     */
    @Override
    public <A extends Writable> void aggregate(String name, A value) {
        Aggregator<Writable> aggregator = currentAggregatorMap.get(name);
        if (aggregator == null) {
            throw new IllegalStateException("aggregate: " + AggregatorUtils.getUnregisteredAggregatorMessage(name,
                    !currentAggregatorMap.isEmpty(), conf));
        }
        progressable.progress();
        synchronized (aggregator) {
            aggregator.aggregate(value);
        }
    }

    /**
     * Look up the value stored for the named aggregator in the map of
     * values from the previous superstep.
     *
     * @param name Name of the aggregator
     * @param <A>  Type the caller expects the value to have
     * @return The stored value, or {@code null} (after logging a warning)
     *         when nothing is stored under {@code name}
     */
    @Override
    public <A extends Writable> A getAggregatedValue(String name) {
        // The map stores plain Writables; the caller chooses A, so the cast
        // cannot be checked here.
        @SuppressWarnings("unchecked")
        A value = (A) previousAggregatedValueMap.get(name);
        if (value == null) {
            LOG.warn("getAggregatedValue: " + AggregatorUtils.getUnregisteredAggregatorMessage(name,
                    !previousAggregatedValueMap.isEmpty(), conf));
        }
        return value;
    }

    /**
     * Prepare aggregators for current superstep
     *
     * @param requestProcessor Request processor for aggregators
     * @throws IllegalStateException if distributing the aggregators fails
     *                               with an IOException
     */
    public void prepareSuperstep(WorkerAggregatorRequestProcessor requestProcessor) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("prepareSuperstep: Start preparing aggregators");
        }
        AllAggregatorServerData allAggregatorData = serviceWorker.getServerData().getAllAggregatorData();
        // Wait for my aggregators
        Iterable<byte[]> dataToDistribute = allAggregatorData
                .getDataFromMasterWhenReady(serviceWorker.getMasterInfo());
        try {
            // Distribute my aggregators
            requestProcessor.distributeAggregators(dataToDistribute);
        } catch (IOException e) {
            throw new IllegalStateException(
                    "prepareSuperstep: IOException occurred while trying to distribute aggregators", e);
        }
        // Wait for all other aggregators and store them
        allAggregatorData.fillNextSuperstepMapsWhenReady(getOtherWorkerIdsSet(), previousAggregatedValueMap,
                currentAggregatorMap);
        allAggregatorData.reset();
        if (LOG.isDebugEnabled()) {
            LOG.debug("prepareSuperstep: Aggregators prepared");
        }
    }

    /**
     * Send aggregators to their owners and in the end to the master
     *
     * @param requestProcessor Request processor for aggregators
     * @throws IllegalStateException if any of the sends fails with an
     *                               IOException
     */
    public void finishSuperstep(WorkerAggregatorRequestProcessor requestProcessor) {
        if (LOG.isInfoEnabled()) {
            LOG.info("finishSuperstep: Start gathering aggregators, workers will send their aggregated values "
                    + "once they are done with superstep computation");
        }
        OwnerAggregatorServerData ownerAggregatorData = serviceWorker.getServerData().getOwnerAggregatorData();

        sendPartialValuesToOwners(requestProcessor, ownerAggregatorData);
        sendOwnedValuesToMaster(requestProcessor, ownerAggregatorData);

        // Wait for master to receive aggregated values before proceeding
        serviceWorker.getWorkerClient().waitAllRequests();

        ownerAggregatorData.reset();
        if (LOG.isDebugEnabled()) {
            LOG.debug("finishSuperstep: Aggregators finished");
        }
    }

    /**
     * Hand every aggregator of the current superstep to the request
     * processor, fold the ones it did not send into the local owner data,
     * then flush the request processor.
     *
     * @param requestProcessor    Request processor for aggregators
     * @param ownerAggregatorData Owner data that receives the values which
     *                            were not sent
     * @throws IllegalStateException if sending or flushing fails with an
     *                               IOException
     */
    private void sendPartialValuesToOwners(WorkerAggregatorRequestProcessor requestProcessor,
            OwnerAggregatorServerData ownerAggregatorData) {
        for (Map.Entry<String, Aggregator<Writable>> entry : currentAggregatorMap.entrySet()) {
            String aggregatorName = entry.getKey();
            try {
                boolean sent = requestProcessor.sendAggregatedValue(aggregatorName,
                        entry.getValue().getAggregatedValue());
                if (!sent) {
                    // If it's my aggregator, add it directly
                    ownerAggregatorData.aggregate(aggregatorName, entry.getValue().getAggregatedValue());
                }
            } catch (IOException e) {
                throw new IllegalStateException("finishSuperstep: IOException occurred while sending aggregator "
                        + aggregatorName + " to its owner", e);
            }
            progressable.progress();
        }
        try {
            // Flush
            requestProcessor.flush();
        } catch (IOException e) {
            throw new IllegalStateException(
                    "finishSuperstep: IOException occurred while sending aggregators to owners", e);
        }
    }

    /**
     * Obtain this worker's aggregator values from the owner data and send
     * them to the master, starting a new request whenever the buffered
     * output grows past {@code maxBytesPerAggregatorRequest}.
     *
     * @param requestProcessor    Request processor for aggregators
     * @param ownerAggregatorData Owner data to read the values from
     * @throws IllegalStateException if writing or sending fails with an
     *                               IOException
     */
    private void sendOwnedValuesToMaster(WorkerAggregatorRequestProcessor requestProcessor,
            OwnerAggregatorServerData ownerAggregatorData) {
        // Wait to receive partial aggregated values from all other workers
        Iterable<Map.Entry<String, Writable>> myAggregators = ownerAggregatorData
                .getMyAggregatorValuesWhenReady(getOtherWorkerIdsSet());

        // Send final aggregated values to master
        AggregatedValueOutputStream aggregatorOutput = new AggregatedValueOutputStream();
        for (Map.Entry<String, Writable> entry : myAggregators) {
            try {
                int currentSize = aggregatorOutput.addAggregator(entry.getKey(), entry.getValue());
                if (currentSize > maxBytesPerAggregatorRequest) {
                    requestProcessor.sendAggregatedValuesToMaster(aggregatorOutput.flush());
                }
                progressable.progress();
            } catch (IOException e) {
                throw new IllegalStateException(
                        "finishSuperstep: IOException occurred while writing aggregator " + entry.getKey(), e);
            }
        }
        try {
            // Send whatever is still buffered
            requestProcessor.sendAggregatedValuesToMaster(aggregatorOutput.flush());
        } catch (IOException e) {
            throw new IllegalStateException(
                    "finishSuperstep: IOException occurred while sending aggregators to master", e);
        }
    }

    /**
     * Create new aggregator usage which will be used by one of the compute
     * threads.
     *
     * @return New aggregator usage
     */
    public WorkerThreadAggregatorUsage newThreadAggregatorUsage() {
        if (AggregatorUtils.useThreadLocalAggregators(conf)) {
            return new ThreadLocalWorkerAggregatorUsage();
        }
        return this;
    }

    /**
     * No-op for the shared handler.
     */
    @Override
    public void finishThreadComputation() {
        // If we don't use thread-local aggregators, all the aggregated values
        // are already in this object
    }

    /**
     * Get set of all worker task ids except the current one
     *
     * @return Set of all other worker task ids
     */
    public Set<Integer> getOtherWorkerIdsSet() {
        Set<Integer> otherWorkers = Sets.newHashSetWithExpectedSize(serviceWorker.getWorkerInfoList().size());
        // Look up this worker's task id once instead of on every iteration
        final int myTaskId = serviceWorker.getWorkerInfo().getTaskId();
        for (WorkerInfo workerInfo : serviceWorker.getWorkerInfoList()) {
            final int taskId = workerInfo.getTaskId();
            if (taskId != myTaskId) {
                otherWorkers.add(taskId);
            }
        }
        return otherWorkers;
    }

    /**
     * Not thread-safe implementation of {@link WorkerThreadAggregatorUsage}.
     * We can use one instance of this object per thread to prevent
     * synchronizing on each aggregate() call. In the end of superstep,
     * values from each of these will be aggregated back to {@link
     * WorkerAggregatorHandler}
     */
    public class ThreadLocalWorkerAggregatorUsage implements WorkerThreadAggregatorUsage {
        /** Thread-local aggregator map */
        private final Map<String, Aggregator<Writable>> threadAggregatorMap;

        /**
         * Constructor
         *
         * Creates new instances of all aggregators from
         * {@link WorkerAggregatorHandler}
         */
        public ThreadLocalWorkerAggregatorUsage() {
            threadAggregatorMap = Maps.newHashMapWithExpectedSize(currentAggregatorMap.size());
            for (Map.Entry<String, Aggregator<Writable>> entry : currentAggregatorMap.entrySet()) {
                // getClass() only yields a wildcard Class type, so this cast
                // cannot be checked by the compiler.
                @SuppressWarnings("unchecked")
                Class<Aggregator<Writable>> aggregatorClass =
                        (Class<Aggregator<Writable>>) entry.getValue().getClass();
                threadAggregatorMap.put(entry.getKey(),
                        AggregatorUtils.newAggregatorInstance(aggregatorClass, conf));
            }
        }

        /**
         * Add a value to this thread's private copy of the named aggregator;
         * no locking is done.
         *
         * @param name  Name of the aggregator
         * @param value Value to aggregate
         * @param <A>   Type of the aggregated value
         * @throws IllegalStateException if this instance holds no aggregator
         *                               under {@code name}
         */
        @Override
        public <A extends Writable> void aggregate(String name, A value) {
            Aggregator<Writable> aggregator = threadAggregatorMap.get(name);
            if (aggregator == null) {
                throw new IllegalStateException("aggregate: " + AggregatorUtils
                        .getUnregisteredAggregatorMessage(name, !threadAggregatorMap.isEmpty(), conf));
            }
            progressable.progress();
            aggregator.aggregate(value);
        }

        /**
         * Delegates to the enclosing {@link WorkerAggregatorHandler}.
         *
         * @param name Name of the aggregator
         * @param <A>  Type the caller expects the value to have
         * @return Whatever the enclosing handler returns for {@code name}
         */
        @Override
        public <A extends Writable> A getAggregatedValue(String name) {
            return WorkerAggregatorHandler.this.<A>getAggregatedValue(name);
        }

        /**
         * Fold every value accumulated by this instance into the enclosing
         * {@link WorkerAggregatorHandler}.
         */
        @Override
        public void finishThreadComputation() {
            // Aggregate the values this thread's vertices provided back to
            // WorkerAggregatorHandler
            for (Map.Entry<String, Aggregator<Writable>> entry : threadAggregatorMap.entrySet()) {
                WorkerAggregatorHandler.this.aggregate(entry.getKey(), entry.getValue().getAggregatedValue());
            }
        }
    }
}