org.apache.mahout.ep.State.java Source code

Java tutorial

Introduction

Here is the source code for the file `org/apache/mahout/ep/State.java` (class `org.apache.mahout.ep.State`).

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.ep;

import com.google.common.collect.Lists;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.common.RandomUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Locale;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Records evolutionary state and provides a mutation operation for recorded-step meta-mutation.
 *
 * <p>You provide the payload; this class provides the mutation operations. During mutation,
 * the payload is copied and, after the state variables are changed, the new mapped parameter
 * values are passed to the payload.
 *
 * <p>Parameters are internally mutated in a state space that spans all of R^n, but parameters
 * passed to the payload are transformed as specified by a call to {@link #setMap(int, Mapping)}.
 * The default mapping (a {@code null} entry) is the identity map, but uniform-ish or
 * exponential-ish coverage of a range are also supported.
 *
 * <p>More information on the underlying algorithm can be found in the following paper:
 * http://arxiv.org/abs/0803.3838
 *
 * <p>Note: {@link #getParams()}, {@link #getStep()} and {@link #getMaps()} return the internal
 * arrays rather than copies, so modifying them modifies this state.
 *
 * @param <T> the payload type carried along with this state
 * @param <U> the type argument of the payload ({@code Payload<U>})
 * @see Mapping
 */
public class State<T extends Payload<U>, U> implements Comparable<State<T, U>>, Writable {

    // Object count is kept to break ties in comparison; every new instance takes the next value.
    private static final AtomicInteger OBJECT_COUNT = new AtomicInteger();

    private int id = OBJECT_COUNT.getAndIncrement();
    // Shared (not cloned) with states produced by copy()/mutate().
    private Random gen = RandomUtils.getRandom();
    // current state, in the unmapped space
    private double[] params;
    // per-parameter mappers to transform state; a null entry means identity
    private Mapping[] maps;
    // omni-directional mutation
    private double omni;
    // directional mutation (the step most recently applied to reach this state)
    private double[] step;
    // current fitness value; higher values sort first in compareTo
    private double value;
    private T payload;

    /**
     * Creates an empty state. Fields are left unset; this is used by {@link #copy()} and before
     * {@link #readFields(DataInput)}.
     */
    public State() {
    }

    /**
     * Invent a new state with no momentum (yet).
     *
     * @param x0   the initial position; it is copied, not retained
     * @param omni the initial omni-directional mutation scale
     */
    public State(double[] x0, double omni) {
        params = Arrays.copyOf(x0, x0.length);
        this.omni = omni;
        step = new double[params.length];
        maps = new Mapping[params.length];
    }

    /**
     * Deep copies a state, useful in mutation.
     *
     * <p>The copy receives a fresh id. The {@code params}, {@code step} and {@code maps} arrays are
     * copied; the {@link Mapping} objects themselves are shared. The payload, if any, is copied
     * via its {@code copy()} method, and the random generator is shared. The fitness value is not
     * copied and is left at 0.
     *
     * @return a new state with the same position, step, omni and mappings as this one
     */
    public State<T, U> copy() {
        State<T, U> r = new State<>();
        r.params = Arrays.copyOf(this.params, this.params.length);
        r.omni = this.omni;
        r.step = Arrays.copyOf(this.step, this.step.length);
        r.maps = Arrays.copyOf(this.maps, this.maps.length);
        if (this.payload != null) {
            // The payload's copy is cast back to T; this cast cannot be checked at runtime.
            @SuppressWarnings("unchecked")
            T payloadCopy = (T) this.payload.copy();
            r.payload = payloadCopy;
        }
        r.gen = this.gen;
        return r;
    }

    /**
     * Clones this state with a random change in position. Copies the payload and
     * lets it know about the change.
     *
     * @return A new state.
     */
    public State<T, U> mutate() {
        // Euclidean length of the previous step.
        double sum = stepNorm();
        // Random momentum factor applied to the previous step.
        double lambda = 1 + gen.nextGaussian();

        State<T, U> r = this.copy();
        double magnitude = 0.9 * omni + sum / 10;
        // -log1p(-u) = -ln(1 - u) for u uniform in [0, 1) is an exponential(1) sample,
        // so the new omni is a randomly scaled version of magnitude.
        r.omni = magnitude * -Math.log1p(-gen.nextDouble());
        for (int i = 0; i < step.length; i++) {
            r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
            r.params[i] += r.step[i];
        }
        if (this.payload != null) {
            r.payload.update(r.getMappedParams());
        }
        return r;
    }

    /**
     * Defines the transformation for a parameter.
     * @param i Which parameter's mapping to define.
     * @param m The mapping to use; {@code null} means identity.
     * @see org.apache.mahout.ep.Mapping
     */
    public void setMap(int i, Mapping m) {
        maps[i] = m;
    }

    /**
     * Returns a transformed parameter.
     * @param i  The parameter to return.
     * @return The value of the parameter after applying its mapping, or the raw value if no
     *         mapping is set.
     */
    public double get(int i) {
        Mapping m = maps[i];
        return m == null ? params[i] : m.apply(params[i]);
    }

    /** @return the id of this state, used as the tie-breaker in {@link #compareTo(State)}. */
    public int getId() {
        return id;
    }

    /** @return the internal (unmapped) parameter array; not a copy. */
    public double[] getParams() {
        return params;
    }

    /** @return the internal mapping array; not a copy. Entries may be {@code null} (identity). */
    public Mapping[] getMaps() {
        return maps;
    }

    /**
     * Returns all the parameters in mapped form.
     * @return A newly allocated array of mapped parameters.
     */
    public double[] getMappedParams() {
        double[] r = new double[params.length];
        for (int i = 0; i < params.length; i++) {
            r[i] = get(i);
        }
        return r;
    }

    /** @return the omni-directional mutation scale. */
    public double getOmni() {
        return omni;
    }

    /** @return the internal step array; not a copy. */
    public double[] getStep() {
        return step;
    }

    /** @return the payload, possibly {@code null}. */
    public T getPayload() {
        return payload;
    }

    /** @return the current fitness value. */
    public double getValue() {
        return value;
    }

    public void setOmni(double omni) {
        this.omni = omni;
    }

    public void setId(int id) {
        this.id = id;
    }

    /** Sets the step array; the array is stored directly, not copied. */
    public void setStep(double[] step) {
        this.step = step;
    }

    /** Sets the mapping array; the array is stored directly, not copied. */
    public void setMaps(Mapping[] maps) {
        this.maps = maps;
    }

    /** Sets the mappings from an iterable, in iteration order. */
    public void setMaps(Iterable<Mapping> maps) {
        Collection<Mapping> list = Lists.newArrayList(maps);
        this.maps = list.toArray(new Mapping[list.size()]);
    }

    public void setValue(double v) {
        value = v;
    }

    public void setPayload(T payload) {
        this.payload = payload;
    }

    /**
     * Two states are equal when they have the same id and the same fitness value.
     *
     * <p>Values are compared with {@link Double#compare} rather than {@code ==}. This keeps
     * equals reflexive for NaN values and treats 0.0 and -0.0 as distinct. That matches
     * {@link #compareTo(State)} and bit-pattern-based hashing.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof State)) {
            return false;
        }
        State<?, ?> other = (State<?, ?>) o;
        return id == other.id && Double.compare(value, other.value) == 0;
    }

    // NOTE(review): assumes RandomUtils.hashDouble is derived from the double's bit pattern,
    // which keeps it consistent with the Double.compare-based equals above — TODO confirm.
    @Override
    public int hashCode() {
        return RandomUtils.hashDouble(value) ^ id;
    }

    /**
     * Natural order is to sort in descending order of score. Creation order (lower id first) is
     * used as a tie-breaker.
     *
     * @param other The state to compare with.
     * @return a negative value if this state has the higher score and so sorts first, a positive
     *         value if the other state has the higher score, and 0 only when score and id match.
     */
    @Override
    public int compareTo(State<T, U> other) {
        int r = Double.compare(other.value, this.value);
        if (r != 0) {
            return r;
        }
        return Integer.compare(this.id, other.id);
    }

    /** Shows the payload, total mutation scale (omni plus step length) and fitness value. */
    @Override
    public String toString() {
        return String.format(Locale.ENGLISH, "<S/%s %.3f %.3f>", payload, omni + stepNorm(), value);
    }

    /**
     * Serializes this state. The layout is: id, n, n params, n mappings, omni, n steps, value,
     * payload.
     *
     * <p>NOTE(review): every {@code maps} entry and the payload are passed to
     * {@code PolymorphicWritable.write}. A {@code null} entry (the identity default) or a
     * {@code null} payload is presumably not supported there — confirm before serializing such
     * states.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(id);
        out.writeInt(params.length);
        for (double v : params) {
            out.writeDouble(v);
        }
        for (Mapping map : maps) {
            PolymorphicWritable.write(out, map);
        }

        out.writeDouble(omni);
        for (double v : step) {
            out.writeDouble(v);
        }

        out.writeDouble(value);
        PolymorphicWritable.write(out, payload);
    }

    /** Reads a state in the layout produced by {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput input) throws IOException {
        id = input.readInt();
        int n = input.readInt();
        params = new double[n];
        for (int i = 0; i < n; i++) {
            params[i] = input.readDouble();
        }

        maps = new Mapping[n];
        for (int i = 0; i < n; i++) {
            maps[i] = PolymorphicWritable.read(input, Mapping.class);
        }
        omni = input.readDouble();
        step = new double[n];
        for (int i = 0; i < n; i++) {
            step[i] = input.readDouble();
        }
        value = input.readDouble();
        // The payload is read as a Payload and cast to T; the cast cannot be checked at runtime.
        @SuppressWarnings("unchecked")
        T readPayload = (T) PolymorphicWritable.read(input, Payload.class);
        payload = readPayload;
    }

    /** Euclidean length of the current step vector, summed in index order. */
    private double stepNorm() {
        double sum = 0;
        for (double v : step) {
            sum += v * v;
        }
        return Math.sqrt(sum);
    }
}