mirror of
https://github.com/furyfire/trueskill.git
synced 2025-04-11 17:14:13 +00:00
130 lines
3.2 KiB
PHP
130 lines
3.2 KiB
PHP
<?php
|
|
|
|
declare(strict_types=1);
|
|
|
|
namespace DNW\Skills\FactorGraphs;
|
|
|
|
use DNW\Skills\Guard;
|
|
use DNW\Skills\HashMap;
|
|
use Exception;
|
|
|
|
abstract class Factor implements \Stringable
{
    /**
     * Messages created for this factor, in the order their variables were bound.
     *
     * @var list<Message> $messages
     */
    private array $messages = [];

    /**
     * Binds each Message in $messages to the Variable it was created for.
     */
    private readonly HashMap $messageToVariableBinding;

    private readonly string $name;

    /**
     * Variables this factor is connected to, appended in step with $messages.
     *
     * @var list<Variable> $variables
     */
    private array $variables = [];

    /**
     * @param string $name Descriptive name; stored (and printed) as "Factor[<name>]"
     */
    protected function __construct(string $name)
    {
        $this->name = 'Factor[' . $name . ']';
        $this->messageToVariableBinding = new HashMap();
    }

    /**
     * @return float The log-normalization constant of this factor. The base
     *               implementation returns 0.0; subclasses may override it.
     */
    public function getLogNormalization(): float
    {
        return 0.0;
    }

    /**
     * @return int The number of messages that the factor has
     */
    public function getNumberOfMessages(): int
    {
        return count($this->messages);
    }

    /**
     * @return list<Variable> The connected variables, in binding order
     */
    protected function getVariables(): array
    {
        return $this->variables;
    }

    /**
     * @return list<Message> The factor's messages, in binding order
     */
    protected function getMessages(): array
    {
        return $this->messages;
    }

    /**
     * Update the message and marginal of the i-th variable that the factor is connected to.
     *
     * @param int $messageIndex Zero-based position of the message in this factor
     *
     * @return float Whatever updateMessageVariable() returns for that message
     *
     * @throws Exception If updateMessageVariable() is not overridden; Guard also
     *                   rejects an out-of-range $messageIndex
     */
    public function updateMessageIndex(int $messageIndex): float
    {
        $message = $this->getMessageAt($messageIndex);
        $variable = $this->messageToVariableBinding->getValue($message);

        return $this->updateMessageVariable($message, $variable);
    }

    /**
     * Update $message and the marginal of its bound $variable.
     *
     * The base implementation always throws, so a factor that supports
     * updateMessageIndex() has to override this method.
     *
     * @throws Exception Always, unless overridden
     */
    protected function updateMessageVariable(Message $message, Variable $variable): float
    {
        // Name the concrete class so a missing override is easy to diagnose.
        throw new Exception(
            sprintf('%s does not support message updates; override updateMessageVariable().', static::class)
        );
    }

    /**
     * Resets the marginal of every variable bound to this factor back to its prior.
     */
    public function resetMarginals(): void
    {
        foreach ($this->messageToVariableBinding->getAllValues() as $currentVariable) {
            $currentVariable->resetToPrior();
        }
    }

    /**
     * Sends the i-th message to the marginal and returns the log-normalization constant.
     *
     * @param int $messageIndex Zero-based position of the message in this factor
     *
     * @throws Exception Guard rejects an out-of-range $messageIndex
     */
    public function sendMessageIndex(int $messageIndex): float|int
    {
        $message = $this->getMessageAt($messageIndex);
        $variable = $this->messageToVariableBinding->getValue($message);

        return $this->sendMessageVariable($message, $variable);
    }

    abstract protected function sendMessageVariable(Message $message, Variable $variable): float|int;

    abstract public function createVariableToMessageBinding(Variable $variable): Message;

    /**
     * Registers $message as bound to $variable and records both on this factor.
     *
     * @return Message The same $message, for call chaining
     */
    protected function createVariableToMessageBindingWithMessage(Variable $variable, Message $message): Message
    {
        $this->messageToVariableBinding->setValue($message, $variable);
        $this->messages[] = $message;
        $this->variables[] = $variable;

        return $message;
    }

    public function __toString(): string
    {
        return $this->name;
    }

    /**
     * Validates $messageIndex against the current message count and returns
     * the message stored at that position.
     *
     * @throws Exception Guard rejects an out-of-range $messageIndex
     */
    private function getMessageAt(int $messageIndex): Message
    {
        Guard::argumentIsValidIndex($messageIndex, count($this->messages), 'messageIndex');

        return $this->messages[$messageIndex];
    }
}
|