com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.java Source code

Java tutorial

Introduction

Here is the source code for com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.java

Source

/*******************************************************************************
*   Copyright 2014 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.model.transform;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.eclipse.jdt.annotation.Nullable;

import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.domains.JointDiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.util.misc.Internal;
import com.google.common.collect.Iterables;

/**
 * Junction tree mapping generated by {@link JunctionTreeTransform}.
 * <p>
 * This holds references to the source factor graph, the transformation of the graph into a tree,
 * and information that describes the mapping between the two.
 * <p>
 * @since 0.05
 * @author Christopher Barber
 */
public class JunctionTreeTransformMap {
    /*-------
     * State
     */

    /** The model from which the transform was generated. */
    private final FactorGraph _sourceModel;
    /** Value of {@link FactorGraph#structureVersion()} of {@link #_sourceModel} at construction. */
    private final long _sourceVersion;
    /** The transformed model; the same object as {@link #_sourceModel} for an identity map. */
    private final FactorGraph _targetModel;
    /** Source-to-target factor mapping, or {@code null} for an identity map. */
    private final @Nullable Map<Factor, Factor> _sourceToTargetFactors;
    /** Source-to-target variable mapping, or {@code null} for an identity map. */
    private final @Nullable Map<Variable, Variable> _sourceToTargetVariables;
    /**
     * Newly created joint variables that are deterministically computed from component variables,
     * keyed by the joint variable itself. Iteration follows insertion order.
     */
    private final Map<Variable, AddedJointVariable<?>> _addedDeterministicVariables;
    /** Source variables that were conditioned on their fixed values; iteration follows insertion order. */
    private final Set<Variable> _conditionedVariables;

    /**
     * Represents a variable that joins two or more other variables along an edge between
     * two factors in the target model to ensure that it is singly connected. There may be
     * multiple such joined variables for the same set of underlying variables.
     * 
     * @param <Var> specifies the variable type. Currently only {@link Discrete} is supported.
     * @since 0.05
     * @author Christopher Barber
     */
    public static abstract class AddedJointVariable<Var extends Variable> {
        protected final Var _variable;
        protected final Var[] _inputs;

        /**
         * @param newVariable the joint variable that was added to the target model
         * @param inputVariables the variables joined into {@code newVariable}, in order. The array
         *        is copied, so later modifications by the caller do not affect this object.
         */
        protected AddedJointVariable(Var newVariable, Var[] inputVariables) {
            _variable = newVariable;
            // Defensive copy: the inputs are exposed read-only through getInput/getInputCount.
            _inputs = inputVariables.clone();
        }

        /**
         * Updates the guess of {@link #getVariable()} from the current guesses of its inputs.
         * 
         * @category internal
         */
        @Internal
        public abstract void updateGuess();

        /**
         * The domain of the joined variable.
         */
        public Domain getDomain() {
            return _variable.getDomain();
        }

        /**
         * The joined variable itself.
         */
        public Var getVariable() {
            return _variable;
        }

        /**
         * The i'th variable that is joined into this one.
         */
        public Var getInput(int i) {
            return _inputs[i];
        }

        /**
         * The number of variables that were joined into this one.
         */
        public final int getInputCount() {
            return _inputs.length;
        }

        /**
         * Sets {@code newVariableValue} to the joint value that corresponds to the
         * given input values.
         * 
         * @category internal
         */
        @Internal
        public abstract void updateValue(Value newVariableValue, Value[] inputs);
    }

    /**
     * {@link AddedJointVariable} whose joint variable is a {@link Discrete} with a
     * {@link JointDiscreteDomain} made up of the input variables' domains, in order.
     */
    public static class AddedJointDiscreteVariable extends AddedJointVariable<Discrete> {

        /**
         * @param newVariable the joint variable; its domain must be a {@link JointDiscreteDomain}
         *        whose i'th component is the domain of {@code inputVariables[i]}
         * @param inputVariables the variables joined into {@code newVariable}, in order
         */
        public AddedJointDiscreteVariable(Discrete newVariable, Discrete[] inputVariables) {
            super(newVariable, inputVariables);
            assert (invariantsHold());
        }

        /**
         * Asserts that the joint domain has one component per input and that each component
         * is the same domain object as the corresponding input's domain. Always returns true
         * so that it can be used inside an {@code assert}.
         */
        private boolean invariantsHold() {
            JointDomainIndexer domain = getDomain().getDomainIndexer();
            assert (domain.size() == _inputs.length);
            for (int i = 0; i < _inputs.length; ++i) {
                assert (domain.get(i) == _inputs[i].getDomain());
            }
            return true;
        }

        @Override
        public JointDiscreteDomain<?> getDomain() {
            return (JointDiscreteDomain<?>) getVariable().getDomain();
        }

        /**
         * Sets the joint variable's guess to the joint index of its inputs' guess indices if
         * every input either had its guess set or has a fixed value; otherwise clears the
         * joint variable's guess.
         */
        @Override
        public void updateGuess() {
            final JointDomainIndexer indexer = getDomain().getDomainIndexer();
            final int[] indices = indexer.allocateIndices(null);
            boolean allWereSet = true;
            for (int i = 0; i < _inputs.length; ++i) {
                final Discrete input = getInput(i);
                allWereSet &= input.guessWasSet() || input.hasFixedValue();
                indices[i] = input.getGuessIndex();
            }
            final Discrete var = getVariable();
            if (allWereSet) {
                var.setGuessIndex(indexer.jointIndexFromIndices(indices));
            } else {
                var.setGuess(null);
            }
        }

        @Override
        public void updateValue(Value newVariableValue, Value[] inputs) {
            JointDomainIndexer indexer = getDomain().getDomainIndexer();
            newVariableValue.setIndex(indexer.jointIndexFromValues(inputs));
        }

    }

    /*--------------
     * Construction
     */

    /**
     * Creates an empty mapping from {@code source} to {@code target}. If both are the same
     * object the map is an identity map and no variable/factor mappings are stored.
     */
    protected JunctionTreeTransformMap(FactorGraph source, FactorGraph target) {
        final boolean identity = (source == target);
        _sourceModel = source;
        _sourceVersion = source.structureVersion();
        _targetModel = target;
        _sourceToTargetVariables = identity ? null : new HashMap<Variable, Variable>(source.getVariableCount());
        _sourceToTargetFactors = identity ? null : new HashMap<Factor, Factor>(source.getFactorCount());
        _addedDeterministicVariables = new LinkedHashMap<Variable, AddedJointVariable<?>>();
        _conditionedVariables = new LinkedHashSet<Variable>();
    }

    /**
     * Creates an identity map whose target is {@code source} itself.
     */
    protected JunctionTreeTransformMap(FactorGraph source) {
        this(source, source);
    }

    static JunctionTreeTransformMap create(FactorGraph source, FactorGraph target) {
        return new JunctionTreeTransformMap(source, target);
    }

    static JunctionTreeTransformMap identity(FactorGraph model) {
        return new JunctionTreeTransformMap(model);
    }

    /*---------
     * Methods
     */

    /**
     * Unmodifiable view of the joint variables added to the target model, in the order
     * they were added.
     */
    public Iterable<AddedJointVariable<?>> addedJointVariables() {
        return Iterables.unmodifiableIterable(_addedDeterministicVariables.values());
    }

    /**
     * Returns the added joint-variable descriptor whose joint variable is {@code targetVariable},
     * or {@code null} if {@code targetVariable} is not an added joint variable.
     */
    public @Nullable <Var extends Variable> AddedJointVariable<Var> getAddedDeterministicVariable(
            Var targetVariable) {
        @SuppressWarnings("unchecked")
        AddedJointVariable<Var> var = (AddedJointVariable<Var>) _addedDeterministicVariables.get(targetVariable);
        return var;
    }

    /**
     * Unmodifiable set of source variables that have been conditioned out of
     * the target graph.
     */
    public Set<Variable> conditionedVariables() {
        return Collections.unmodifiableSet(_conditionedVariables);
    }

    /**
     * True if this is the identity mapping, i.e. {@link #target()} is the very same
     * object as {@link #source()} (not a copy).
     */
    public boolean isIdentity() {
        return _sourceToTargetVariables == null;
    }

    /**
     * True if the current mapping is up-to-date with respect to the current state of
     * the {@link #source()} model (and therefore can be reused for inference).
     * <p>
     * Returns false if the source's {@link FactorGraph#structureVersion()} has changed since
     * the map was created, or if any {@linkplain #conditionedVariables() conditioned variable}
     * no longer has a fixed value in the source or target, or its prior differs between them.
     */
    public boolean isValid() {
        if (_sourceVersion != _sourceModel.structureVersion()) {
            return false;
        }

        for (Variable sourceVar : _conditionedVariables) {
            if (!sourceVar.hasFixedValue())
                return false;
            Variable targetVar = sourceToTargetVariable(sourceVar);
            if (!targetVar.hasFixedValue())
                return false;
            if (!Objects.equals(sourceVar.getPrior(), targetVar.getPrior()))
                return false;
        }

        return true;
    }

    /**
     * The original model from which the transformation was generated.
     */
    public FactorGraph source() {
        return _sourceModel;
    }

    /**
     * Returns the target factor that subsumes the given {@code sourceFactor}.
     * <p>
     * As long as the transform {@link #isValid()} this is guaranteed to return a
     * non-null factor in {@link #target()} for every factor in {@link #source()}.
     * Note that unlike {@link #sourceToTargetVariable(Variable)} the target factor
     * may not exactly correspond to the source factor. Instead it may represent the
     * product of multiple factors.
     * <p>
     * For an {@linkplain #isIdentity() identity} map this returns {@code sourceFactor} itself.
     * <p>
     * @see #sourceToTargetFactors()
     */
    public Factor sourceToTargetFactor(Factor sourceFactor) {
        final Map<Factor, Factor> sourceToTargetFactors = _sourceToTargetFactors;
        if (sourceToTargetFactors == null) {
            return sourceFactor;
        }
        return sourceToTargetFactors.get(sourceFactor);
    }

    /**
     * Returns a read-only mapping from factors in {@link #source()} to factors
     * in {@link #target()}.
     * <p>
     * For an {@linkplain #isIdentity() identity} map this is an empty map; use
     * {@link #sourceToTargetFactor(Factor)} to look up individual factors.
     * 
     * @see #sourceToTargetFactor(Factor)
     */
    public Map<Factor, Factor> sourceToTargetFactors() {
        final Map<Factor, Factor> sourceToTargetFactors = _sourceToTargetFactors;
        if (sourceToTargetFactors == null) {
            return Collections.emptyMap();
        }
        return Collections.unmodifiableMap(sourceToTargetFactors);
    }

    /**
     * Returns the target variable corresponding to the given {@code sourceVariable}.
     * <p>
     * As long as the transform {@link #isValid()} this is guaranteed to return a
     * non-null variable in {@link #target()} for every variable in {@link #source()}.
     * For an {@linkplain #isIdentity() identity} map this returns {@code sourceVariable} itself.
     * <p>
     * @see #sourceToTargetVariables()
     */
    public Variable sourceToTargetVariable(Variable sourceVariable) {
        final Map<Variable, Variable> sourceToTargetVariables = _sourceToTargetVariables;
        if (sourceToTargetVariables == null) {
            return sourceVariable;
        }
        return sourceToTargetVariables.get(sourceVariable);
    }

    /**
     * Returns a read-only mapping from variables in {@link #source()} to variables
     * in {@link #target()}.
     * <p>
     * For an {@linkplain #isIdentity() identity} map this is an empty map; use
     * {@link #sourceToTargetVariable(Variable)} to look up individual variables.
     * 
     * @see #sourceToTargetVariable(Variable)
     */
    public Map<Variable, Variable> sourceToTargetVariables() {
        final Map<Variable, Variable> sourceToTargetVariables = _sourceToTargetVariables;
        if (sourceToTargetVariables == null) {
            return Collections.emptyMap();
        }
        return Collections.unmodifiableMap(sourceToTargetVariables);
    }

    /**
     * Value of {@link FactorGraph#structureVersion()} of {@link #source()} when
     * transform map was created.
     */
    public long sourceVersion() {
        return _sourceVersion;
    }

    /**
     * The target model generated from {@link #source()} by {@link JunctionTreeTransform}.
     * <p>
     * As long as {@link #isValid()} this will have variables corresponding to the ones in the source model.
     * <p>
     * @see #sourceToTargetVariable(Variable)
     * @see #sourceToTargetFactor(Factor)
     */
    public FactorGraph target() {
        return _targetModel;
    }

    /*------------------
     * Internal methods
     */

    /**
     * Copies the guess of each mapped source variable to its target variable (clearing the
     * target's guess when the source's guess was not set), then updates the guesses of all
     * added joint variables. For an identity map only the joint variables are updated.
     * 
     * @category internal
     */
    @Internal
    public void updateGuesses() {

        for (Map.Entry<Variable, Variable> entry : sourceToTargetVariables().entrySet()) {
            Variable sourceVar = entry.getKey();
            Variable targetVar = entry.getValue();

            if (!sourceVar.guessWasSet()) {
                targetVar.setGuess(null);
            } else if (sourceVar instanceof Discrete) {
                // Copy by index to avoid converting to and from an element object.
                ((Discrete) targetVar).setGuessIndex(((Discrete) sourceVar).getGuessIndex());
            } else {
                targetVar.setGuess(sourceVar.getGuess());
            }
        }

        for (AddedJointVariable<?> added : addedJointVariables()) {
            added.updateGuess();
        }
    }

    /*-----------------
     * Package methods
     */

    /** Records {@code variable}, which must have a fixed value, as conditioned. */
    void addConditionedVariable(Variable variable) {
        assert (variable.hasFixedValue());
        _conditionedVariables.add(variable);
    }

    /** Registers {@code addedVar}, keyed by its joint variable. */
    void addDeterministicVariable(AddedJointVariable<?> addedVar) {
        _addedDeterministicVariables.put(addedVar.getVariable(), addedVar);
    }

    /**
     * Records that {@code sourceFactor} maps to {@code targetFactor}.
     * 
     * @throws NullPointerException if this is an identity map
     */
    void addFactorMapping(Factor sourceFactor, Factor targetFactor) {
        Objects.requireNonNull(_sourceToTargetFactors, "cannot add a factor mapping to an identity transform map")
            .put(sourceFactor, targetFactor);
    }

    /**
     * Records that {@code sourceVariable} maps to {@code targetVariable}.
     * 
     * @throws NullPointerException if this is an identity map
     */
    void addVariableMapping(Variable sourceVariable, Variable targetVariable) {
        Objects.requireNonNull(_sourceToTargetVariables, "cannot add a variable mapping to an identity transform map")
            .put(sourceVariable, targetVariable);
    }

}