chapter4.src.logistic.CsvRecordFactoryPredict.java Source code

Java tutorial

Introduction

Here is the source code for chapter4.src.logistic.CsvRecordFactoryPredict.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package chapter4.src.logistic;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import org.apache.commons.csv.CSVUtils;
import org.apache.mahout.classifier.sgd.RecordFactory;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Converts CSV data lines to vectors. Compared with a plain training-time factory, this variant
 * can optionally track an id column that identifies each record, and can be set up (via
 * {@link #firstLine(String, String)}) for data whose header has no target column.
 *
 * Use of this class proceeds in a few steps.
 * <ul>
 * <li> At construction time, you tell the class about the target variable and provide
 * a dictionary of the types of the predictor values.  At this point,
 * the class cannot yet decode inputs because it doesn't know the fields that are in the
 * data records, nor their order.
 * <li> Optionally, you tell the parser object about the possible values of the target
 * variable.  If you don't do this then you probably should set the number of distinct
 * values so that the target variable values will be taken from a restricted range.
 * <li> Later, when you get a list of the fields, typically from the first line of a CSV
 * file, you tell the factory about these fields and it builds internal data structures
 * that allow it to decode inputs.  The most important internal state is the field numbers
 * for various fields.  After this point, you can use the factory for decoding data.
 * <li> To encode data as a vector, you present a line of input to the factory and it
 * mutates a vector that you provide.  The factory also retains trace information so
 * that it can approximately reverse engineer vectors later.
 * <li> After converting data, you can ask for an explanation of the data in terms of
 * terms and weights.  In order to explain a vector accurately, the factory needs to
 * have seen the particular values of categorical fields (typically during encoding vectors)
 * and needs to have a reasonably small number of collisions in the vector encoding.
 * </ul>
 */
public class CsvRecordFactoryPredict implements RecordFactory {
    private static final String INTERCEPT_TERM = "Intercept Term";

    /** Pseudo column number used in {@link #predictors} for the constant intercept (bias) term. */
    private static final int INTERCEPT_COLUMN = -1;

    private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY = ImmutableMap
            .<String, Class<? extends FeatureVectorEncoder>>builder()
            .put("continuous", ContinuousValueEncoder.class).put("numeric", ContinuousValueEncoder.class)
            .put("n", ContinuousValueEncoder.class).put("word", StaticWordValueEncoder.class)
            .put("w", StaticWordValueEncoder.class).put("text", TextValueEncoder.class)
            .put("t", TextValueEncoder.class).build();

    private final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();

    // Column index of the target variable; set by firstLine.
    private int target;
    private final Dictionary targetDictionary;

    // Name of the column used to identify a CSV line, or null when records carry no id.
    private String idName;
    // Column index of the id column; stays -1 until firstLine resolves a non-null idName.
    private int id = -1;

    private List<Integer> predictors;
    private Map<Integer, FeatureVectorEncoder> predictorEncoders;
    private int maxTargetValue = Integer.MAX_VALUE;
    private final String targetName;
    private final Map<String, String> typeMap;
    private List<String> variableNames;
    private boolean includeBiasTerm;
    private static final String CANNOT_CONSTRUCT_CONVERTER = "Unable to construct type converter... shouldn't be possible";

    /**
     * Parse a single line of CSV-formatted text.
     *
     * Separated to make changing this functionality for the entire class easier
     * in the future.
     * @param line - CSV formatted text
     * @return the parsed fields (possibly a fixed-size list); if the CSV parser reports an
     *         IOException, a single-element list containing the raw line
     */
    private List<String> parseCsvLine(String line) {
        try {
            return Arrays.asList(CSVUtils.parseLine(line));
        } catch (IOException e) {
            // Best-effort fallback: treat the whole line as one field.
            List<String> list = Lists.newArrayList();
            list.add(line);
            return list;
        }
    }

    /** Convenience overload of {@link #parseCsvLine(String)} for any character sequence. */
    private List<String> parseCsvLine(CharSequence line) {
        return parseCsvLine(line.toString());
    }

    /**
     * Construct a parser for CSV lines that encodes the parsed data in vector form.
     * @param targetName            The name of the target variable.
     * @param typeMap               A map describing the types of the predictor variables.
     */
    public CsvRecordFactoryPredict(String targetName, Map<String, String> typeMap) {
        this.targetName = targetName;
        this.typeMap = typeMap;
        targetDictionary = new Dictionary();
    }

    /**
     * Construct a parser for CSV lines that also tracks an id column.
     * @param targetName            The name of the target variable.
     * @param idName                The name of the column that identifies each record.
     * @param typeMap               A map describing the types of the predictor variables.
     */
    public CsvRecordFactoryPredict(String targetName, String idName, Map<String, String> typeMap) {
        this(targetName, typeMap);
        this.idName = idName;
    }

    /**
     * Defines the values and thus the encoding of values of the target variables.  Note
     * that any values of the target variable not present in this list will be given the
     * value of the last member of the list.
     * @param values  The values the target variable can have.
     * @throws IllegalArgumentException if there are more values than the current maximum
     *         number of target categories
     */
    public void defineTargetCategories(List<String> values) {
        Preconditions.checkArgument(values.size() <= maxTargetValue, "Must have less than or equal to "
                + maxTargetValue + " categories for target variable, but found " + values.size());
        if (maxTargetValue == Integer.MAX_VALUE) {
            maxTargetValue = values.size();
        }

        for (String value : values) {
            targetDictionary.intern(value);
        }
    }

    /**
     * Defines the number of target variable categories, but allows this parser to
     * pick encodings for them as they appear.
     * @param max  The number of categories that will be expected.  Once this many have been
     * seen, all others will get the encoding max-1.
     * @return this factory, for chaining
     */
    public CsvRecordFactoryPredict maxTargetValue(int max) {
        maxTargetValue = max;
        return this;
    }

    /** @return always {@code true}: the first CSV line is expected to hold the column names. */
    public boolean usesFirstLineAsSchema() {
        return true;
    }

    /**
     * Processes the first line of a file (which should contain the variable names). The target,
     * id (if an id name was given) and predictor column numbers are set from the names on this line.
     *
     * @param line       Header line for the file.
     * @throws IllegalArgumentException if the target, id or a predictor column is missing from the
     *         header, or a predictor has an unknown type
     */
    public void firstLine(String line) {
        Map<String, Integer> vars = indexColumns(line);

        // record target column and establish dictionary for decoding target
        target = requireColumn(vars, targetName);

        resolveIdColumn(vars);
        buildPredictorEncoders(vars);
    }

    /**
     * Processes the header line of data that has no target column (e.g. production data to be
     * classified). The id (if an id name was given) and predictor column numbers are set from
     * the names on this line; the target column index is set past the last header column.
     *
     * @param line       Header line for the file.
     * @param targetName Currently ignored; the factory's own target name is not looked up either.
     * @throws IllegalArgumentException if the id or a predictor column is missing from the header,
     *         or a predictor has an unknown type
     */
    public void firstLine(String line, String targetName) {
        Map<String, Integer> vars = indexColumns(line);

        // No target column in this header: point beyond every header column (counting duplicate
        // names, which the name->column map collapses) so the target never aliases a real column.
        target = variableNames.size() + 1;

        resolveIdColumn(vars);
        buildPredictorEncoders(vars);
    }

    /**
     * Parses a header line, stores its fields in {@link #variableNames} and returns a map from
     * column name to column index (a repeated name maps to its last occurrence).
     */
    private Map<String, Integer> indexColumns(String line) {
        variableNames = parseCsvLine(line);
        Map<String, Integer> vars = Maps.newHashMap();
        int column = 0;
        for (String var : variableNames) {
            vars.put(var, column++);
        }
        return vars;
    }

    /** Looks up a column index by name, failing with a descriptive message if it is absent. */
    private static int requireColumn(Map<String, Integer> vars, String name) {
        Integer column = vars.get(name);
        Preconditions.checkArgument(column != null, "Can't find variable %s, only know about %s", name, vars);
        return column;
    }

    /** Resolves {@link #id} from the header when an id column name has been configured. */
    private void resolveIdColumn(Map<String, Integer> vars) {
        if (idName != null) {
            id = requireColumn(vars, idName);
        }
    }

    /** Builds the sorted predictor column list and one encoder per predictor column. */
    private void buildPredictorEncoders(final Map<String, Integer> vars) {
        // create list of predictor column numbers
        predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {

            public Integer apply(String from) {
                return requireColumn(vars, from);
            }
        }));

        if (includeBiasTerm) {
            predictors.add(INTERCEPT_COLUMN);
        }
        Collections.sort(predictors);

        // and map from column number to type encoder for each column that is a predictor
        predictorEncoders = Maps.newHashMap();
        for (Integer predictor : predictors) {
            predictorEncoders.put(predictor, createEncoder(predictor));
        }
    }

    /**
     * Instantiates the encoder for one predictor column (or the intercept pseudo column) and
     * attaches the shared trace dictionary to it.
     */
    private FeatureVectorEncoder createEncoder(int predictor) {
        String name;
        Class<? extends FeatureVectorEncoder> c;
        if (predictor == INTERCEPT_COLUMN) {
            name = INTERCEPT_TERM;
            c = ConstantValueEncoder.class;
        } else {
            name = variableNames.get(predictor);
            c = TYPE_DICTIONARY.get(typeMap.get(name));
        }
        Preconditions.checkArgument(c != null, "Invalid type of variable %s,  wanted one of %s",
                typeMap.get(name), TYPE_DICTIONARY.keySet());
        try {
            // getConstructor throws NoSuchMethodException rather than returning null.
            Constructor<? extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
            FeatureVectorEncoder encoder = constructor.newInstance(name);
            encoder.setTraceDictionary(traceDictionary);
            return encoder;
        } catch (InstantiationException e) {
            throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
        } catch (InvocationTargetException e) {
            throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
        }
    }

    /**
     * Decodes a single line of CSV data and records the target and predictor variables in a record.
     * As a side effect, features are added into the featureVector.  Returns the value of the target
     * variable.
     *
     * @param line          The raw data.
     * @param featureVector Where to fill in the features.  Should be zeroed before calling
     *                      processLine.
     * @return The value of the target variable.
     */
    public int processLine(String line, Vector featureVector) {
        return processLine(line, featureVector, true);
    }

    /**
     * Decodes a single line of CSV data and records the target (if returnTarget is true)
     * and predictor variables in a record. As a side effect, features are added into the featureVector.
     * Returns the value of the target variable. When used during classification of production data
     * without a target value, the method should be called with returnTarget = false.
     * @param line The raw data.
     * @param featureVector Where to fill in the features.  Should be zeroed before calling
     *                      processLine.
     * @param returnTarget whether to process and return the target value; -1 is returned if false.
     * @return The encoded target value (capped at maxTargetValue - 1), or -1 if returnTarget is false.
     */
    public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
        List<String> values = parseCsvLine(line);
        int targetValue = -1;
        if (returnTarget) {
            targetValue = targetDictionary.intern(values.get(target));
            if (targetValue >= maxTargetValue) {
                targetValue = maxTargetValue - 1;
            }
        }

        for (Integer predictor : predictors) {
            // The intercept pseudo column has no field; its encoder receives null.
            String value = predictor >= 0 ? values.get(predictor) : null;
            predictorEncoders.get(predictor).addToVector(value, featureVector);
        }
        return targetValue;
    }

    /**
     * Extract the raw target string from a line read from a CSV file.
     * @param line the line of content read from CSV file
     * @return the raw target value in the corresponding column of the CSV line
     */
    public String getTargetString(CharSequence line) {
        List<String> values = parseCsvLine(line);
        return values.get(target);
    }

    /**
     * Extract the raw target label corresponding to an encoded target value.
     * @param code the integer code assigned to the label (e.g. during training)
     * @return the raw target label, or null if no known label has this code
     */
    public String getTargetLabel(int code) {
        for (String key : targetDictionary.values()) {
            if (targetDictionary.intern(key) == code) {
                return key;
            }
        }
        return null;
    }

    /**
     * Extract the id column value from the CSV record.
     * @param line the line of content read from CSV file
     * @return the id value of the CSV record
     * @throws IllegalStateException if no id column was resolved by {@code firstLine}
     */
    public String getIdString(CharSequence line) {
        Preconditions.checkState(id >= 0, "No id column resolved; set the id column name before calling firstLine");
        List<String> values = parseCsvLine(line);
        return values.get(id);
    }

    /**
     * Returns a list of the names of the predictor variables.
     *
     * @return A list of variable names; the intercept, if enabled, appears as "Intercept Term".
     */
    public Iterable<String> getPredictors() {
        return Lists.transform(predictors, new Function<Integer, String>() {

            public String apply(Integer v) {
                if (v >= 0) {
                    return variableNames.get(v);
                } else {
                    return INTERCEPT_TERM;
                }
            }
        });
    }

    /** @return the trace dictionary shared with (and filled by) the predictor encoders. */
    public Map<String, Set<Integer>> getTraceDictionary() {
        return traceDictionary;
    }

    /**
     * Sets whether a constant intercept term is added as a predictor. Takes effect at the next
     * {@code firstLine} call.
     * @param useBias whether to include the bias term
     * @return this factory, for chaining
     */
    public CsvRecordFactoryPredict includeBiasTerm(boolean useBias) {
        includeBiasTerm = useBias;
        return this;
    }

    /** @return the known target labels, truncated to at most maxTargetValue entries. */
    public List<String> getTargetCategories() {
        List<String> r = targetDictionary.values();
        if (r.size() > maxTargetValue) {
            r.subList(maxTargetValue, r.size()).clear();
        }
        return r;
    }

    /** @return the name of the id column, or null if none is configured. */
    public String getIdName() {
        return idName;
    }

    /**
     * Sets the name of the id column. The column index is only resolved by {@code firstLine},
     * so call this before processing the header line.
     * @param idName the id column name
     */
    public void setIdName(String idName) {
        this.idName = idName;
    }

}