org.linqs.psl.utils.dataloading.graph.Graph.java Source code

Java tutorial

Introduction

Here is the source code for org.linqs.psl.utils.dataloading.graph.Graph.java

Source

/*
 * This file is part of the PSL software.
 * Copyright 2011-2015 University of Maryland
 * Copyright 2013-2018 The Regents of the University of California
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.linqs.psl.utils.dataloading.graph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;

/**
 * In-memory collection of typed entities, indexed first by entity type and then by
 * integer id. Besides basic create/lookup/delete operations it offers loaders that
 * populate entities and binary relations from delimited files, and two strategies for
 * partitioning the entities of one type into {@link Subgraph}s.
 *
 * @param <ET> the entity type vocabulary
 * @param <RT> the relation type vocabulary
 */
public class Graph<ET extends EntityType, RT extends RelationType> {

    /** Entities grouped by type, then keyed by id within that type. */
    private final Map<ET, Map<Integer, Entity<ET, RT>>> entities;

    /** Creates a graph that holds no entities. */
    public Graph() {
        entities = new HashMap<ET, Map<Integer, Entity<ET, RT>>>();
    }

    /**
     * Counts the entities stored under the given type.
     *
     * @param type the entity type to look up
     * @return the number of entities of that type, or 0 if the type was never used
     */
    public int getNoEntities(ET type) {
        Map<Integer, Entity<ET, RT>> byId = entities.get(type);
        return (byId == null) ? 0 : byId.size();
    }

    /**
     * Iterates over all entities of the given type.
     *
     * <p>The returned iterator supports {@code remove()}, but only for entities whose
     * degree is 0.
     *
     * @param type the entity type to iterate over
     * @return an iterator over the entities of that type; empty if the type was never used
     */
    public Iterator<Entity<ET, RT>> getEntities(ET type) {
        Map<Integer, Entity<ET, RT>> byId = entities.get(type);
        if (byId == null) {
            return new ArrayList<Entity<ET, RT>>().iterator();
        }

        final Iterator<Entity<ET, RT>> iter = byId.values().iterator();
        return new Iterator<Entity<ET, RT>>() {

            // Entity handed out by the most recent next(); null when there is nothing to remove.
            private Entity<ET, RT> current = null;

            @Override
            public boolean hasNext() {
                return iter.hasNext();
            }

            @Override
            public Entity<ET, RT> next() {
                current = iter.next();
                return current;
            }

            /**
             * Removes the entity returned by the last {@code next()} call.
             *
             * @throws IllegalStateException if there is no such entity to remove
             * @throws IllegalArgumentException if that entity has a degree above 0
             */
            @Override
            public void remove() {
                // Fail with the exception the Iterator contract prescribes rather than
                // a NullPointerException when next() has not been called yet.
                if (current == null) {
                    throw new IllegalStateException("next() must be called before remove()");
                }
                if (current.getDegree() > 0) {
                    throw new IllegalArgumentException("Cannot delete connected entity!");
                }
                iter.remove();
                current = null;
            }

        };
    }

    /**
     * Looks up an entity by id and type.
     *
     * @param id the entity id
     * @param type the entity type
     * @return the matching entity, or {@code null} if none is stored
     * @throws AssertionError if the stored entity reports that it does not have {@code type}
     */
    public Entity<ET, RT> getEntity(int id, ET type) {
        Map<Integer, Entity<ET, RT>> map = entities.get(type);
        if (map == null) {
            return null;
        }
        Entity<ET, RT> e = map.get(id);
        if (e != null && !e.hasType(type)) {
            throw new AssertionError("Entity does not have the expected type!");
        }
        return e;
    }

    /**
     * Removes an entity from the graph. The lookup uses the entity's own type and id.
     *
     * @param e the entity to remove
     * @return {@code true} if an entry was removed, {@code false} if no entry with that
     *         type and id was stored
     * @throws IllegalArgumentException if the entity has a degree above 0
     */
    public boolean deleteEntity(Entity<ET, RT> e) {
        if (e.getDegree() > 0) {
            throw new IllegalArgumentException("Cannot delete connected entity!");
        }
        Map<Integer, Entity<ET, RT>> map = entities.get(e.getType());
        if (map == null || !map.containsKey(e.getId())) {
            return false;
        }
        map.remove(e.getId());
        return true;
    }

    /**
     * Returns the entity with the given id and type, creating it first if it is missing.
     *
     * @param id the entity id
     * @param type the entity type
     * @return the existing or newly created entity
     */
    public Entity<ET, RT> getorCreateEntity(int id, ET type) {
        Entity<ET, RT> e = getEntity(id, type);
        if (e == null) {
            e = createEntity(id, type);
        }
        return e;
    }

    /**
     * Creates and stores a new entity.
     *
     * @param id the entity id
     * @param type the entity type
     * @return the newly created entity
     * @throws AssertionError if an entity with that id is already stored under {@code type}
     */
    public Entity<ET, RT> createEntity(int id, ET type) {
        Map<Integer, Entity<ET, RT>> map = entities.get(type);
        if (map == null) {
            map = new HashMap<Integer, Entity<ET, RT>>();
            entities.put(type, map);
        }
        if (map.containsKey(id)) {
            throw new AssertionError("Entity already exists!");
        }
        Entity<ET, RT> e = new Entity<ET, RT>(id, type);
        map.put(id, e);
        return e;
    }

    /**
     * Loads entity attributes from a file without filtering any rows; see
     * {@link #loadEntityAttributes(String, EntityType, String[], DelimitedObjectConstructor.Filter, boolean)}.
     *
     * @param file the file handed to {@code LoadDelimitedData.loadTabData}
     * @param type the type of every entity in the file
     * @param attNames attribute names for the columns after the id; {@code null} entries are skipped
     * @param createEntity whether to create entities that are not yet in the graph
     */
    public void loadEntityAttributes(String file, final ET type, final String[] attNames,
            final boolean createEntity) {
        loadEntityAttributes(file, type, attNames, DelimitedObjectConstructor.NoFilter, createEntity);
    }

    /**
     * Loads entity attributes from a file. For each row accepted by {@code filter},
     * column 0 is parsed as the entity id and column {@code a + 1} is stored under
     * {@code attNames[a]}.
     *
     * <p>Reading and splitting the file is delegated to
     * {@code LoadDelimitedData.loadTabData}; presumably it calls the row handler once
     * per line (confirm against that class).
     *
     * @param file the file handed to {@code LoadDelimitedData.loadTabData}
     * @param type the type of every entity in the file
     * @param attNames attribute names for the columns after the id; {@code null} entries are skipped
     * @param filter rows for which {@code include} returns {@code false} are ignored
     * @param createEntity whether to create entities that are not yet in the graph;
     *        when {@code false}, rows for unknown entities are ignored
     */
    public void loadEntityAttributes(String file, final ET type, final String[] attNames,
            final DelimitedObjectConstructor.Filter filter, final boolean createEntity) {
        DelimitedObjectConstructor<Object> loader = new DelimitedObjectConstructor<Object>() {
            @Override
            public Object create(String[] data) {
                if (!filter.include(data)) {
                    return null;
                }

                int id = Integer.parseInt(data[0]);
                Entity<ET, RT> e = getEntity(id, type);
                if (e == null) {
                    if (!createEntity) {
                        return null;
                    }
                    e = createEntity(id, type);
                }

                // Attribute a lives one column to the right because column 0 is the id.
                for (int a = 0; a < attNames.length; a++) {
                    if (attNames[a] != null) {
                        e.setAttribute(attNames[a], data[a + 1]);
                    }
                }
                // The graph is updated as a side effect; nothing is handed back to the loader.
                return null;
            }

            @Override
            public int length() {
                // One id column plus one column per attribute name.
                return attNames.length + 1;
            }

        };
        LoadDelimitedData.loadTabData(file, loader);
    }

    /**
     * Loads relations that carry no attributes and applies no row filter; see
     * {@link #loadRelationship(String, String[], RelationType, EntityType[], DelimitedObjectConstructor.Filter, boolean[])}.
     *
     * @param file the file handed to {@code LoadDelimitedData.loadTabData}
     * @param relType the type of every relation in the file
     * @param types entity type per leading column; {@code null} entries are skipped
     * @param createEntity per column, whether to create entities that are not yet in the graph
     */
    public void loadRelationship(String file, final RT relType, final ET[] types, final boolean[] createEntity) {
        loadRelationship(file, new String[0], relType, types, createEntity);
    }

    /**
     * Loads relations with attributes and applies no row filter; see
     * {@link #loadRelationship(String, String[], RelationType, EntityType[], DelimitedObjectConstructor.Filter, boolean[])}.
     *
     * @param file the file handed to {@code LoadDelimitedData.loadTabData}
     * @param attNames attribute names for the columns after the entity columns; {@code null} entries are skipped
     * @param relType the type of every relation in the file
     * @param types entity type per leading column; {@code null} entries are skipped
     * @param createEntity per column, whether to create entities that are not yet in the graph
     */
    public void loadRelationship(String file, final String[] attNames, final RT relType, final ET[] types,
            final boolean[] createEntity) {
        loadRelationship(file, attNames, relType, types, DelimitedObjectConstructor.NoFilter, createEntity);
    }

    /**
     * Loads binary relations from a file. For each row accepted by {@code filter},
     * column {@code i} is parsed as the id of an entity of type {@code types[i]}
     * (columns whose type is {@code null} are skipped), and column
     * {@code types.length + a} is stored on the relation under {@code attNames[a]}.
     * The resulting relation is registered on both participating entities.
     *
     * <p>Reading and splitting the file is delegated to
     * {@code LoadDelimitedData.loadTabData}; presumably it calls the row handler once
     * per line (confirm against that class).
     *
     * @param file the file handed to {@code LoadDelimitedData.loadTabData}
     * @param attNames attribute names for the columns after the entity columns; {@code null} entries are skipped
     * @param relType the type of every relation in the file
     * @param types entity type per leading column; {@code null} entries are skipped
     * @param filter rows for which {@code include} returns {@code false} are ignored
     * @param createEntity per column, whether to create entities that are not yet in the
     *        graph; a row is ignored as soon as it names an unknown entity in a column
     *        whose flag is {@code false}
     * @throws AssertionError if a row resolves to a number of entities other than two
     */
    public void loadRelationship(String file, final String[] attNames, final RT relType, final ET[] types,
            final DelimitedObjectConstructor.Filter filter, final boolean[] createEntity) {
        DelimitedObjectConstructor<Object> loader = new DelimitedObjectConstructor<Object>() {
            @Override
            public Object create(String[] data) {
                if (!filter.include(data)) {
                    return null;
                }

                // Named "participants" so the enclosing graph's entity index is not shadowed.
                List<Entity<ET, RT>> participants = new ArrayList<Entity<ET, RT>>(types.length);
                for (int i = 0; i < types.length; i++) {
                    ET type = types[i];
                    if (type == null) {
                        continue;
                    }
                    int id = Integer.parseInt(data[i]);
                    Entity<ET, RT> e = getEntity(id, type);
                    if (e == null) {
                        if (!createEntity[i]) {
                            // NOTE(review): entities created for earlier columns of this
                            // row stay in the graph even though the row is dropped here.
                            return null;
                        }
                        e = createEntity(id, type);
                    }
                    participants.add(e);
                }

                if (participants.size() != 2) {
                    throw new AssertionError("Currently, only binary relations are supported!");
                }
                Relation<ET, RT> rel = new BinaryRelation<ET, RT>(relType, participants.get(0),
                        participants.get(1));

                // Attribute columns start right after the entity columns.
                for (int a = 0; a < attNames.length; a++) {
                    if (attNames[a] != null) {
                        rel.setAttribute(attNames[a], data[a + types.length]);
                    }
                }

                // Register the relation on each participating entity.
                for (Entity<ET, RT> e : participants) {
                    e.addRelation(rel);
                }
                // The graph is updated as a side effect; nothing is handed back to the loader.
                return null;
            }

            @Override
            public int length() {
                // One column per entry of types plus one column per attribute name.
                return types.length + attNames.length;
            }

        };
        LoadDelimitedData.loadTabData(file, loader);
    }

    /**
     * Splits the graph by assigning every entity of {@code splitType} to one of
     * {@code numsplits} groups uniformly at random and then growing each group into a
     * subgraph. While a group is grown, the entities assigned to all other groups are
     * excluded, and every reachable neighbour is explored (explore probability 1.0).
     *
     * @param numsplits the number of subgraphs to produce
     * @param splitType the entity type whose members are distributed over the splits
     * @return one subgraph per split, in group order
     * @throws IllegalArgumentException if no entities were ever stored under {@code splitType}
     */
    public List<Subgraph<ET, RT>> splitGraphRandom(int numsplits, ET splitType) {
        List<Subgraph<ET, RT>> splits = new ArrayList<Subgraph<ET, RT>>(numsplits);
        List<Set<Entity<ET, RT>>> starts = new ArrayList<Set<Entity<ET, RT>>>(numsplits);
        for (int i = 0; i < numsplits; i++) {
            starts.add(new HashSet<Entity<ET, RT>>());
        }

        if (!entities.containsKey(splitType)) {
            throw new IllegalArgumentException("There are no entities of given type!");
        }
        for (Entity<ET, RT> e : entities.get(splitType).values()) {
            int group = (int) Math.floor(Math.random() * numsplits);
            starts.get(group).add(e);
        }

        // Grow each group, keeping it away from the start entities of every other group.
        for (int i = 0; i < numsplits; i++) {
            Set<Entity<ET, RT>> excluded = new HashSet<Entity<ET, RT>>();
            for (int j = 0; j < numsplits; j++) {
                if (j != i) {
                    excluded.addAll(starts.get(j));
                }
            }
            splits.add(growSplit(starts.get(i), excluded, new KeepGrowing(), 1.0));
        }

        return splits;
    }

    /**
     * Splits the graph by repeatedly growing a sample from randomly chosen seed
     * entities of {@code splitType}. Seeds are added one at a time, re-growing the
     * sample after each addition, until the sample holds {@code splitSize} entities of
     * {@code splitType}. Entities of {@code splitType} that end up in one split are
     * excluded from all later splits.
     *
     * <p>If the entities not yet assigned to a split cannot fill a split (every one of
     * them is already a seed and the sample is still too small), that split is returned
     * undersized, possibly empty, instead of seeding forever.
     *
     * @param numsplits the number of subgraphs to produce
     * @param splitType the entity type that seeds are drawn from and that split size is measured in
     * @param splitSize the target number of {@code splitType} entities per split
     * @param exploreProbability the probability with which each newly seen neighbour is
     *        queued for exploration while growing
     * @return one subgraph per split, in creation order
     * @throws IllegalArgumentException if no entities were ever stored under {@code splitType}
     */
    public List<Subgraph<ET, RT>> splitGraphSnowball(int numsplits, ET splitType, int splitSize,
            double exploreProbability) {
        List<Subgraph<ET, RT>> splits = new ArrayList<Subgraph<ET, RT>>(numsplits);

        if (!entities.containsKey(splitType)) {
            throw new IllegalArgumentException("There are no entities of given type!");
        }
        Set<Entity<ET, RT>> remaining = new HashSet<Entity<ET, RT>>(entities.get(splitType).values());
        Set<Entity<ET, RT>> excluded = new HashSet<Entity<ET, RT>>();

        // The condition keeps no state, so one instance serves every split.
        GrowCondition sizeLimit = new SizeLimit(splitType, splitSize);

        for (int i = 0; i < numsplits; i++) {
            Set<Entity<ET, RT>> seed = new HashSet<Entity<ET, RT>>();
            Subgraph<ET, RT> sample = new Subgraph<ET, RT>();

            // Seeds are drawn from "remaining", so once seed is as large as remaining
            // there is nothing new to add and further iterations could never terminate.
            while (sizeLimit.continueGrowing(sample) && seed.size() < remaining.size()) {
                int pos = (int) Math.floor(Math.random() * remaining.size());
                seed.add(Iterables.get(remaining, pos));
                sample = growSplit(seed, excluded, sizeLimit, exploreProbability);
            }

            // Entities claimed by this split are off limits for all later splits.
            remaining.removeAll(sample.getEntities(splitType));
            excluded.addAll(sample.getEntities(splitType));
            splits.add(sample);
        }

        return splits;
    }

    /**
     * Grows a subgraph outwards from a set of start entities using a FIFO work queue.
     *
     * <p>Each dequeued entity that is not yet part of the subgraph is added, and then
     * each of its relations is handled as follows:
     * <ul>
     * <li>if any other endpoint is in {@code excluded}, the relation is ignored;</li>
     * <li>otherwise, if every other endpoint is already in the subgraph, the relation
     *     is added to the subgraph;</li>
     * <li>otherwise, each other endpoint not yet in the subgraph is queued with
     *     probability {@code exploreProbability}.</li>
     * </ul>
     *
     * @param start the entities the queue is initialised with
     * @param excluded entities that must not be reached through any relation
     * @param growcondition checked before each dequeue; growth stops once it returns {@code false}
     * @param exploreProbability the probability of queueing each newly seen neighbour
     * @return the grown subgraph
     */
    private Subgraph<ET, RT> growSplit(Set<Entity<ET, RT>> start, Set<Entity<ET, RT>> excluded,
            GrowCondition growcondition, double exploreProbability) {
        Subgraph<ET, RT> subgraph = new Subgraph<ET, RT>();
        Queue<Entity<ET, RT>> frontier = new LinkedList<Entity<ET, RT>>(start);

        while (!frontier.isEmpty() && growcondition.continueGrowing(subgraph)) {
            Entity<ET, RT> entity = frontier.poll();
            if (subgraph.containsEntity(entity)) {
                // The same entity can be queued several times; only the first visit counts.
                continue;
            }
            subgraph.addEntity(entity);

            for (Relation<ET, RT> relation : entity.getAllRelations()) {
                // Other endpoints of this relation that are not in the subgraph yet.
                List<Entity<ET, RT>> unvisited = new ArrayList<Entity<ET, RT>>();
                boolean touchesExcluded = false;
                for (int i = 0; i < relation.getArity(); i++) {
                    Entity<ET, RT> ngh = relation.get(i);
                    if (ngh.equals(entity)) {
                        continue;
                    }
                    if (excluded.contains(ngh)) {
                        touchesExcluded = true;
                        break;
                    }
                    if (!subgraph.containsEntity(ngh)) {
                        unvisited.add(ngh);
                    }
                }

                if (touchesExcluded) {
                    continue;
                }
                if (unvisited.isEmpty()) {
                    // All endpoints are inside the subgraph, so the relation is too.
                    subgraph.addRelation(relation);
                } else {
                    for (Entity<ET, RT> ngh : unvisited) {
                        if (Math.random() < exploreProbability) {
                            frontier.add(ngh);
                        }
                    }
                }
            }
        }
        return subgraph;
    }

    /** Decides whether {@link #growSplit} should keep expanding a subgraph. */
    private abstract class GrowCondition {
        /**
         * @param subgraph the subgraph grown so far
         * @return {@code true} if growth should continue
         */
        abstract boolean continueGrowing(Subgraph<ET, RT> subgraph);
    }

    /** Condition that never stops growth; expansion ends only when the queue runs dry. */
    private class KeepGrowing extends GrowCondition {
        @Override
        boolean continueGrowing(Subgraph<ET, RT> subgraph) {
            return true;
        }
    }

    /** Condition that stops growth once the subgraph holds enough entities of one type. */
    private class SizeLimit extends GrowCondition {
        /** Entity type whose members are counted. */
        private final ET type;
        /** Number of entities of {@link #type} at which growth stops. */
        private final int size;

        public SizeLimit(ET t, int s) {
            type = t;
            size = s;
        }

        @Override
        boolean continueGrowing(Subgraph<ET, RT> subgraph) {
            return subgraph.getEntities(type).size() < size;
        }
    }
}