Files

130 lines
3.2 KiB
PHP
Raw Normal View History

2022-07-05 15:55:47 +02:00
<?php
2024-02-02 14:53:38 +00:00
declare(strict_types=1);
2022-07-05 15:55:47 +02:00
namespace DNW\Skills\FactorGraphs;
2022-07-05 15:33:34 +02:00
use DNW\Skills\Guard;
use DNW\Skills\HashMap;
2022-07-05 15:55:47 +02:00
use Exception;
2022-07-05 16:21:06 +02:00
abstract class Factor implements \Stringable
{
    /**
     * Messages owned by this factor, in the order their bindings were created.
     *
     * @var Message[] $messages
     */
    private array $messages = [];

    /**
     * Maps each Message (key) to the Variable it was bound to in
     * createVariableToMessageBindingWithMessage().
     */
    private readonly HashMap $messageToVariableBinding;

    /**
     * Display name, stored as "Factor[<name>]" and returned by __toString().
     */
    private readonly string $name;

    /**
     * Variables this factor is connected to, in binding order.
     *
     * @var Variable[] $variables
     */
    private array $variables = [];

    /**
     * @param string $name Short factor name; wrapped as "Factor[<name>]" for display.
     */
    protected function __construct(string $name)
    {
        $this->name = 'Factor[' . $name . ']';
        $this->messageToVariableBinding = new HashMap();
    }

    /**
     * The base implementation always returns 0; subclasses may override it.
     *
     * @return mixed The log-normalization constant of that factor
     */
    public function getLogNormalization()
    {
        return 0;
    }

    /**
     * @return int The number of messages that the factor has
     */
    public function getNumberOfMessages(): int
    {
        return count($this->messages);
    }

    /**
     * @return Variable[] The connected variables, in binding order.
     */
    protected function getVariables(): array
    {
        return $this->variables;
    }

    /**
     * @return Message[] The factor's messages, in binding order.
     */
    protected function getMessages(): array
    {
        return $this->messages;
    }

    /**
     * Update the message and marginal of the i-th variable that the factor is connected to.
     *
     * @param int $messageIndex Zero-based index into this factor's messages.
     *
     * @return float Whatever updateMessageVariable() returns for that message/variable pair.
     *
     * @throws Exception When updateMessageVariable() is not overridden (base implementation),
     *                   or presumably when Guard rejects the index — confirm against Guard.
     */
    public function updateMessageIndex(int $messageIndex): float
    {
        [$message, $variable] = $this->getMessageAndBoundVariable($messageIndex);

        return $this->updateMessageVariable($message, $variable);
    }

    /**
     * Hook used by updateMessageIndex(). The base implementation always throws;
     * a subclass must override it for updateMessageIndex() to succeed.
     *
     * @throws Exception Always, in this base implementation.
     */
    protected function updateMessageVariable(Message $message, Variable $variable): float
    {
        throw new Exception(static::class . ' does not implement updateMessageVariable().');
    }

    /**
     * Resets the marginal of the variables a factor is connected to.
     *
     * Calls resetToPrior() on every variable bound in messageToVariableBinding.
     */
    public function resetMarginals(): void
    {
        foreach ($this->messageToVariableBinding->getAllValues() as $currentVariable) {
            $currentVariable->resetToPrior();
        }
    }

    /**
     * Sends the ith message to the marginal and returns the log-normalization constant.
     *
     * @param int $messageIndex Zero-based index into this factor's messages.
     *
     * @throws Exception
     */
    public function sendMessageIndex(int $messageIndex): float|int
    {
        [$message, $variable] = $this->getMessageAndBoundVariable($messageIndex);

        return $this->sendMessageVariable($message, $variable);
    }

    abstract protected function sendMessageVariable(Message $message, Variable $variable): float|int;

    abstract public function createVariableToMessageBinding(Variable $variable): Message;

    /**
     * Bind $variable to $message and register both with this factor.
     *
     * Appends the message and the variable to their respective lists, so the
     * message's position in getMessages() is the index used by
     * updateMessageIndex() / sendMessageIndex().
     *
     * @return Message The same $message instance that was passed in.
     */
    protected function createVariableToMessageBindingWithMessage(Variable $variable, Message $message): Message
    {
        $this->messageToVariableBinding->setValue($message, $variable);
        $this->messages[] = $message;
        $this->variables[] = $variable;

        return $message;
    }

    public function __toString(): string
    {
        return $this->name;
    }

    /**
     * Validate $messageIndex, then return the message at that index together
     * with the variable it is bound to.
     *
     * Shared by updateMessageIndex() and sendMessageIndex() so that both perform
     * identical index checking and lookup.
     *
     * @return array{0: Message, 1: Variable}
     *
     * @throws Exception
     */
    private function getMessageAndBoundVariable(int $messageIndex): array
    {
        Guard::argumentIsValidIndex($messageIndex, count($this->messages), 'messageIndex');

        $message = $this->messages[$messageIndex];

        return [$message, $this->messageToVariableBinding->getValue($message)];
    }
}