using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;

namespace Moserware.Skills.FactorGraphs
{
    /// <summary>
    /// Base class for a factor in a factor graph. A factor keeps an ordered list of messages,
    /// each of which is bound to exactly one variable the factor is connected to.
    /// </summary>
    /// <typeparam name="TValue">The value type shared by this factor's messages and variables.</typeparam>
    public abstract class Factor<TValue>
    {
        // Messages in binding order; an index into this list is the "message index" used by
        // UpdateMessage(int) and SendMessage(int).
        private readonly List<Message<TValue>> _Messages = new List<Message<TValue>>();

        // Maps each message to the variable it was bound to.
        private readonly Dictionary<Message<TValue>, Variable<TValue>> _MessageToVariableBinding =
            new Dictionary<Message<TValue>, Variable<TValue>>();

        private readonly string _Name;

        // Variables in binding order (parallel to _Messages).
        private readonly List<Variable<TValue>> _Variables = new List<Variable<TValue>>();

        /// <summary>
        /// Initializes the factor. Its display name becomes <c>Factor[name]</c>.
        /// </summary>
        /// <param name="name">Name embedded in the value returned by <see cref="ToString"/>.</param>
        protected Factor(string name)
        {
            _Name = "Factor[" + name + "]";
        }

        /// <summary>
        /// Returns the log-normalization constant of this factor. The base implementation returns 0.
        /// </summary>
        public virtual double LogNormalization
        {
            get { return 0; }
        }

        /// <summary>
        /// Returns the number of messages that the factor has.
        /// </summary>
        public int NumberOfMessages
        {
            get { return _Messages.Count; }
        }

        /// <summary>
        /// Read-only view of the bound variables, in the order they were bound.
        /// </summary>
        protected ReadOnlyCollection<Variable<TValue>> Variables
        {
            get { return _Variables.AsReadOnly(); }
        }

        /// <summary>
        /// Read-only view of the factor's messages, in the order they were bound.
        /// </summary>
        protected ReadOnlyCollection<Message<TValue>> Messages
        {
            get { return _Messages.AsReadOnly(); }
        }

        /// <summary>
        /// Updates the message and marginal of the i-th variable that the factor is connected to.
        /// </summary>
        /// <param name="messageIndex">Zero-based index of the message to update.</param>
        /// <returns>The value returned by the protected <c>UpdateMessage(message, variable)</c> overload.</returns>
        public virtual double UpdateMessage(int messageIndex)
        {
            // NOTE(review): Guard is a project helper; presumably it throws when the index is out of range.
            Guard.ArgumentIsValidIndex(messageIndex, _Messages.Count, "messageIndex");

            Message<TValue> message = _Messages[messageIndex];
            return UpdateMessage(message, _MessageToVariableBinding[message]);
        }

        /// <summary>
        /// Updates <paramref name="message"/> for its bound <paramref name="variable"/>.
        /// The base implementation throws <see cref="NotImplementedException"/>; factors that
        /// support updates override it.
        /// </summary>
        protected virtual double UpdateMessage(Message<TValue> message, Variable<TValue> variable)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Resets the marginal of every variable this factor is connected to by calling
        /// <c>ResetToPrior()</c> on it.
        /// </summary>
        public virtual void ResetMarginals()
        {
            foreach (var currentVariable in _MessageToVariableBinding.Values)
            {
                currentVariable.ResetToPrior();
            }
        }

        /// <summary>
        /// Sends the i-th message to the marginal and returns the log-normalization constant.
        /// </summary>
        /// <param name="messageIndex">Zero-based index of the message to send.</param>
        public virtual double SendMessage(int messageIndex)
        {
            Guard.ArgumentIsValidIndex(messageIndex, _Messages.Count, "messageIndex");

            Message<TValue> message = _Messages[messageIndex];
            Variable<TValue> variable = _MessageToVariableBinding[message];
            return SendMessage(message, variable);
        }

        /// <summary>
        /// Sends <paramref name="message"/> to its bound <paramref name="variable"/>; implemented by each concrete factor.
        /// </summary>
        protected abstract double SendMessage(Message<TValue> message, Variable<TValue> variable);

        /// <summary>
        /// Creates a message for <paramref name="variable"/> and binds it to this factor.
        /// </summary>
        public abstract Message<TValue> CreateVariableToMessageBinding(Variable<TValue> variable);

        /// <summary>
        /// Binds <paramref name="message"/> to <paramref name="variable"/> and appends both to this
        /// factor's message and variable lists.
        /// </summary>
        /// <returns>The same <paramref name="message"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="variable"/> or <paramref name="message"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="message"/> is already bound to this factor.</exception>
        protected Message<TValue> CreateVariableToMessageBinding(Variable<TValue> variable, Message<TValue> message)
        {
            if (variable == null)
            {
                throw new ArgumentNullException("variable");
            }

            if (message == null)
            {
                throw new ArgumentNullException("message");
            }

            // Register the binding first. Dictionary.Add throws for an already-bound message, and
            // doing it before touching the lists keeps _Messages, _Variables and the binding map in
            // step (previously a failure here left a message in _Messages with no binding).
            _MessageToVariableBinding.Add(message, variable);
            _Messages.Add(message);
            _Variables.Add(variable);

            return message;
        }

        /// <summary>
        /// Returns the display name <c>Factor[name]</c> given at construction.
        /// </summary>
        public override string ToString()
        {
            return _Name ?? base.ToString();
        }
    }
}