Cleaned up some of the run schedule parts

This commit is contained in:
Jeff Moser 2010-09-28 08:12:06 -04:00
parent 196d09429a
commit a45a1c47da
8 changed files with 21 additions and 13 deletions

@ -49,7 +49,9 @@ abstract class Factor
public function updateMessageIndex($messageIndex) public function updateMessageIndex($messageIndex)
{ {
Guard::argumentIsValidIndex($messageIndex, count($this->_messages), "messageIndex"); Guard::argumentIsValidIndex($messageIndex, count($this->_messages), "messageIndex");
return $this->updateMessageVariable($this->_messages[$messageIndex], $this->_messageToVariableBinding->getValue($messageIndex)); $message = $this->_messages[$messageIndex];
$variable = $this->_messageToVariableBinding->getValue($this->_messages[$messageIndex]);
return $this->updateMessageVariable($message, $variable);
} }
protected function updateMessageVariable(Message $message, Variable $variable) protected function updateMessageVariable(Message $message, Variable $variable)
@ -60,7 +62,7 @@ abstract class Factor
/// Resets the marginal of the variables a factor is connected to /// Resets the marginal of the variables a factor is connected to
public function resetMarginals() public function resetMarginals()
{ {
foreach ($this->_messageToVariableBindings->getAllValues() as $currentVariable) foreach ($this->_messageToVariableBinding->getAllValues() as $currentVariable)
{ {
$currentVariable->resetToPrior(); $currentVariable->resetToPrior();
} }
@ -69,7 +71,7 @@ abstract class Factor
/// Sends the ith message to the marginal and returns the log-normalization constant /// Sends the ith message to the marginal and returns the log-normalization constant
public function sendMessageIndex($messageIndex) public function sendMessageIndex($messageIndex)
{ {
Guard::argumentIsValidIndex($messageIndex, count($_messages), "messageIndex"); Guard::argumentIsValidIndex($messageIndex, count($this->_messages), "messageIndex");
$message = $this->_messages[$messageIndex]; $message = $this->_messages[$messageIndex];
$variable = $this->_messageToVariableBinding->getValue($message); $variable = $this->_messageToVariableBinding->getValue($message);

@ -43,7 +43,7 @@ abstract class FactorGraphLayer
$this->_inputVariablesGroups = $value; $this->_inputVariablesGroups = $value;
} }
protected function scheduleSequence(&$itemsToSequence, $name) protected function scheduleSequence($itemsToSequence, $name)
{ {
return new ScheduleSequence($name, $itemsToSequence); return new ScheduleSequence($name, $itemsToSequence);
} }

@ -12,11 +12,16 @@ class Message
$this->_value = $value; $this->_value = $value;
} }
public function getValue() public function& getValue()
{ {
return $this->_value; return $this->_value;
} }
public function setValue(&$value)
{
$this->_value = &$value;
}
public function __toString() public function __toString()
{ {
return $this->_name; return $this->_name;

@ -9,7 +9,8 @@ class HashMap
public function getValue($key) public function getValue($key)
{ {
return $this->_hashToValue[self::getHash($key)]; $hash = self::getHash($key);
return $this->_hashToValue[$hash];
} }
public function setValue($key, $value) public function setValue($key, $value)

@ -115,7 +115,7 @@ class GaussianDistribution
// Computes the absolute difference between two Gaussians // Computes the absolute difference between two Gaussians
public static function subtract(GaussianDistribution $left, GaussianDistribution $right) public static function subtract(GaussianDistribution $left, GaussianDistribution $right)
{ {
return absoluteDifference($left, $right); return GaussianDistribution::absoluteDifference($left, $right);
} }
public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right) public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right)

@ -25,7 +25,7 @@ abstract class GaussianFactor extends Factor
$marginal = &$variable->getValue(); $marginal = &$variable->getValue();
$messageValue = &$message->getValue(); $messageValue = &$message->getValue();
$logZ = GaussianDistribution::logProductNormalization($marginal, $messageValue); $logZ = GaussianDistribution::logProductNormalization($marginal, $messageValue);
$variable->setValue($marginal*$messageValue); $variable->setValue(GaussianDistribution::multiply($marginal, $messageValue));
return $logZ; return $logZ;
} }

@ -66,8 +66,8 @@ class GaussianLikelihoodFactor extends GaussianFactor
public function updateMessageIndex($messageIndex) public function updateMessageIndex($messageIndex)
{ {
$messages = $this->getMessages(); $messages = &$this->getMessages();
$vars = $this->getVariables(); $vars = &$this->getVariables();
switch ($messageIndex) switch ($messageIndex)
{ {

@ -91,7 +91,7 @@ class TrueSkillFactorGraph extends FactorGraph
foreach ($this->_layers as $currentLayer) foreach ($this->_layers as $currentLayer)
{ {
foreach ($currentLayer->getFactors() as $currentFactor) foreach ($currentLayer->getLocalFactors() as $currentFactor)
{ {
$factorList->addFactor($currentFactor); $factorList->addFactor($currentFactor);
} }
@ -107,10 +107,10 @@ class TrueSkillFactorGraph extends FactorGraph
foreach ($this->_layers as $currentLayer) foreach ($this->_layers as $currentLayer)
{ {
$currentPriorSchedule = &$currentLayer->createPriorSchedule(); $currentPriorSchedule = $currentLayer->createPriorSchedule();
if ($currentPriorSchedule != null) if ($currentPriorSchedule != null)
{ {
$fullSchedule[] = &$currentPriorSchedule; $fullSchedule[] = $currentPriorSchedule;
} }
} }