Some reference updates and adding the needed array(...) scope to existing data for single team values

This commit is contained in:
Jeff Moser 2010-09-25 22:40:56 -04:00
parent 8e9e2d0d86
commit 04c911742d
8 changed files with 51 additions and 47 deletions

@ -22,8 +22,8 @@ abstract class GaussianFactor extends Factor
/// Sends the factor-graph message with and returns the log-normalization constant /// Sends the factor-graph message with and returns the log-normalization constant
protected function sendMessageVariable(Message &$message, Variable &$variable) protected function sendMessageVariable(Message &$message, Variable &$variable)
{ {
$marginal = $variable->getValue(); $marginal = &$variable->getValue();
$messageValue = $message->getValue(); $messageValue = &$message->getValue();
$logZ = GaussianDistribution::logProductNormalization($marginal, $messageValue); $logZ = GaussianDistribution::logProductNormalization($marginal, $messageValue);
$variable->setValue($marginal*$messageValue); $variable->setValue($marginal*$messageValue);
return $logZ; return $logZ;

@ -22,17 +22,17 @@ class GaussianGreaterThanFactor extends GaussianFactor
public function __construct($epsilon, Variable &$variable) public function __construct($epsilon, Variable &$variable)
{ {
parent::_construct("{0} > {1:0.000}"); parent::__construct("{0} > {1:0.000}");
$this->_epsilon = $epsilon; $this->_epsilon = $epsilon;
$this->createVariableToMessageBinding($variable); $this->createVariableToMessageBinding($variable);
} }
public function getLogNormalization() public function getLogNormalization()
{ {
$vars = $this->getVariables(); $vars = &$this->getVariables();
$marginal = $vars[0]->getValue(); $marginal = &$vars[0]->getValue();
$messages = $this->getMessages(); $messages = &$this->getMessages();
$message = $messages[0]->getValue(); $message = &$messages[0]->getValue();
$messageFromVariable = GaussianDistribution::divide($marginal, $message); $messageFromVariable = GaussianDistribution::divide($marginal, $message);
return -GaussianDistribution::logProductNormalization($messageFromVariable, $message) return -GaussianDistribution::logProductNormalization($messageFromVariable, $message)
+ +

@ -2,14 +2,16 @@
namespace Moserware\Skills\TrueSkill\Factors; namespace Moserware\Skills\TrueSkill\Factors;
require_once(dirname(__FILE__) . "/GaussianFactor.php"); require_once(dirname(__FILE__) . "/GaussianFactor.php");
require_once(dirname(__FILE__) . "/../../Guard.php");
require_once(dirname(__FILE__) . "/../../FactorGraphs/Message.php"); require_once(dirname(__FILE__) . "/../../FactorGraphs/Message.php");
require_once(dirname(__FILE__) . "/../../FactorGraphs/Variable.php"); require_once(dirname(__FILE__) . "/../../FactorGraphs/Variable.php");
require_once(dirname(__FILE__) . "/../../Numerics/GaussianDistribution.php"); require_once(dirname(__FILE__) . "/../../Numerics/GaussianDistribution.php");
require_once(dirname(__FILE__) . "/../../Numerics/BasicMath.php"); require_once(dirname(__FILE__) . "/../../Numerics/BasicMath.php");
use Moserware\Numerics\GaussianDistribution; use Moserware\Numerics\GaussianDistribution;
use Moserware\Skills\Guard;
use Moserware\Skills\FactorGraphs\Message; use Moserware\Skills\FactorGraphs\Message;
use Moserware\Skills\FactorGraphs\Variable; use Moserware\Skills\FactorGraphs\Variable;
@ -40,9 +42,9 @@ class GaussianWeightedSumFactor extends GaussianFactor
for($i = 0; $i < $variableWeightsLength; $i++) for($i = 0; $i < $variableWeightsLength; $i++)
{ {
$weight = $variableWeights[$i]; $weight = &$variableWeights[$i];
$this->_weights[0][$i] = $weight; $this->_weights[0][$i] = $weight;
$this->_weightsSquared[0][$i] = $weight * $weight; $this->_weightsSquared[0][$i] = square($weight);
} }
$variablesToSumLength = count($variablesToSum); $variablesToSumLength = count($variablesToSum);
@ -123,8 +125,8 @@ class GaussianWeightedSumFactor extends GaussianFactor
public function getLogNormalization() public function getLogNormalization()
{ {
$vars = $this->getVariables(); $vars = &$this->getVariables();
$messages = $this->getMessages(); $messages = &$this->getMessages();
$result = 0.0; $result = 0.0;
@ -132,7 +134,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
$varCount = count($vars); $varCount = count($vars);
for ($i = 1; $i < $varCount; $i++) for ($i = 1; $i < $varCount; $i++)
{ {
$result += GaussianDistribution::logRatioNormalization($vars[i]->getValue(), $messages[i]->getValue()); $result += GaussianDistribution::logRatioNormalization($vars[$i]->getValue(), $messages[$i]->getValue());
} }
return $result; return $result;
@ -145,7 +147,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
// Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as // Potentially look at http://mathworld.wolfram.com/NormalSumDistribution.html for clues as
// to what it's doing // to what it's doing
$messages = $this->getMessages(); $messages = &$this->getMessages();
$message0 = clone $messages[0]->getValue(); $message0 = clone $messages[0]->getValue();
$marginal0 = clone $variables[0]->getValue(); $marginal0 = clone $variables[0]->getValue();
@ -161,13 +163,13 @@ class GaussianWeightedSumFactor extends GaussianFactor
{ {
// These flow directly from the paper // These flow directly from the paper
$inverseOfNewPrecisionSum += $weightsSquared[i]/ $inverseOfNewPrecisionSum += $weightsSquared[$i]/
($variables[$i + 1]->getValue()->getPrecision() - $messages[$i + 1]->getValue()->getPrecision()); ($variables[$i + 1]->getValue()->getPrecision() - $messages[$i + 1]->getValue()->getPrecision());
$diff = GaussianDistribution::divide($variables[$i + 1]->getValue(), $messages[$i + 1]->getValue()); $diff = GaussianDistribution::divide($variables[$i + 1]->getValue(), $messages[$i + 1]->getValue());
$anotherInverseOfNewPrecisionSum += $weightsSquared[i]/$diff->getPrecision(); $anotherInverseOfNewPrecisionSum += $weightsSquared[$i]/$diff->getPrecision();
$weightedMeanSum += $weights[i] $weightedMeanSum += $weights[$i]
* *
($variables[$i + 1]->getValue()->getPrecisionMean() - $messages[$i + 1]->getValue()->getPrecisionMean()) ($variables[$i + 1]->getValue()->getPrecisionMean() - $messages[$i + 1]->getValue()->getPrecisionMean())
/ /
@ -199,21 +201,21 @@ class GaussianWeightedSumFactor extends GaussianFactor
public function updateMessageIndex($messageIndex) public function updateMessageIndex($messageIndex)
{ {
$allMessages = $this->getMessages(); $allMessages = &$this->getMessages();
$allVariables = $this->getVariables(); $allVariables = &$this->getVariables();
Guard::argumentIsValidIndex($messageIndex, count($allMessages), "messageIndex"); Guard::argumentIsValidIndex($messageIndex, count($allMessages), "messageIndex");
$updatedMessages = array(); $updatedMessages = array();
$updatedVariables = array(); $updatedVariables = array();
$indicesToUse = $this->_variableIndexOrdersForWeights[$messageIndex]; $indicesToUse = &$this->_variableIndexOrdersForWeights[$messageIndex];
// The tricky part here is that we have to put the messages and variables in the same // The tricky part here is that we have to put the messages and variables in the same
// order as the weights. Thankfully, the weights and messages share the same index numbers, // order as the weights. Thankfully, the weights and messages share the same index numbers,
// so we just need to make sure they're consistent // so we just need to make sure they're consistent
$allMessagesCount = count($allMessages); $allMessagesCount = count($allMessages);
for ($i = 0; i < $allMessagesCount; $i++) for ($i = 0; $i < $allMessagesCount; $i++)
{ {
$updatedMessages[] =$allMessages[$indicesToUse[$i]]; $updatedMessages[] =$allMessages[$indicesToUse[$i]];
$updatedVariables[] = $allVariables[$indicesToUse[$i]]; $updatedVariables[] = $allVariables[$indicesToUse[$i]];

@ -28,11 +28,11 @@ class GaussianWithinFactor extends GaussianFactor
public function getLogNormalization() public function getLogNormalization()
{ {
$variables = $this->getVariables(); $variables = &$this->getVariables();
$marginal = $variables[0]->getValue(); $marginal = &$variables[0]->getValue();
$messages = $this->getMessages(); $messages = &$this->getMessages();
$message = $messages[0]->getValue(); $message = &$messages[0]->getValue();
$messageFromVariable = GaussianDistribution::divide($marginal, $message); $messageFromVariable = GaussianDistribution::divide($marginal, $message);
$mean = $messageFromVariable->getMean(); $mean = $messageFromVariable->getMean();
$std = $messageFromVariable->getStandardDeviation(); $std = $messageFromVariable->getStandardDeviation();

@ -29,7 +29,7 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
public function buildLayer() public function buildLayer()
{ {
$inputVariablesGroups = $this->getInputVariablesGroups(); $inputVariablesGroups = &$this->getInputVariablesGroups();
$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->setInputVariablesGroups($inputVariablesGroups); $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->setInputVariablesGroups($inputVariablesGroups);
$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->buildLayer(); $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->buildLayer();
@ -58,7 +58,7 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
$totalTeamDifferences = count($this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors()); $totalTeamDifferences = count($this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors());
$totalTeams = $totalTeamDifferences + 1; $totalTeams = $totalTeamDifferences + 1;
$localFactors = $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors(); $localFactors = &$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors();
$innerSchedule = new ScheduleSequence( $innerSchedule = new ScheduleSequence(
"inner schedule", "inner schedule",
array( array(
@ -72,16 +72,15 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
) )
); );
return innerSchedule; return $innerSchedule;
} }
private function createTwoTeamInnerPriorLoopSchedule() private function createTwoTeamInnerPriorLoopSchedule()
{ {
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors(); $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = &$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors();
$teamDifferencesComparisonLayerLocalFactors = $this->_TeamDifferencesComparisonLayer->getLocalFactors(); $teamDifferencesComparisonLayerLocalFactors = &$this->_TeamDifferencesComparisonLayer->getLocalFactors();
return $this->scheduleSequence( $itemsToSequence = array(
array(
new ScheduleStep( new ScheduleStep(
"send team perf to perf differences", "send team perf to perf differences",
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[0], $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[0],
@ -90,7 +89,10 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
"send to greater than or within factor", "send to greater than or within factor",
$teamDifferencesComparisonLayerLocalFactors[0], $teamDifferencesComparisonLayerLocalFactors[0],
0) 0)
), );
return $this->scheduleSequence(
$itemsToSequence,
"loop of just two teams inner sequence"); "loop of just two teams inner sequence");
} }
@ -102,8 +104,8 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
for ($i = 0; $i < $totalTeamDifferences - 1; $i++) for ($i = 0; $i < $totalTeamDifferences - 1; $i++)
{ {
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors(); $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = &$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors();
$teamDifferencesComparisonLayerLocalFactors = $this->_TeamDifferencesComparisonLayer->getLocalFactors(); $teamDifferencesComparisonLayerLocalFactors = &$this->_TeamDifferencesComparisonLayer->getLocalFactors();
$currentForwardSchedulePiece = $currentForwardSchedulePiece =
$this->scheduleSequence( $this->scheduleSequence(
@ -119,7 +121,7 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[$i], 2) $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[$i], 2)
), sprintf("current forward schedule piece %d", $i)); ), sprintf("current forward schedule piece %d", $i));
$forwardScheduleList[] = $currentForwardSchedulePiece; $forwardScheduleList[] = &$currentForwardSchedulePiece;
} }
$forwardSchedule = new ScheduleSequence("forward schedule", $forwardScheduleList); $forwardSchedule = new ScheduleSequence("forward schedule", $forwardScheduleList);
@ -128,8 +130,8 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
for ($i = 0; $i < $totalTeamDifferences - 1; $i++) for ($i = 0; $i < $totalTeamDifferences - 1; $i++)
{ {
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = $this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors(); $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors = &$this->_TeamPerformancesToTeamPerformanceDifferencesLayer->getLocalFactors();
$teamDifferencesComparisonLayerLocalFactors = $this->_TeamDifferencesComparisonLayer->getLocalFactors(); $teamDifferencesComparisonLayerLocalFactors = &$this->_TeamDifferencesComparisonLayer->getLocalFactors();
$currentBackwardSchedulePiece = new ScheduleSequence( $currentBackwardSchedulePiece = new ScheduleSequence(
"current backward schedule piece", "current backward schedule piece",
@ -144,7 +146,7 @@ class IteratedTeamDifferencesInnerLayer extends TrueSkillFactorGraphLayer
sprintf("teamPerformanceToPerformanceDifferenceFactors[totalTeamDifferences - 1 - %d] @ 1", $i), sprintf("teamPerformanceToPerformanceDifferenceFactors[totalTeamDifferences - 1 - %d] @ 1", $i),
$teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[$totalTeamDifferences - 1 - $i], 1) $teamPerformancesToTeamPerformanceDifferencesLayerLocalFactors[$totalTeamDifferences - 1 - $i], 1)
)); ));
$backwardScheduleList[] = $currentBackwardSchedulePiece; $backwardScheduleList[] = &$currentBackwardSchedulePiece;
} }
$backwardSchedule = new ScheduleSequence("backward schedule", $backwardScheduleList); $backwardSchedule = new ScheduleSequence("backward schedule", $backwardScheduleList);

@ -27,7 +27,7 @@ class TeamDifferencesComparisonLayer extends TrueSkillFactorGraphLayer
public function buildLayer() public function buildLayer()
{ {
$inputVarGroups = $this->getInputVariablesGroups(); $inputVarGroups = &$this->getInputVariablesGroups();
$inputVarGroupsCount = count($inputVarGroups); $inputVarGroupsCount = count($inputVarGroups);
for ($i = 0; $i < $inputVarGroupsCount; $i++) for ($i = 0; $i < $inputVarGroupsCount; $i++)

@ -28,13 +28,13 @@ class TeamPerformancesToTeamPerformanceDifferencesLayer extends TrueSkillFactorG
$strongerTeam = $inputVariablesGroups[$i][0]; $strongerTeam = $inputVariablesGroups[$i][0];
$weakerTeam = $inputVariablesGroups[$i + 1][0]; $weakerTeam = $inputVariablesGroups[$i + 1][0];
$currentDifference = $this->createOutputVariable(); $currentDifference = &$this->createOutputVariable();
$newDifferencesFactor = $this->createTeamPerformanceToDifferenceFactor($strongerTeam, $weakerTeam, $currentDifference); $newDifferencesFactor = &$this->createTeamPerformanceToDifferenceFactor($strongerTeam, $weakerTeam, $currentDifference);
$this->addLayerFactor($newDifferencesFactor); $this->addLayerFactor($newDifferencesFactor);
// REVIEW: Does it make sense to have groups of one? // REVIEW: Does it make sense to have groups of one?
$outputVariablesGroup = $this->getOutputVariablesGroups(); $outputVariablesGroup = &$this->getOutputVariablesGroups();
$outputVariablesGroup[] = $currentDifference; $outputVariablesGroup[] = array($currentDifference);
} }
} }

@ -107,10 +107,10 @@ class TrueSkillFactorGraph extends FactorGraph
foreach ($this->_layers as $currentLayer) foreach ($this->_layers as $currentLayer)
{ {
$currentPriorSchedule = $currentLayer->createPriorSchedule(); $currentPriorSchedule = &$currentLayer->createPriorSchedule();
if ($currentPriorSchedule != null) if ($currentPriorSchedule != null)
{ {
$fullSchedule[] = $currentPriorSchedule; $fullSchedule[] = &$currentPriorSchedule;
} }
} }