com.antbrains.crf.hadoop.CalcFeatureWeights.java Source code

Java tutorial

Introduction

Here is the source code for com.antbrains.crf.hadoop.CalcFeatureWeights.java

Source

/**
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package com.antbrains.crf.hadoop;

import gnu.trove.map.hash.TObjectIntHashMap;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.Reducer.Context;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.mapreduce.Counter;

import com.antbrains.crf.BESB1B2MTagConvertor;
import com.antbrains.crf.FeatureDict;
import com.antbrains.crf.Instance;
import com.antbrains.crf.SgdCrf;
import com.antbrains.crf.TagConvertor;
import com.antbrains.crf.Template;
import com.antbrains.crf.TrainingDataSet;
import com.antbrains.crf.TrainingParams;
import com.antbrains.crf.TrainingProgress;
import com.antbrains.crf.TrainingWeights;
import com.antbrains.crf.hadoop.ParallelTraining2.TrainingMapper;
import com.google.gson.Gson;

/**
 * Two-step Hadoop pipeline that regroups flat CRF feature weights per feature and then sorts them.
 *
 * <p>Step 1 ({@link CalcFeatureMapper} / {@link CalcFeatureReducer}) reads SequenceFile records
 * whose {@link Text} key is {@code TrainingMapper.FEATURE_WEIGHT} followed by a flat weight index.
 * It emits one record per feature id. The record carries the feature's {@code TAG_COUNT} weights
 * and their sum of squares.
 *
 * <p>Step 2 ({@link IdentityMapper} / {@link IdentityReducer}) passes those records unchanged
 * through a single reducer, so the output is ordered by {@link MyKey}'s sort order.
 * NOTE(review): the exact ordering is defined by MyKey, which is not visible here; confirm.
 *
 * <p>Usage: {@code CalcFeatureWeights <inDir> <tmpDir> <outDir> [startStep]}.
 * {@code tmpDir} is deleted before step 1 and {@code outDir} is deleted before step 2.
 * Pass {@code startStep = 2} to rerun only the sort. If a job fails, the process exits with a
 * non-zero status.
 */
public class CalcFeatureWeights {

    /**
     * Number of weights stored per feature.
     *
     * <p>A flat weight index {@code idx} maps to feature {@code idx / TAG_COUNT} and slot
     * {@code idx % TAG_COUNT}.
     * NOTE(review): presumably one slot per tag of the six-tag BESB1B2M scheme; confirm against
     * the trainer's weight layout.
     */
    private static final int TAG_COUNT = 6;

    /**
     * Maps each flat feature-weight entry to {@code (featureId, MyKey(slot, weight))}.
     *
     * <p>Entries whose key does not start with {@code TrainingMapper.FEATURE_WEIGHT} are ignored.
     * A non-numeric index after the prefix throws {@link NumberFormatException}.
     */
    public static class CalcFeatureMapper extends Mapper<Text, DoubleWritable, IntWritable, MyKey> {
        @Override
        public void map(Text key, DoubleWritable value, final Context context)
                throws IOException, InterruptedException {
            String s = key.toString();
            if (!s.startsWith(TrainingMapper.FEATURE_WEIGHT)) {
                return;
            }
            // Derive the offset from the prefix itself rather than hard-coding its length.
            int idx = Integer.parseInt(s.substring(TrainingMapper.FEATURE_WEIGHT.length()));
            int featureId = idx / TAG_COUNT;
            int slot = idx % TAG_COUNT;
            context.write(new IntWritable(featureId), new MyKey(slot, value.get()));
        }
    }

    /** Pass-through mapper used by the sort step; it forwards every record unchanged. */
    public static class IdentityMapper extends Mapper<MyKey, MyValue, MyKey, MyValue> {
        @Override
        public void map(MyKey key, MyValue value, final Context context) throws IOException, InterruptedException {
            context.write(key, value);
        }
    }

    /** Pass-through reducer used by the sort step; it writes every grouped value under its key. */
    public static class IdentityReducer extends Reducer<MyKey, MyValue, MyKey, MyValue> {
        @Override
        public void reduce(MyKey key, Iterable<MyValue> values, final Context context)
                throws IOException, InterruptedException {
            for (MyValue value : values) {
                context.write(key, value);
            }
        }
    }

    /**
     * Collects the {@code TAG_COUNT} slot weights of one feature into a single record.
     *
     * <p>The output key is {@code MyKey(featureId, sum of squared weights)}. The output value is a
     * {@link MyValue} wrapping the weights indexed by slot.
     *
     * @throws IOException if a slot index is out of range, if a slot appears more than once, or
     *     if fewer than {@code TAG_COUNT} slots arrive for the feature
     */
    public static class CalcFeatureReducer extends Reducer<IntWritable, MyKey, MyKey, MyValue> {
        @Override
        protected void reduce(IntWritable key, Iterable<MyKey> values, Context context)
                throws IOException, InterruptedException {
            int featureId = key.get();
            double sumOfSquares = 0;
            int total = 0;
            double[] weights = new double[TAG_COUNT];
            boolean[] seen = new boolean[TAG_COUNT];
            for (MyKey value : values) {
                int slot = value.id;
                if (slot < 0 || slot >= TAG_COUNT) {
                    throw new IOException("slot " + slot + " out of range [0, " + TAG_COUNT
                            + ") for: " + featureId);
                }
                // A duplicate slot would otherwise still add up to TAG_COUNT entries, leaving
                // another slot at 0.0 and double-counting this one in sumOfSquares.
                if (seen[slot]) {
                    throw new IOException("duplicate slot " + slot + " for: " + featureId);
                }
                seen[slot] = true;
                total++;
                double score = value.score;
                sumOfSquares += score * score;
                weights[slot] = score;
            }
            if (total != TAG_COUNT) {
                throw new IOException("not " + TAG_COUNT + " for: " + featureId);
            }

            context.write(new MyKey(featureId, sumOfSquares), new MyValue(weights));
        }

    }

    /**
     * Runs step 1 (calc: {@code inDir -> tmpDir}) and/or step 2 (sort: {@code tmpDir -> outDir}).
     *
     * @param args Hadoop generic options followed by
     *     {@code <inDir> <tmpDir> <outDir> [startStep]}
     */
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();

        if (otherArgs.length != 3 && otherArgs.length != 4) {
            System.err.println("CalcFeatureWeights <inDir> <tmpDir> <outDir> [startStep]");
            System.exit(-1);
        }
        int startStep = 1;
        if (otherArgs.length == 4) {
            startStep = Integer.parseInt(otherArgs[3]);
        }
        FileSystem fs = FileSystem.get(conf);

        if (startStep <= 1) {
            System.out.println("calc");
            Path tmpDir = new Path(otherArgs[1]);
            fs.delete(tmpDir, true);
            Job job = newSingleReducerJob(conf, new Path(otherArgs[0]), tmpDir);
            job.setMapperClass(CalcFeatureMapper.class);
            job.setReducerClass(CalcFeatureReducer.class);
            job.setMapOutputKeyClass(IntWritable.class);
            job.setMapOutputValueClass(MyKey.class);
            if (!job.waitForCompletion(true)) {
                System.err.println("step1 failed");
                // Exit non-zero so calling scripts can detect the failure.
                System.exit(-1);
            }
        }

        if (startStep <= 2) {
            Path outDir = new Path(otherArgs[2]);
            fs.delete(outDir, true);
            System.out.println("sort");
            Job job = newSingleReducerJob(conf, new Path(otherArgs[1]), outDir);
            job.setMapperClass(IdentityMapper.class);
            job.setReducerClass(IdentityReducer.class);
            job.setMapOutputKeyClass(MyKey.class);
            job.setMapOutputValueClass(MyValue.class);
            if (!job.waitForCompletion(true)) {
                System.err.println("step2 failed");
                System.exit(-1);
            }
        }
    }

    /**
     * Creates the job configuration shared by both steps.
     *
     * <p>The job is named after this class and uses one reducer, SequenceFile input and output,
     * {@link MyKey}/{@link MyValue} as final output types, and the given paths. Callers still set
     * the mapper, the reducer and the map-output types.
     */
    private static Job newSingleReducerJob(Configuration conf, Path input, Path output) throws IOException {
        Job job = new Job(conf, CalcFeatureWeights.class.getSimpleName());
        job.setJarByClass(CalcFeatureWeights.class);
        // With one reducer, every key reaches the same reduce task in sorted order, so the
        // sort step produces a single, globally ordered output file.
        job.setNumReduceTasks(1);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setOutputKeyClass(MyKey.class);
        job.setOutputValueClass(MyValue.class);
        FileInputFormat.setInputPaths(job, input);
        FileOutputFormat.setOutputPath(job, output);
        return job;
    }
}