mirror of
https://github.com/furyfire/trueskill.git
synced 2025-04-19 20:34:28 +00:00
General cleanup and removal of all unnecessary references
This commit is contained in:
@ -18,19 +18,19 @@ abstract class GaussianFactor extends Factor
|
||||
* @param Variable $variable
|
||||
* @return float|int
|
||||
*/
|
||||
protected function sendMessageVariable(Message &$message, Variable &$variable)
|
||||
protected function sendMessageVariable(Message $message, Variable $variable)
|
||||
{
|
||||
$marginal = &$variable->getValue();
|
||||
$messageValue = &$message->getValue();
|
||||
$marginal = $variable->getValue();
|
||||
$messageValue = $message->getValue();
|
||||
$logZ = GaussianDistribution::logProductNormalization($marginal, $messageValue);
|
||||
$variable->setValue(GaussianDistribution::multiply($marginal, $messageValue));
|
||||
return $logZ;
|
||||
}
|
||||
|
||||
public function &createVariableToMessageBinding(Variable &$variable)
|
||||
public function createVariableToMessageBinding(Variable $variable)
|
||||
{
|
||||
$newDistribution = GaussianDistribution::fromPrecisionMean(0, 0);
|
||||
$binding = &parent::createVariableToMessageBindingWithMessage($variable,
|
||||
$binding = parent::createVariableToMessageBindingWithMessage($variable,
|
||||
new Message(
|
||||
$newDistribution,
|
||||
sprintf("message from %s to %s", $this, $variable)));
|
||||
|
@ -14,7 +14,7 @@ class GaussianGreaterThanFactor extends GaussianFactor
|
||||
{
|
||||
private $_epsilon;
|
||||
|
||||
public function __construct($epsilon, Variable &$variable)
|
||||
public function __construct($epsilon, Variable $variable)
|
||||
{
|
||||
parent::__construct(\sprintf("%s > %.2f", $variable, $epsilon));
|
||||
$this->_epsilon = $epsilon;
|
||||
@ -23,20 +23,26 @@ class GaussianGreaterThanFactor extends GaussianFactor
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
$vars = &$this->getVariables();
|
||||
$marginal = &$vars[0]->getValue();
|
||||
$messages = &$this->getMessages();
|
||||
$message = &$messages[0]->getValue();
|
||||
/** @var Variable[] $vars */
|
||||
$vars = $this->getVariables();
|
||||
$marginal = $vars[0]->getValue();
|
||||
|
||||
/** @var Message[] $messages */
|
||||
$messages = $this->getMessages();
|
||||
$message = $messages[0]->getValue();
|
||||
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
|
||||
return -GaussianDistribution::logProductNormalization($messageFromVariable, $message)
|
||||
+
|
||||
log(
|
||||
GaussianDistribution::cumulativeTo(($messageFromVariable->getMean() - $this->_epsilon) /
|
||||
$messageFromVariable->getStandardDeviation()));
|
||||
GaussianDistribution::cumulativeTo(
|
||||
($messageFromVariable->getMean() - $this->_epsilon) /
|
||||
$messageFromVariable->getStandardDeviation()
|
||||
)
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
protected function updateMessageVariable(Message &$message, Variable &$variable)
|
||||
protected function updateMessageVariable(Message $message, Variable $variable)
|
||||
{
|
||||
$oldMarginal = clone $variable->getValue();
|
||||
$oldMessage = clone $message->getValue();
|
||||
@ -55,16 +61,18 @@ class GaussianGreaterThanFactor extends GaussianFactor
|
||||
$denom = 1.0 - TruncatedGaussianCorrectionFunctions::wExceedsMargin($dOnSqrtC, $epsilsonTimesSqrtC);
|
||||
|
||||
$newPrecision = $c / $denom;
|
||||
$newPrecisionMean = ($d +
|
||||
$newPrecisionMean = (
|
||||
$d +
|
||||
$sqrtC *
|
||||
TruncatedGaussianCorrectionFunctions::vExceedsMargin($dOnSqrtC, $epsilsonTimesSqrtC)) /
|
||||
$denom;
|
||||
TruncatedGaussianCorrectionFunctions::vExceedsMargin($dOnSqrtC, $epsilsonTimesSqrtC)
|
||||
) / $denom;
|
||||
|
||||
$newMarginal = GaussianDistribution::fromPrecisionMean($newPrecisionMean, $newPrecision);
|
||||
|
||||
$newMessage = GaussianDistribution::divide(
|
||||
GaussianDistribution::multiply($oldMessage, $newMarginal),
|
||||
$oldMarginal);
|
||||
$oldMarginal
|
||||
);
|
||||
|
||||
// Update the message and marginal
|
||||
$message->setValue($newMessage);
|
||||
|
@ -1,6 +1,7 @@
|
||||
<?php namespace Moserware\Skills\TrueSkill\Factors;
|
||||
|
||||
use Exception;
|
||||
use Moserware\Skills\FactorGraphs\KeyedVariable;
|
||||
use Moserware\Skills\FactorGraphs\Message;
|
||||
use Moserware\Skills\FactorGraphs\Variable;
|
||||
use Moserware\Skills\Numerics\GaussianDistribution;
|
||||
@ -14,7 +15,7 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
{
|
||||
private $_precision;
|
||||
|
||||
public function __construct($betaSquared, Variable &$variable1, Variable &$variable2)
|
||||
public function __construct($betaSquared, Variable $variable1, Variable $variable2)
|
||||
{
|
||||
parent::__construct(sprintf("Likelihood of %s going to %s", $variable2, $variable1));
|
||||
$this->_precision = 1.0 / $betaSquared;
|
||||
@ -24,16 +25,18 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
$vars = &$this->getVariables();
|
||||
$messages = &$this->getMessages();
|
||||
/** @var KeyedVariable[]|mixed $vars */
|
||||
$vars = $this->getVariables();
|
||||
/** @var Message[] $messages */
|
||||
$messages = $this->getMessages();
|
||||
|
||||
return GaussianDistribution::logRatioNormalization(
|
||||
$vars[0]->getValue(),
|
||||
$messages[0]->getValue());
|
||||
$messages[0]->getValue()
|
||||
);
|
||||
}
|
||||
|
||||
private function updateHelper(Message &$message1, Message &$message2,
|
||||
Variable &$variable1, Variable &$variable2)
|
||||
private function updateHelper(Message $message1, Message $message2, Variable $variable1, Variable $variable2)
|
||||
{
|
||||
$message1Value = clone $message1->getValue();
|
||||
$message2Value = clone $message2->getValue();
|
||||
@ -45,7 +48,8 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
|
||||
$newMessage = GaussianDistribution::fromPrecisionMean(
|
||||
$a * ($marginal2->getPrecisionMean() - $message2Value->getPrecisionMean()),
|
||||
$a * ($marginal2->getPrecision() - $message2Value->getPrecision()));
|
||||
$a * ($marginal2->getPrecision() - $message2Value->getPrecision())
|
||||
);
|
||||
|
||||
$oldMarginalWithoutMessage = GaussianDistribution::divide($marginal1, $message1Value);
|
||||
|
||||
@ -62,8 +66,8 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
|
||||
public function updateMessageIndex($messageIndex)
|
||||
{
|
||||
$messages = &$this->getMessages();
|
||||
$vars = &$this->getVariables();
|
||||
$messages = $this->getMessages();
|
||||
$vars = $this->getVariables();
|
||||
|
||||
switch ($messageIndex) {
|
||||
case 0:
|
||||
|
@ -13,7 +13,7 @@ class GaussianPriorFactor extends GaussianFactor
|
||||
{
|
||||
private $_newMessage;
|
||||
|
||||
public function __construct($mean, $variance, Variable &$variable)
|
||||
public function __construct($mean, $variance, Variable $variable)
|
||||
{
|
||||
parent::__construct(sprintf("Prior value going to %s", $variable));
|
||||
$this->_newMessage = new GaussianDistribution($mean, sqrt($variance));
|
||||
@ -23,18 +23,17 @@ class GaussianPriorFactor extends GaussianFactor
|
||||
$this->createVariableToMessageBindingWithMessage($variable, $newMessage);
|
||||
}
|
||||
|
||||
protected function updateMessageVariable(Message &$message, Variable &$variable)
|
||||
protected function updateMessageVariable(Message $message, Variable $variable)
|
||||
{
|
||||
$oldMarginal = clone $variable->getValue();
|
||||
$oldMessage = $message;
|
||||
$newMarginal =
|
||||
GaussianDistribution::fromPrecisionMean(
|
||||
$oldMarginal->getPrecisionMean() + $this->_newMessage->getPrecisionMean() - $oldMessage->getValue()->getPrecisionMean(),
|
||||
$oldMarginal->getPrecision() + $this->_newMessage->getPrecision() - $oldMessage->getValue()->getPrecision());
|
||||
$newMarginal = GaussianDistribution::fromPrecisionMean(
|
||||
$oldMarginal->getPrecisionMean() + $this->_newMessage->getPrecisionMean() - $oldMessage->getValue()->getPrecisionMean(),
|
||||
$oldMarginal->getPrecision() + $this->_newMessage->getPrecision() - $oldMessage->getValue()->getPrecision()
|
||||
);
|
||||
|
||||
$variable->setValue($newMarginal);
|
||||
$newMessage = &$this->_newMessage;
|
||||
$message->setValue($newMessage);
|
||||
$message->setValue($this->_newMessage);
|
||||
return GaussianDistribution::subtract($oldMarginal, $newMarginal);
|
||||
}
|
||||
}
|
@ -20,7 +20,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
private $_weights;
|
||||
private $_weightsSquared;
|
||||
|
||||
public function __construct(Variable &$sumVariable, array &$variablesToSum, array &$variableWeights = null)
|
||||
public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights = null)
|
||||
{
|
||||
parent::__construct(self::createName($sumVariable, $variablesToSum, $variableWeights));
|
||||
$this->_weights = array();
|
||||
@ -29,7 +29,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
// The first weights are a straightforward copy
|
||||
// v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
|
||||
$variableWeightsLength = count($variableWeights);
|
||||
$this->_weights[0] = \array_fill(0, count($variableWeights), 0);
|
||||
$this->_weights[0] = array_fill(0, count($variableWeights), 0);
|
||||
|
||||
for ($i = 0; $i < $variableWeightsLength; $i++) {
|
||||
$weight = &$variableWeights[$i];
|
||||
@ -104,16 +104,16 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
|
||||
$this->createVariableToMessageBinding($sumVariable);
|
||||
|
||||
foreach ($variablesToSum as &$currentVariable) {
|
||||
$localCurrentVariable = &$currentVariable;
|
||||
foreach ($variablesToSum as $currentVariable) {
|
||||
$localCurrentVariable = $currentVariable;
|
||||
$this->createVariableToMessageBinding($localCurrentVariable);
|
||||
}
|
||||
}
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
$vars = &$this->getVariables();
|
||||
$messages = &$this->getMessages();
|
||||
$vars = $this->getVariables();
|
||||
$messages = $this->getMessages();
|
||||
|
||||
$result = 0.0;
|
||||
|
||||
@ -126,9 +126,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
return $result;
|
||||
}
|
||||
|
||||
private function updateHelper(array &$weights, array &$weightsSquared,
|
||||
array &$messages,
|
||||
array &$variables)
|
||||
private function updateHelper(array $weights, array $weightsSquared, array $messages, array $variables)
|
||||
{
|
||||
// Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as
|
||||
// to what it's doing
|
||||
@ -185,23 +183,23 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
|
||||
public function updateMessageIndex($messageIndex)
|
||||
{
|
||||
$allMessages = &$this->getMessages();
|
||||
$allVariables = &$this->getVariables();
|
||||
$allMessages = $this->getMessages();
|
||||
$allVariables = $this->getVariables();
|
||||
|
||||
Guard::argumentIsValidIndex($messageIndex, count($allMessages), "messageIndex");
|
||||
|
||||
$updatedMessages = array();
|
||||
$updatedVariables = array();
|
||||
|
||||
$indicesToUse = &$this->_variableIndexOrdersForWeights[$messageIndex];
|
||||
$indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex];
|
||||
|
||||
// The tricky part here is that we have to put the messages and variables in the same
|
||||
// order as the weights. Thankfully, the weights and messages share the same index numbers,
|
||||
// so we just need to make sure they're consistent
|
||||
$allMessagesCount = count($allMessages);
|
||||
for ($i = 0; $i < $allMessagesCount; $i++) {
|
||||
$updatedMessages[] = &$allMessages[$indicesToUse[$i]];
|
||||
$updatedVariables[] = &$allVariables[$indicesToUse[$i]];
|
||||
$updatedMessages[] = $allMessages[$indicesToUse[$i]];
|
||||
$updatedVariables[] = $allVariables[$indicesToUse[$i]];
|
||||
}
|
||||
|
||||
return $this->updateHelper($this->_weights[$messageIndex],
|
||||
|
@ -14,7 +14,7 @@ class GaussianWithinFactor extends GaussianFactor
|
||||
{
|
||||
private $_epsilon;
|
||||
|
||||
public function __construct($epsilon, Variable &$variable)
|
||||
public function __construct($epsilon, Variable $variable)
|
||||
{
|
||||
parent::__construct(sprintf("%s <= %.2f", $variable, $epsilon));
|
||||
$this->_epsilon = $epsilon;
|
||||
@ -23,11 +23,13 @@ class GaussianWithinFactor extends GaussianFactor
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
$variables = &$this->getVariables();
|
||||
$marginal = &$variables[0]->getValue();
|
||||
/** @var Variable[] $variables */
|
||||
$variables = $this->getVariables();
|
||||
$marginal = $variables[0]->getValue();
|
||||
|
||||
$messages = &$this->getMessages();
|
||||
$message = &$messages[0]->getValue();
|
||||
/** @var Message[] $messages */
|
||||
$messages = $this->getMessages();
|
||||
$message = $messages[0]->getValue();
|
||||
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
|
||||
$mean = $messageFromVariable->getMean();
|
||||
$std = $messageFromVariable->getStandardDeviation();
|
||||
@ -38,7 +40,7 @@ class GaussianWithinFactor extends GaussianFactor
|
||||
return -GaussianDistribution::logProductNormalization($messageFromVariable, $message) + log($z);
|
||||
}
|
||||
|
||||
protected function updateMessageVariable(Message &$message, Variable &$variable)
|
||||
protected function updateMessageVariable(Message $message, Variable $variable)
|
||||
{
|
||||
$oldMarginal = clone $variable->getValue();
|
||||
$oldMessage = clone $message->getValue();
|
||||
@ -55,15 +57,16 @@ class GaussianWithinFactor extends GaussianFactor
|
||||
|
||||
$denominator = 1.0 - TruncatedGaussianCorrectionFunctions::wWithinMargin($dOnSqrtC, $epsilonTimesSqrtC);
|
||||
$newPrecision = $c / $denominator;
|
||||
$newPrecisionMean = ($d +
|
||||
$sqrtC *
|
||||
TruncatedGaussianCorrectionFunctions::vWithinMargin($dOnSqrtC, $epsilonTimesSqrtC)) /
|
||||
$denominator;
|
||||
$newPrecisionMean = ( $d +
|
||||
$sqrtC *
|
||||
TruncatedGaussianCorrectionFunctions::vWithinMargin($dOnSqrtC, $epsilonTimesSqrtC)
|
||||
) / $denominator;
|
||||
|
||||
$newMarginal = GaussianDistribution::fromPrecisionMean($newPrecisionMean, $newPrecision);
|
||||
$newMessage = GaussianDistribution::divide(
|
||||
GaussianDistribution::multiply($oldMessage, $newMarginal),
|
||||
$oldMarginal);
|
||||
$oldMarginal
|
||||
);
|
||||
|
||||
// Update the message and marginal
|
||||
$message->setValue($newMessage);
|
||||
|
Reference in New Issue
Block a user