org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java Source code

Java tutorial

Introduction

Here is the source code for org.trustedanalytics.atk.giraph.algorithms.lp.LabelPropagationComputation.java

Source

/**
 *  Copyright(c)2015 IntelCorporation
 *
 *  LicensedundertheApacheLicense,Version2.0(the"License");
 *  youmaynotusethisfileexceptincompliancewiththeLicense.
 *  YoumayobtainacopyoftheLicenseat
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unlessrequiredbyapplicablelaworagreedtoinwriting,software
 *  distributedundertheLicenseisdistributedonan"ASIS"BASIS,
 *  WITHOUTWARRANTIESORCONDITIONSOFANYKIND,eitherexpressorimplied.
 *  SeetheLicenseforthespecificlanguagegoverningpermissionsand
 *  limitationsundertheLicense.
 */

package org.trustedanalytics.atk.giraph.algorithms.lp;

import org.apache.giraph.Algorithm;
import org.apache.giraph.aggregators.AggregatorWriter;
import org.apache.giraph.aggregators.DoubleSumAggregator;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.counters.GiraphStats;
import org.apache.giraph.edge.Edge;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.master.DefaultMasterCompute;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper.Context;
import org.apache.mahout.math.Vector;
import org.trustedanalytics.atk.giraph.config.lp.LabelPropagationConfig;
import org.trustedanalytics.atk.giraph.config.lp.LabelPropagationConfiguration;
import org.trustedanalytics.atk.giraph.io.IdWithVectorMessage;
import org.trustedanalytics.atk.giraph.io.VertexData4LPWritable;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map.Entry;

/**
 * Label Propagation on Gaussian Random Fields
 * The algorithm presented in:
 * X. Zhu and Z. Ghahramani. Learning from labeled and unlabeled data with
 * label propagation. Technical Report CMU-CALD-02-107, CMU, 2002.
 */
@Algorithm(name = "Label Propagation on Gaussian Random Fields")
public class LabelPropagationComputation
        extends BasicComputation<LongWritable, VertexData4LPWritable, DoubleWritable, IdWithVectorMessage> {

    /** Aggregator name for sum of cost at each super step */
    private static String SUM_COST = "sum_cost";

    /** Cost of previous super step for convergence monitoring */
    private static String PREV_COST = "prev_cost";

    /** Number of super steps */
    private int maxSupersteps;

    /** The trade-off parameter between prior and posterior */
    private float lambda;

    /** Default to an invalid value */
    private Vector initialVectorValues;

    @Override
    public void preSuperstep() {
        LabelPropagationConfig config = new LabelPropagationConfiguration(getConf()).getConfig();

        maxSupersteps = config.maxIterations();
        lambda = config.lambda();
        initialVectorValues = null;
    }

    /**
     * initialize vertex and edges
     *
     * @param vertex a graph vertex
     */
    private void initializeVertexEdges(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex) {

        // normalize prior and initialize posterior
        VertexData4LPWritable vertexValue = vertex.getValue();
        Vector priorValues = vertexValue.getPriorVector();
        if (null != priorValues) {
            priorValues = priorValues.normalize(1d);
            initialVectorValues = priorValues;
        } else if (initialVectorValues != null) {
            priorValues = initialVectorValues;
            vertexValue.setLabeledStatus(false);
        } else {
            throw new RuntimeException("Vector labels missing from input data for vertex " + vertex.getId()
                    + ". Add edge with vertex as first column.");
        }
        vertexValue.setPriorVector(priorValues);
        vertexValue.setPosteriorVector(priorValues.clone());
        vertexValue.setDegree(initializeEdge(vertex));

        // send out messages
        IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(), priorValues);
        sendMessageToAllEdges(vertex, newMessage);
    }

    @Override
    public void compute(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex,
            Iterable<IdWithVectorMessage> messages) throws IOException {
        long superStep = getSuperstep();

        if (superStep == 0) {
            initializeVertexEdges(vertex);
            vertex.voteToHalt();
        } else if (superStep <= maxSupersteps) {
            VertexData4LPWritable vertexValue = vertex.getValue();
            Vector prior = vertexValue.getPriorVector();
            Vector posterior = vertexValue.getPosteriorVector();
            double degree = vertexValue.getDegree();

            // collect messages sent to this vertex
            HashMap<Long, Vector> map = new HashMap();
            for (IdWithVectorMessage message : messages) {
                map.put(message.getData(), message.getVector());
            }

            // Update belief and calculate cost
            double hi = prior.getQuick(0);
            double fi = posterior.getQuick(0);
            double crossSum = 0d;
            Vector newBelief = posterior.clone().assign(0d);

            for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
                double weight = edge.getValue().get();
                if (weight <= 0d) {
                    throw new IllegalArgumentException(
                            "Vertex ID: " + vertex.getId() + "has an edge with negative or zero value");
                }
                long targetVertex = edge.getTargetVertexId().get();
                if (map.containsKey(targetVertex)) {
                    Vector tempVector = map.get(targetVertex);
                    newBelief = newBelief.plus(tempVector.times(weight));
                    double fj = tempVector.getQuick(0);
                    crossSum += weight * fi * fj;
                }
            }

            double cost = degree
                    * ((1 - lambda) * (Math.pow(fi, 2) - crossSum) + 0.5 * lambda * Math.pow((fi - hi), 2));
            aggregate(SUM_COST, new DoubleWritable(cost));

            // Update posterior if the vertex was not processed
            if (vertexValue.wasLabeled() == false) {
                newBelief = (newBelief.times(1 - lambda).plus(prior.times(lambda))).normalize(1d);
                vertexValue.setPosteriorVector(newBelief);
            }

            // Send out messages if not the last step
            if (superStep != maxSupersteps) {
                IdWithVectorMessage newMessage = new IdWithVectorMessage(vertex.getId().get(),
                        vertexValue.getPosteriorVector());
                sendMessageToAllEdges(vertex, newMessage);
            }
        }

        vertex.voteToHalt();
    }

    private double calculateVertexDegree(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex) {
        double degree = 0d;
        for (Edge<LongWritable, DoubleWritable> edge : vertex.getEdges()) {
            double weight = edge.getValue().get();
            if (weight <= 0d) {
                throw new IllegalArgumentException(
                        "Vertex ID: " + vertex.getId() + " has an edge with negative or zero weight value.");
            }
            degree += weight;
        }
        return degree;
    }

    private double initializeEdge(Vertex<LongWritable, VertexData4LPWritable, DoubleWritable> vertex) {
        double degree = calculateVertexDegree(vertex);

        for (Edge<LongWritable, DoubleWritable> edge : vertex.getMutableEdges()) {
            edge.getValue().set(edge.getValue().get() / degree);
        }

        return degree;
    }

    /**
     * Master compute associated with {@link LabelPropagationComputation}. It registers required aggregators.
     */
    public static class LabelPropagationMasterCompute extends DefaultMasterCompute {

        @Override
        public void initialize() throws InstantiationException, IllegalAccessException {
            registerAggregator(SUM_COST, DoubleSumAggregator.class);
        }

        @Override
        public void compute() {
            if (getSuperstep() > 1) {

                DoubleWritable sumCost = getAggregatedValue(SUM_COST);
                double cost = sumCost.get() / getTotalNumVertices();
                sumCost.set(cost);

                getConf().setFloat(PREV_COST, (float) cost);
            }
        }
    }

    /**
     * This is an aggregator writer for label propagation, which after each super step will persist the
     * aggregator values to disk, by use of the Writable interface.
     */
    public static class LabelPropagationAggregatorWriter implements AggregatorWriter {
        /** Name of the file we wrote to */
        private static String filename;

        /** Saved output stream to write to */
        private FSDataOutputStream output;

        /** Last superstep number */
        private long currentStep = -1L;

        /** giraph configuration */
        private ImmutableClassesGiraphConfiguration conf;

        public static String getFilename() {
            return filename;
        }

        @SuppressWarnings("rawtypes")
        @Override
        public void initialize(Context context, long applicationAttempt) throws IOException {

            setFilename(applicationAttempt);
            String outputDir = context.getConfiguration().get("mapred.output.dir");
            Path outputPath = new Path(outputDir + File.separator + filename);
            FileSystem fileSystem = FileSystem.get(context.getConfiguration());

            if (fileSystem.exists(outputPath)) {
                fileSystem.delete(outputPath, true);
            }
            output = fileSystem.create(outputPath, true);
        }

        /**
         * Set filename written to
         *
         * @param applicationAttempt of type long
         */
        private static void setFilename(long applicationAttempt) {
            filename = "lp-learning-report_" + applicationAttempt;
        }

        @Override
        public void writeAggregator(Iterable<Entry<String, Writable>> aggregatorMap, long superstep)
                throws IOException {
            LabelPropagationConfig config = new LabelPropagationConfiguration(getConf()).getConfig();

            if (currentStep == 0) {

                output.writeBytes("======Graph Statistics======\n");
                output.writeBytes(String.format("Number of vertices: %d%n",
                        GiraphStats.getInstance().getVertices().getValue()));
                output.writeBytes(
                        String.format("Number of edges: %d%n", GiraphStats.getInstance().getEdges().getValue()));
                output.writeBytes("\n");

                output.writeBytes("======LP Configuration======\n");
                output.writeBytes(String.format("lambda: %f%n", config.lambda()));
                output.writeBytes(String.format("convergence threshold: %f%n", config.convergenceThreshold()));
                output.writeBytes(String.format("max iterations: %d%n", config.maxIterations()));
                output.writeBytes("\n");
                output.writeBytes("======Learning Progress======\n");
            } else if (currentStep > 0) {

                // collect aggregator data
                HashMap<String, String> map = new HashMap();
                for (Entry<String, Writable> entry : aggregatorMap) {
                    map.put(entry.getKey(), entry.getValue().toString());
                }

                // output learning progress
                output.writeBytes(String.format("superstep = %d%c", currentStep, '\t'));
                double cost = Double.parseDouble(map.get(SUM_COST));
                output.writeBytes(String.format("cost = %f%n", cost));
            }
            output.flush();
            currentStep = superstep;
        }

        @Override
        public void close() throws IOException {
            output.close();
        }

        /**
         * Set the configuration to be used by this object.
         *
         * @param configuration Set configuration
         */
        @Override
        public void setConf(ImmutableClassesGiraphConfiguration configuration) {
            this.conf = configuration;
        }

        /**
         * Return the configuration used by this object.
         *
         * @return Set configuration
         */
        @Override
        public ImmutableClassesGiraphConfiguration getConf() {
            return conf;
        }
    }
}