com.rapidminer.operator.learner.tree.SelectionCreator.java Source code

Java tutorial

Introduction

Here is the source code for com.rapidminer.operator.learner.tree.SelectionCreator.java

Source

/**
 * Copyright (C) 2001-2015 by RapidMiner and the contributors
 *
 * Complete list of developers available at our web site:
 *
 *      http://rapidminer.com
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;

import org.apache.commons.lang.ArrayUtils;

import com.rapidminer.core.internal.Resources;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.tools.Tools;

/**
 * Handles selections of attributes and examples of a {@link ColumnExampleTable}. Creates start
 * selections and updates them.
 *
 * @author Gisa Schaefer
 *
 */
public class SelectionCreator {

    /** Table whose nominal and numerical attribute columns are read to sort and split selections. */
    private final ColumnExampleTable columnTable;

    /**
     * Creates a selection creator operating on the given table.
     *
     * @param columnTable
     *            the table providing the attribute columns and example counts
     */
    public SelectionCreator(ColumnExampleTable columnTable) {
        this.columnTable = columnTable;
    }

    /**
     * Creates an example index start selection for each numerical attribute, or, if there is no
     * numerical attribute, a single unsorted selection stored under key 0.
     *
     * @return a map containing for each numerical attribute an example index array such that the
     *         associated attribute values are in ascending order (NaN values last, as ordered by
     *         {@link Double#compare(double, double)})
     */
    public Map<Integer, int[]> getStartSelection() {
        Map<Integer, int[]> selection = new HashMap<>();
        if (columnTable.getNumberOfRegularNumericalAttributes() == 0) {
            selection.put(0, createFullArray(columnTable.getNumberOfExamples()));
            return selection;
        }

        int firstNumerical = columnTable.getNumberOfRegularNominalAttributes();
        int endNumerical = columnTable.getTotalNumberOfRegularAttributes();
        Integer[] bigSelectionArray = createFullBigArray(columnTable.getNumberOfExamples());
        for (int j = firstNumerical; j < endNumerical; j++) {
            selection.put(j, sortByAttributeValues(bigSelectionArray, columnTable.getNumericalAttributeColumn(j)));
        }
        return selection;
    }

    /**
     * Creates in parallel an example index start selection for each numerical attribute, or, if
     * there is no numerical attribute, a single unsorted selection stored under key 0.
     *
     * @param operator
     *            the operator for which the calculation is done (used to obtain the concurrency
     *            context)
     * @return a map containing for each numerical attribute an example index array such that the
     *         associated attribute values are in ascending order (NaN values last)
     * @throws OperatorException
     *             if a sorting task fails with a checked exception
     */
    public Map<Integer, int[]> getStartSelectionParallel(Operator operator) throws OperatorException {
        Map<Integer, int[]> selection = new HashMap<>();
        if (columnTable.getNumberOfRegularNumericalAttributes() == 0) {
            selection.put(0, createFullArray(columnTable.getNumberOfExamples()));
            return selection;
        }

        int firstNumerical = columnTable.getNumberOfRegularNominalAttributes();
        int endNumerical = columnTable.getTotalNumberOfRegularAttributes();
        // shared by all tasks; it is only read (copied) by sortByAttributeValues, never modified
        final Integer[] bigSelectionArray = createFullBigArray(columnTable.getNumberOfExamples());
        List<Callable<int[]>> tasks = new ArrayList<>();
        for (int j = firstNumerical; j < endNumerical; j++) {
            final double[] attributeColumn = columnTable.getNumericalAttributeColumn(j);
            tasks.add(new Callable<int[]>() {

                @Override
                public int[] call() {
                    return sortByAttributeValues(bigSelectionArray, attributeColumn);
                }
            });
        }

        List<int[]> results;
        try {
            results = Resources.getConcurrencyContext(operator).call(tasks);
        } catch (ExecutionException e) {
            // rethrow unchecked failures unchanged; wrap checked ones, keeping the original cause
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            } else if (cause instanceof Error) {
                throw (Error) cause;
            } else {
                throw new OperatorException(cause.getMessage(), cause);
            }
        }

        // the results list follows the order in which the tasks were submitted
        for (int j = firstNumerical; j < endNumerical; j++) {
            selection.put(j, results.get(j - firstNumerical));
        }
        return selection;
    }

    /**
     * Returns a copy of the given example indices, sorted ascending by their value in the given
     * column. NaN values end up last because {@link Double#compare(double, double)} orders them
     * above every other value.
     *
     * @param exampleIndices
     *            the example indices to sort; not modified
     * @param attributeColumn
     *            the attribute values, indexed by example index
     * @return the sorted indices as a primitive array
     */
    private static int[] sortByAttributeValues(Integer[] exampleIndices, final double[] attributeColumn) {
        Integer[] sorted = Arrays.copyOf(exampleIndices, exampleIndices.length);
        Arrays.sort(sorted, new Comparator<Integer>() {

            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(attributeColumn[a], attributeColumn[b]);
            }
        });
        return ArrayUtils.toPrimitive(sorted);
    }

    /**
     * Splits the selected examples according to the bestAttribute and, if the attribute is
     * numerical, the bestSplitValue.
     *
     * @param allSelectedExamples
     *            maps each selection key to its array of selected example indices
     * @param bestAttribute
     *            the attribute to split on
     * @param bestSplitValue
     *            the split threshold; only used if bestAttribute is numerical
     * @return a collection of maps mapping the numerical attribute number to the sorted array
     *         containing the selected example numbers, one map per resulting branch
     */
    public Collection<Map<Integer, int[]>> getSplits(Map<Integer, int[]> allSelectedExamples, int bestAttribute,
            double bestSplitValue) {
        if (columnTable.representsNominalAttribute(bestAttribute)) {
            return calculateSplits(allSelectedExamples, bestAttribute);
        }
        return calculateSplits(allSelectedExamples, bestAttribute, bestSplitValue);
    }

    /**
     * Splits for every numerical attribute the sorted index array according to the bestSplitValue
     * at the bestAttribute. Groups by smaller than or equal to bestSplitValue (as decided by
     * {@code Tools.isLessEqual}), greater than bestSplitValue, and value is NaN. The relative order
     * of the examples within each array is preserved.
     *
     * @param allSelectedExamples
     *            maps each selection key to its array of selected example indices; must be non-empty
     *            and contain bestAttribute as a key, whose array is sorted by bestAttribute's values
     * @param bestAttribute
     *            the numerical attribute to split on
     * @param bestSplitValue
     *            the split threshold
     * @return a list containing first the map of examples whose value is smaller than or equal to
     *         bestSplitValue, then the map of the ones greater, and — only if the selection contains
     *         NaN values for bestAttribute — a third map holding the NaNs
     */
    public Collection<Map<Integer, int[]>> calculateSplits(Map<Integer, int[]> allSelectedExamples,
            int bestAttribute, double bestSplitValue) {
        double[] attributeColumn = columnTable.getNumericalAttributeColumn(bestAttribute);
        List<Map<Integer, int[]>> results = new ArrayList<>(3);
        results.add(new HashMap<Integer, int[]>());
        results.add(new HashMap<Integer, int[]>());

        // The selection keyed by bestAttribute is sorted by its values with NaNs at the end, so only
        // the last entry has to be inspected. An empty selection cannot contain NaNs.
        int[] bestAttributeSelection = allSelectedExamples.get(bestAttribute);
        boolean existNaNs = bestAttributeSelection.length > 0
                && Double.isNaN(attributeColumn[bestAttributeSelection[bestAttributeSelection.length - 1]]);
        if (existNaNs) {
            results.add(new HashMap<Integer, int[]>());
        }

        // The selections created by this class all hold the same examples in different orders, so
        // the length of any one of them bounds every partition. The buffers are reused per key.
        int maximalLength = getArbitraryValue(allSelectedExamples).length;
        int[] smaller = new int[maximalLength];
        int[] bigger = new int[maximalLength];
        int[] naNs = new int[maximalLength];

        for (Map.Entry<Integer, int[]> entry : allSelectedExamples.entrySet()) {
            int smallerPosition = 0;
            int biggerPosition = 0;
            int naNsPosition = 0;

            for (int example : entry.getValue()) {
                double value = attributeColumn[example];
                if (Double.isNaN(value)) {
                    naNs[naNsPosition++] = example;
                } else if (Tools.isLessEqual(value, bestSplitValue)) {
                    smaller[smallerPosition++] = example;
                } else {
                    bigger[biggerPosition++] = example;
                }
            }

            Integer key = entry.getKey();
            results.get(0).put(key, Arrays.copyOf(smaller, smallerPosition));
            results.get(1).put(key, Arrays.copyOf(bigger, biggerPosition));
            if (existNaNs) {
                results.get(2).put(key, Arrays.copyOf(naNs, naNsPosition));
            }
        }

        return results;
    }

    /**
     * Splits for every numerical attribute the sorted index array according to the value at the
     * best attribute. Groups the split arrays by the value at the best attribute; the relative
     * order of the examples within each array is preserved.
     *
     * @param allSelectedExamples
     *            maps each selection key to its array of selected example indices
     * @param bestAttribute
     *            the nominal attribute to split on
     * @return one map per distinct value of bestAttribute occurring in the selection, mapping each
     *         selection key to the examples having that value; the order of the maps in the returned
     *         collection is unspecified (it is backed by a {@link HashMap})
     */
    public Collection<Map<Integer, int[]>> calculateSplits(Map<Integer, int[]> allSelectedExamples,
            int bestAttribute) {
        byte[] attributeColumn = columnTable.getNominalAttributeColumn(bestAttribute);
        Map<Byte, Map<Integer, int[]>> results = new HashMap<>();

        for (Map.Entry<Integer, int[]> selectionEntry : allSelectedExamples.entrySet()) {
            Integer selectionKey = selectionEntry.getKey();

            // group the examples of this selection by their value at bestAttribute
            Map<Byte, List<Integer>> valueLists = new HashMap<>();
            for (int example : selectionEntry.getValue()) {
                byte value = attributeColumn[example];
                List<Integer> examplesWithValue = valueLists.get(value);
                if (examplesWithValue == null) {
                    examplesWithValue = new ArrayList<>();
                    valueLists.put(value, examplesWithValue);
                }
                examplesWithValue.add(example);
            }

            // store each group (value, list) as (value, (selectionKey, array(list)))
            for (Map.Entry<Byte, List<Integer>> valueEntry : valueLists.entrySet()) {
                List<Integer> examples = valueEntry.getValue();
                int[] partition = ArrayUtils.toPrimitive(examples.toArray(new Integer[examples.size()]));
                Map<Integer, int[]> split = results.get(valueEntry.getKey());
                if (split == null) {
                    split = new HashMap<>();
                    results.put(valueEntry.getKey(), split);
                }
                split.put(selectionKey, partition);
            }
        }

        return results.values();
    }

    /**
     * If the bestAttribute is nominal, its number is removed from the selectedAttributes, otherwise
     * they stay the same.
     *
     * @param selectedAttributes
     *            the currently selected attribute numbers
     * @param bestAttribute
     *            the attribute that was used for the split
     * @return a new array without bestAttribute if it is nominal, otherwise the given array itself
     */
    public int[] updateRemainingAttributes(int[] selectedAttributes, int bestAttribute) {
        return columnTable.representsNominalAttribute(bestAttribute)
                ? removeAttribute(bestAttribute, selectedAttributes)
                : selectedAttributes;
    }

    /**
     * Creates a new array containing all entries of selectedAttributes except for
     * attributeNumberToDelete, in their original order. If attributeNumberToDelete does not occur,
     * the result is a copy of selectedAttributes; if it occurs several times, all occurrences are
     * removed.
     *
     * @param attributeNumberToDelete
     *            the attribute number to remove
     * @param selectedAttributes
     *            the attribute numbers to filter; not modified
     * @return a new array of the remaining attribute numbers
     */
    public int[] removeAttribute(int attributeNumberToDelete, int[] selectedAttributes) {
        // size by the entries actually kept, so an absent or repeated attribute cannot overflow the
        // array or leave unset trailing slots
        int[] newSelection = new int[selectedAttributes.length];
        int size = 0;
        for (int attribute : selectedAttributes) {
            if (attribute != attributeNumberToDelete) {
                newSelection[size++] = attribute;
            }
        }
        return Arrays.copyOf(newSelection, size);
    }

    /**
     * Creates a selection array containing all rows, i.e. containing all consecutive numbers
     * [0..length-1].
     *
     * @param length
     *            the number of rows
     * @return the array {@code {0, 1, ..., length-1}}
     */
    public int[] createFullArray(int length) {
        int[] fullSelection = new int[length];
        for (int i = 0; i < length; i++) {
            fullSelection[i] = i;
        }
        return fullSelection;
    }

    /**
     * Creates an Integer array containing all consecutive numbers [0..length-1].
     *
     * @param length
     *            the number of entries
     * @return the boxed array {@code {0, 1, ..., length-1}}
     */
    public Integer[] createFullBigArray(int length) {
        Integer[] fullSelection = new Integer[length];
        for (int i = 0; i < length; i++) {
            fullSelection[i] = i;
        }
        return fullSelection;
    }

    /**
     * Returns a value of the map.
     *
     * @param map
     *            a non-empty map
     * @return the first value produced by the map's value iterator
     * @throws java.util.NoSuchElementException
     *             if the map is empty
     */
    public static int[] getArbitraryValue(Map<Integer, int[]> map) {
        return map.values().iterator().next();
    }

}