String based "name" for Variable class removed for performance

This commit is contained in:
2024-03-19 15:09:13 +00:00
parent 0095829906
commit ae5d2a8b73
14 changed files with 28 additions and 83 deletions

View File

@ -6,9 +6,9 @@ namespace DNW\Skills\FactorGraphs;
class KeyedVariable extends Variable class KeyedVariable extends Variable
{ {
public function __construct(private readonly mixed $key, string $name, mixed $prior) public function __construct(private readonly mixed $key, mixed $prior)
{ {
parent::__construct($name, $prior); parent::__construct($prior);
} }
public function getKey(): mixed public function getKey(): mixed

View File

@ -6,15 +6,12 @@ namespace DNW\Skills\FactorGraphs;
use DNW\Skills\Numerics\GaussianDistribution; use DNW\Skills\Numerics\GaussianDistribution;
class Variable implements \Stringable class Variable
{ {
private readonly string $name;
private mixed $value; private mixed $value;
public function __construct(string $name, private GaussianDistribution $prior) public function __construct(private GaussianDistribution $prior)
{ {
$this->name = 'Variable[' . $name . ']';
$this->resetToPrior(); $this->resetToPrior();
} }
@ -32,9 +29,4 @@ class Variable implements \Stringable
{ {
$this->value = $this->prior; $this->value = $this->prior;
} }
public function __toString(): string
{
return $this->name;
}
} }

View File

@ -10,17 +10,17 @@ class VariableFactory
{ {
} }
public function createBasicVariable(string $name): Variable public function createBasicVariable(): Variable
{ {
$initializer = $this->variablePriorInitializer; $initializer = $this->variablePriorInitializer;
return new Variable($name, $initializer()); return new Variable($initializer());
} }
public function createKeyedVariable(mixed $key, string $name): KeyedVariable public function createKeyedVariable(mixed $key): KeyedVariable
{ {
$initializer = $this->variablePriorInitializer; $initializer = $this->variablePriorInitializer;
return new KeyedVariable($key, $name, $initializer()); return new KeyedVariable($key, $initializer());
} }
} }

View File

@ -31,8 +31,7 @@ abstract class GaussianFactor extends Factor
return parent::createVariableToMessageBindingWithMessage( return parent::createVariableToMessageBindingWithMessage(
$variable, $variable,
new Message( new Message(
$newDistribution, $newDistribution,'message from %s to %s'
sprintf('message from %s to %s', (string)$this, (string)$variable)
) )
); );
} }

View File

@ -16,12 +16,9 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
*/ */
class GaussianGreaterThanFactor extends GaussianFactor class GaussianGreaterThanFactor extends GaussianFactor
{ {
private readonly float $epsilon; public function __construct(private readonly float $epsilon, Variable $variable)
public function __construct(float $epsilon, Variable $variable)
{ {
parent::__construct(\sprintf('%s > %.2f', (string)$variable, $epsilon)); parent::__construct('%s > %.2f');
$this->epsilon = $epsilon;
$this->createVariableToMessageBinding($variable); $this->createVariableToMessageBinding($variable);
} }

View File

@ -21,7 +21,7 @@ class GaussianLikelihoodFactor extends GaussianFactor
public function __construct(float $betaSquared, Variable $variable1, Variable $variable2) public function __construct(float $betaSquared, Variable $variable1, Variable $variable2)
{ {
parent::__construct(sprintf('Likelihood of %s going to %s', (string)$variable2, (string)$variable1)); parent::__construct('Likelihood of %s going to %s');
$this->precision = 1.0 / $betaSquared; $this->precision = 1.0 / $betaSquared;
$this->createVariableToMessageBinding($variable1); $this->createVariableToMessageBinding($variable1);
$this->createVariableToMessageBinding($variable2); $this->createVariableToMessageBinding($variable2);

View File

@ -19,11 +19,11 @@ class GaussianPriorFactor extends GaussianFactor
public function __construct(float $mean, float $variance, Variable $variable) public function __construct(float $mean, float $variance, Variable $variable)
{ {
parent::__construct(sprintf('Prior value going to %s', (string)$variable)); parent::__construct('Prior value going to %s');
$this->newMessage = new GaussianDistribution($mean, sqrt($variance)); $this->newMessage = new GaussianDistribution($mean, sqrt($variance));
$newMessage = new Message( $newMessage = new Message(
GaussianDistribution::fromPrecisionMean(0, 0), GaussianDistribution::fromPrecisionMean(0, 0),
sprintf('message from %s to %s', (string)$this, (string)$variable) 'message from %s to %s'
); );
$this->createVariableToMessageBindingWithMessage($variable, $newMessage); $this->createVariableToMessageBindingWithMessage($variable, $newMessage);

View File

@ -42,7 +42,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
*/ */
public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights) public function __construct(Variable $sumVariable, array $variablesToSum, array $variableWeights)
{ {
parent::__construct(self::createName((string)$sumVariable, $variablesToSum, $variableWeights)); parent::__construct('$sumVariable, $variablesToSum, $variableWeights');
// The first weights are a straightforward copy // The first weights are a straightforward copy
// v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n // v_0 = a_1*v_1 + a_2*v_2 + ... + a_n * v_n
@ -235,42 +235,4 @@ class GaussianWeightedSumFactor extends GaussianFactor
$updatedVariables $updatedVariables
); );
} }
/**
* @param Variable[] $variablesToSum
* @param float[] $weights
*/
private static function createName(string $sumVariable, array $variablesToSum, array $weights): string
{
// TODO: Perf? Use PHP equivalent of StringBuilder? implode on arrays?
$result = $sumVariable;
$result .= ' = ';
$totalVars = count($variablesToSum);
for ($i = 0; $i < $totalVars; ++$i) {
$isFirst = ($i == 0);
if ($isFirst && ($weights[$i] < 0)) {
$result .= '-';
}
$absValue = sprintf('%.2f', \abs($weights[$i])); // 0.00?
$result .= $absValue;
$result .= '*[';
$result .= (string)$variablesToSum[$i];
$result .= ']';
$isLast = ($i === $totalVars - 1);
if (! $isLast) {
if ($weights[$i + 1] >= 0) {
$result .= ' + ';
} else {
$result .= ' - ';
}
}
}
return $result;
}
} }

View File

@ -16,12 +16,9 @@ use DNW\Skills\TrueSkill\TruncatedGaussianCorrectionFunctions;
*/ */
class GaussianWithinFactor extends GaussianFactor class GaussianWithinFactor extends GaussianFactor
{ {
private readonly float $epsilon; public function __construct(private readonly float $epsilon, Variable $variable)
public function __construct(float $epsilon, Variable $variable)
{ {
parent::__construct(sprintf('%s <= %.2f', (string)$variable, $epsilon)); parent::__construct('%s <= %.2f');
$this->epsilon = $epsilon;
$this->createVariableToMessageBinding($variable); $this->createVariableToMessageBinding($variable);
} }

View File

@ -21,7 +21,7 @@ class PlayerPerformancesToTeamPerformancesLayer extends TrueSkillFactorGraphLaye
*/ */
foreach ($inputVariablesGroups as $currentTeam) { foreach ($inputVariablesGroups as $currentTeam) {
$localCurrentTeam = $currentTeam; $localCurrentTeam = $currentTeam;
$teamPerformance = $this->createOutputVariable($localCurrentTeam); $teamPerformance = $this->createOutputVariable();
$newSumFactor = $this->createPlayerToTeamSumFactor($localCurrentTeam, $teamPerformance); $newSumFactor = $this->createPlayerToTeamSumFactor($localCurrentTeam, $teamPerformance);
$this->addLayerFactor($newSumFactor); $this->addLayerFactor($newSumFactor);
@ -85,14 +85,10 @@ class PlayerPerformancesToTeamPerformancesLayer extends TrueSkillFactorGraphLaye
} }
/** /**
* @param KeyedVariable[] $team * Team's performance
*/ */
private function createOutputVariable(array $team): Variable private function createOutputVariable(): Variable
{ {
$memberNames = array_map(static fn($currentPlayer): string => (string)($currentPlayer->getKey()->getId()), $team); return $this->getParentFactorGraph()->getVariableFactory()->createBasicVariable();
$teamMemberNames = \implode(', ', $memberNames);
return $this->getParentFactorGraph()->getVariableFactory()->createBasicVariable('Team[' . $teamMemberNames . "]'s performance");
} }
} }

View File

@ -77,6 +77,6 @@ class PlayerPriorValuesToSkillsLayer extends TrueSkillFactorGraphLayer
$parentFactorGraph = $this->getParentFactorGraph(); $parentFactorGraph = $this->getParentFactorGraph();
$variableFactory = $parentFactorGraph->getVariableFactory(); $variableFactory = $parentFactorGraph->getVariableFactory();
return $variableFactory->createKeyedVariable($key, $key->getId() . "'s skill"); return $variableFactory->createKeyedVariable($key);
} }
} }

View File

@ -48,7 +48,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer
private function createOutputVariable(mixed $key): KeyedVariable private function createOutputVariable(mixed $key): KeyedVariable
{ {
return $this->getParentFactorGraph()->getVariableFactory()->createKeyedVariable($key, $key->getId() . "'s performance"); return $this->getParentFactorGraph()->getVariableFactory()->createKeyedVariable($key);
} }
public function createPriorSchedule(): ?ScheduleSequence public function createPriorSchedule(): ?ScheduleSequence

View File

@ -40,8 +40,11 @@ class TeamPerformancesToTeamPerformanceDifferencesLayer extends TrueSkillFactorG
return new GaussianWeightedSumFactor($output, $teams, $weights); return new GaussianWeightedSumFactor($output, $teams, $weights);
} }
/**
* Team performance difference
*/
private function createOutputVariable(): Variable private function createOutputVariable(): Variable
{ {
return $this->getParentFactorGraph()->getVariableFactory()->createBasicVariable('Team performance difference'); return $this->getParentFactorGraph()->getVariableFactory()->createBasicVariable();
} }
} }

View File

@ -13,13 +13,12 @@ class VariableTest extends TestCase
public function testGetterSetter(): void public function testGetterSetter(): void
{ {
$gd_prior = new GaussianDistribution(); $gd_prior = new GaussianDistribution();
$var = new Variable('dummy', $gd_prior); $var = new Variable($gd_prior);
$this->assertEquals($gd_prior, $var->getValue()); $this->assertEquals($gd_prior, $var->getValue());
$gd_new = new GaussianDistribution(); $gd_new = new GaussianDistribution();
$this->assertEquals($gd_new, $var->getValue()); $this->assertEquals($gd_new, $var->getValue());
$var->resetToPrior(); $var->resetToPrior();
$this->assertEquals($gd_prior, $var->getValue()); $this->assertEquals($gd_prior, $var->getValue());
$this->assertEquals('Variable[dummy]', (string)$var);
} }
} }