edu.umd.cs.psl.model.kernel.externalinducer.ExternalInducerKernel.java Source code

Java tutorial

Introduction

Here is the source code for edu.umd.cs.psl.model.kernel.externalinducer.ExternalInducerKernel.java

Source

/*
 * This file is part of the PSL software.
 * Copyright 2011 University of Maryland
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.umd.cs.psl.model.kernel.externalinducer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang.builder.HashCodeBuilder;

import com.google.common.base.Preconditions;

import edu.umd.cs.psl.application.GroundingMode;
import edu.umd.cs.psl.application.ModelApplication;
import edu.umd.cs.psl.database.DatabaseAtomStoreQuery;
import edu.umd.cs.psl.database.ResultList;
import edu.umd.cs.psl.database.ResultListValues;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.argument.Entity;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.argument.Term;
import edu.umd.cs.psl.model.argument.Variable;
import edu.umd.cs.psl.model.argument.type.VariableTypeMap;
import edu.umd.cs.psl.model.atom.Atom;
import edu.umd.cs.psl.model.atom.AtomEvent;
import edu.umd.cs.psl.model.atom.AtomEventFramework;
import edu.umd.cs.psl.model.atom.AtomEventSets;
import edu.umd.cs.psl.model.atom.AtomManager;
import edu.umd.cs.psl.model.formula.Conjunction;
import edu.umd.cs.psl.model.function.BulkExternalFunction;
import edu.umd.cs.psl.model.function.ExternalFunction;
import edu.umd.cs.psl.model.function.LearnableExternalFunction;
import edu.umd.cs.psl.model.kernel.Kernel;
import edu.umd.cs.psl.model.parameters.Parameters;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import edu.umd.cs.psl.model.parameters.Weight;
import edu.umd.cs.psl.model.predicate.ExternalFunctionPredicate;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.model.predicate.type.PredicateTypes;

/**
 * A compatibility kernel of the form {@code body => head} whose body contains
 * exactly one atom backed by an {@link ExternalFunction}.
 *
 * <p>The remaining body atoms are "retrieval" atoms: each of them must mention
 * exactly one variable of the head, and is used to query the database for the
 * data that is handed to the external function. The external function is then
 * either asked to compute values in bulk (see {@link #groundAll}) or to learn
 * from the known facts of the head predicate (see {@link #learn}).
 *
 * <p>Note: {@link #hashCode()} is derived from body and head only, and
 * {@code equals} is not overridden by this class.
 */
public class ExternalInducerKernel implements Kernel {

    /** Model that is notified when the parameters of this kernel change. */
    private final Model model;
    /** Current weight of this kernel; replaced by {@link #setParameters}. */
    private PositiveWeight weight;

    /** Argument types of the head variables, as declared by the head predicate. */
    private final VariableTypeMap vartypes;
    /** The rule body, including the external function atom. */
    private final Conjunction body;
    /** The rule head. */
    private final Atom head;
    /** The single body atom whose predicate is an {@link ExternalFunctionPredicate}. */
    private Atom externalAtom;
    /** All variables occurring as arguments of {@link #externalAtom}. */
    private final Set<Variable> externalAtomVariables;
    /** The external function attached to the predicate of {@link #externalAtom}. */
    private ExternalFunction extFun;
    /** For every head variable, the body atoms that mention it (possibly none). */
    private final Map<Variable, List<Atom>> retrievals;

    /** Hash code precomputed in the constructor from body and head. */
    private final int hashcode;

    /**
     * Creates the kernel and validates the structure of the rule.
     *
     * @param m    model to notify about parameter changes
     * @param c    rule body; must contain exactly one external function atom
     * @param head rule head; its predicate must be a {@link StandardPredicate}
     *             with exactly one value
     * @param w    initial weight
     * @throws IllegalArgumentException if the head predicate is unsuitable, the
     *         body contains no or more than one external function atom, a body
     *         atom mentions no or more than one head variable, or a variable of
     *         the external function atom occurs neither in the head nor in any
     *         other body atom
     */
    public ExternalInducerKernel(Model m, Conjunction c, Atom head, double w) {
        model = m;
        weight = new PositiveWeight(w);

        Preconditions.checkArgument(head.getPredicate() instanceof StandardPredicate);
        Preconditions.checkArgument(head.getPredicate().getNumberOfValues() == 1);
        this.head = head;
        retrievals = new HashMap<Variable, List<Atom>>();
        vartypes = new VariableTypeMap();
        // Every variable of the head gets a (initially empty) retrieval list and its type recorded.
        for (int i = 0; i < head.getArity(); i++) {
            Term t = head.getArguments()[i];
            if (t instanceof Variable) {
                retrievals.put((Variable) t, new ArrayList<Atom>());
                vartypes.put((Variable) t, head.getPredicate().getArgumentType(i));
            }
        }

        externalAtom = null;
        // Variables that occur in retrieval atoms but are not head variables.
        Set<Variable> associatedVariables = new HashSet<Variable>();

        body = c;
        //Preconditions.checkArgument(FormulaUtil.isConjunction(c));
        for (Atom atom : c.getAtoms(new HashSet<Atom>())) {
            if (atom.getPredicate() instanceof ExternalFunctionPredicate) {
                if (externalAtom != null)
                    throw new IllegalArgumentException("Only one external function is allowed!");
                externalAtom = atom;
                extFun = ((ExternalFunctionPredicate) atom.getPredicate()).getExternalFunction();
            } else {
                assert atom.getPredicate() instanceof StandardPredicate;
                boolean associated = false;
                for (Term t : atom.getArguments()) {
                    if ((t instanceof Variable) && retrievals.containsKey(t)) {
                        if (associated)
                            throw new IllegalArgumentException(
                                    "One atom cannot be associated with multiple variables: " + atom);
                        associated = true;
                        retrievals.get(t).add(atom);
                    } else if (t instanceof Variable)
                        associatedVariables.add((Variable) t);
                }
                if (!associated)
                    throw new IllegalArgumentException("Atom could not be associated: " + atom);
            }
        }

        // Without this check a body lacking an external function atom would fail
        // below with an uninformative NullPointerException.
        if (externalAtom == null)
            throw new IllegalArgumentException("Exactly one external function is required in the body: " + c);

        externalAtomVariables = new HashSet<Variable>();
        for (Term t : externalAtom.getArguments()) {
            if ((t instanceof Variable)) {
                Variable v = (Variable) t;
                externalAtomVariables.add(v);
                if (!associatedVariables.contains(v) && !retrievals.containsKey(v))
                    throw new IllegalArgumentException(
                            "Variable in external function could not be associated: " + v);
            }
        }

        hashcode = new HashCodeBuilder().append(body).append(head).toHashCode();
    }

    /**
     * Builds a new kernel from the same model, body and head and the current
     * weight value.
     */
    @Override
    public Kernel clone() {
        return new ExternalInducerKernel(model, body, head, weight.getWeight());
    }

    /** Returns the current weight object of this kernel (not a copy). */
    public Weight getWeight() {
        return weight;
    }

    /** Returns a duplicate of the current weight. */
    @Override
    public Parameters getParameters() {
        return weight.duplicate();
    }

    /**
     * Replaces the weight of this kernel and notifies the model if the new
     * weight differs from the current one.
     *
     * @param para the new parameters; must be a {@link PositiveWeight}
     * @throws IllegalArgumentException if {@code para} is not a {@link PositiveWeight}
     */
    @Override
    public void setParameters(Parameters para) {
        // Check against the type that is actually stored, so that an unsuitable
        // Weight is rejected here rather than by a ClassCastException on the cast.
        if (!(para instanceof PositiveWeight))
            throw new IllegalArgumentException("Expected weight parameter!");
        PositiveWeight newweight = (PositiveWeight) para;
        if (!newweight.equals(weight)) {
            weight = newweight;
            model.notifyKernelParametersModified(this);
        }
    }

    /** Always {@code true} for this kernel. */
    @Override
    public boolean isCompatibilityKernel() {
        return true;
    }

    /** Collects the predicates of all retrieval atoms. */
    private Set<Predicate> getRetrievalPredicates() {
        Set<Predicate> predicates = new HashSet<Predicate>();
        for (List<Atom> atoms : retrievals.values()) {
            for (Atom atom : atoms) {
                predicates.add(atom.getPredicate());
            }
        }
        return predicates;
    }

    /**
     * Checks that every retrieval predicate is closed in the given database and
     * is of type {@link PredicateTypes#BooleanTruth}.
     *
     * @throws IllegalArgumentException if a predicate violates either condition
     */
    private void validatePredicates(DatabaseAtomStoreQuery db) {
        for (Predicate p : getRetrievalPredicates()) {
            Preconditions.checkArgument(db.isClosed(p), p);
            Preconditions.checkArgument(p.getType() == PredicateTypes.BooleanTruth, p);
        }
    }

    /**
     * Retrieves the data for the external function from the database and lets
     * the function learn from the known facts of the head predicate.
     *
     * @param db database to read retrieval data and head facts from
     * @throws IllegalStateException if the external function is not a
     *         {@link LearnableExternalFunction}
     */
    public void learn(DatabaseAtomStoreQuery db) {
        if (!(extFun instanceof LearnableExternalFunction))
            throw new IllegalStateException("The provided external function is not learnable!");
        getDataAndExecute(db, Execution.Learn, null);
    }

    /** The two ways in which the external function can be invoked. */
    private static enum Execution {
        Bulk, Learn
    }

    /**
     * Gathers the input data for the external function and invokes it.
     *
     * <p>One data map is built per head variable. If no body atom mentions the
     * variable, the map contains every entity of the variable's type, mapped to
     * a one-element array holding that entity. Otherwise the conjunction of the
     * atoms mentioning the variable is queried, projected to the head variable
     * followed by the external-function variables occurring in those atoms; each
     * result row is keyed by its first column. The row is stored unchanged if the
     * head variable is itself an argument of the external function, and without
     * its first column otherwise. {@code argMap} records, for each variable made
     * available this way, the pair {@code {data map index, column index}}.
     *
     * @param db  database to query
     * @param exe how to invoke the external function
     * @param app application receiving the ground kernels; only used for
     *            {@link Execution#Bulk} and may be {@code null} otherwise
     */
    @SuppressWarnings("unchecked")
    private void getDataAndExecute(DatabaseAtomStoreQuery db, Execution exe, ModelApplication app) {
        // Generic arrays cannot be created directly; the array must nevertheless be a
        // real Map[] (an Object[] cannot be cast to Map[] at runtime).
        Map<GroundTerm, GroundTerm[]>[] datas = (Map<GroundTerm, GroundTerm[]>[]) new Map[retrievals.size()];
        Map<Variable, int[]> argMap = new HashMap<Variable, int[]>();

        // Assign every head variable the index of its data map.
        Map<Variable, Integer> varkeymap = new HashMap<Variable, Integer>();
        for (Variable v : retrievals.keySet()) {
            varkeymap.put(v, varkeymap.size());
        }
        for (Map.Entry<Variable, Integer> varpos : varkeymap.entrySet()) {
            int position = varpos.getValue();
            Variable var = varpos.getKey();
            List<Atom> atoms = retrievals.get(var);
            datas[position] = new HashMap<GroundTerm, GroundTerm[]>();
            assert atoms != null;
            if (atoms.isEmpty()) {
                // Unconstrained head variable: enumerate all entities of its type.
                assert externalAtomVariables.contains(var);
                Set<Entity> entities = db.getEntities(vartypes.get(var));
                for (Entity e : entities) {
                    datas[position].put(e, new GroundTerm[] { e });
                }
                argMap.put(var, new int[] { position, 0 });
            } else {
                Conjunction c = new Conjunction(atoms.toArray(new Atom[atoms.size()]));
                // The head variable always comes first so that result[0] can serve as key.
                List<Variable> projectTo = new ArrayList<Variable>();
                projectTo.add(var);
                for (Atom a : atoms)
                    for (Term t : a.getArguments())
                        if (externalAtomVariables.contains(t) && !projectTo.contains(t))
                            projectTo.add((Variable) t);
                assert !projectTo.isEmpty();
                ResultList results = db.query(c, projectTo);
                // If the head variable is not an argument of the external function,
                // its column is dropped from the stored rows and column indices shift by one.
                int start = externalAtomVariables.contains(var) ? 0 : 1;
                for (int k = 0; k < results.size(); k++) {
                    GroundTerm[] result = results.get(k);
                    if (start == 0)
                        datas[position].put(result[0], result);
                    else
                        datas[position].put(result[0], Arrays.copyOfRange(result, 1, result.length));
                }
                for (int i = start; i < projectTo.size(); i++) {
                    argMap.put(projectTo.get(i), new int[] { position, i - start });
                }
            }
        }

        // Template for the head arguments: constants are filled in, variables stay null.
        GroundTerm[] args = new GroundTerm[head.getArity()];
        for (int i = 0; i < head.getArity(); i++) {
            if (!(head.getArguments()[i] instanceof Variable))
                args[i] = (GroundTerm) head.getArguments()[i];
        }
        Predicate p = head.getPredicate();

        if (exe == Execution.Bulk) {
            AtomManager amanager = app.getAtomManager();
            Map<GroundTerm[], double[]> results = ((BulkExternalFunction) extFun).bulkCompute(argMap, datas);
            for (Map.Entry<GroundTerm[], double[]> entry : results.entrySet()) {
                GroundTerm[] result = entry.getKey();
                GroundTerm[] arguments = args.clone();
                //Fill in remaining args
                for (int i = 0; i < head.getArity(); i++) {
                    if ((head.getArguments()[i] instanceof Variable))
                        arguments[i] = result[varkeymap.get((Variable) head.getArguments()[i])];
                }
                GroundExternalInducer ei = new GroundExternalInducer(this, amanager.getAtom(p, arguments),
                        entry.getValue());
                app.addGroundKernel(ei);
            }
        } else if (exe == Execution.Learn) {
            // Query pattern for the head facts: ground arguments are fixed, the rest is left null.
            GroundTerm[] query = new GroundTerm[head.getArity()];
            for (int i = 0; i < head.getArity(); i++) {
                if (head.getArguments()[i] instanceof GroundTerm)
                    query[i] = (GroundTerm) head.getArguments()[i];
            }
            ResultListValues results = db.getFacts(head.getPredicate(), query);

            Map<GroundTerm[], double[]> truth = new HashMap<GroundTerm[], double[]>();

            for (int k = 0; k < results.size(); k++) {
                truth.put(results.get(k), results.getValues(k));
            }
            ((LearnableExternalFunction) extFun).learn(truth, argMap, datas);
        } else
            throw new UnsupportedOperationException(exe.toString());
    }

    /**
     * Grounds this kernel by letting the external function compute values in
     * bulk and adding one {@code GroundExternalInducer} per computed result to
     * the application.
     *
     * @throws IllegalStateException if the external function is not a
     *         {@link BulkExternalFunction}
     */
    @Override
    public void groundAll(ModelApplication app) {
        if (!(extFun instanceof BulkExternalFunction))
            throw new IllegalStateException("The provided external function does not support bulk computation!");
        // Grounding must use the bulk path; the learn path never adds ground kernels.
        getDataAndExecute(app.getDatabase(), Execution.Bulk, app);
    }

    /**
     * Reacts to atom events. Activation events are ignored unless grounding is
     * requested in backward mode; every other case is unsupported.
     *
     * @throws UnsupportedOperationException for fact changes, for backward
     *         grounding, and for any other event
     */
    @Override
    public void notifyAtomEvent(AtomEvent event, Atom atom, GroundingMode mode, ModelApplication app) {
        if (AtomEventSets.NonDefaultFactEvent.subsumes(event)) {
            throw new UnsupportedOperationException("Changes to facts not yet supported!");
        } else if (AtomEventSets.ActivationEvent.subsumes(event)) {
            assert atom.getPredicate().equals(head.getPredicate());
            if (mode == GroundingMode.Backward) {
                //TODO: implement backward grounding of such rules
                throw new UnsupportedOperationException("Not yet implemented!");
            } //else do nothing
        } else
            throw new UnsupportedOperationException("Unsupported event encountered: " + event);
    }

    /**
     * Validates the retrieval predicates against the database and registers
     * this kernel for (de)activation events on the head predicate and for
     * non-default fact events on every retrieval predicate.
     */
    @Override
    public void registerForAtomEvents(AtomEventFramework framework, DatabaseAtomStoreQuery db) {
        validatePredicates(db);
        framework.registerAtomEventObserver(head.getPredicate(), AtomEventSets.DeOrActivationEvent, this);
        for (Predicate p : getRetrievalPredicates()) {
            framework.registerAtomEventObserver(p, AtomEventSets.NonDefaultFactEvent, this);
        }
    }

    /** Removes the registrations made by {@link #registerForAtomEvents}. */
    @Override
    public void unregisterForAtomEvents(AtomEventFramework framework, DatabaseAtomStoreQuery db) {
        framework.unregisterAtomEventObserver(head.getPredicate(), AtomEventSets.DeOrActivationEvent, this);
        for (Predicate p : getRetrievalPredicates()) {
            framework.unregisterAtomEventObserver(p, AtomEventSets.NonDefaultFactEvent, this);
        }
    }

    /** Renders the kernel as <code>{weight} External Function: body =&gt; head</code>. */
    @Override
    public String toString() {
        return "{" + weight.toString() + "} External Function: " + body.toString() + " => " + head.toString();
    }

    /** Returns the hash code precomputed from body and head. */
    @Override
    public int hashCode() {
        return hashcode;
    }
}