edu.emory.mathcs.nlp.learn.weight.MultinomialWeightVector.java Source code

Java tutorial

Introduction

Here is the source code for edu.emory.mathcs.nlp.learn.weight.MultinomialWeightVector.java

Source

/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.learn.weight;

import java.util.Arrays;

import org.apache.commons.math3.util.FastMath;

import edu.emory.mathcs.nlp.common.util.DSUtils;
import edu.emory.mathcs.nlp.learn.util.Prediction;
import edu.emory.mathcs.nlp.learn.vector.IndexValuePair;
import edu.emory.mathcs.nlp.learn.vector.Vector;

/**
 * Weight vector that keeps one {@code float} weight per (label, feature) pair in a single flat
 * array.
 * <p>
 * The weight of label {@code y} for feature {@code xi} is stored at index
 * {@code xi * label_size + y} (see {@link #indexOf(int, int)}), so the weights of all labels for
 * one feature are contiguous.
 *
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class MultinomialWeightVector extends WeightVector {
    private static final long serialVersionUID = 2190946158451118027L;

    /** Creates a weight vector by passing label and feature sizes of {@code 0} to the superclass. */
    public MultinomialWeightVector() {
        super(0, 0);
    }

    /**
     * Creates a weight vector by passing the given sizes to the superclass constructor.
     *
     * @param labelSize   the number of labels.
     * @param featureSize the number of features.
     */
    public MultinomialWeightVector(int labelSize, int featureSize) {
        super(labelSize, featureSize);
    }

    /**
     * Replaces the weight array with a new zero-filled array of
     * {@code labelSize * featureSize} entries and records both sizes.
     *
     * @param labelSize   the number of labels.
     * @param featureSize the number of features.
     */
    public void init(int labelSize, int featureSize) {
        weight_vector = new float[labelSize * featureSize];
        label_size = labelSize;
        feature_size = featureSize;
    }

    /**
     * Grows the weight array to the given dimensions, keeping every existing weight addressable
     * at {@code indexOf(y, xi)} under the new label size; new entries are {@code 0}.
     *
     * @param labelSize   the new number of labels.
     * @param featureSize the new number of features.
     * @return {@code false}, leaving this vector untouched, if either size is smaller than the
     *         current one or both sizes equal the current ones; otherwise {@code true}.
     */
    @Override
    public boolean expand(int labelSize, int featureSize) {
        boolean shrinks = labelSize < label_size || featureSize < feature_size;
        boolean unchanged = labelSize == label_size && featureSize == feature_size;

        if (shrinks || unchanged) {
            return false;
        }

        float[] expanded;

        if (labelSize > label_size) {
            // The row stride changes with the label size, so each feature's run of label
            // weights has to be moved to its new offset.
            expanded = new float[labelSize * featureSize];

            // Nothing to copy when there are no existing labels.
            if (label_size > 0) {
                for (int xi = 0; xi < feature_size; xi++) {
                    System.arraycopy(weight_vector, xi * label_size, expanded, xi * labelSize, label_size);
                }
            }
        } else {
            // Same stride, only more features: existing offsets stay valid, new rows are appended.
            expanded = Arrays.copyOf(weight_vector, labelSize * featureSize);
        }

        weight_vector = expanded;
        label_size = labelSize;
        feature_size = featureSize;
        return true;
    }

    /**
     * @param y  the label index.
     * @param xi the feature index.
     * @return the position in the flat weight array of the weight for label {@code y} and
     *         feature {@code xi}, i.e. {@code xi * label_size + y}.
     */
    @Override
    public int indexOf(int y, int xi) {
        return y + indexOf(xi);
    }

    /**
     * @param xi the feature index.
     * @return the position in the flat weight array of the first label weight for feature {@code xi}.
     */
    private int indexOf(int xi) {
        return xi * label_size;
    }

    /**
     * Computes one score per label as the sum, over the entries of {@code x}, of the entry's value
     * times the corresponding weight. Entries whose index is not below the current feature size
     * are ignored.
     * <p>
     * When {@code isRegression()} is true, the scores are additionally passed through a softmax
     * (exponentiated and divided by their sum). The maximum score is subtracted before
     * exponentiation; this leaves the mathematical result unchanged but keeps {@code exp} from
     * overflowing to infinity (and the result from becoming {@code NaN}) for large scores.
     *
     * @param x the feature vector to score.
     * @return a new array of length {@code label_size} holding the score of each label.
     */
    @Override
    public double[] scores(Vector x) {
        double[] scores = new double[label_size];

        for (IndexValuePair p : x) {
            int xi = p.getIndex();

            if (xi < feature_size) {
                int offset = indexOf(xi);

                for (int y = 0; y < label_size; y++) {
                    scores[y] += weight_vector[offset + y] * p.getValue();
                }
            }
        }

        if (isRegression() && label_size > 0) {
            double max = scores[0];

            for (int y = 1; y < label_size; y++) {
                max = Math.max(max, scores[y]);
            }

            // An infinite maximum cannot be shifted away (inf - inf is NaN), so fall back to the
            // unshifted computation in that degenerate case.
            double shift = Double.isInfinite(max) ? 0 : max;
            double sum = 0;

            for (int y = 0; y < label_size; y++) {
                scores[y] = FastMath.exp(scores[y] - shift);
                sum += scores[y];
            }

            for (int y = 0; y < label_size; y++) {
                scores[y] /= sum;
            }
        }

        return scores;
    }

    /**
     * Scores {@code x} with {@link #scores(Vector)} and wraps the label index chosen by
     * {@code DSUtils.maxIndex} (presumably the index of the largest score) together with its score.
     *
     * @param x the feature vector to classify.
     * @return the prediction holding the chosen label index and its score.
     */
    @Override
    public Prediction predictBest(Vector x) {
        double[] scores = scores(x);
        int label = DSUtils.maxIndex(scores);
        return new Prediction(label, scores[label]);
    }

    //   @Override
    //   public Pair<Prediction,Prediction> predictTop2(Vector x)
    //   {
    //      double[] scores = scores(x);
    //      Prediction fst, snd;
    //
    //      if (scores[0] < scores[1])
    //      {
    //         fst = new Prediction(1, scores[1]);
    //         snd = new Prediction(0, scores[0]);
    //      }
    //      else
    //      {
    //         fst = new Prediction(0, scores[0]);
    //         snd = new Prediction(1, scores[1]);
    //      }
    //
    //      for (int i=2; i<label_size; i++)
    //      {
    //         if (fst.getScore() < scores[i])
    //         {
    //            snd.copy(fst);
    //            fst.set(i, scores[i]);
    //         }
    //         else if (snd.getScore() < scores[i])
    //            snd.set(i, scores[i]);
    //      }
    //
    //      return new Pair<Prediction,Prediction>(fst, snd);
    //   }
    //
    //   @Override
    //   public Prediction[] predictAll(Vector x)
    //   {
    //      double[] scores = scores(x);
    //      Prediction[] ps = new Prediction[label_size];
    //
    //      for (int i=0; i<label_size; i++)
    //         ps[i] = new Prediction(i, scores[i]);
    //
    //      return ps;
    //   }
}