<?php

namespace DNW\Skills\FactorGraphs;

use DNW\Skills\Guard;
use DNW\Skills\HashMap;
use Exception;

/**
 * Base class for a factor in a factor graph.
 *
 * A factor keeps an ordered list of messages, the variable each message is
 * bound to, and a display name of the form "Factor[<name>]".
 */
abstract class Factor implements \Stringable
{
    /**
     * Messages registered on this factor, in registration order.
     *
     * @var Message[]
     */
    private array $_messages = [];

    /**
     * Lookup from a message to the variable it was registered with.
     */
    private HashMap $_messageToVariableBinding;

    /**
     * Display name returned by __toString().
     */
    private string $_name;

    /**
     * Variables registered on this factor, index-aligned with $_messages.
     *
     * @var Variable[]
     */
    private array $_variables = [];

    /**
     * @param string $name Short name; it is wrapped as "Factor[<name>]"
     */
    protected function __construct(string $name)
    {
        $this->_name = 'Factor['.$name.']';
        $this->_messageToVariableBinding = new HashMap();
    }

    /**
     * @return mixed The log-normalization constant of that factor (0 in this base implementation)
     */
    public function getLogNormalization()
    {
        return 0;
    }

    /**
     * @return int The number of messages that the factor has
     */
    public function getNumberOfMessages(): int
    {
        return count($this->_messages);
    }

    /**
     * @return Variable[] The variables registered on this factor, in registration order
     */
    protected function getVariables(): array
    {
        return $this->_variables;
    }

    /**
     * @return Message[] The messages registered on this factor, in registration order
     */
    protected function getMessages(): array
    {
        return $this->_messages;
    }

    /**
     * Update the message and marginal of the i-th variable that the factor is connected to
     *
     * @param int $messageIndex Index into the factor's message list
     *
     * @return mixed Whatever updateMessageVariable() returns for the selected message
     *
     * @throws Exception
     */
    public function updateMessageIndex(int $messageIndex)
    {
        Guard::argumentIsValidIndex($messageIndex, count($this->_messages), 'messageIndex');

        $message = $this->_messages[$messageIndex];
        $variable = $this->_messageToVariableBinding->getValue($message);

        return $this->updateMessageVariable($message, $variable);
    }

    /**
     * Updates a single message/variable pair.
     *
     * This base implementation always throws; a factor that supports updates
     * is expected to override it.
     *
     * @param Message  $message  The message to update
     * @param Variable $variable The variable bound to that message
     *
     * @throws Exception Always, in this base implementation
     */
    protected function updateMessageVariable(Message $message, Variable $variable)
    {
        // Name the concrete class so the failure can be traced to the factor
        // that is missing an override.
        throw new Exception(static::class.' does not implement updateMessageVariable()');
    }

    /**
     * Resets the marginal of the variables a factor is connected to
     */
    public function resetMarginals()
    {
        foreach ($this->_messageToVariableBinding->getAllValues() as $currentVariable) {
            $currentVariable->resetToPrior();
        }
    }

    /**
     * Sends the ith message to the marginal and returns the log-normalization constant
     *
     * @param int $messageIndex Index into the factor's message list
     *
     * @return mixed Whatever sendMessageVariable() returns for the selected message
     *
     * @throws Exception
     */
    public function sendMessageIndex(int $messageIndex)
    {
        Guard::argumentIsValidIndex($messageIndex, count($this->_messages), 'messageIndex');

        $message = $this->_messages[$messageIndex];
        $variable = $this->_messageToVariableBinding->getValue($message);

        return $this->sendMessageVariable($message, $variable);
    }

    /**
     * Sends the given message to the given variable; implemented by concrete factors.
     *
     * @param Message  $message  The message to send
     * @param Variable $variable The variable bound to that message
     */
    abstract protected function sendMessageVariable(Message $message, Variable $variable);

    /**
     * Creates the message binding for the given variable; implemented by concrete factors.
     *
     * @param Variable $variable The variable to bind
     */
    abstract public function createVariableToMessageBinding(Variable $variable);

    /**
     * Registers a message/variable pair on this factor.
     *
     * The pair is stored in the binding map and appended to the message and
     * variable lists, so both lists stay index-aligned.
     *
     * @param Variable $variable The variable to bind
     * @param Message  $message  The message to bind it to
     *
     * @return Message The message that was passed in
     */
    protected function createVariableToMessageBindingWithMessage(Variable $variable, Message $message)
    {
        $this->_messageToVariableBinding->setValue($message, $variable);
        $this->_messages[] = $message;
        $this->_variables[] = $variable;

        return $message;
    }

    /**
     * @return string The factor's display name, "Factor[<name>]"
     */
    public function __toString(): string
    {
        return $this->_name;
    }
}
|