org.apache.mahout.classifier.df.node.CategoricalNode.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.df.node.CategoricalNode.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.node;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.data.Instance;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

public class CategoricalNode extends Node {

    private int attr;
    private double[] values;
    private Node[] childs;

    public CategoricalNode() {
    }

    public CategoricalNode(int attr, double[] values, Node[] childs) {
        this.attr = attr;
        this.values = values;
        this.childs = childs;
    }

    @Override
    public double classify(Instance instance) {
        int index = ArrayUtils.indexOf(values, instance.get(attr));
        if (index == -1) {
            // value not available, we cannot predict
            return Double.NaN;
        }
        return childs[index].classify(instance);
    }

    @Override
    public long maxDepth() {
        long max = 0;

        for (Node child : childs) {
            long depth = child.maxDepth();
            if (depth > max) {
                max = depth;
            }
        }

        return 1 + max;
    }

    @Override
    public long nbNodes() {
        long nbNodes = 1;

        for (Node child : childs) {
            nbNodes += child.nbNodes();
        }

        return nbNodes;
    }

    @Override
    protected Type getType() {
        return Type.CATEGORICAL;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof CategoricalNode)) {
            return false;
        }

        CategoricalNode node = (CategoricalNode) obj;

        return attr == node.attr && Arrays.equals(values, node.values) && Arrays.equals(childs, node.childs);
    }

    @Override
    public int hashCode() {
        int hashCode = attr;
        for (double value : values) {
            hashCode = 31 * hashCode + (int) Double.doubleToLongBits(value);
        }
        for (Node node : childs) {
            hashCode = 31 * hashCode + node.hashCode();
        }
        return hashCode;
    }

    @Override
    protected String getString() {
        StringBuilder buffer = new StringBuilder();

        for (Node child : childs) {
            buffer.append(child).append(',');
        }

        return buffer.toString();
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        attr = in.readInt();
        values = DFUtils.readDoubleArray(in);
        childs = DFUtils.readNodeArray(in);
    }

    @Override
    protected void writeNode(DataOutput out) throws IOException {
        out.writeInt(attr);
        DFUtils.writeArray(out, values);
        DFUtils.writeArray(out, childs);
    }
}