Moved UnitTests to tests/ and Skills to src/

This commit is contained in:
Alexander Liljengård
2016-05-24 13:53:56 +02:00
parent 11b5033c8a
commit 4ab0c5d719
64 changed files with 0 additions and 0 deletions

113
src/FactorGraphs/Factor.php Normal file
View File

@ -0,0 +1,113 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/../Guard.php");
require_once(dirname(__FILE__) . "/../HashMap.php");
require_once(dirname(__FILE__) . "/Message.php");
require_once(dirname(__FILE__) . "/Variable.php");
use Moserware\Skills\Guard;
use Moserware\Skills\HashMap;
abstract class Factor
{
private $_messages = array();
private $_messageToVariableBinding;
private $_name;
private $_variables = array();
protected function __construct($name)
{
$this->_name = "Factor[" . $name . "]";
$this->_messageToVariableBinding = new HashMap();
}
/**
* @return The log-normalization constant of that factor
*/
public function getLogNormalization()
{
return 0;
}
/**
* @return The number of messages that the factor has
*/
public function getNumberOfMessages()
{
return count($this->_messages);
}
protected function &getVariables()
{
return $this->_variables;
}
protected function &getMessages()
{
return $this->_messages;
}
/**
* Update the message and marginal of the i-th variable that the factor is connected to
*/
public function updateMessageIndex($messageIndex)
{
Guard::argumentIsValidIndex($messageIndex, count($this->_messages), "messageIndex");
$message = &$this->_messages[$messageIndex];
$variable = &$this->_messageToVariableBinding->getValue($message);
return $this->updateMessageVariable($message, $variable);
}
protected function updateMessageVariable(Message $message, Variable $variable)
{
throw new Exception();
}
/**
* Resets the marginal of the variables a factor is connected to
*/
public function resetMarginals()
{
$allValues = &$this->_messageToVariableBinding->getAllValues();
foreach ($allValues as &$currentVariable)
{
$currentVariable->resetToPrior();
}
}
/**
* Sends the ith message to the marginal and returns the log-normalization constant
*/
public function sendMessageIndex($messageIndex)
{
Guard::argumentIsValidIndex($messageIndex, count($this->_messages), "messageIndex");
$message = &$this->_messages[$messageIndex];
$variable = &$this->_messageToVariableBinding->getValue($message);
return $this->sendMessageVariable($message, $variable);
}
protected abstract function sendMessageVariable(Message &$message, Variable &$variable);
public abstract function &createVariableToMessageBinding(Variable &$variable);
protected function &createVariableToMessageBindingWithMessage(Variable &$variable, Message &$message)
{
$index = count($this->_messages);
$localMessages = &$this->_messages;
$localMessages[] = &$message;
$this->_messageToVariableBinding->setValue($message, $variable);
$localVariables = &$this->_variables;
$localVariables[] = &$variable;
return $message;
}
public function __toString()
{
return ($this->_name != null) ? $this->_name : base::__toString();
}
}
?>

View File

@ -0,0 +1,21 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/VariableFactory.php");
class FactorGraph
{
private $_variableFactory;
public function &getVariableFactory()
{
$factory = &$this->_variableFactory;
return $factory;
}
public function setVariableFactory(VariableFactory &$factory)
{
$this->_variableFactory = &$factory;
}
}
?>

View File

@ -0,0 +1,74 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/Factor.php");
require_once(dirname(__FILE__) . "/FactorGraph.php");
require_once(dirname(__FILE__) . "/Schedule.php");
abstract class FactorGraphLayer
{
private $_localFactors = array();
private $_outputVariablesGroups = array();
private $_inputVariablesGroups = array();
private $_parentFactorGraph;
protected function __construct(FactorGraph &$parentGraph)
{
$this->_parentFactorGraph = &$parentGraph;
}
protected function &getInputVariablesGroups()
{
$inputVariablesGroups = &$this->_inputVariablesGroups;
return $inputVariablesGroups;
}
// HACK
public function &getParentFactorGraph()
{
$parentFactorGraph = &$this->_parentFactorGraph;
return $parentFactorGraph;
}
public function &getOutputVariablesGroups()
{
$outputVariablesGroups = &$this->_outputVariablesGroups;
return $outputVariablesGroups;
}
public function &getLocalFactors()
{
$localFactors = &$this->_localFactors;
return $localFactors;
}
public function setInputVariablesGroups(&$value)
{
$this->_inputVariablesGroups = $value;
}
protected function scheduleSequence(array $itemsToSequence, $name)
{
return new ScheduleSequence($name, $itemsToSequence);
}
protected function addLayerFactor(Factor &$factor)
{
$this->_localFactors[] = $factor;
}
public abstract function buildLayer();
public function createPriorSchedule()
{
return null;
}
public function createPosteriorSchedule()
{
return null;
}
}
?>

View File

@ -0,0 +1,60 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/Factor.php");
/**
* Helper class for computing the factor graph's normalization constant.
*/
class FactorList
{
private $_list = array();
public function getLogNormalization()
{
$list = &$this->_list;
foreach($list as &$currentFactor)
{
$currentFactor->resetMarginals();
}
$sumLogZ = 0.0;
$listCount = count($this->_list);
for ($i = 0; $i < $listCount; $i++)
{
$f = $this->_list[$i];
$numberOfMessages = $f->getNumberOfMessages();
for ($j = 0; $j < $numberOfMessages; $j++)
{
$sumLogZ += $f->sendMessageIndex($j);
}
}
$sumLogS = 0;
foreach($list as &$currentFactor)
{
$sumLogS = $sumLogS + $currentFactor->getLogNormalization();
}
return $sumLogZ + $sumLogS;
}
public function count()
{
return count($this->_list);
}
public function &addFactor(Factor &$factor)
{
$this->_list[] = $factor;
return $factor;
}
}
?>

View File

@ -0,0 +1,32 @@
<?php
namespace Moserware\Skills\FactorGraphs;
class Message
{
private $_name;
private $_value;
public function __construct(&$value = null, $name = null)
{
$this->_name = $name;
$this->_value = $value;
}
public function& getValue()
{
$value = &$this->_value;
return $value;
}
public function setValue(&$value)
{
$this->_value = &$value;
}
public function __toString()
{
return $this->_name;
}
}
?>

View File

@ -0,0 +1,94 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/Factor.php");
abstract class Schedule
{
private $_name;
protected function __construct($name)
{
$this->_name = $name;
}
public abstract function visit($depth = -1, $maxDepth = 0);
public function __toString()
{
return $this->_name;
}
}
class ScheduleStep extends Schedule
{
private $_factor;
private $_index;
public function __construct($name, Factor &$factor, $index)
{
parent::__construct($name);
$this->_factor = $factor;
$this->_index = $index;
}
public function visit($depth = -1, $maxDepth = 0)
{
$currentFactor = &$this->_factor;
$delta = $currentFactor->updateMessageIndex($this->_index);
return $delta;
}
}
class ScheduleSequence extends Schedule
{
private $_schedules;
public function __construct($name, array $schedules)
{
parent::__construct($name);
$this->_schedules = $schedules;
}
public function visit($depth = -1, $maxDepth = 0)
{
$maxDelta = 0;
$schedules = &$this->_schedules;
foreach ($schedules as &$currentSchedule)
{
$currentVisit = $currentSchedule->visit($depth + 1, $maxDepth);
$maxDelta = max($currentVisit, $maxDelta);
}
return $maxDelta;
}
}
class ScheduleLoop extends Schedule
{
private $_maxDelta;
private $_scheduleToLoop;
public function __construct($name, Schedule &$scheduleToLoop, $maxDelta)
{
parent::__construct($name);
$this->_scheduleToLoop = $scheduleToLoop;
$this->_maxDelta = $maxDelta;
}
public function visit($depth = -1, $maxDepth = 0)
{
$totalIterations = 1;
$delta = $this->_scheduleToLoop->visit($depth + 1, $maxDepth);
while ($delta > $this->_maxDelta)
{
$delta = $this->_scheduleToLoop->visit($depth + 1, $maxDepth);
$totalIterations++;
}
return $delta;
}
}
?>

View File

@ -0,0 +1,73 @@
<?php
namespace Moserware\Skills\FactorGraphs;
class Variable
{
private $_name;
private $_prior;
private $_value;
public function __construct($name, &$prior)
{
$this->_name = "Variable[" . $name . "]";
$this->_prior = $prior;
$this->resetToPrior();
}
public function &getValue()
{
$value = &$this->_value;
return $value;
}
public function setValue(&$value)
{
$this->_value = &$value;
}
public function resetToPrior()
{
$this->_value = $this->_prior;
}
public function __toString()
{
return $this->_name;
}
}
class DefaultVariable extends Variable
{
public function __construct()
{
parent::__construct("Default", null);
}
public function &getValue()
{
return null;
}
public function setValue(&$value)
{
throw new Exception();
}
}
class KeyedVariable extends Variable
{
private $_key;
public function __construct(&$key, $name, &$prior)
{
parent::__construct($name, $prior);
$this->_key = &$key;
}
public function &getKey()
{
$key = &$this->_key;
return $key;
}
}
?>

View File

@ -0,0 +1,32 @@
<?php
namespace Moserware\Skills\FactorGraphs;
require_once(dirname(__FILE__) . "/Variable.php");
class VariableFactory
{
// using a Func<TValue> to encourage fresh copies in case it's overwritten
private $_variablePriorInitializer;
public function __construct($variablePriorInitializer)
{
$this->_variablePriorInitializer = &$variablePriorInitializer;
}
public function &createBasicVariable($name)
{
$initializer = $this->_variablePriorInitializer;
$newVar = new Variable($name, $initializer());
return $newVar;
}
public function &createKeyedVariable(&$key, $name)
{
$initializer = $this->_variablePriorInitializer;
$newVar = new KeyedVariable($key, $name, $initializer());
return $newVar;
}
}
?>