mirror of
				https://github.com/furyfire/trueskill.git
				synced 2025-11-04 02:02:29 +01:00 
			
		
		
		
	Start of factor graph port. Things don't work yet, but a lot of syntax updates towards PHP
This commit is contained in:
		
							
								
								
									
										216
									
								
								PHPSkills/TrueSkill/Factors/GaussianWeightedSumFactor.php
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										216
									
								
								PHPSkills/TrueSkill/Factors/GaussianWeightedSumFactor.php
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,216 @@
 | 
			
		||||
<?php
 | 
			
		||||
 | 
			
		||||
namespace Moserware\Skills\TrueSkill\Factors;
 | 
			
		||||
 | 
			
		||||
/// <summary>
 | 
			
		||||
/// Factor that sums together multiple Gaussians.
 | 
			
		||||
/// </summary>
 | 
			
		||||
/// <remarks>See the accompanying math paper for more details.</remarks>
 | 
			
		||||
class GaussianWeightedSumFactor extends GaussianFactor
 | 
			
		||||
{
 | 
			
		||||
    private $_variableIndexOrdersForWeights = array();
 | 
			
		||||
 | 
			
		||||
    // This following is used for convenience, for example, the first entry is [0, 1, 2]
 | 
			
		||||
    // corresponding to v[0] = a1*v[1] + a2*v[2]
 | 
			
		||||
    private $_weights;
 | 
			
		||||
    private $_weightsSquared;
 | 
			
		||||
 | 
			
		||||
    public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights = null)
 | 
			
		||||
    {
 | 
			
		||||
        parent::__construct($this->createName($sumVariable, $variablesToSum, $variableWeights));
 | 
			
		||||
        $this->_weights = array();
 | 
			
		||||
        $this->_weightsSquared = array();
 | 
			
		||||
 | 
			
		||||
        // The first weights are a straightforward copy
 | 
			
		||||
        // v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
 | 
			
		||||
        $this->_weights[0] = array();
 | 
			
		||||
 | 
			
		||||
        $variableWeightsLength = count($variableWeights);
 | 
			
		||||
 | 
			
		||||
        for($i = 0; $i < $variableWeightsLength; $i++)
 | 
			
		||||
        {
 | 
			
		||||
            $weight = $variableWeights[$i];
 | 
			
		||||
            $this->_weights[0][$i] = $weight;
 | 
			
		||||
            $this->_weightsSquared[0][i] = $weight * $weight;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        $variablesToSumLength = count($variablesToSum);
 | 
			
		||||
 | 
			
		||||
        // 0..n-1
 | 
			
		||||
        for($i = 0; $i < ($variablesToSumLength + 1); $i++)
 | 
			
		||||
        {
 | 
			
		||||
            $this->_variableIndexOrdersForWeights[] = $i;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        // The rest move the variables around and divide out the constant.
 | 
			
		||||
        // For example:
 | 
			
		||||
        // v_1 = (-a_2 / a_1) * v_2 + (-a3/a1) * v_3 + ... + (1.0 / a_1) * v_0
 | 
			
		||||
        // By convention, we'll put the v_0 term at the end
 | 
			
		||||
 | 
			
		||||
        $weightsLength = $variableWeightsLength + 1;
 | 
			
		||||
        for ($weightsIndex = 1; $weightsIndex < $weightsLength; $weightsIndex++)
 | 
			
		||||
        {            
 | 
			
		||||
            $this->_weights[$weightsIndex] = array();
 | 
			
		||||
 | 
			
		||||
            $variableIndices = array();
 | 
			
		||||
            $variableIndices[] = $weightsIndex;
 | 
			
		||||
 | 
			
		||||
            $currentWeightsSquared = array();
 | 
			
		||||
            $this->_WeightsSquared[$weightsIndex] = $currentWeightsSquared;
 | 
			
		||||
 | 
			
		||||
            // keep a single variable to keep track of where we are in the array.
 | 
			
		||||
            // This is helpful since we skip over one of the spots
 | 
			
		||||
            $currentDestinationWeightIndex = 0;
 | 
			
		||||
 | 
			
		||||
            for ($currentWeightSourceIndex = 0;
 | 
			
		||||
                 $currentWeightSourceIndex < $variableWeights.Length;
 | 
			
		||||
                 $currentWeightSourceIndex++)
 | 
			
		||||
            {
 | 
			
		||||
                if ($currentWeightSourceIndex == ($weightsIndex - 1))
 | 
			
		||||
                {
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                $currentWeight = (-$variableWeights[$currentWeightSourceIndex]/$variableWeights[$weightsIndex - 1]);
 | 
			
		||||
 | 
			
		||||
                if ($variableWeights[$weightsIndex - 1] == 0)
 | 
			
		||||
                {
 | 
			
		||||
                    // HACK: Getting around division by zero
 | 
			
		||||
                    $currentWeight = 0;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                $currentWeights[$currentDestinationWeightIndex] = $currentWeight;
 | 
			
		||||
                $currentWeightsSquared[$currentDestinationWeightIndex] = $currentWeight*currentWeight;
 | 
			
		||||
 | 
			
		||||
                $variableIndices[$currentDestinationWeightIndex + 1] = $currentWeightSourceIndex + 1;
 | 
			
		||||
                $currentDestinationWeightIndex++;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // And the final one
 | 
			
		||||
            $finalWeight = 1.0/$variableWeights[$weightsIndex - 1];
 | 
			
		||||
 | 
			
		||||
            if ($variableWeights[$weightsIndex - 1] == 0)
 | 
			
		||||
            {
 | 
			
		||||
                // HACK: Getting around division by zero
 | 
			
		||||
                $finalWeight = 0;
 | 
			
		||||
            }
 | 
			
		||||
            $currentWeights[$currentDestinationWeightIndex] = finalWeight;
 | 
			
		||||
            $currentWeightsSquared[currentDestinationWeightIndex] = finalWeight*finalWeight;
 | 
			
		||||
            $variableIndices[$variableIndices.Length - 1] = 0;
 | 
			
		||||
            $this->_variableIndexOrdersForWeights[] = $variableIndices;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        $this->createVariableToMessageBinding($sumVariable);
 | 
			
		||||
 | 
			
		||||
        foreach ($variablesToSum as $currentVariable)
 | 
			
		||||
        {
 | 
			
		||||
            $this->createVariableToMessageBinding($currentVariable);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public function getLogNormalization()
 | 
			
		||||
    {
 | 
			
		||||
        $vars = $this->getVariables();
 | 
			
		||||
        $messages = $this->getMessages();
 | 
			
		||||
 | 
			
		||||
        $result = 0.0;
 | 
			
		||||
 | 
			
		||||
        // We start at 1 since offset 0 has the sum
 | 
			
		||||
        $varCount = count($vars);
 | 
			
		||||
        for ($i = 1; $i < $varCount; $i++)
 | 
			
		||||
        {
 | 
			
		||||
            $result += GaussianDistribution::logRatioNormalization($vars[i]->getValue(), $messages[i]->getValue());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return $result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private function updateHelper(array $weights, array $weightsSquared,
 | 
			
		||||
                                  array $messages,
 | 
			
		||||
                                  array $variables)
 | 
			
		||||
    {
 | 
			
		||||
        // Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as
 | 
			
		||||
        // to what it's doing
 | 
			
		||||
 | 
			
		||||
        $messages = $this->getMessages();
 | 
			
		||||
        $message0 = clone $messages[0]->getValue();
 | 
			
		||||
        $marginal0 = clone $variables[0]->getValue();
 | 
			
		||||
 | 
			
		||||
        // The math works out so that 1/newPrecision = sum of a_i^2 /marginalsWithoutMessages[i]
 | 
			
		||||
        $inverseOfNewPrecisionSum = 0.0;
 | 
			
		||||
        $anotherInverseOfNewPrecisionSum = 0.0;
 | 
			
		||||
        $weightedMeanSum = 0.0;
 | 
			
		||||
        $anotherWeightedMeanSum = 0.0;
 | 
			
		||||
 | 
			
		||||
        $weightsSquaredLength = count($weightsSquared);
 | 
			
		||||
 | 
			
		||||
        for ($i = 0; $i < $weightsSquaredLength; $i++)
 | 
			
		||||
        {
 | 
			
		||||
            // These flow directly from the paper
 | 
			
		||||
 | 
			
		||||
            $inverseOfNewPrecisionSum += $weightsSquared[i]/
 | 
			
		||||
                                         ($variables[$i + 1]->getValue()->getPrecision() - $messages[$i + 1]->getValue()->getPrecision());
 | 
			
		||||
 | 
			
		||||
            $diff = GaussianDistribution::divide($variables[$i + 1]->getValue(), $messages[$i + 1]->getValue());
 | 
			
		||||
            $anotherInverseOfNewPrecisionSum += $weightsSquared[i]/$diff->getPrecision();
 | 
			
		||||
 | 
			
		||||
            $weightedMeanSum += $weights[i]
 | 
			
		||||
                                *
 | 
			
		||||
                                ($variables[$i + 1]->getValue()->getPrecisionMean() - $messages[$i + 1]->getValue()->getPrecisionMean())
 | 
			
		||||
                                /
 | 
			
		||||
                                ($variables[$i + 1]->getValue()->getPrecision() - $messages[$i + 1]->getValue()->getPrecision());
 | 
			
		||||
 | 
			
		||||
            $anotherWeightedMeanSum += $weights[$i]*$diff->getPrecisionMean()/$diff->getPrecision();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        $newPrecision = 1.0/$inverseOfNewPrecisionSum;
 | 
			
		||||
        $anotherNewPrecision = 1.0/$anotherInverseOfNewPrecisionSum;
 | 
			
		||||
 | 
			
		||||
        $newPrecisionMean = $newPrecision*$weightedMeanSum;
 | 
			
		||||
        $anotherNewPrecisionMean = $anotherNewPrecision*$anotherWeightedMeanSum;
 | 
			
		||||
 | 
			
		||||
        $newMessage = GaussianDistribution::fromPrecisionMean($newPrecisionMean, $newPrecision);
 | 
			
		||||
        $oldMarginalWithoutMessage = GaussianDistribution::divide($marginal0, $message0);
 | 
			
		||||
 | 
			
		||||
        $newMarginal = GaussianDistribution::multiply($oldMarginalWithoutMessage, $newMessage);
 | 
			
		||||
 | 
			
		||||
        /// Update the message and marginal
 | 
			
		||||
 | 
			
		||||
        $messages[0]->setValue($newMessage);
 | 
			
		||||
        $variables[0]->setValue($newMarginal);
 | 
			
		||||
 | 
			
		||||
        /// Return the difference in the new marginal
 | 
			
		||||
        $finalDiff = GaussianDistribution::subtract($newMarginal, $marginal0);
 | 
			
		||||
        return $finalDiff;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public function updateMessageIndex($messageIndex)
 | 
			
		||||
    {
 | 
			
		||||
        $allMessages = $this->getMessages();
 | 
			
		||||
        $allVariables = $this->getVariables();
 | 
			
		||||
 | 
			
		||||
        Guard::argumentIsValidIndex($messageIndex, count($allMessages),"messageIndex");
 | 
			
		||||
 | 
			
		||||
        $updatedMessages = array();
 | 
			
		||||
        $updatedVariables = array();
 | 
			
		||||
 | 
			
		||||
        $indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex];
 | 
			
		||||
 | 
			
		||||
        // The tricky part here is that we have to put the messages and variables in the same
 | 
			
		||||
        // order as the weights. Thankfully, the weights and messages share the same index numbers,
 | 
			
		||||
        // so we just need to make sure they're consistent
 | 
			
		||||
        $allMessagesCount = count($allMessages);
 | 
			
		||||
        for ($i = 0; i < $allMessagesCount; $i++)
 | 
			
		||||
        {
 | 
			
		||||
            $updatedMessages[] =$allMessages[$indicesToUse[$i]];
 | 
			
		||||
            $updatedVariables[] = $allVariables[$indicesToUse[$i]];
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return updateHelper($this->_weights[$messageIndex],
 | 
			
		||||
                            $this->_weightsSquared[$messageIndex],
 | 
			
		||||
                            $updatedMessages,
 | 
			
		||||
                            $updatedVariables);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
?>
 | 
			
		||||
		Reference in New Issue
	
	Block a user