com.wsc.myexample.decisionForest.MyDataset.java Source code

Java tutorial

Introduction

Here is the source code for com.wsc.myexample.decisionForest.MyDataset.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.wsc.myexample.decisionForest;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;

/**
 * Contains informations about the attributes.
 */
public class MyDataset extends Dataset {

    private Dataset.Attribute[] attributes;

    /**
     * list of ignored attributes
     */
    private int[] ignored;

    /**
     * distinct values (CATEGORIAL attributes only)
     */
    private String[][] values;

    /**
     * index of the label attribute in the loaded data (without ignored attributed)
     */
    private int labelId;

    /**
     * number of instances in the dataset
     */
    private int nbInstances;

    public MyDataset() {

    }

    /**
     * Should only be called by a DataLoader
     *
     * @param attrs  attributes description
     * @param values distinct values for all CATEGORICAL attributes
     */
    public MyDataset(Dataset.Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) {
        validateValues(attrs, values);

        int nbattrs = countAttributes(attrs);

        // the label values are set apart
        attributes = new Dataset.Attribute[nbattrs];
        this.values = new String[nbattrs][];
        ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs

        labelId = -1;
        int ignoredId = 0;
        int ind = 0;
        for (int attr = 0; attr < attrs.length; attr++) {
            if (attrs[attr].isIgnored()) {
                ignored[ignoredId++] = attr;
                continue;
            }

            if (attrs[attr].isLabel()) {
                if (labelId != -1) {
                    throw new IllegalStateException("Label found more than once");
                }
                labelId = ind;
                if (regression) {
                    attrs[attr] = Dataset.Attribute.NUMERICAL;
                } else {
                    attrs[attr] = Dataset.Attribute.CATEGORICAL;
                }
            }

            if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
                this.values[ind] = new String[values[attr].size()];
                values[attr].toArray(this.values[ind]);
            }

            attributes[ind++] = attrs[attr];
        }

        if (labelId == -1) {
            throw new IllegalStateException("Label not found");
        }

        this.nbInstances = nbInstances;
    }

    public int nbValues(int attr) {
        return values[attr].length;
    }

    public String[] labels() {
        return Arrays.copyOf(values[labelId], nblabels());
    }

    public int nblabels() {
        return values[labelId].length;
    }

    public int getLabelId() {
        return labelId;
    }

    public double getLabel(Instance instance) {
        return instance.get(getLabelId());
    }

    public int nbInstances() {
        return nbInstances;
    }

    /**
     * Returns the code used to represent the label value in the data
     *
     * @param label label's value to code
     * @return label's code
     */
    public int labelCode(String label) {
        return ArrayUtils.indexOf(values[labelId], label);
    }

    /**
     * Returns the label value in the data
     * This method can be used when the criterion variable is the categorical attribute.
     *
     * @param code label's code
     * @return label's value
     */
    public String getLabelString(double code) {
        // handle the case (prediction == -1)
        if (code == -1) {
            return "unknown";
        }
        return values[labelId][(int) code];
    }

    /**
     * Converts a token to its corresponding int code for a given attribute
     *
     * @param attr attribute's index
     */
    public int valueOf(int attr, String token) {
        Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes");
        Preconditions.checkArgument(values != null, "Values not found");
        return ArrayUtils.indexOf(values[attr], token);
    }

    public int[] getIgnored() {
        return ignored;
    }

    /**
     * @return number of attributes that are not IGNORED
     */
    private static int countAttributes(Dataset.Attribute[] attrs) {
        int nbattrs = 0;

        for (Dataset.Attribute attr : attrs) {
            if (!attr.isIgnored()) {
                nbattrs++;
            }
        }

        return nbattrs;
    }

    private static void validateValues(Dataset.Attribute[] attrs, List<String>[] values) {
        Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length");
        for (int attr = 0; attr < attrs.length; attr++) {
            Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null,
                    "values not found for attribute " + attr);
        }
    }

    /**
     * @return number of attributes
     */
    public int nbAttributes() {
        return attributes.length;
    }

    /**
     * Is this a numerical attribute ?
     *
     * @param attr index of the attribute to check
     * @return true if the attribute is numerical
     */
    public boolean isNumerical(int attr) {
        return attributes[attr].isNumerical();
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof MyDataset)) {
            return false;
        }

        MyDataset dataset = (MyDataset) obj;

        if (!Arrays.equals(attributes, dataset.attributes)) {
            return false;
        }

        for (int attr = 0; attr < nbAttributes(); attr++) {
            if (!Arrays.equals(values[attr], dataset.values[attr])) {
                return false;
            }
        }

        return labelId == dataset.labelId && nbInstances == dataset.nbInstances;
    }

    @Override
    public int hashCode() {
        int hashCode = labelId + 31 * nbInstances;
        for (Dataset.Attribute attr : attributes) {
            hashCode = 31 * hashCode + attr.hashCode();
        }
        for (String[] valueRow : values) {
            if (valueRow == null) {
                continue;
            }
            for (String value : valueRow) {
                hashCode = 31 * hashCode + value.hashCode();
            }
        }
        return hashCode;
    }

    /**
     * Loads the dataset from a file
     *
     * @throws java.io.IOException
     */
    //  public static MyDataset load(Configuration conf, Path path) throws IOException {
    //    FileSystem fs = path.getFileSystem(conf);
    //    FSDataInputStream input = fs.open(path);
    //    try {
    //      return read(input);
    //    } finally {
    //      Closeables.closeQuietly(input);
    //    }
    //  }

    public static MyDataset load(String path) throws IOException {
        //       FileSystem fs = path.getFileSystem(conf);
        //       FSDataInputStream input = fs.open(path);

        //       InputStream Fin=new FileInputStream(path);
        DataInputStream input = new DataInputStream(new FileInputStream(path));

        try {
            return read(input);
        } finally {
            Closeables.closeQuietly(input);
        }
    }

    public static MyDataset read(DataInputStream in) throws IOException {
        MyDataset dataset = new MyDataset();

        dataset.readFields(in);
        return dataset;
    }

    public void readFields(DataInput in) throws IOException {
        int nbAttributes = in.readInt();
        attributes = new Dataset.Attribute[nbAttributes];
        for (int attr = 0; attr < nbAttributes; attr++) {
            String name = readString(in);
            attributes[attr] = Dataset.Attribute.valueOf(name);
        }

        ignored = DFUtils.readIntArray(in);

        // only CATEGORICAL attributes have values
        values = new String[nbAttributes][];
        for (int attr = 0; attr < nbAttributes; attr++) {
            if (attributes[attr].isCategorical()) {
                values[attr] = readStringArray(in);
            }
        }

        labelId = in.readInt();
        nbInstances = in.readInt();
    }

    public void write(DataOutput out) throws IOException {
        out.writeInt(attributes.length); // nb attributes
        for (Dataset.Attribute attr : attributes) {
            writeString(out, attr.name());
        }

        MyDFUtils.writeArray(out, ignored);

        // only CATEGORICAL attributes have values
        for (String[] vals : values) {
            if (vals != null) {
                writeStringArray(out, vals);
            }
        }

        out.writeInt(labelId);
        out.writeInt(nbInstances);
    }

    public static void writeString(DataOutput out, String s) throws IOException {
        if (s != null) {
            byte[] buffer = s.getBytes("UTF-8");
            int len = buffer.length;
            out.writeInt(len);
            out.write(buffer, 0, len);
        } else {
            out.writeInt(-1);
        }
    }

    /*
     * Read a String as a Network Int n, followed by n Bytes
     * Alternative to 16 bit read/writeUTF.
     * Encoding standard is... ?
     *
     */
    public static String readString(DataInput in) throws IOException {
        int length = in.readInt();
        if (length == -1)
            return null;
        byte[] buffer = new byte[length];
        in.readFully(buffer); // could/should use readFully(buffer,0,length)?
        return new String(buffer, "UTF-8");
    }

    /*
     * Write a String array as a Nework Int N, followed by Int N Byte Array Strings.
     * Could be generalised using introspection.
     *
     */
    public static void writeStringArray(DataOutput out, String[] s) throws IOException {
        out.writeInt(s.length);
        for (int i = 0; i < s.length; i++) {
            writeString(out, s[i]);
        }
    }

    /*
     * Write a String array as a Nework Int N, followed by Int N Byte Array Strings.
     * Could be generalised using introspection. Actually this bit couldn't...
     *
     */
    public static String[] readStringArray(DataInput in) throws IOException {
        int len = in.readInt();
        if (len == -1)
            return null;
        String[] s = new String[len];
        for (int i = 0; i < len; i++) {
            s[i] = readString(in);
        }
        return s;
    }
}