2022-07-05 15:55:47 +02:00
|
|
|
<?php
|
|
|
|
|
2024-02-02 15:16:11 +00:00
|
|
|
declare(strict_types=1);
|
|
|
|
|
2022-07-05 15:55:47 +02:00
|
|
|
namespace DNW\Skills\TrueSkill\Factors;
|
2010-09-18 11:11:44 -04:00
|
|
|
|
2022-07-05 15:33:34 +02:00
|
|
|
use DNW\Skills\FactorGraphs\Message;
|
|
|
|
use DNW\Skills\FactorGraphs\Variable;
|
2022-07-05 15:55:47 +02:00
|
|
|
use DNW\Skills\Numerics\GaussianDistribution;
|
|
|
|
use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
|
2010-09-18 17:56:57 -04:00
|
|
|
|
2010-10-08 21:44:36 -04:00
|
|
|
/**
 * Factor representing a team difference that has not exceeded the draw margin.
 *
 * The factor constrains its single variable to the interval [-epsilon, epsilon].
 * Its message is updated with the within-margin truncated-Gaussian correction
 * functions (v and w).
 *
 * See the accompanying math paper for more details.
 */
class GaussianWithinFactor extends GaussianFactor
{
    /** Draw margin: the constrained variable is kept within [-epsilon, epsilon]. */
    private readonly float $epsilon;

    /**
     * @param float    $epsilon  Draw margin; the factor enforces -epsilon <= x <= epsilon.
     * @param Variable $variable The (team-difference) variable this factor is attached to.
     */
    public function __construct(float $epsilon, Variable $variable)
    {
        parent::__construct(sprintf('%s <= %.2f', (string)$variable, $epsilon));
        $this->epsilon = $epsilon;
        $this->createVariableToMessageBinding($variable);
    }

    /**
     * Log normalization constant contributed by this factor.
     *
     * First forms the message from the variable (marginal divided by this factor's
     * current message). The result is the negated log product normalization of
     * that distribution and the message, plus the log of its mass inside
     * [-epsilon, epsilon].
     */
    public function getLogNormalization(): float
    {
        /** @var Variable[] $variables */
        $variables = $this->getVariables();
        $marginal = $variables[0]->getValue();

        /** @var Message[] $messages */
        $messages = $this->getMessages();
        $message = $messages[0]->getValue();

        // Marginal with this factor's own message divided out.
        $messageFromVariable = GaussianDistribution::divide($marginal, $message);
        $mean = $messageFromVariable->getMean();
        $std = $messageFromVariable->getStandardDeviation();

        // CDF(upper) - CDF(lower): the mass of $messageFromVariable inside the draw margin
        // (assuming cumulativeTo is the standard normal CDF).
        $z = GaussianDistribution::cumulativeTo(($this->epsilon - $mean) / $std)
            - GaussianDistribution::cumulativeTo((-$this->epsilon - $mean) / $std);

        return -GaussianDistribution::logProductNormalization($messageFromVariable, $message) + log($z);
    }

    /**
     * Recompute this factor's message to $variable and update the variable's marginal.
     *
     * @param Message  $message  This factor's message to $variable; its value is replaced.
     * @param Variable $variable The constrained variable; its marginal is replaced.
     *
     * @return float GaussianDistribution::subtract(new marginal, old marginal), i.e. how much
     *               the marginal changed in this update.
     */
    protected function updateMessageVariable(Message $message, Variable $variable): float
    {
        // Snapshot the current values before they are replaced below.
        $oldMarginal = clone $variable->getValue();
        $oldMessage = clone $message->getValue();
        $messageFromVariable = GaussianDistribution::divide($oldMarginal, $oldMessage);

        $c = $messageFromVariable->getPrecision();
        $d = $messageFromVariable->getPrecisionMean();

        $sqrtC = sqrt($c);
        $dOnSqrtC = $d / $sqrtC;
        $epsilonTimesSqrtC = $this->epsilon * $sqrtC;

        // Within-margin truncated-Gaussian correction of precision and precision-mean.
        $denominator = 1.0 - TruncatedGaussianCorrectionFunctions::wWithinMargin($dOnSqrtC, $epsilonTimesSqrtC);
        $newPrecision = $c / $denominator;
        $newPrecisionMean = ($d
            + $sqrtC * TruncatedGaussianCorrectionFunctions::vWithinMargin($dOnSqrtC, $epsilonTimesSqrtC))
            / $denominator;

        $newMarginal = GaussianDistribution::fromPrecisionMean($newPrecisionMean, $newPrecision);
        $newMessage = GaussianDistribution::divide(
            GaussianDistribution::multiply($oldMessage, $newMarginal),
            $oldMarginal
        );

        // Update the message and marginal
        $message->setValue($newMessage);
        $variable->setValue($newMarginal);

        // Return the difference in the new marginal
        return GaussianDistribution::subtract($newMarginal, $oldMarginal);
    }
}
|