2010-09-18 11:11:44 -04:00
|
|
|
<?php
|
|
|
|
namespace Moserware\Skills\TrueSkill\Factors;
|
|
|
|
|
2010-09-25 10:15:51 -04:00
|
|
|
require_once(dirname(__FILE__) . "/GaussianFactor.php");
|
2010-09-25 22:40:56 -04:00
|
|
|
require_once(dirname(__FILE__) . "/../../Guard.php");
|
2010-09-25 10:15:51 -04:00
|
|
|
require_once(dirname(__FILE__) . "/../../FactorGraphs/Message.php");
|
|
|
|
require_once(dirname(__FILE__) . "/../../FactorGraphs/Variable.php");
|
|
|
|
require_once(dirname(__FILE__) . "/../../Numerics/GaussianDistribution.php");
|
2010-09-25 20:12:38 -04:00
|
|
|
require_once(dirname(__FILE__) . "/../../Numerics/BasicMath.php");
|
|
|
|
|
2010-09-18 17:56:57 -04:00
|
|
|
use Moserware\Numerics\GaussianDistribution;
|
2010-09-25 22:40:56 -04:00
|
|
|
use Moserware\Skills\Guard;
|
2010-09-18 17:56:57 -04:00
|
|
|
use Moserware\Skills\FactorGraphs\Message;
|
|
|
|
use Moserware\Skills\FactorGraphs\Variable;
|
|
|
|
|
2010-09-18 11:11:44 -04:00
|
|
|
/// <summary>
/// Factor that sums together multiple Gaussians.
/// </summary>
/// <remarks>
/// Models the linear constraint v_0 = a_1*v_1 + a_2*v_2 + ... + a_n*v_n, where v_0 is the
/// sum variable, v_1..v_n are the variables being summed and a_1..a_n are their weights.
/// See the accompanying math paper for more details.
/// </remarks>
class GaussianWeightedSumFactor extends GaussianFactor
{
    /// For each message index k, the order in which variables/messages are handed to
    /// updateHelper(). Entry 0 is [0, 1, ..., n] (sum variable first). Entry k >= 1 is
    /// [k, <the other summand indices in ascending order>, 0]: the variable being solved
    /// for comes first and, by convention, the sum variable v_0 comes last.
    private $_variableIndexOrdersForWeights = array();

    /// _weights[k] holds the coefficients of the equation that solves for the variable
    /// listed first in _variableIndexOrdersForWeights[k]. For example _weights[0] is
    /// [a_1, ..., a_n], matching the order [0, 1, ..., n] for v_0 = a_1*v_1 + ... + a_n*v_n.
    private $_weights;

    /// Element-wise squares of _weights (same indexing and ordering).
    private $_weightsSquared;

    /// <summary>
    /// Builds the factor v_0 = sum_i a_i * v_i, precomputes the rearranged equation for
    /// every variable, and binds a message to the sum variable and to each summand.
    /// </summary>
    /// <param name="sumVariable">The variable v_0 that holds the weighted sum.</param>
    /// <param name="variablesToSum">The summands v_1..v_n.</param>
    /// <param name="variableWeights">The weights a_1..a_n, positionally aligned with
    /// $variablesToSum. When null, every summand gets weight 1.0.</param>
    public function __construct(Variable &$sumVariable, array &$variablesToSum, array &$variableWeights = null)
    {
        parent::__construct(self::createName($sumVariable, $variablesToSum, $variableWeights));

        $weights = self::normalizeWeights($variablesToSum, $variableWeights);
        $weightCount = count($weights);

        // The first weights are a straightforward copy
        // v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
        $this->_weights = array(0 => $weights);
        $this->_weightsSquared = array(0 => self::squareEach($weights));

        // Order 0..n: the sum variable followed by every summand
        $this->_variableIndexOrdersForWeights = array(0 => range(0, count($variablesToSum)));

        // The rest move the variables around and divide out the constant.
        // For example:
        // v_1 = (-a_2 / a_1) * v_2 + (-a3/a1) * v_3 + ... + (1.0 / a_1) * v_0
        // By convention, we'll put the v_0 term at the end
        for ($weightsIndex = 1; $weightsIndex <= $weightCount; $weightsIndex++)
        {
            $this->addSolvedEquation($weights, $weightsIndex);
        }

        $this->createVariableToMessageBinding($sumVariable);

        foreach ($variablesToSum as $currentVariable)
        {
            $this->createVariableToMessageBinding($currentVariable);
        }
    }

    /// Returns the weights as a zero-based list. A null $variableWeights means
    /// "weight 1.0 for every entry of $variablesToSum".
    private static function normalizeWeights(array $variablesToSum, $variableWeights)
    {
        if ($variableWeights === null)
        {
            $defaults = array();
            foreach ($variablesToSum as $unused)
            {
                $defaults[] = 1.0;
            }
            return $defaults;
        }

        return array_values($variableWeights);
    }

    /// Returns a new list containing the square of each value, in the same order.
    private static function squareEach(array $values)
    {
        $squares = array();
        foreach ($values as $value)
        {
            $squares[] = $value * $value;
        }
        return $squares;
    }

    /// Rearranges v_0 = sum_i a_i*v_i to solve for v_k (k = $weightsIndex):
    ///   v_k = sum_{i != k} (-a_i / a_k) * v_i + (1 / a_k) * v_0
    /// and stores the coefficients, their squares and the matching variable order at index k.
    private function addSolvedEquation(array $weights, $weightsIndex)
    {
        $weightCount = count($weights);
        $pivot = $weights[$weightsIndex - 1];

        // HACK: Getting around division by zero. The check has to happen BEFORE dividing:
        // PHP 8 throws DivisionByZeroError for x / 0, so a post-hoc fix-up is never reached.
        $pivotIsZero = ($pivot == 0);

        // Fresh local arrays every call; they are stored by value below, so no two
        // indices can end up sharing (and overwriting) the same array.
        $currentWeights = array();
        $currentWeightsSquared = array();
        $variableIndices = array($weightsIndex);

        for ($sourceIndex = 0; $sourceIndex < $weightCount; $sourceIndex++)
        {
            if ($sourceIndex == ($weightsIndex - 1))
            {
                // This is the variable being solved for; it sits at position 0.
                continue;
            }

            $currentWeight = $pivotIsZero ? 0 : (-$weights[$sourceIndex] / $pivot);

            $currentWeights[] = $currentWeight;
            $currentWeightsSquared[] = $currentWeight * $currentWeight;
            $variableIndices[] = $sourceIndex + 1;
        }

        // And the final one: the coefficient of v_0, placed last by convention
        $finalWeight = $pivotIsZero ? 0 : (1.0 / $pivot);

        $currentWeights[] = $finalWeight;
        $currentWeightsSquared[] = $finalWeight * $finalWeight;
        $variableIndices[] = 0;

        $this->_weights[$weightsIndex] = $currentWeights;
        $this->_weightsSquared[$weightsIndex] = $currentWeightsSquared;
        $this->_variableIndexOrdersForWeights[$weightsIndex] = $variableIndices;
    }

    /// <summary>
    /// Sums GaussianDistribution::logRatioNormalization(variable value, message value)
    /// over every summand.
    /// </summary>
    /// <remarks>Index 0 (the sum variable) is deliberately excluded.</remarks>
    public function getLogNormalization()
    {
        $vars = $this->getVariables();
        $messages = $this->getMessages();

        $result = 0.0;

        // We start at 1 since offset 0 has the sum
        $varCount = count($vars);
        for ($i = 1; $i < $varCount; $i++)
        {
            $result += GaussianDistribution::logRatioNormalization($vars[$i]->getValue(), $messages[$i]->getValue());
        }

        return $result;
    }

    /// <summary>
    /// Computes a new message for the variable at position 0 of the (already reordered)
    /// $messages/$variables lists, stores it, and updates that variable's marginal.
    /// </summary>
    /// <param name="weights">Coefficients for positions 1..n of the reordered lists.</param>
    /// <param name="weightsSquared">Element-wise squares of $weights.</param>
    /// <param name="messages">Messages in the same order as $weights (position 0 = target).</param>
    /// <param name="variables">Variables in the same order as $messages.</param>
    /// <returns>The result of GaussianDistribution::subtract(newMarginal, oldMarginal);
    /// presumably a scalar measure of how much the marginal moved.</returns>
    private function updateHelper(array $weights, array $weightsSquared,
                                  array $messages,
                                  array $variables)
    {
        // Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as
        // to what it's doing

        // NOTE: $messages must be used as passed in. It is ordered to match $weights, and
        // re-fetching the factor's messages here would undo that ordering.
        $message0 = clone $messages[0]->getValue();
        $marginal0 = clone $variables[0]->getValue();

        // The math works out so that 1/newPrecision = sum of a_i^2 /marginalsWithoutMessages[i]
        $inverseOfNewPrecisionSum = 0.0;
        $weightedMeanSum = 0.0;

        $weightsSquaredLength = count($weightsSquared);
        for ($i = 0; $i < $weightsSquaredLength; $i++)
        {
            // These flow directly from the paper
            $variableValue = $variables[$i + 1]->getValue();
            $messageValue = $messages[$i + 1]->getValue();

            // Precision / precision-mean of the marginal with this factor's message removed
            $precisionWithoutMessage = $variableValue->getPrecision() - $messageValue->getPrecision();
            $precisionMeanWithoutMessage = $variableValue->getPrecisionMean() - $messageValue->getPrecisionMean();

            $inverseOfNewPrecisionSum += $weightsSquared[$i] / $precisionWithoutMessage;
            $weightedMeanSum += $weights[$i] * $precisionMeanWithoutMessage / $precisionWithoutMessage;
        }

        $newPrecision = 1.0 / $inverseOfNewPrecisionSum;
        $newPrecisionMean = $newPrecision * $weightedMeanSum;

        $newMessage = GaussianDistribution::fromPrecisionMean($newPrecisionMean, $newPrecision);
        $oldMarginalWithoutMessage = GaussianDistribution::divide($marginal0, $message0);
        $newMarginal = GaussianDistribution::multiply($oldMarginalWithoutMessage, $newMessage);

        // Update the message and marginal
        $messages[0]->setValue($newMessage);
        $variables[0]->setValue($newMarginal);

        // Return the difference in the new marginal
        return GaussianDistribution::subtract($newMarginal, $marginal0);
    }

    /// <summary>
    /// Recomputes the message this factor sends to the variable at $messageIndex
    /// (0 = the sum variable, k >= 1 = summand v_k).
    /// </summary>
    /// <param name="messageIndex">Index of the message to update; validated by Guard.</param>
    /// <returns>Whatever updateHelper() returns for the updated variable.</returns>
    public function updateMessageIndex($messageIndex)
    {
        $allMessages = $this->getMessages();
        $allVariables = $this->getVariables();

        Guard::argumentIsValidIndex($messageIndex, count($allMessages), "messageIndex");

        $updatedMessages = array();
        $updatedVariables = array();

        $indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex];

        // The tricky part here is that we have to put the messages and variables in the same
        // order as the weights. Thankfully, the weights and messages share the same index numbers,
        // so we just need to make sure they're consistent
        $allMessagesCount = count($allMessages);
        for ($i = 0; $i < $allMessagesCount; $i++)
        {
            $updatedMessages[] = $allMessages[$indicesToUse[$i]];
            $updatedVariables[] = $allVariables[$indicesToUse[$i]];
        }

        return $this->updateHelper($this->_weights[$messageIndex],
                                   $this->_weightsSquared[$messageIndex],
                                   $updatedMessages,
                                   $updatedVariables);
    }

    /// Placeholder factor name: currently always returns "TODO" and ignores its arguments.
    private static function createName($sumVariable, $variablesToSum, $variableWeights)
    {
        return "TODO";
    }
}
|
|
|
|
|
|
|
|
?>
|