org.apache.mahout.clustering.AbstractCluster.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.clustering.AbstractCluster.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.common.parameters.Parameter;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.SquareRootFunction;
import org.codehaus.jackson.map.ObjectMapper;

/**
 * Base implementation of {@link Cluster} that accumulates weighted observation statistics and
 * derives the cluster's center and per-component radius from them.
 *
 * <p>Three running sums are maintained between calls to {@link #computeParameters()}:
 * <ul>
 *   <li>{@code s0} &ndash; the total observed weight;</li>
 *   <li>{@code s1} &ndash; the weighted sum of the observed vectors;</li>
 *   <li>{@code s2} &ndash; the weighted sum of the element-wise squares of the observed vectors.</li>
 * </ul>
 * {@link #computeParameters()} converts them into a center ({@code s1 / s0}) and, once more than
 * one unit of weight has been observed, a radius (the per-component standard deviation), and then
 * resets them.
 *
 * <p>NOTE(review): {@code log} is used below but not declared in this class; it is presumably
 * inherited as a constant from one of the implemented interfaces &ndash; confirm.
 */
public abstract class AbstractCluster implements Cluster {

    // cluster persistent state
    private int id;

    /** Observation count recorded by the last {@link #computeParameters()} (truncated {@code s0}). */
    private long numObservations;

    /** Sum of the observation counts recorded by every {@link #computeParameters()} call. */
    private long totalObservations;

    private Vector center;

    /** Cluster radius; {@link #computeParameters()} sets it to the per-component standard deviation. */
    private Vector radius;

    // the observation statistics, accumulated by the observe(...) methods

    /** Total observed weight. */
    private double s0;

    /** Weighted sum of the observed vectors; {@code null} after the no-arg constructor until set. */
    private Vector s1;

    /** Weighted sum of the squared observed vectors; {@code null} after the no-arg constructor until set. */
    private Vector s2;

    /** Shared JSON mapper used by the human-readable formatting methods. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /**
     * Creates a cluster with no state set; center, radius and accumulators remain {@code null}
     * (presumably populated later via {@link #readFields(DataInput)} or the protected setters).
     */
    protected AbstractCluster() {
    }

    /**
     * Creates a cluster centered on a copy of {@code point}.
     *
     * @param point the initial center; cloned so later changes to the argument do not affect this cluster
     * @param id2   the cluster id
     */
    protected AbstractCluster(Vector point, int id2) {
        this.numObservations = 0L;
        this.totalObservations = 0L;
        this.center = point.clone();
        // the radius and the vector accumulators start as center.like()
        this.radius = center.like();
        this.s0 = 0.0;
        this.s1 = center.like();
        this.s2 = center.like();
        this.id = id2;
    }

    /**
     * Creates a cluster with the given center and radius, both copied into
     * {@link RandomAccessSparseVector}s.
     *
     * @param center2 the initial center
     * @param radius2 the initial radius
     * @param id2     the cluster id
     */
    protected AbstractCluster(Vector center2, Vector radius2, int id2) {
        this.numObservations = 0L;
        this.totalObservations = 0L;
        this.center = new RandomAccessSparseVector(center2);
        this.radius = new RandomAccessSparseVector(radius2);
        this.s0 = 0.0;
        this.s1 = center.like();
        this.s2 = center.like();
        this.id = id2;
    }

    /**
     * Serializes the cluster as: id, numObservations, totalObservations, center, radius, s0, s1, s2.
     * The order must match {@link #readFields(DataInput)}.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(id);
        out.writeLong(getNumObservations());
        out.writeLong(getTotalObservations());
        VectorWritable.writeVector(out, getCenter());
        VectorWritable.writeVector(out, getRadius());
        out.writeDouble(s0);
        VectorWritable.writeVector(out, s1);
        VectorWritable.writeVector(out, s2);
    }

    /**
     * Replaces this cluster's state with values read in the order produced by
     * {@link #write(DataOutput)}.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.id = in.readInt();
        this.setNumObservations(in.readLong());
        this.setTotalObservations(in.readLong());
        this.setCenter(VectorWritable.readVector(in));
        this.setRadius(VectorWritable.readVector(in));
        this.setS0(in.readDouble());
        this.setS1(VectorWritable.readVector(in));
        this.setS2(VectorWritable.readVector(in));
    }

    /** No configuration is needed at this level; subclasses may override. */
    @Override
    public void configure(Configuration job) {
        // nothing to do
    }

    /** @return an empty collection; this class declares no parameters */
    @Override
    public Collection<Parameter<?>> getParameters() {
        return Collections.emptyList();
    }

    /** No parameters are created at this level; subclasses may override. */
    @Override
    public void createParameters(String prefix, Configuration jobConf) {
        // nothing to do
    }

    @Override
    public int getId() {
        return id;
    }

    /**
     * @param id
     *          the id to set
     */
    protected void setId(int id) {
        this.id = id;
    }

    @Override
    public long getNumObservations() {
        return numObservations;
    }

    /**
     * @param l
     *          the number of observations to set
     */
    protected void setNumObservations(long l) {
        this.numObservations = l;
    }

    @Override
    public long getTotalObservations() {
        return totalObservations;
    }

    /**
     * @param totalPoints
     *          the total number of observations to set
     */
    protected void setTotalObservations(long totalPoints) {
        this.totalObservations = totalPoints;
    }

    @Override
    public Vector getCenter() {
        return center;
    }

    /**
     * @param center
     *          the center to set
     */
    protected void setCenter(Vector center) {
        this.center = center;
    }

    @Override
    public Vector getRadius() {
        return radius;
    }

    /**
     * @param radius
     *          the radius to set
     */
    protected void setRadius(Vector radius) {
        this.radius = radius;
    }

    /**
     * @return the accumulated total weight (s0)
     */
    protected double getS0() {
        return s0;
    }

    protected void setS0(double s0) {
        this.s0 = s0;
    }

    /**
     * @return the accumulated weighted vector sum (s1); may be {@code null}
     */
    protected Vector getS1() {
        return s1;
    }

    protected void setS1(Vector s1) {
        this.s1 = s1;
    }

    /**
     * @return the accumulated weighted sum of squares (s2); may be {@code null}
     */
    protected Vector getS2() {
        return s2;
    }

    protected void setS2(Vector s2) {
        this.s2 = s2;
    }

    /**
     * Merges another cluster's accumulated statistics (s0, s1, s2) into this one.
     *
     * <p>Like {@link #observe(Vector)}, a {@code null} accumulator on either side is tolerated: a
     * missing local accumulator is initialized from a copy of the other cluster's, and a missing
     * accumulator on the other cluster contributes nothing.
     *
     * @param x the cluster to merge; must be an {@code AbstractCluster}
     * @throws ClassCastException if {@code x} is not an {@code AbstractCluster}
     */
    @Override
    public void observe(Model<VectorWritable> x) {
        AbstractCluster cl = (AbstractCluster) x;
        setS0(getS0() + cl.getS0());
        setS1(sumOrCopy(getS1(), cl.getS1()));
        setS2(sumOrCopy(getS2(), cl.getS2()));
    }

    /**
     * Adds two accumulators, treating {@code null} as "nothing accumulated yet".
     * The returned vector never aliases {@code other}.
     */
    private static Vector sumOrCopy(Vector mine, Vector other) {
        if (other == null) {
            return mine;
        }
        if (mine == null) {
            // copy so later in-place accumulation does not mutate the other cluster's state
            return other.clone();
        }
        return mine.plus(other);
    }

    /** Observes the wrapped vector with unit weight. */
    @Override
    public void observe(VectorWritable x) {
        observe(x.get());
    }

    /** Observes the wrapped vector with the given weight. */
    @Override
    public void observe(VectorWritable x, double weight) {
        observe(x.get(), weight);
    }

    /**
     * Accumulates a weighted observation: {@code s0 += weight}, {@code s1 += weight * x},
     * {@code s2 += weight * x * x} (element-wise). A weight of exactly 1.0 delegates to
     * {@link #observe(Vector)}.
     *
     * @param x      the observed vector (not modified)
     * @param weight the observation weight
     */
    public void observe(Vector x, double weight) {
        if (weight == 1.0) {
            observe(x);
        } else {
            setS0(getS0() + weight);
            Vector weightedX = x.times(weight);
            if (getS1() == null) {
                setS1(weightedX);
            } else {
                getS1().assign(weightedX, Functions.PLUS);
            }
            Vector x2 = x.times(x).times(weight);
            if (getS2() == null) {
                setS2(x2);
            } else {
                getS2().assign(x2, Functions.PLUS);
            }
        }
    }

    /**
     * Accumulates an observation with unit weight: {@code s0 += 1}, {@code s1 += x},
     * {@code s2 += x * x} (element-wise).
     *
     * @param x the observed vector (not modified; cloned if it seeds {@code s1})
     */
    public void observe(Vector x) {
        setS0(getS0() + 1);
        if (getS1() == null) {
            setS1(x.clone());
        } else {
            getS1().assign(x, Functions.PLUS);
        }
        Vector x2 = x.times(x);
        if (getS2() == null) {
            setS2(x2);
        } else {
            getS2().assign(x2, Functions.PLUS);
        }
    }

    /**
     * Converts the accumulated statistics into new cluster parameters and resets the accumulators.
     *
     * <p>Does nothing when no weight has been observed. Otherwise it does the following:
     * <ul>
     *   <li>sets numObservations to the truncated {@code s0} and adds it to totalObservations;</li>
     *   <li>sets the center to {@code s1 / s0};</li>
     *   <li>when {@code s0 > 1}, sets the radius to the per-component standard deviation
     *       {@code sqrt(s0 * s2 - s1^2) / s0};</li>
     *   <li>resets {@code s0} to 0, and {@code s1} and {@code s2} to {@code center.like()}.</li>
     * </ul>
     */
    @Override
    public void computeParameters() {
        double weight = getS0();
        if (weight == 0) {
            return;
        }
        setNumObservations((long) weight);
        setTotalObservations(getTotalObservations() + getNumObservations());
        setCenter(getS1().divide(weight));
        // compute the component stds
        if (weight > 1) {
            // s0 * s2 - s1^2 equals s0^2 times the variance. Floating-point cancellation can push it
            // slightly below zero (e.g. when all observations are identical); clamp at 0 so the
            // square root does not produce NaN.
            Vector scaledVariance = getS2().times(weight).minus(getS1().times(getS1()))
                    .assign(Functions.max(0.0));
            setRadius(scaledVariance.assign(new SquareRootFunction()).divide(weight));
        }
        setS0(0);
        setS1(center.like());
        setS2(center.like());
    }

    /**
     * Renders {@link #asJson(String[])} as a JSON string.
     *
     * @param bindings optional term names indexed by vector position; may be {@code null}
     * @return the JSON text, or an empty string if serialization failed (the error is logged)
     */
    @Override
    public String asFormatString(String[] bindings) {
        String fmtString = "";
        try {
            fmtString = JSON_MAPPER.writeValueAsString(asJson(bindings));
        } catch (IOException e) {
            log.error("Error writing JSON as String.", e);
        }
        return fmtString;
    }

    /**
     * Builds a JSON-ready map describing this cluster with keys {@code "identifier"},
     * {@code "n"} (numObservations), and, when non-null, {@code "c"} (center) and
     * {@code "r"} (radius) formatted by {@link #formatVectorAsJson(Vector, String[])}.
     *
     * @param bindings optional term names indexed by vector position; may be {@code null}
     * @return a new mutable map
     */
    public Map<String, Object> asJson(String[] bindings) {
        Map<String, Object> dict = new HashMap<>();
        dict.put("identifier", getIdentifier());
        dict.put("n", getNumObservations());
        if (getCenter() != null) {
            try {
                dict.put("c", formatVectorAsJson(getCenter(), bindings));
            } catch (IOException e) {
                log.error("IOException:  ", e);
            }
        }
        if (getRadius() != null) {
            try {
                dict.put("r", formatVectorAsJson(getRadius(), bindings));
            } catch (IOException e) {
                log.error("IOException:  ", e);
            }
        }
        return dict;
    }

    /** @return the identifier emitted under the {@code "identifier"} key of {@link #asJson(String[])} */
    public abstract String getIdentifier();

    /**
     * Computes the centroid by averaging the accumulated point totals.
     *
     * @return {@code s1 / s0}, or the current center if no weight has been observed
     */
    public Vector computeCentroid() {
        return getS0() == 0 ? getCenter() : getS1().divide(getS0());
    }

    /**
     * Returns a human-readable formatted string representation of the vector. It is not
     * intended to be complete, nor usable as an input/output representation.
     *
     * @return the JSON text, or an empty string if serialization failed (the error is logged)
     */
    public static String formatVector(Vector v, String[] bindings) {
        String fmtString = "";
        try {
            fmtString = JSON_MAPPER.writeValueAsString(formatVectorAsJson(v, bindings));
        } catch (IOException e) {
            log.error("Error writing JSON as String.", e);
        }
        return fmtString;
    }

    /**
     * Creates a list describing the non-zero entries of {@code v}, with weights rounded to three
     * decimal places.
     *
     * <p>When {@code bindings} is non-null or the vector is sparse (has zero entries), each element
     * is a single-entry map from term to weight. The term is {@code bindings[index]} when that
     * exists and is non-null, otherwise the decimal index. For a dense vector without bindings,
     * the elements are the bare weights.
     *
     * @param v        the vector to format
     * @param bindings optional term names indexed by vector position; may be {@code null}
     * @return a new mutable list in sequential index order
     * @throws IOException declared for compatibility; not thrown by this implementation
     */
    public static List<Object> formatVectorAsJson(Vector v, String[] bindings) throws IOException {
        boolean hasBindings = bindings != null;
        boolean isSparse = v.getNumNonZeroElements() != v.size();
        boolean keyedByTerm = hasBindings || isSparse;

        // we assume sequential access in the output
        Vector provider = v.isSequentialAccess() ? v : new SequentialAccessSparseVector(v);

        List<Object> terms = new LinkedList<>();
        for (Element elem : provider.nonZeroes()) {
            double roundedWeight = roundToThousandths(elem.get());
            if (keyedByTerm) {
                Map<String, Object> termEntry = new HashMap<>();
                termEntry.put(termFor(elem.index(), bindings), roundedWeight);
                terms.add(termEntry);
            } else {
                terms.add(roundedWeight);
            }
        }
        return terms;
    }

    /** Returns the bound name for {@code index} if present and non-null, otherwise the index as text. */
    private static String termFor(int index, String[] bindings) {
        if (bindings != null && index < bindings.length && bindings[index] != null) {
            return bindings[index];
        }
        return String.valueOf(index);
    }

    /** Rounds {@code value} to three decimal places using {@link Math#round(double)}. */
    private static double roundToThousandths(double value) {
        return Math.round(value * 1000) / 1000.0;
    }

    /** @return always {@code false}; convergence is not defined at this level */
    @Override
    public boolean isConverged() {
        // Convergence has no meaning yet, perhaps in subclasses
        return false;
    }
}