hivemall.fm.FFMPredictionModel.java Source code

Java tutorial

Introduction

Here is the source code for hivemall.fm.FFMPredictionModel.java

Source

/*
 * Hivemall: Hive scalable Machine Learning Library
 *
 * Copyright (C) 2015 Makoto YUI
 * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package hivemall.fm;

import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.codec.VariableByteCodec;
import hivemall.utils.codec.ZigZagLEB128Codec;
import hivemall.utils.collections.Int2LongOpenHashTable;
import hivemall.utils.collections.IntOpenHashTable;
import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.HalfFloat;
import hivemall.utils.lang.ObjectUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public final class FFMPredictionModel implements Externalizable {
    private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class);

    private static final byte HALF_FLOAT_ENTRY = 1;
    private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2;
    private static final byte FLOAT_ENTRY = 3;
    private static final byte W_ONLY_FLOAT_ENTRY = 4;

    /**
     * maps feature to feature weight pointer
     */
    private Int2LongOpenHashTable _map;
    private HeapBuffer _buf;

    private double _w0;
    private int _factors;
    private int _numFeatures;
    private int _numFields;

    public FFMPredictionModel() {
    }// for Externalizable

    public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf, double w0, int factor,
            int numFeatures, int numFields) {
        this._map = map;
        this._buf = buf;
        this._w0 = w0;
        this._factors = factor;
        this._numFeatures = numFeatures;
        this._numFields = numFields;
    }

    public int getNumFactors() {
        return _factors;
    }

    public double getW0() {
        return _w0;
    }

    public int getNumFeatures() {
        return _numFeatures;
    }

    public int getNumFields() {
        return _numFields;
    }

    public int getActualNumFeatures() {
        return _map.size();
    }

    public long approxBytesConsumed() {
        int size = _map.size();

        // [map] size * (|state| + |key| + |entry|) 
        long bytes = size * (1L + 4L + 4L + (4L * _factors));
        int rest = _map.capacity() - size;
        if (rest > 0) {
            bytes += rest * 1L;
        }
        // w0, factors, numFeatures, numFields, used, size
        bytes += (8 + 4 + 4 + 4 + 4 + 4);
        return bytes;
    }

    @Nullable
    private Entry getEntry(final int key) {
        final long ptr = _map.get(key);
        if (ptr == -1L) {
            return null;
        }
        return new Entry(_buf, _factors, ptr);
    }

    public float getW(@Nonnull final Feature x) {
        int j = x.getFeatureIndex();

        Entry entry = getEntry(j);
        if (entry == null) {
            return 0.f;
        }
        return entry.getW();
    }

    /**
     * @return true if V exists
     */
    public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) {
        int j = Feature.toIntFeature(x, yField, _numFields);

        Entry entry = getEntry(j);
        if (entry == null) {
            return false;
        }

        entry.getV(dst);
        if (ArrayUtils.equals(dst, 0.f)) {
            return false; // treat as null
        }
        return true;
    }

    @Override
    public void writeExternal(@Nonnull ObjectOutput out) throws IOException {
        out.writeDouble(_w0);
        final int factors = _factors;
        out.writeInt(factors);
        out.writeInt(_numFeatures);
        out.writeInt(_numFields);

        int used = _map.size();
        out.writeInt(used);

        final int[] keys = _map.getKeys();
        final int size = keys.length;
        out.writeInt(size);

        final byte[] states = _map.getStates();
        writeStates(states, out);

        final long[] values = _map.getValues();

        final HeapBuffer buf = _buf;
        final Entry e = new Entry(buf, factors);
        final float[] Vf = new float[factors];
        for (int i = 0; i < size; i++) {
            if (states[i] != IntOpenHashTable.FULL) {
                continue;
            }
            ZigZagLEB128Codec.writeSignedInt(keys[i], out);
            e.setOffset(values[i]);
            writeEntry(e, factors, Vf, out);
        }

        // help GC
        this._map = null;
        this._buf = null;
    }

    private static void writeEntry(@Nonnull final Entry e, final int factors, @Nonnull final float[] Vf,
            @Nonnull final DataOutput out) throws IOException {
        final float W = e.getW();
        e.getV(Vf);

        if (ArrayUtils.almostEquals(Vf, 0.f)) {
            if (HalfFloat.isRepresentable(W)) {
                out.writeByte(W_ONLY_HALF_FLOAT_ENTRY);
                out.writeShort(HalfFloat.floatToHalfFloat(W));
            } else {
                out.writeByte(W_ONLY_FLOAT_ENTRY);
                out.writeFloat(W);
            }
        } else if (isRepresentableAsHalfFloat(W, Vf)) {
            out.writeByte(HALF_FLOAT_ENTRY);
            out.writeShort(HalfFloat.floatToHalfFloat(W));
            for (int i = 0; i < factors; i++) {
                out.writeShort(HalfFloat.floatToHalfFloat(Vf[i]));
            }
        } else {
            out.writeByte(FLOAT_ENTRY);
            out.writeFloat(W);
            IOUtils.writeFloats(Vf, factors, out);
        }
    }

    private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) {
        if (!HalfFloat.isRepresentable(W)) {
            return false;
        }
        for (float V : Vf) {
            if (!HalfFloat.isRepresentable(V)) {
                return false;
            }
        }
        return true;
    }

    @Nonnull
    static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out) throws IOException {
        // write empty states's indexes differentially
        final int size = status.length;
        int cardinarity = 0;
        for (int i = 0; i < size; i++) {
            if (status[i] != IntOpenHashTable.FULL) {
                cardinarity++;
            }
        }
        out.writeInt(cardinarity);
        if (cardinarity == 0) {
            return;
        }
        int prev = 0;
        for (int i = 0; i < size; i++) {
            if (status[i] != IntOpenHashTable.FULL) {
                int diff = i - prev;
                assert (diff >= 0);
                VariableByteCodec.encodeUnsignedInt(diff, out);
                prev = i;
            }
        }
    }

    @Override
    public void readExternal(@Nonnull final ObjectInput in) throws IOException, ClassNotFoundException {
        this._w0 = in.readDouble();
        final int factors = in.readInt();
        this._factors = factors;
        this._numFeatures = in.readInt();
        this._numFields = in.readInt();

        final int used = in.readInt();
        final int size = in.readInt();
        final int[] keys = new int[size];
        final long[] values = new long[size];
        final byte[] states = new byte[size];
        readStates(in, states);

        final int entrySize = Entry.sizeOf(factors);
        int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1;
        final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks);
        final Entry e = new Entry(buf, factors);
        final float[] Vf = new float[factors];
        for (int i = 0; i < size; i++) {
            if (states[i] != IntOpenHashTable.FULL) {
                continue;
            }
            keys[i] = ZigZagLEB128Codec.readSignedInt(in);
            long ptr = buf.allocate(entrySize);
            e.setOffset(ptr);
            readEntry(in, factors, Vf, e);
            values[i] = ptr;
        }

        this._map = new Int2LongOpenHashTable(keys, values, states, used);
        this._buf = buf;
    }

    @Nonnull
    private static void readEntry(@Nonnull final DataInput in, final int factors, @Nonnull final float[] Vf,
            @Nonnull Entry dst) throws IOException {
        final byte type = in.readByte();
        switch (type) {
        case HALF_FLOAT_ENTRY: {
            float W = HalfFloat.halfFloatToFloat(in.readShort());
            dst.setW(W);
            for (int i = 0; i < factors; i++) {
                Vf[i] = HalfFloat.halfFloatToFloat(in.readShort());
            }
            dst.setV(Vf);
            break;
        }
        case W_ONLY_HALF_FLOAT_ENTRY: {
            float W = HalfFloat.halfFloatToFloat(in.readShort());
            dst.setW(W);
            break;
        }
        case FLOAT_ENTRY: {
            float W = in.readFloat();
            dst.setW(W);
            IOUtils.readFloats(in, Vf);
            dst.setV(Vf);
            break;
        }
        case W_ONLY_FLOAT_ENTRY: {
            float W = in.readFloat();
            dst.setW(W);
            break;
        }
        default:
            throw new IOException("Unexpected Entry type: " + type);
        }
    }

    @Nonnull
    static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status) throws IOException {
        // read non-empty states differentially
        final int cardinarity = in.readInt();
        Arrays.fill(status, IntOpenHashTable.FULL);
        int prev = 0;
        for (int j = 0; j < cardinarity; j++) {
            int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
            status[i] = IntOpenHashTable.FREE;
            prev = i;
        }
    }

    public byte[] serialize() throws IOException {
        LOG.info("FFMPredictionModel#serialize(): " + _buf.toString());
        return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true);
    }

    public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len)
            throws ClassNotFoundException, IOException {
        FFMPredictionModel model = new FFMPredictionModel();
        ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2, true);
        LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString());
        return model;
    }

}