CodeBeautifier for PSR-12 standard.

This commit is contained in:
2023-08-01 13:53:19 +00:00
parent da8125be94
commit c8c126962d
27 changed files with 307 additions and 213 deletions

View File

@ -26,9 +26,12 @@ abstract class GaussianFactor extends Factor
{
$newDistribution = GaussianDistribution::fromPrecisionMean(0, 0);
return parent::createVariableToMessageBindingWithMessage($variable,
return parent::createVariableToMessageBindingWithMessage(
$variable,
new Message(
$newDistribution,
sprintf('message from %s to %s', $this, $variable)));
sprintf('message from %s to %s', $this, $variable)
)
);
}
}

View File

@ -14,22 +14,26 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
*/
class GaussianGreaterThanFactor extends GaussianFactor
{
private $_epsilon;
private $epsilon;
public function __construct($epsilon, Variable $variable)
{
parent::__construct(\sprintf('%s > %.2f', $variable, $epsilon));
$this->_epsilon = $epsilon;
$this->epsilon = $epsilon;
$this->createVariableToMessageBinding($variable);
}
public function getLogNormalization()
{
/** @var Variable[] $vars */
/**
* @var Variable[] $vars
*/
$vars = $this->getVariables();
$marginal = $vars[0]->getValue();
/** @var Message[] $messages */
/**
* @var Message[] $messages
*/
$messages = $this->getMessages();
$message = $messages[0]->getValue();
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
@ -38,7 +42,7 @@ class GaussianGreaterThanFactor extends GaussianFactor
+
log(
GaussianDistribution::cumulativeTo(
($messageFromVariable->getMean() - $this->_epsilon) /
($messageFromVariable->getMean() - $this->epsilon) /
$messageFromVariable->getStandardDeviation()
)
);
@ -57,7 +61,7 @@ class GaussianGreaterThanFactor extends GaussianFactor
$dOnSqrtC = $d / $sqrtC;
$epsilsonTimesSqrtC = $this->_epsilon * $sqrtC;
$epsilsonTimesSqrtC = $this->epsilon * $sqrtC;
$d = $messageFromVar->getPrecisionMean();
$denom = 1.0 - TruncatedGaussianCorrectionFunctions::wExceedsMargin($dOnSqrtC, $epsilsonTimesSqrtC);

View File

@ -15,21 +15,25 @@ use Exception;
*/
class GaussianLikelihoodFactor extends GaussianFactor
{
private $_precision;
private $precision;
public function __construct($betaSquared, Variable $variable1, Variable $variable2)
{
parent::__construct(sprintf('Likelihood of %s going to %s', $variable2, $variable1));
$this->_precision = 1.0 / $betaSquared;
$this->precision = 1.0 / $betaSquared;
$this->createVariableToMessageBinding($variable1);
$this->createVariableToMessageBinding($variable2);
}
public function getLogNormalization(): float
{
/** @var KeyedVariable[]|mixed $vars */
/**
* @var KeyedVariable[]|mixed $vars
*/
$vars = $this->getVariables();
/** @var Message[] $messages */
/**
* @var Message[] $messages
*/
$messages = $this->getMessages();
return GaussianDistribution::logRatioNormalization(
@ -46,7 +50,7 @@ class GaussianLikelihoodFactor extends GaussianFactor
$marginal1 = clone $variable1->getValue();
$marginal2 = clone $variable2->getValue();
$a = $this->_precision / ($this->_precision + $marginal2->getPrecision() - $message2Value->getPrecision());
$a = $this->precision / ($this->precision + $marginal2->getPrecision() - $message2Value->getPrecision());
$newMessage = GaussianDistribution::fromPrecisionMean(
$a * ($marginal2->getPrecisionMean() - $message2Value->getPrecisionMean()),
@ -72,10 +76,16 @@ class GaussianLikelihoodFactor extends GaussianFactor
$vars = $this->getVariables();
return match ($messageIndex) {
0 => $this->updateHelper($messages[0], $messages[1],
$vars[0], $vars[1]),
1 => $this->updateHelper($messages[1], $messages[0],
$vars[1], $vars[0]),
0 => $this->updateHelper(
$messages[0],
$messages[1],
$vars[0], $vars[1]
),
1 => $this->updateHelper(
$messages[1],
$messages[0],
$vars[1], $vars[0]
),
default => throw new Exception(),
};
}

View File

@ -19,8 +19,10 @@ class GaussianPriorFactor extends GaussianFactor
{
parent::__construct(sprintf('Prior value going to %s', $variable));
$this->_newMessage = new GaussianDistribution($mean, sqrt($variance));
$newMessage = new Message(GaussianDistribution::fromPrecisionMean(0, 0),
sprintf('message from %s to %s', $this, $variable));
$newMessage = new Message(
GaussianDistribution::fromPrecisionMean(0, 0),
sprintf('message from %s to %s', $this, $variable)
);
$this->createVariableToMessageBindingWithMessage($variable, $newMessage);
}

View File

@ -15,13 +15,13 @@ use DNW\Skills\Numerics\GaussianDistribution;
*/
class GaussianWeightedSumFactor extends GaussianFactor
{
private array $_variableIndexOrdersForWeights = [];
private array $variableIndexOrdersForWeights = [];
// This following is used for convenience, for example, the first entry is [0, 1, 2]
// corresponding to v[0] = a1*v[1] + a2*v[2]
private array $_weights = [];
private array $weights = [];
private array $_weightsSquared = [];
private array $weightsSquared = [];
public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights = null)
{
@ -30,20 +30,20 @@ class GaussianWeightedSumFactor extends GaussianFactor
// The first weights are a straightforward copy
// v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
$variableWeightsLength = count((array) $variableWeights);
$this->_weights[0] = array_fill(0, count((array) $variableWeights), 0);
$this->weights[0] = array_fill(0, count((array) $variableWeights), 0);
for ($i = 0; $i < $variableWeightsLength; $i++) {
$weight = &$variableWeights[$i];
$this->_weights[0][$i] = $weight;
$this->_weightsSquared[0][$i] = BasicMath::square($weight);
$this->weights[0][$i] = $weight;
$this->weightsSquared[0][$i] = BasicMath::square($weight);
}
$variablesToSumLength = count($variablesToSum);
// 0..n-1
$this->_variableIndexOrdersForWeights[0] = [];
$this->variableIndexOrdersForWeights[0] = [];
for ($i = 0; $i < ($variablesToSumLength + 1); $i++) {
$this->_variableIndexOrdersForWeights[0][] = $i;
$this->variableIndexOrdersForWeights[0][] = $i;
}
$variableWeightsLength = count((array) $variableWeights);
@ -66,9 +66,11 @@ class GaussianWeightedSumFactor extends GaussianFactor
// This is helpful since we skip over one of the spots
$currentDestinationWeightIndex = 0;
for ($currentWeightSourceIndex = 0;
for (
$currentWeightSourceIndex = 0;
$currentWeightSourceIndex < $variableWeightsLength;
$currentWeightSourceIndex++) {
$currentWeightSourceIndex++
) {
if ($currentWeightSourceIndex === $weightsIndex - 1) {
continue;
}
@ -97,10 +99,10 @@ class GaussianWeightedSumFactor extends GaussianFactor
$currentWeights[$currentDestinationWeightIndex] = $finalWeight;
$currentWeightsSquared[$currentDestinationWeightIndex] = BasicMath::square($finalWeight);
$variableIndices[count((array) $variableWeights)] = 0;
$this->_variableIndexOrdersForWeights[] = $variableIndices;
$this->variableIndexOrdersForWeights[] = $variableIndices;
$this->_weights[$weightsIndex] = $currentWeights;
$this->_weightsSquared[$weightsIndex] = $currentWeightsSquared;
$this->weights[$weightsIndex] = $currentWeights;
$this->weightsSquared[$weightsIndex] = $currentWeightsSquared;
}
$this->createVariableToMessageBinding($sumVariable);
@ -192,7 +194,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
$updatedMessages = [];
$updatedVariables = [];
$indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex];
$indicesToUse = $this->variableIndexOrdersForWeights[$messageIndex];
// The tricky part here is that we have to put the messages and variables in the same
// order as the weights. Thankfully, the weights and messages share the same index numbers,
@ -203,10 +205,12 @@ class GaussianWeightedSumFactor extends GaussianFactor
$updatedVariables[] = $allVariables[$indicesToUse[$i]];
}
return $this->updateHelper($this->_weights[$messageIndex],
$this->_weightsSquared[$messageIndex],
return $this->updateHelper(
$this->weights[$messageIndex],
$this->weightsSquared[$messageIndex],
$updatedMessages,
$updatedVariables);
$updatedVariables
);
}
private static function createName($sumVariable, $variablesToSum, $weights)

View File

@ -14,30 +14,34 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
*/
class GaussianWithinFactor extends GaussianFactor
{
private $_epsilon;
private $epsilon;
public function __construct($epsilon, Variable $variable)
{
parent::__construct(sprintf('%s <= %.2f', $variable, $epsilon));
$this->_epsilon = $epsilon;
$this->epsilon = $epsilon;
$this->createVariableToMessageBinding($variable);
}
public function getLogNormalization()
{
/** @var Variable[] $variables */
/**
* @var Variable[] $variables
*/
$variables = $this->getVariables();
$marginal = $variables[0]->getValue();
/** @var Message[] $messages */
/**
* @var Message[] $messages
*/
$messages = $this->getMessages();
$message = $messages[0]->getValue();
$messageFromVariable = GaussianDistribution::divide($marginal, $message);
$mean = $messageFromVariable->getMean();
$std = $messageFromVariable->getStandardDeviation();
$z = GaussianDistribution::cumulativeTo(($this->_epsilon - $mean) / $std)
$z = GaussianDistribution::cumulativeTo(($this->epsilon - $mean) / $std)
-
GaussianDistribution::cumulativeTo((-$this->_epsilon - $mean) / $std);
GaussianDistribution::cumulativeTo((-$this->epsilon - $mean) / $std);
return -GaussianDistribution::logProductNormalization($messageFromVariable, $message) + log($z);
}
@ -54,7 +58,7 @@ class GaussianWithinFactor extends GaussianFactor
$sqrtC = sqrt($c);
$dOnSqrtC = $d / $sqrtC;
$epsilonTimesSqrtC = $this->_epsilon * $sqrtC;
$epsilonTimesSqrtC = $this->epsilon * $sqrtC;
$d = $messageFromVariable->getPrecisionMean();
$denominator = 1.0 - TruncatedGaussianCorrectionFunctions::wWithinMargin($dOnSqrtC, $epsilonTimesSqrtC);