using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Text;
using Moserware.Numerics;
using Moserware.Skills.FactorGraphs;

namespace Moserware.Skills.TrueSkill.Factors
{
    /// <summary>
    /// Factor that sums together multiple Gaussians.
    /// </summary>
    /// <remarks>
    /// The factor models v_0 = a_1*v_1 + a_2*v_2 + ... + a_n*v_n, where v_0 is the sum variable
    /// (bound first, so it lives at index 0) and v_1..v_n are the variables to sum.
    /// See the accompanying math paper for more details.
    /// </remarks>
    public class GaussianWeightedSumFactor : GaussianFactor
    {
        // Entry k lists, for message k, the order in which variables/messages must be visited so
        // that they line up with _Weights[k]. The first element is always the variable being
        // solved for. For example, with two summed variables the first entry is [0, 1, 2],
        // corresponding to v[0] = a1*v[1] + a2*v[2].
        private readonly List<int[]> _VariableIndexOrdersForWeights = new List<int[]>();

        // _Weights[k] holds the coefficients of the equation rearranged to solve for variable k;
        // _WeightsSquared[k] holds the element-wise squares of _Weights[k].
        private readonly double[][] _Weights;
        private readonly double[][] _WeightsSquared;

        /// <summary>
        /// Creates a factor where every summed variable has a weight of 1.0.
        /// </summary>
        /// <param name="sumVariable">The variable that receives the sum.</param>
        /// <param name="variablesToSum">The variables that are added together.</param>
        public GaussianWeightedSumFactor(Variable<GaussianDistribution> sumVariable,
                                         Variable<GaussianDistribution>[] variablesToSum)
            : this(sumVariable,
                   variablesToSum,
                   variablesToSum.Select(v => 1.0).ToArray()) // By default, set the weight to 1.0
        {
        }

        /// <summary>
        /// Creates a factor for sumVariable = sum of variableWeights[i] * variablesToSum[i].
        /// </summary>
        /// <param name="sumVariable">The variable that receives the weighted sum.</param>
        /// <param name="variablesToSum">The variables that are added together.</param>
        /// <param name="variableWeights">One weight per entry of <paramref name="variablesToSum"/>.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="variablesToSum"/> or <paramref name="variableWeights"/> is null.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// The number of weights differs from the number of variables to sum.
        /// </exception>
        public GaussianWeightedSumFactor(Variable<GaussianDistribution> sumVariable,
                                         Variable<GaussianDistribution>[] variablesToSum,
                                         double[] variableWeights)
            // The validation is threaded through the base call so that it runs before CreateName
            // indexes into the weights.
            : base(CreateName(sumVariable, variablesToSum, ValidateWeights(variablesToSum, variableWeights)))
        {
            _Weights = new double[variableWeights.Length + 1][];
            _WeightsSquared = new double[_Weights.Length][];

            // The first weights are a straightforward copy:
            // v_0 = a_1*v_1 + a_2*v_2 + ... + a_n*v_n
            _Weights[0] = new double[variableWeights.Length];
            Array.Copy(variableWeights, _Weights[0], variableWeights.Length);
            _WeightsSquared[0] = _Weights[0].Select(w => w*w).ToArray();

            // 0..n
            _VariableIndexOrdersForWeights.Add(Enumerable.Range(0, 1 + variablesToSum.Length).ToArray());

            // The rest move the variables around and divide out the constant.
            // For example:
            // v_1 = (-a_2 / a_1) * v_2 + (-a_3 / a_1) * v_3 + ... + (1.0 / a_1) * v_0
            // By convention, we'll put the v_0 term at the end.
            for (int weightsIndex = 1; weightsIndex < _Weights.Length; weightsIndex++)
            {
                // The weight of the variable we are solving for; everything is divided by it.
                double pivotWeight = variableWeights[weightsIndex - 1];

                // A zero weight cannot be divided out; in that case every rearranged weight
                // is defined to be 0 instead of producing infinities/NaNs.
                bool pivotIsZero = (pivotWeight == 0);

                var currentWeights = new double[variableWeights.Length];
                _Weights[weightsIndex] = currentWeights;

                var currentWeightsSquared = new double[variableWeights.Length];
                _WeightsSquared[weightsIndex] = currentWeightsSquared;

                var variableIndices = new int[variableWeights.Length + 1];
                variableIndices[0] = weightsIndex;

                // Keep a single variable to keep track of where we are in the destination arrays.
                // This is helpful since we skip over one of the source spots.
                int currentDestinationWeightIndex = 0;

                for (int currentWeightSourceIndex = 0;
                     currentWeightSourceIndex < variableWeights.Length;
                     currentWeightSourceIndex++)
                {
                    if (currentWeightSourceIndex == (weightsIndex - 1))
                    {
                        // This is the variable being solved for; it has no term on the right side.
                        continue;
                    }

                    double currentWeight = pivotIsZero
                                               ? 0.0
                                               : (-variableWeights[currentWeightSourceIndex]/pivotWeight);

                    currentWeights[currentDestinationWeightIndex] = currentWeight;
                    currentWeightsSquared[currentDestinationWeightIndex] = currentWeight*currentWeight;
                    variableIndices[currentDestinationWeightIndex + 1] = currentWeightSourceIndex + 1;
                    currentDestinationWeightIndex++;
                }

                // And the final one: the v_0 (sum) term.
                double finalWeight = pivotIsZero ? 0.0 : (1.0/pivotWeight);

                currentWeights[currentDestinationWeightIndex] = finalWeight;
                currentWeightsSquared[currentDestinationWeightIndex] = finalWeight*finalWeight;
                variableIndices[variableIndices.Length - 1] = 0;
                _VariableIndexOrdersForWeights.Add(variableIndices);
            }

            // The sum variable is bound first, followed by the summed variables in order.
            CreateVariableToMessageBinding(sumVariable);

            foreach (var currentVariable in variablesToSum)
            {
                CreateVariableToMessageBinding(currentVariable);
            }
        }

        /// <summary>
        /// Adds up <c>GaussianDistribution.LogRatioNormalization</c> of each summed variable's
        /// value and its message; the sum variable at offset 0 is skipped.
        /// </summary>
        public override double LogNormalization
        {
            get
            {
                ReadOnlyCollection<Variable<GaussianDistribution>> vars = Variables;
                ReadOnlyCollection<Message<GaussianDistribution>> messages = Messages;

                double result = 0.0;

                // We start at 1 since offset 0 has the sum
                for (int i = 1; i < vars.Count; i++)
                {
                    result += GaussianDistribution.LogRatioNormalization(vars[i].Value, messages[i].Value);
                }

                return result;
            }
        }

        /// <summary>
        /// Checks that there is exactly one weight per variable to sum.
        /// </summary>
        /// <returns><paramref name="variableWeights"/>, so the call can be used inline.</returns>
        private static double[] ValidateWeights(Variable<GaussianDistribution>[] variablesToSum,
                                                double[] variableWeights)
        {
            if (variablesToSum == null)
            {
                throw new ArgumentNullException("variablesToSum");
            }

            if (variableWeights == null)
            {
                throw new ArgumentNullException("variableWeights");
            }

            if (variableWeights.Length != variablesToSum.Length)
            {
                throw new ArgumentException("There must be exactly one weight per variable to sum.",
                                            "variableWeights");
            }

            return variableWeights;
        }

        /// <summary>
        /// Recomputes the message and marginal at position 0 of the given (reordered) lists from
        /// the remaining entries.
        /// </summary>
        /// <param name="weights">Coefficient for each entry at positions 1..n.</param>
        /// <param name="weightsSquared">Element-wise squares of <paramref name="weights"/>.</param>
        /// <param name="messages">Messages, ordered to match the weights; position 0 is updated.</param>
        /// <param name="variables">Variables, ordered to match the weights; position 0 is updated.</param>
        /// <returns>The value of <c>newMarginal - oldMarginal</c> for the variable at position 0.</returns>
        private double UpdateHelper(double[] weights, double[] weightsSquared,
                                    IList<Message<GaussianDistribution>> messages,
                                    IList<Variable<GaussianDistribution>> variables)
        {
            // Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as
            // to what it's doing

            GaussianDistribution message0 = messages[0].Value.Clone();
            GaussianDistribution marginal0 = variables[0].Value.Clone();

            // The math works out so that 1/newPrecision = sum of a_i^2 / (precision of the
            // i-th marginal with its message removed)
            double inverseOfNewPrecisionSum = 0.0;
            double weightedMeanSum = 0.0;

            for (int i = 0; i < weightsSquared.Length; i++)
            {
                GaussianDistribution variableValue = variables[i + 1].Value;
                GaussianDistribution messageValue = messages[i + 1].Value;

                // Natural parameters of the marginal without this factor's message.
                double precisionWithoutMessage = variableValue.Precision - messageValue.Precision;
                double precisionMeanWithoutMessage = variableValue.PrecisionMean - messageValue.PrecisionMean;

                // These flow directly from the paper
                inverseOfNewPrecisionSum += weightsSquared[i]/precisionWithoutMessage;
                weightedMeanSum += weights[i]*precisionMeanWithoutMessage/precisionWithoutMessage;
            }

            double newPrecision = 1.0/inverseOfNewPrecisionSum;
            double newPrecisionMean = newPrecision*weightedMeanSum;

            GaussianDistribution newMessage = GaussianDistribution.FromPrecisionMean(newPrecisionMean, newPrecision);
            GaussianDistribution oldMarginalWithoutMessage = marginal0/message0;
            GaussianDistribution newMarginal = oldMarginalWithoutMessage*newMessage;

            // Update the message and marginal
            messages[0].Value = newMessage;
            variables[0].Value = newMarginal;

            // Return the difference in the new marginal
            return newMarginal - marginal0;
        }

        /// <summary>
        /// Updates the message (and the marginal of its variable) at the given index.
        /// </summary>
        /// <param name="messageIndex">Index of the message to update; 0 is the sum variable.</param>
        /// <returns>The difference between the new and the previous marginal.</returns>
        public override double UpdateMessage(int messageIndex)
        {
            ReadOnlyCollection<Message<GaussianDistribution>> allMessages = Messages;
            ReadOnlyCollection<Variable<GaussianDistribution>> allVariables = Variables;

            Guard.ArgumentIsValidIndex(messageIndex, allMessages.Count, "messageIndex");

            var updatedMessages = new List<Message<GaussianDistribution>>(allMessages.Count);
            var updatedVariables = new List<Variable<GaussianDistribution>>(allMessages.Count);

            int[] indicesToUse = _VariableIndexOrdersForWeights[messageIndex];

            // The tricky part here is that we have to put the messages and variables in the same
            // order as the weights. Thankfully, the weights and messages share the same index numbers,
            // so we just need to make sure they're consistent
            for (int i = 0; i < allMessages.Count; i++)
            {
                updatedMessages.Add(allMessages[indicesToUse[i]]);
                updatedVariables.Add(allVariables[indicesToUse[i]]);
            }

            return UpdateHelper(_Weights[messageIndex], _WeightsSquared[messageIndex], updatedMessages,
                                updatedVariables);
        }

        /// <summary>
        /// Builds a readable name of the form "sum = 1.00*[a] + 2.00*[b] - 0.50*[c]".
        /// </summary>
        /// <param name="sumVariable">The variable on the left-hand side.</param>
        /// <param name="variablesToSum">The variables on the right-hand side.</param>
        /// <param name="weights">The weight of each right-hand side variable.</param>
        /// <returns>The formatted equation.</returns>
        private static string CreateName(Variable<GaussianDistribution> sumVariable,
                                         IList<Variable<GaussianDistribution>> variablesToSum,
                                         double[] weights)
        {
            var sb = new StringBuilder();
            sb.Append(sumVariable.ToString());
            sb.Append(" = ");

            for (int i = 0; i < variablesToSum.Count; i++)
            {
                // Only the first term carries its own sign; later terms get it from the separator.
                bool isFirst = (i == 0);

                if (isFirst && (weights[i] < 0))
                {
                    sb.Append("-");
                }

                sb.Append(Math.Abs(weights[i]).ToString("0.00"));
                sb.Append("*[");
                sb.Append(variablesToSum[i]);
                sb.Append("]");

                bool isLast = (i == variablesToSum.Count - 1);

                if (!isLast)
                {
                    // The separator encodes the sign of the next term.
                    if (weights[i + 1] >= 0)
                    {
                        sb.Append(" + ");
                    }
                    else
                    {
                        sb.Append(" - ");
                    }
                }
            }

            return sb.ToString();
        }
    }
}