More type work

This commit is contained in:
Jens True 2023-08-02 13:19:35 +00:00
parent 4f60c3d024
commit 73781e9000
13 changed files with 35 additions and 25 deletions

@ -44,11 +44,17 @@ abstract class Factor implements \Stringable
return count($this->messages); return count($this->messages);
} }
/**
* @return Variable[]
*/
protected function getVariables(): array protected function getVariables(): array
{ {
return $this->variables; return $this->variables;
} }
/**
* @return Message[]
*/
protected function getMessages(): array protected function getMessages(): array
{ {
return $this->messages; return $this->messages;
@ -61,7 +67,7 @@ abstract class Factor implements \Stringable
* *
* @throws Exception * @throws Exception
*/ */
public function updateMessageIndex(int $messageIndex) public function updateMessageIndex(int $messageIndex): float
{ {
Guard::argumentIsValidIndex($messageIndex, count($this->messages), 'messageIndex'); Guard::argumentIsValidIndex($messageIndex, count($this->messages), 'messageIndex');
$message = $this->messages[$messageIndex]; $message = $this->messages[$messageIndex];
@ -78,7 +84,7 @@ abstract class Factor implements \Stringable
/** /**
* Resets the marginal of the variables a factor is connected to * Resets the marginal of the variables a factor is connected to
*/ */
public function resetMarginals() public function resetMarginals(): void
{ {
$allValues = $this->messageToVariableBinding->getAllValues(); $allValues = $this->messageToVariableBinding->getAllValues();
foreach ($allValues as $currentVariable) { foreach ($allValues as $currentVariable) {
@ -102,7 +108,7 @@ abstract class Factor implements \Stringable
abstract protected function sendMessageVariable(Message $message, Variable $variable): float|int; abstract protected function sendMessageVariable(Message $message, Variable $variable): float|int;
abstract public function createVariableToMessageBinding(Variable $variable); abstract public function createVariableToMessageBinding(Variable $variable): Message;
protected function createVariableToMessageBindingWithMessage(Variable $variable, Message $message): Message protected function createVariableToMessageBindingWithMessage(Variable $variable, Message $message): Message
{ {

@ -4,7 +4,7 @@ namespace DNW\Skills\FactorGraphs;
class VariableFactory class VariableFactory
{ {
public function __construct(private $variablePriorInitializer) public function __construct(private \Closure $variablePriorInitializer)
{ {
} }

@ -26,32 +26,32 @@ class GameInfo
) { ) {
} }
public function getInitialMean() public function getInitialMean(): float
{ {
return $this->initialMean; return $this->initialMean;
} }
public function getInitialStandardDeviation() public function getInitialStandardDeviation(): float
{ {
return $this->initialStandardDeviation; return $this->initialStandardDeviation;
} }
public function getBeta() public function getBeta(): float
{ {
return $this->beta; return $this->beta;
} }
public function getDynamicsFactor() public function getDynamicsFactor(): float
{ {
return $this->dynamicsFactor; return $this->dynamicsFactor;
} }
public function getDrawProbability() public function getDrawProbability(): float
{ {
return $this->drawProbability; return $this->drawProbability;
} }
public function getDefaultRating() public function getDefaultRating(): Rating
{ {
return new Rating($this->initialMean, $this->initialStandardDeviation); return new Rating($this->initialMean, $this->initialStandardDeviation);
} }

@ -12,11 +12,11 @@ class GaussianDistribution implements \Stringable
{ {
// precision and precisionMean are used because they make multiplying and dividing simpler // precision and precisionMean are used because they make multiplying and dividing simpler
// (the the accompanying math paper for more details) // (the the accompanying math paper for more details)
private $precision; private float $precision;
private $precisionMean; private float $precisionMean;
private $variance; private float $variance;
public function __construct(private float $mean = 0.0, private float $standardDeviation = 1.0) public function __construct(private float $mean = 0.0, private float $standardDeviation = 1.0)
{ {

@ -26,7 +26,7 @@ class Range
return $this->max; return $this->max;
} }
protected static function create(int $min, int $max): self protected static function create(int $min, int $max): static
{ {
return new Range($min, $max); return new Range($min, $max);
} }

@ -4,7 +4,7 @@ namespace DNW\Skills\Numerics;
class SquareMatrix extends Matrix class SquareMatrix extends Matrix
{ {
public function __construct(...$allValues) public function __construct(float|int ...$allValues)
{ {
$rows = (int) sqrt(count($allValues)); $rows = (int) sqrt(count($allValues));
$cols = $rows; $cols = $rows;

@ -7,8 +7,7 @@ class PartialPlay
public static function getPartialPlayPercentage(Player $player): float public static function getPartialPlayPercentage(Player $player): float
{ {
// If the player doesn't support the interface, assume 1.0 == 100% // If the player doesn't support the interface, assume 1.0 == 100%
$supportsPartialPlay = $player instanceof ISupportPartialPlay; if (! $player instanceof ISupportPartialPlay) {
if (! $supportsPartialPlay) {
return 1.0; return 1.0;
} }

@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range;
class PlayersRange extends Range class PlayersRange extends Range
{ {
protected static function create(int $min, int $max): self protected static function create(int $min, int $max): static
{ {
return new PlayersRange($min, $max); return new PlayersRange($min, $max);
} }

@ -14,7 +14,7 @@ class RankSorter
* @param array $teamRanks The ranks for each item where 1 is first place. * @param array $teamRanks The ranks for each item where 1 is first place.
* @return array * @return array
*/ */
public static function sort(array &$teams, array &$teamRanks) public static function sort(array &$teams, array &$teamRanks): array
{ {
array_multisort($teamRanks, $teams); array_multisort($teamRanks, $teams);

@ -21,11 +21,17 @@ class RatingContainer
return $this->playerToRating->setValue($player, $rating); return $this->playerToRating->setValue($player, $rating);
} }
/**
* @return Player[]
*/
public function getAllPlayers(): array public function getAllPlayers(): array
{ {
return $this->playerToRating->getAllKeys(); return $this->playerToRating->getAllKeys();
} }
/**
* @return Rating[]
*/
public function getAllRatings(): array public function getAllRatings(): array
{ {
return $this->playerToRating->getAllValues(); return $this->playerToRating->getAllValues();

@ -39,9 +39,9 @@ abstract class SkillCalculator
*/ */
abstract public function calculateMatchQuality(GameInfo $gameInfo, array $teamsOfPlayerToRatings): float; abstract public function calculateMatchQuality(GameInfo $gameInfo, array $teamsOfPlayerToRatings): float;
public function isSupported(SkillCalculatorSupportedOptions $option): bool public function isSupported(int $option): bool
{ {
return (bool)($this->supportedOptions & $option->value) == $option; return (bool)($this->supportedOptions & $option) == $option;
} }
protected function validateTeamCountAndPlayersCountPerTeam(array $teamsOfPlayerToRatings): void protected function validateTeamCountAndPlayersCountPerTeam(array $teamsOfPlayerToRatings): void

@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range;
class TeamsRange extends Range class TeamsRange extends Range
{ {
protected static function create(int $min, int $max): self protected static function create(int $min, int $max): static
{ {
return new TeamsRange($min, $max); return new TeamsRange($min, $max);
} }

@ -121,8 +121,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
$result = 0.0; $result = 0.0;
// We start at 1 since offset 0 has the sum // We start at 1 since offset 0 has the sum
$varCount = is_countable($vars) ? count($vars) : 0; for ($i = 1; $i < count($vars); $i++) {
for ($i = 1; $i < $varCount; $i++) {
$result += GaussianDistribution::logRatioNormalization($vars[$i]->getValue(), $messages[$i]->getValue()); $result += GaussianDistribution::logRatioNormalization($vars[$i]->getValue(), $messages[$i]->getValue());
} }
@ -189,7 +188,7 @@ class GaussianWeightedSumFactor extends GaussianFactor
$allMessages = $this->getMessages(); $allMessages = $this->getMessages();
$allVariables = $this->getVariables(); $allVariables = $this->getVariables();
Guard::argumentIsValidIndex($messageIndex, is_countable($allMessages) ? count($allMessages) : 0, 'messageIndex'); Guard::argumentIsValidIndex($messageIndex, count($allMessages), 'messageIndex');
$updatedMessages = []; $updatedMessages = [];
$updatedVariables = []; $updatedVariables = [];