mirror of
https://github.com/furyfire/trueskill.git
synced 2025-04-19 20:34:28 +00:00
CodeBeautifier for PSR-12 standard.
This commit is contained in:
@ -26,9 +26,12 @@ abstract class GaussianFactor extends Factor
|
||||
{
|
||||
$newDistribution = GaussianDistribution::fromPrecisionMean(0, 0);
|
||||
|
||||
return parent::createVariableToMessageBindingWithMessage($variable,
|
||||
return parent::createVariableToMessageBindingWithMessage(
|
||||
$variable,
|
||||
new Message(
|
||||
$newDistribution,
|
||||
sprintf('message from %s to %s', $this, $variable)));
|
||||
sprintf('message from %s to %s', $this, $variable)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -14,22 +14,26 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
|
||||
*/
|
||||
class GaussianGreaterThanFactor extends GaussianFactor
|
||||
{
|
||||
private $_epsilon;
|
||||
private $epsilon;
|
||||
|
||||
public function __construct($epsilon, Variable $variable)
|
||||
{
|
||||
parent::__construct(\sprintf('%s > %.2f', $variable, $epsilon));
|
||||
$this->_epsilon = $epsilon;
|
||||
$this->epsilon = $epsilon;
|
||||
$this->createVariableToMessageBinding($variable);
|
||||
}
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
/** @var Variable[] $vars */
|
||||
/**
|
||||
* @var Variable[] $vars
|
||||
*/
|
||||
$vars = $this->getVariables();
|
||||
$marginal = $vars[0]->getValue();
|
||||
|
||||
/** @var Message[] $messages */
|
||||
/**
|
||||
* @var Message[] $messages
|
||||
*/
|
||||
$messages = $this->getMessages();
|
||||
$message = $messages[0]->getValue();
|
||||
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
|
||||
@ -38,7 +42,7 @@ class GaussianGreaterThanFactor extends GaussianFactor
|
||||
+
|
||||
log(
|
||||
GaussianDistribution::cumulativeTo(
|
||||
($messageFromVariable->getMean() - $this->_epsilon) /
|
||||
($messageFromVariable->getMean() - $this->epsilon) /
|
||||
$messageFromVariable->getStandardDeviation()
|
||||
)
|
||||
);
|
||||
@ -57,7 +61,7 @@ class GaussianGreaterThanFactor extends GaussianFactor
|
||||
|
||||
$dOnSqrtC = $d / $sqrtC;
|
||||
|
||||
$epsilsonTimesSqrtC = $this->_epsilon * $sqrtC;
|
||||
$epsilsonTimesSqrtC = $this->epsilon * $sqrtC;
|
||||
$d = $messageFromVar->getPrecisionMean();
|
||||
|
||||
$denom = 1.0 - TruncatedGaussianCorrectionFunctions::wExceedsMargin($dOnSqrtC, $epsilsonTimesSqrtC);
|
||||
|
@ -15,21 +15,25 @@ use Exception;
|
||||
*/
|
||||
class GaussianLikelihoodFactor extends GaussianFactor
|
||||
{
|
||||
private $_precision;
|
||||
private $precision;
|
||||
|
||||
public function __construct($betaSquared, Variable $variable1, Variable $variable2)
|
||||
{
|
||||
parent::__construct(sprintf('Likelihood of %s going to %s', $variable2, $variable1));
|
||||
$this->_precision = 1.0 / $betaSquared;
|
||||
$this->precision = 1.0 / $betaSquared;
|
||||
$this->createVariableToMessageBinding($variable1);
|
||||
$this->createVariableToMessageBinding($variable2);
|
||||
}
|
||||
|
||||
public function getLogNormalization(): float
|
||||
{
|
||||
/** @var KeyedVariable[]|mixed $vars */
|
||||
/**
|
||||
* @var KeyedVariable[]|mixed $vars
|
||||
*/
|
||||
$vars = $this->getVariables();
|
||||
/** @var Message[] $messages */
|
||||
/**
|
||||
* @var Message[] $messages
|
||||
*/
|
||||
$messages = $this->getMessages();
|
||||
|
||||
return GaussianDistribution::logRatioNormalization(
|
||||
@ -46,7 +50,7 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
$marginal1 = clone $variable1->getValue();
|
||||
$marginal2 = clone $variable2->getValue();
|
||||
|
||||
$a = $this->_precision / ($this->_precision + $marginal2->getPrecision() - $message2Value->getPrecision());
|
||||
$a = $this->precision / ($this->precision + $marginal2->getPrecision() - $message2Value->getPrecision());
|
||||
|
||||
$newMessage = GaussianDistribution::fromPrecisionMean(
|
||||
$a * ($marginal2->getPrecisionMean() - $message2Value->getPrecisionMean()),
|
||||
@ -72,10 +76,16 @@ class GaussianLikelihoodFactor extends GaussianFactor
|
||||
$vars = $this->getVariables();
|
||||
|
||||
return match ($messageIndex) {
|
||||
0 => $this->updateHelper($messages[0], $messages[1],
|
||||
$vars[0], $vars[1]),
|
||||
1 => $this->updateHelper($messages[1], $messages[0],
|
||||
$vars[1], $vars[0]),
|
||||
0 => $this->updateHelper(
|
||||
$messages[0],
|
||||
$messages[1],
|
||||
$vars[0], $vars[1]
|
||||
),
|
||||
1 => $this->updateHelper(
|
||||
$messages[1],
|
||||
$messages[0],
|
||||
$vars[1], $vars[0]
|
||||
),
|
||||
default => throw new Exception(),
|
||||
};
|
||||
}
|
||||
|
@ -19,8 +19,10 @@ class GaussianPriorFactor extends GaussianFactor
|
||||
{
|
||||
parent::__construct(sprintf('Prior value going to %s', $variable));
|
||||
$this->_newMessage = new GaussianDistribution($mean, sqrt($variance));
|
||||
$newMessage = new Message(GaussianDistribution::fromPrecisionMean(0, 0),
|
||||
sprintf('message from %s to %s', $this, $variable));
|
||||
$newMessage = new Message(
|
||||
GaussianDistribution::fromPrecisionMean(0, 0),
|
||||
sprintf('message from %s to %s', $this, $variable)
|
||||
);
|
||||
|
||||
$this->createVariableToMessageBindingWithMessage($variable, $newMessage);
|
||||
}
|
||||
|
@ -15,13 +15,13 @@ use DNW\Skills\Numerics\GaussianDistribution;
|
||||
*/
|
||||
class GaussianWeightedSumFactor extends GaussianFactor
|
||||
{
|
||||
private array $_variableIndexOrdersForWeights = [];
|
||||
private array $variableIndexOrdersForWeights = [];
|
||||
|
||||
// This following is used for convenience, for example, the first entry is [0, 1, 2]
|
||||
// corresponding to v[0] = a1*v[1] + a2*v[2]
|
||||
private array $_weights = [];
|
||||
private array $weights = [];
|
||||
|
||||
private array $_weightsSquared = [];
|
||||
private array $weightsSquared = [];
|
||||
|
||||
public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights = null)
|
||||
{
|
||||
@ -30,20 +30,20 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
// The first weights are a straightforward copy
|
||||
// v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
|
||||
$variableWeightsLength = count((array) $variableWeights);
|
||||
$this->_weights[0] = array_fill(0, count((array) $variableWeights), 0);
|
||||
$this->weights[0] = array_fill(0, count((array) $variableWeights), 0);
|
||||
|
||||
for ($i = 0; $i < $variableWeightsLength; $i++) {
|
||||
$weight = &$variableWeights[$i];
|
||||
$this->_weights[0][$i] = $weight;
|
||||
$this->_weightsSquared[0][$i] = BasicMath::square($weight);
|
||||
$this->weights[0][$i] = $weight;
|
||||
$this->weightsSquared[0][$i] = BasicMath::square($weight);
|
||||
}
|
||||
|
||||
$variablesToSumLength = count($variablesToSum);
|
||||
|
||||
// 0..n-1
|
||||
$this->_variableIndexOrdersForWeights[0] = [];
|
||||
$this->variableIndexOrdersForWeights[0] = [];
|
||||
for ($i = 0; $i < ($variablesToSumLength + 1); $i++) {
|
||||
$this->_variableIndexOrdersForWeights[0][] = $i;
|
||||
$this->variableIndexOrdersForWeights[0][] = $i;
|
||||
}
|
||||
|
||||
$variableWeightsLength = count((array) $variableWeights);
|
||||
@ -66,9 +66,11 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
// This is helpful since we skip over one of the spots
|
||||
$currentDestinationWeightIndex = 0;
|
||||
|
||||
for ($currentWeightSourceIndex = 0;
|
||||
for (
|
||||
$currentWeightSourceIndex = 0;
|
||||
$currentWeightSourceIndex < $variableWeightsLength;
|
||||
$currentWeightSourceIndex++) {
|
||||
$currentWeightSourceIndex++
|
||||
) {
|
||||
if ($currentWeightSourceIndex === $weightsIndex - 1) {
|
||||
continue;
|
||||
}
|
||||
@ -97,10 +99,10 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
$currentWeights[$currentDestinationWeightIndex] = $finalWeight;
|
||||
$currentWeightsSquared[$currentDestinationWeightIndex] = BasicMath::square($finalWeight);
|
||||
$variableIndices[count((array) $variableWeights)] = 0;
|
||||
$this->_variableIndexOrdersForWeights[] = $variableIndices;
|
||||
$this->variableIndexOrdersForWeights[] = $variableIndices;
|
||||
|
||||
$this->_weights[$weightsIndex] = $currentWeights;
|
||||
$this->_weightsSquared[$weightsIndex] = $currentWeightsSquared;
|
||||
$this->weights[$weightsIndex] = $currentWeights;
|
||||
$this->weightsSquared[$weightsIndex] = $currentWeightsSquared;
|
||||
}
|
||||
|
||||
$this->createVariableToMessageBinding($sumVariable);
|
||||
@ -192,7 +194,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
$updatedMessages = [];
|
||||
$updatedVariables = [];
|
||||
|
||||
$indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex];
|
||||
$indicesToUse = $this->variableIndexOrdersForWeights[$messageIndex];
|
||||
|
||||
// The tricky part here is that we have to put the messages and variables in the same
|
||||
// order as the weights. Thankfully, the weights and messages share the same index numbers,
|
||||
@ -203,10 +205,12 @@ class GaussianWeightedSumFactor extends GaussianFactor
|
||||
$updatedVariables[] = $allVariables[$indicesToUse[$i]];
|
||||
}
|
||||
|
||||
return $this->updateHelper($this->_weights[$messageIndex],
|
||||
$this->_weightsSquared[$messageIndex],
|
||||
return $this->updateHelper(
|
||||
$this->weights[$messageIndex],
|
||||
$this->weightsSquared[$messageIndex],
|
||||
$updatedMessages,
|
||||
$updatedVariables);
|
||||
$updatedVariables
|
||||
);
|
||||
}
|
||||
|
||||
private static function createName($sumVariable, $variablesToSum, $weights)
|
||||
|
@ -14,30 +14,34 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
|
||||
*/
|
||||
class GaussianWithinFactor extends GaussianFactor
|
||||
{
|
||||
private $_epsilon;
|
||||
private $epsilon;
|
||||
|
||||
public function __construct($epsilon, Variable $variable)
|
||||
{
|
||||
parent::__construct(sprintf('%s <= %.2f', $variable, $epsilon));
|
||||
$this->_epsilon = $epsilon;
|
||||
$this->epsilon = $epsilon;
|
||||
$this->createVariableToMessageBinding($variable);
|
||||
}
|
||||
|
||||
public function getLogNormalization()
|
||||
{
|
||||
/** @var Variable[] $variables */
|
||||
/**
|
||||
* @var Variable[] $variables
|
||||
*/
|
||||
$variables = $this->getVariables();
|
||||
$marginal = $variables[0]->getValue();
|
||||
|
||||
/** @var Message[] $messages */
|
||||
/**
|
||||
* @var Message[] $messages
|
||||
*/
|
||||
$messages = $this->getMessages();
|
||||
$message = $messages[0]->getValue();
|
||||
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
|
||||
$mean = $messageFromVariable->getMean();
|
||||
$std = $messageFromVariable->getStandardDeviation();
|
||||
$z = GaussianDistribution::cumulativeTo(($this->_epsilon - $mean) / $std)
|
||||
$z = GaussianDistribution::cumulativeTo(($this->epsilon - $mean) / $std)
|
||||
-
|
||||
GaussianDistribution::cumulativeTo((-$this->_epsilon - $mean) / $std);
|
||||
GaussianDistribution::cumulativeTo((-$this->epsilon - $mean) / $std);
|
||||
|
||||
return -GaussianDistribution::logProductNormalization($messageFromVariable, $message) + log($z);
|
||||
}
|
||||
@ -54,7 +58,7 @@ class GaussianWithinFactor extends GaussianFactor
|
||||
$sqrtC = sqrt($c);
|
||||
$dOnSqrtC = $d / $sqrtC;
|
||||
|
||||
$epsilonTimesSqrtC = $this->_epsilon * $sqrtC;
|
||||
$epsilonTimesSqrtC = $this->epsilon * $sqrtC;
|
||||
$d = $messageFromVariable->getPrecisionMean();
|
||||
|
||||
$denominator = 1.0 - TruncatedGaussianCorrectionFunctions::wWithinMargin($dOnSqrtC, $epsilonTimesSqrtC);
|
||||
|
Reference in New Issue
Block a user