com.rapidminer.gui.graphs.TransitionGraphCreator.java Source code

Java tutorial

Introduction

Here is the source code for com.rapidminer.gui.graphs.TransitionGraphCreator.java

Source

/**
 * Copyright (C) 2001-2015 by RapidMiner and the contributors
 *
 * Complete list of developers available at our web site:
 *
 *      http://rapidminer.com
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.gui.graphs;

import com.rapidminer.ObjectVisualizer;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.tools.ExtendedJComboBox;
import com.rapidminer.gui.tools.SwingTools;
import com.rapidminer.operator.visualization.dependencies.TransitionGraph;
import com.rapidminer.tools.ObjectVisualizerService;
import com.rapidminer.tools.Tools;
import edu.uci.ics.jung.graph.DirectedSparseGraph;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.util.EdgeType;
import edu.uci.ics.jung.visualization.VisualizationViewer;

import java.awt.Dimension;
import java.awt.Paint;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import javax.swing.JComboBox;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JSlider;
import javax.swing.JSpinner;
import javax.swing.SpinnerNumberModel;
import javax.swing.SwingConstants;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

import org.apache.commons.collections15.Factory;
import org.apache.commons.collections15.Transformer;

/**
 * The graph model creator for transition graphs.
 * <p>
 * Every example of the underlying example set describes one directed edge from the value of the
 * source attribute to the value of the target attribute. An optional strength attribute provides
 * the edge value (1.0 if absent) and an optional type attribute is passed to {@link SortableEdge}
 * as edge name. The displayed graph can be restricted to the edges reachable from a selected
 * source within a number of hops, and to the first N edges in the order defined by
 * {@link SortableEdge}.
 *
 * @author Ingo Mierswa
 */
public class TransitionGraphCreator extends GraphCreatorAdaptor {

    /**
     * Entry of the source filter combo box: the vertex id together with the label shown to the
     * user (the node description if available, otherwise the id itself). Entries are ordered by
     * label and, for equal labels, by id.
     */
    private static class SourceId implements Comparable<SourceId> {

        private final String id;

        private final String label;

        public SourceId(String id, String label) {
            this.id = id;
            this.label = label;
        }

        public String getId() {
            return id;
        }

        @Override
        public String toString() {
            return label;
        }

        /**
         * Compares by label first and by id second. Including the id keeps different sources that
         * share the same description as separate entries in a sorted set and keeps the ordering
         * consistent with {@link #equals(Object)}.
         */
        @Override
        public int compareTo(SourceId o) {
            int result = this.label.compareTo(o.label);
            if (result != 0) {
                return result;
            }
            return this.id.compareTo(o.id);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SourceId other = (SourceId) obj;
            return equalsNullSafe(id, other.id) && equalsNullSafe(label, other.label);
        }

        /** Hash code consistent with {@link #equals(Object)} (both id and label). */
        @Override
        public int hashCode() {
            int result = id == null ? 0 : id.hashCode();
            return 31 * result + (label == null ? 0 : label.hashCode());
        }

        /** Null-safe string equality. */
        private static boolean equalsNullSafe(String a, String b) {
            return a == null ? b == null : a.equals(b);
        }
    }

    /** Creates unique edge ids "E0", "E1", ... for all edges ever added to the graph. */
    private final Factory<String> edgeFactory = new Factory<String>() {

        int i = 0;

        @Override
        public String create() {
            return "E" + i++;
        }
    };

    /** Maximum number of edges shown in the graph (0 to 1000, default 100). */
    private final JSlider edgeSlider = new JSlider(SwingConstants.HORIZONTAL, 0, 1000, 100) {

        private static final long serialVersionUID = -6931545310805789589L;

        @Override
        public Dimension getMinimumSize() {
            return new Dimension(40, (int) super.getMinimumSize().getHeight());
        }

        @Override
        public Dimension getPreferredSize() {
            return new Dimension(40, (int) super.getPreferredSize().getHeight());
        }

        @Override
        public Dimension getMaximumSize() {
            return new Dimension(40, (int) super.getMaximumSize().getHeight());
        }
    };

    /** Source filter; index 0 is the "None" entry, all other items are {@link SourceId}s. */
    private final JComboBox sourceFilter;

    /** Number of hops followed from the selected source (at least 1). */
    private final JSpinner numberOfHops = new JSpinner(new SpinnerNumberModel(1, 1, Integer.MAX_VALUE, 1));

    private Graph<String, String> graph;

    private final Attribute sourceAttribute;

    private final Attribute targetAttribute;

    /** May be null, in which case every edge has strength 1.0. */
    private Attribute strengthAttribute;

    /** May be null, in which case the strength is used as edge name. */
    private Attribute typeAttribute;

    /** Node description template; "%{field}" placeholders are resolved by the object visualizer. */
    private final String nodeDescription;

    private final ExampleSet exampleSet;

    /** Edge id to edge label for all edges currently in the graph. */
    private final Map<String, String> edgeLabelMap = new HashMap<String, String>();

    /** Edge id to normalized edge strength for all edges currently in the graph. */
    private final Map<String, Double> edgeStrengthMap = new HashMap<String, Double>();

    /** Vertex id to vertex label for all vertices currently in the graph. */
    private final Map<String, String> vertexLabelMap = new HashMap<String, String>();

    private final DefaultObjectViewer objectViewer;

    /**
     * Creates the graph creator and fills the source filter with all distinct source values.
     *
     * @param transitionGraph
     *            provides the names of the source, target, strength and type attributes and the
     *            node description template
     * @param exampleSet
     *            the examples, each of which describes one transition (edge)
     */
    public TransitionGraphCreator(TransitionGraph transitionGraph, ExampleSet exampleSet) {
        this.sourceAttribute = exampleSet.getAttributes().get(transitionGraph.getSourceAttribute());
        this.targetAttribute = exampleSet.getAttributes().get(transitionGraph.getTargetAttribute());
        if (transitionGraph.getStrengthAttribute() != null) {
            this.strengthAttribute = exampleSet.getAttributes().get(transitionGraph.getStrengthAttribute());
        }
        if (transitionGraph.getTypeAttribute() != null) {
            this.typeAttribute = exampleSet.getAttributes().get(transitionGraph.getTypeAttribute());
        }
        this.exampleSet = exampleSet;
        this.nodeDescription = transitionGraph.getNodeDescription();

        SortedSet<SourceId> sourceNames = new TreeSet<SourceId>();
        for (Example example : exampleSet) {
            // Use the same string representation as updateGraph() so that the selected filter id
            // matches the vertex names of the graph (also for numerical source attributes).
            String sourceName = example.getValueAsString(sourceAttribute);
            Object sourceValue;
            if (sourceAttribute.isNominal()) {
                sourceValue = sourceName;
            } else {
                sourceValue = example.getValue(sourceAttribute);
            }
            String description = getNodeDescription(sourceValue);
            sourceNames.add(new SourceId(sourceName, description == null ? sourceName : description));
        }

        sourceFilter = new ExtendedJComboBox(200);
        sourceFilter.addItem(new SourceId("None", "None"));
        for (SourceId sourceId : sourceNames) {
            sourceFilter.addItem(sourceId);
        }

        objectViewer = new DefaultObjectViewer(exampleSet);
    }

    @Override
    public Graph<String, String> createGraph() {
        graph = new DirectedSparseGraph<String, String>();
        updateGraph();
        return graph;
    }

    @Override
    public String getEdgeName(String id) {
        return edgeLabelMap.get(id);
    }

    /** Returns the stored label of the vertex, or the id itself if none is stored. */
    @Override
    public String getVertexName(String id) {
        String storedName = vertexLabelMap.get(id);
        return storedName == null ? id : storedName;
    }

    @Override
    public String getVertexToolTip(String id) {
        return id;
    }

    /**
     * Returns the label offset. In most cases, using -1 is just fine (default offset). Some
     * tree-like graphs might prefer to use 0 since they manage the offset themselves.
     */
    @Override
    public int getLabelOffset() {
        return -1;
    }

    @Override
    public int getNumberOfOptionComponents() {
        return 6;
    }

    /**
     * Returns the option component with the given index: label and combo box for the source filter
     * (0, 1), label and spinner for the number of hops (2, 3) and label and slider for the number
     * of edges (4, 5). The interactive components get a listener that rebuilds the graph and
     * updates the layout of the given viewer. Returns null for any other index.
     */
    @Override
    public JComponent getOptionComponent(final GraphViewer viewer, int index) {
        switch (index) {
            case 0:
                return new JLabel("Source Filter:");
            case 1:
                sourceFilter.addActionListener(new ActionListener() {

                    @Override
                    public void actionPerformed(ActionEvent e) {
                        updateGraph();
                        viewer.updateLayout();
                    }
                });
                return sourceFilter;
            case 2:
                return new JLabel("Number of Hops:");
            case 3:
                this.numberOfHops.addChangeListener(new ChangeListener() {

                    @Override
                    public void stateChanged(ChangeEvent e) {
                        updateGraph();
                        viewer.updateLayout();
                    }
                });
                return numberOfHops;
            case 4:
                return new JLabel("Number of Edges:");
            case 5:
                this.edgeSlider.addChangeListener(new ChangeListener() {

                    @Override
                    public void stateChanged(ChangeEvent e) {
                        // only rebuild once the user has released the slider
                        if (!edgeSlider.getValueIsAdjusting()) {
                            updateGraph();
                            viewer.updateLayout();
                        }
                    }
                });
                return edgeSlider;
            default:
                return null;
        }
    }

    /** Returns the id of the selected source, or null if "None" (index 0) is selected. */
    private String getSelectedSourceId() {
        if (sourceFilter.getSelectedIndex() > 0) {
            return ((SourceId) sourceFilter.getSelectedItem()).getId();
        }
        return null;
    }

    /**
     * Rebuilds the graph according to the current option settings: source filter, number of hops
     * and number of edges.
     */
    private void updateGraph() {
        removeOldGraphElements();

        List<SortableEdge> sortableEdges = collectEdges(getSelectedSourceId());
        Collections.sort(sortableEdges);

        // keep exactly as many edges as selected with the slider (the slider minimum is 0)
        int numberOfEdges = Math.min(edgeSlider.getValue(), sortableEdges.size());
        List<SortableEdge> shownEdges = sortableEdges.subList(0, numberOfEdges);

        addVertices(shownEdges);
        addEdges(shownEdges);
    }

    /** Removes all edges and vertices added by the previous update and clears their labels. */
    private void removeOldGraphElements() {
        for (String edge : edgeLabelMap.keySet()) {
            graph.removeEdge(edge);
        }
        edgeLabelMap.clear();
        edgeStrengthMap.clear();

        for (String vertex : vertexLabelMap.keySet()) {
            graph.removeVertex(vertex);
        }
        vertexLabelMap.clear();
    }

    /**
     * Creates one edge per example if no source is selected. Otherwise creates only the edges
     * reachable from the selected source within the configured number of hops.
     */
    private List<SortableEdge> collectEdges(String sourceFilterName) {
        List<SortableEdge> sortableEdges = new ArrayList<SortableEdge>();
        if (sourceFilterName == null) {
            for (Example example : exampleSet) {
                sortableEdges.add(createSortableEdge(example));
            }
            return sortableEdges;
        }

        int maxHops = (Integer) numberOfHops.getValue();
        // All sources expanded so far. A source reached again (via a cycle or duplicate rows) is
        // not expanded a second time; doing so would add duplicate edges and let the frontier
        // grow exponentially with the number of hops.
        Set<String> expanded = new HashSet<String>();
        Set<String> frontier = new LinkedHashSet<String>();
        frontier.add(sourceFilterName);
        for (int hop = 1; hop <= maxHops && !frontier.isEmpty(); hop++) {
            expanded.addAll(frontier);
            Set<String> nextFrontier = new LinkedHashSet<String>();
            for (String currentSource : frontier) {
                for (Example example : exampleSet) {
                    if (!currentSource.equals(example.getValueAsString(sourceAttribute))) {
                        continue;
                    }
                    SortableEdge edge = createSortableEdge(example);
                    sortableEdges.add(edge);
                    if (!expanded.contains(edge.getSecondVertex())) {
                        nextFrontier.add(edge.getSecondVertex());
                    }
                }
            }
            frontier = nextFrontier;
        }
        return sortableEdges;
    }

    /** Creates the edge described by the given example. */
    private SortableEdge createSortableEdge(Example example) {
        String source = example.getValueAsString(sourceAttribute);
        String target = example.getValueAsString(targetAttribute);

        double strength = 1.0d;
        if (strengthAttribute != null) {
            strength = example.getValue(strengthAttribute);
        }

        String type = null;
        if (typeAttribute != null) {
            type = example.getValueAsString(typeAttribute);
        }

        // the type (if available) is used as edge name, otherwise the strength
        String edgeName = type != null ? type : strength + "";
        return new SortableEdge(source, target, edgeName, strength, SortableEdge.DIRECTION_INCREASE);
    }

    /** Adds the end points of the given edges to the graph and stores their display labels. */
    private void addVertices(List<SortableEdge> edges) {
        Set<String> vertices = new HashSet<String>();
        for (SortableEdge edge : edges) {
            vertices.add(edge.getFirstVertex());
            vertices.add(edge.getSecondVertex());
        }

        for (String vertex : vertices) {
            graph.addVertex(vertex);
            String description = getNodeDescription(vertex);
            vertexLabelMap.put(vertex, description == null ? vertex : description);
        }
    }

    /**
     * Adds the given edges to the graph and labels them with their value. Also stores each edge's
     * strength min-max normalized over the given edges. The result is NaN when all values are
     * equal, which {@link #getEdgeStrength(String)} maps to 1.
     */
    private void addEdges(List<SortableEdge> edges) {
        double minStrength = Double.POSITIVE_INFINITY;
        double maxStrength = Double.NEGATIVE_INFINITY;
        Map<String, Double> strengthMap = new HashMap<String, Double>();
        for (SortableEdge edge : edges) {
            String edgeId = edgeFactory.create();
            graph.addEdge(edgeId, edge.getFirstVertex(), edge.getSecondVertex(), EdgeType.DIRECTED);

            double strength = edge.getEdgeValue();
            edgeLabelMap.put(edgeId, Tools.formatIntegerIfPossible(strength));

            minStrength = Math.min(minStrength, strength);
            maxStrength = Math.max(maxStrength, strength);
            strengthMap.put(edgeId, strength);
        }

        for (Entry<String, Double> entry : strengthMap.entrySet()) {
            edgeStrengthMap.put(entry.getKey(), (entry.getValue() - minStrength) / (maxStrength - minStrength));
        }
    }

    /**
     * Builds the display label of a vertex from the node description template. Each "%{field}"
     * placeholder is replaced by the visualizer's detail data for that field, or "?" if there is
     * none. Text after an unterminated "%{" is kept verbatim.
     *
     * @return the description, or null if no template is set or no capable visualizer exists
     */
    private String getNodeDescription(Object vertexId) {
        if (nodeDescription == null) {
            return null;
        }
        ObjectVisualizer visualizer = ObjectVisualizerService.getVisualizerForObject(exampleSet);
        if (visualizer == null || !visualizer.isCapableToVisualize(vertexId)) {
            return null;
        }

        StringBuilder result = new StringBuilder();
        int currentIndex = 0;
        int startIndex = nodeDescription.indexOf("%{", currentIndex);
        while (startIndex >= 0) {
            int endIndex = nodeDescription.indexOf("}", startIndex);
            if (endIndex < 0) {
                // unterminated placeholder: the remainder (including the text before it) is
                // appended verbatim below
                break;
            }
            String fieldName = nodeDescription.substring(startIndex + 2, endIndex);
            String fieldValue = visualizer.getDetailData(vertexId, fieldName);
            result.append(nodeDescription, currentIndex, startIndex);
            result.append(fieldValue != null ? fieldValue : "?");
            currentIndex = endIndex + 1;
            startIndex = nodeDescription.indexOf("%{", currentIndex);
        }

        if (currentIndex < nodeDescription.length()) {
            result.append(nodeDescription.substring(currentIndex));
        }
        return result.toString();
    }

    /** Paints the selected source vertex light yellow and all other vertices light blue. */
    @Override
    public Transformer<String, Paint> getVertexPaintTransformer(VisualizationViewer<String, String> viewer) {
        return new Transformer<String, Paint>() {

            @Override
            public Paint transform(String name) {
                String selectedSource = getSelectedSourceId();
                if (selectedSource != null && selectedSource.equals(name)) {
                    return SwingTools.LIGHT_YELLOW;
                }
                return SwingTools.LIGHT_BLUE;
            }
        };
    }

    /** Returns false. */
    @Override
    public boolean showEdgeLabelsDefault() {
        return false;
    }

    /** Returns true. */
    @Override
    public boolean showVertexLabelsDefault() {
        return true;
    }

    /** Returns the normalized strength of the edge, or 1 if it is unknown or NaN. */
    @Override
    public double getEdgeStrength(String id) {
        Double value = edgeStrengthMap.get(id);
        if (value == null || Double.isNaN(value)) {
            return 1.0d;
        }
        return value;
    }

    /** Returns the shape of the edges. */
    @Override
    public int getEdgeShape() {
        return EDGE_SHAPE_QUAD_CURVE;
    }

    @Override
    public Object getObject(String id) {
        return id;
    }

    @Override
    public GraphObjectViewer getObjectViewer() {
        return objectViewer;
    }
}