From d5bba04f4f2a6d0bcf2acb719f6cc1cdf096a07f Mon Sep 17 00:00:00 2001 From: Jens True Date: Tue, 1 Aug 2023 12:13:24 +0000 Subject: [PATCH] More type checks --- src/FactorGraphs/FactorGraphLayer.php | 3 +- src/FactorGraphs/Schedule.php | 2 +- src/FactorGraphs/ScheduleLoop.php | 2 +- src/FactorGraphs/Variable.php | 10 ++--- src/Numerics/BasicMath.php | 4 +- src/Numerics/GaussianDistribution.php | 38 +++++++++---------- src/Numerics/Matrix.php | 32 ++++++++-------- src/Numerics/Range.php | 20 +++++----- src/PlayersRange.php | 2 +- src/TeamsRange.php | 2 +- .../FactorGraphTrueSkillCalculator.php | 3 +- .../PlayerSkillsToPerformancesLayer.php | 7 ++-- src/TrueSkill/TrueSkillFactorGraph.php | 2 +- .../TruncatedGaussianCorrectionFunctions.php | 16 ++++---- .../TwoPlayerTrueSkillCalculator.php | 2 +- 15 files changed, 72 insertions(+), 73 deletions(-) diff --git a/src/FactorGraphs/FactorGraphLayer.php b/src/FactorGraphs/FactorGraphLayer.php index 2aa8c62..8b79644 100644 --- a/src/FactorGraphs/FactorGraphLayer.php +++ b/src/FactorGraphs/FactorGraphLayer.php @@ -2,7 +2,6 @@ namespace DNW\Skills\FactorGraphs; -// edit this abstract class FactorGraphLayer { private array $_localFactors = []; @@ -47,7 +46,7 @@ abstract class FactorGraphLayer $this->_inputVariablesGroups = $value; } - protected function scheduleSequence(array $itemsToSequence, $name) + protected function scheduleSequence(array $itemsToSequence, $name): ScheduleSequence { return new ScheduleSequence($name, $itemsToSequence); } diff --git a/src/FactorGraphs/Schedule.php b/src/FactorGraphs/Schedule.php index 79ae6af..5877081 100644 --- a/src/FactorGraphs/Schedule.php +++ b/src/FactorGraphs/Schedule.php @@ -8,7 +8,7 @@ abstract class Schedule implements \Stringable { } - abstract public function visit($depth = -1, $maxDepth = 0); + abstract public function visit(int $depth = -1, int $maxDepth = 0); public function __toString(): string { diff --git a/src/FactorGraphs/ScheduleLoop.php b/src/FactorGraphs/ScheduleLoop.php index 7ae5475..9aaacea 100644 --- a/src/FactorGraphs/ScheduleLoop.php +++ b/src/FactorGraphs/ScheduleLoop.php @@ -9,7 +9,7 @@ class ScheduleLoop extends Schedule parent::__construct($name); } - public function visit($depth = -1, $maxDepth = 0) + public function visit(int $depth = -1, int $maxDepth = 0) { $totalIterations = 1; $delta = $this->_scheduleToLoop->visit($depth + 1, $maxDepth); diff --git a/src/FactorGraphs/Variable.php b/src/FactorGraphs/Variable.php index 5c6512d..f3e8f42 100644 --- a/src/FactorGraphs/Variable.php +++ b/src/FactorGraphs/Variable.php @@ -4,22 +4,22 @@ namespace DNW\Skills\FactorGraphs; class Variable implements \Stringable { - private $_name; + private string $_name; - private $_value; + private mixed $_value; - public function __construct($name, private $_prior) + public function __construct(string $name, private mixed $_prior) { $this->_name = 'Variable['.$name.']'; $this->resetToPrior(); } - public function getValue() + public function getValue(): mixed { return $this->_value; } - public function setValue($value) + public function setValue(mixed $value): void { $this->_value = $value; } diff --git a/src/Numerics/BasicMath.php b/src/Numerics/BasicMath.php index 1310c32..3d1314d 100644 --- a/src/Numerics/BasicMath.php +++ b/src/Numerics/BasicMath.php @@ -16,7 +16,7 @@ class BasicMath * @param number $x Value to square (x) * @return number The squared value (x^2) */ - public static function square($x) + public static function square($x): float|int { return $x * $x; } @@ -28,7 +28,7 @@ class BasicMath * @param callable $callback The function to apply to each array element before summing. * @return number The sum. */ - public static function sum(array $itemsToSum, $callback) + public static function sum(array $itemsToSum, \Closure $callback): float|int { $mappedItems = array_map($callback, $itemsToSum); diff --git a/src/Numerics/GaussianDistribution.php b/src/Numerics/GaussianDistribution.php index 8983088..42e4e2e 100644 --- a/src/Numerics/GaussianDistribution.php +++ b/src/Numerics/GaussianDistribution.php @@ -18,7 +18,7 @@ class GaussianDistribution implements \Stringable private $_variance; - public function __construct(private $_mean = 0.0, private $_standardDeviation = 1.0) + public function __construct(private float $_mean = 0.0, private float $_standardDeviation = 1.0) { $this->_variance = BasicMath::square($_standardDeviation); @@ -32,32 +32,32 @@ class GaussianDistribution implements \Stringable } } - public function getMean() + public function getMean(): float { return $this->_mean; } - public function getVariance() + public function getVariance(): float { return $this->_variance; } - public function getStandardDeviation() + public function getStandardDeviation(): float { return $this->_standardDeviation; } - public function getPrecision() + public function getPrecision(): float { return $this->_precision; } - public function getPrecisionMean() + public function getPrecisionMean(): float { return $this->_precisionMean; } - public function getNormalizationConstant() + public function getNormalizationConstant(): float { // Great derivation of this is at http://www.astro.psu.edu/~mce/A451_2/A451/downloads/notes0.pdf return 1.0 / (sqrt(2 * M_PI) * $this->_standardDeviation); @@ -75,7 +75,7 @@ class GaussianDistribution implements \Stringable return $result; } - public static function fromPrecisionMean($precisionMean, $precision) + public static function fromPrecisionMean(float $precisionMean, float $precision): self { $result = new GaussianDistribution(); $result->_precision = $precision; @@ -96,13 +96,13 @@ class GaussianDistribution implements \Stringable // For details, see http://www.tina-vision.net/tina-knoppix/tina-memo/2003-003.pdf // for multiplication, the precision mean ones are easier to write :) - public static function multiply(GaussianDistribution $left, GaussianDistribution $right) + public static function multiply(GaussianDistribution $left, GaussianDistribution $right): self { return GaussianDistribution::fromPrecisionMean($left->_precisionMean + $right->_precisionMean, $left->_precision + $right->_precision); } // Computes the absolute difference between two Gaussians - public static function absoluteDifference(GaussianDistribution $left, GaussianDistribution $right) + public static function absoluteDifference(GaussianDistribution $left, GaussianDistribution $right): float { return max( abs($left->_precisionMean - $right->_precisionMean), @@ -111,12 +111,12 @@ class GaussianDistribution implements \Stringable } // Computes the absolute difference between two Gaussians - public static function subtract(GaussianDistribution $left, GaussianDistribution $right) + public static function subtract(GaussianDistribution $left, GaussianDistribution $right): float { return GaussianDistribution::absoluteDifference($left, $right); } - public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right) + public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right): float { if (($left->_precision == 0) || ($right->_precision == 0)) { return 0; @@ -130,7 +130,7 @@ class GaussianDistribution implements \Stringable return -$logSqrt2Pi - (log($varianceSum) / 2.0) - (BasicMath::square($meanDifference) / (2.0 * $varianceSum)); } - public static function divide(GaussianDistribution $numerator, GaussianDistribution $denominator) + public static function divide(GaussianDistribution $numerator, GaussianDistribution $denominator): self { return GaussianDistribution::fromPrecisionMean( $numerator->_precisionMean - $denominator->_precisionMean, @@ -138,7 +138,7 @@ class GaussianDistribution implements \Stringable ); } - public static function logRatioNormalization(GaussianDistribution $numerator, GaussianDistribution $denominator) + public static function logRatioNormalization(GaussianDistribution $numerator, GaussianDistribution $denominator): float { if (($numerator->_precision == 0) || ($denominator->_precision == 0)) { return 0; @@ -153,7 +153,7 @@ class GaussianDistribution implements \Stringable BasicMath::square($meanDifference) / (2 * $varianceDifference); } - public static function at($x, $mean = 0.0, $standardDeviation = 1.0) + public static function at(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float { // See http://mathworld.wolfram.com/NormalDistribution.html // 1 -(x-mean)^2 / (2*stdDev^2) @@ -166,7 +166,7 @@ class GaussianDistribution implements \Stringable return $multiplier * $expPart; } - public static function cumulativeTo($x, $mean = 0.0, $standardDeviation = 1.0) + public static function cumulativeTo(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float { $invsqrt2 = -0.707106781186547524400844362104; $result = GaussianDistribution::errorFunctionCumulativeTo($invsqrt2 * $x); @@ -174,7 +174,7 @@ class GaussianDistribution implements \Stringable return 0.5 * $result; } - private static function errorFunctionCumulativeTo($x) + private static function errorFunctionCumulativeTo($x): float { // Derived from page 265 of Numerical Recipes 3rd Edition $z = abs($x); @@ -227,7 +227,7 @@ class GaussianDistribution implements \Stringable return ($x >= 0.0) ? $ans : (2.0 - $ans); } - private static function inverseErrorFunctionCumulativeTo($p) + private static function inverseErrorFunctionCumulativeTo(float $p): float { // From page 265 of numerical recipes @@ -250,7 +250,7 @@ class GaussianDistribution implements \Stringable return ($p < 1.0) ? $x : -$x; } - public static function inverseCumulativeTo($x, $mean = 0.0, $standardDeviation = 1.0) + public static function inverseCumulativeTo(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float { // From numerical recipes, page 320 return $mean - sqrt(2) * $standardDeviation * GaussianDistribution::inverseErrorFunctionCumulativeTo(2 * $x); diff --git a/src/Numerics/Matrix.php b/src/Numerics/Matrix.php index 15ccda7..0397e52 100644 --- a/src/Numerics/Matrix.php +++ b/src/Numerics/Matrix.php @@ -8,11 +8,11 @@ class Matrix { public const ERROR_TOLERANCE = 0.0000000001; - public function __construct(private $_rowCount = 0, private $_columnCount = 0, private $_matrixRowData = null) + public function __construct(private int $_rowCount = 0, private int $_columnCount = 0, private $_matrixRowData = null) { } - public static function fromColumnValues($rows, $columns, $columnValues) + public static function fromColumnValues(int $rows, int $columns, array $columnValues): self { $data = []; $result = new Matrix($rows, $columns, $data); @@ -28,7 +28,7 @@ class Matrix return $result; } - public static function fromRowsColumns(...$args) + public static function fromRowsColumns(...$args): Matrix { $rows = $args[0]; $cols = $args[1]; @@ -44,27 +44,27 @@ class Matrix return $result; } - public function getRowCount() + public function getRowCount(): int { return $this->_rowCount; } - public function getColumnCount() + public function getColumnCount(): int { return $this->_columnCount; } - public function getValue($row, $col) + public function getValue(int $row, int $col): float|int { return $this->_matrixRowData[$row][$col]; } - public function setValue($row, $col, $value) + public function setValue(int $row, int $col, float|int $value) { $this->_matrixRowData[$row][$col] = $value; } - public function getTranspose() + public function getTranspose(): self { // Just flip everything $transposeMatrix = []; @@ -84,12 +84,12 @@ class Matrix return new Matrix($this->_columnCount, $this->_rowCount, $transposeMatrix); } - private function isSquare() + private function isSquare(): bool { return ($this->_rowCount == $this->_columnCount) && ($this->_rowCount > 0); } - public function getDeterminant() + public function getDeterminant(): float { // Basic argument checking if (! $this->isSquare()) { @@ -134,7 +134,7 @@ class Matrix return $result; } - public function getAdjugate() + public function getAdjugate(): SquareMatrix|self { if (! $this->isSquare()) { throw new Exception('Matrix must be square!'); @@ -171,7 +171,7 @@ class Matrix return new Matrix($this->_columnCount, $this->_rowCount, $result); } - public function getInverse() + public function getInverse(): Matrix|SquareMatrix { if (($this->_rowCount == 1) && ($this->_columnCount == 1)) { return new SquareMatrix(1.0 / $this->_matrixRowData[0][0]); @@ -185,7 +185,7 @@ class Matrix return self::scalarMultiply($determinantInverse, $adjugate); } - public static function scalarMultiply($scalarValue, $matrix) + public static function scalarMultiply(float|int $scalarValue, Matrix $matrix): Matrix { $rows = $matrix->getRowCount(); $columns = $matrix->getColumnCount(); @@ -200,7 +200,7 @@ class Matrix return new Matrix($rows, $columns, $newValues); } - public static function add($left, $right) + public static function add(Matrix $left, Matrix $right): Matrix { if ( ($left->getRowCount() != $right->getRowCount()) @@ -226,7 +226,7 @@ class Matrix return new Matrix($left->getRowCount(), $right->getColumnCount(), $resultMatrix); } - public static function multiply($left, $right) + public static function multiply(Matrix $left, Matrix $right): Matrix { // Just your standard matrix multiplication. // See http://en.wikipedia.org/wiki/Matrix_multiplication for details @@ -258,7 +258,7 @@ class Matrix return new Matrix($resultRows, $resultColumns, $resultMatrix); } - private function getMinorMatrix($rowToRemove, $columnToRemove) + private function getMinorMatrix(int $rowToRemove, int $columnToRemove): Matrix { // See http://en.wikipedia.org/wiki/Minor_(linear_algebra) diff --git a/src/Numerics/Range.php b/src/Numerics/Range.php index 40937fc..47e1256 100644 --- a/src/Numerics/Range.php +++ b/src/Numerics/Range.php @@ -9,11 +9,11 @@ use Exception; class Range { - private $_min; + private int $_min; - private $_max; + private int $_max; - public function __construct($min, $max) + public function __construct(int $min, int $max) { if ($min > $max) { throw new Exception('min > max'); @@ -23,39 +23,39 @@ class Range $this->_max = $max; } - public function getMin() + public function getMin(): int { return $this->_min; } - public function getMax() + public function getMax(): int { return $this->_max; } - protected static function create($min, $max) + protected static function create(int $min, int $max): self { return new Range($min, $max); } // REVIEW: It's probably bad form to have access statics via a derived class, but the syntax looks better :-) - public static function inclusive($min, $max) + public static function inclusive(int $min, int $max): self { return static::create($min, $max); } - public static function exactly($value) + public static function exactly(int $value): self { return static::create($value, $value); } - public static function atLeast($minimumValue) + public static function atLeast(int $minimumValue): self { return static::create($minimumValue, PHP_INT_MAX); } - public function isInRange($value) + public function isInRange(int $value): bool { return ($this->_min <= $value) && ($value <= $this->_max); } diff --git a/src/PlayersRange.php b/src/PlayersRange.php index c1d9949..ea465f6 100644 --- a/src/PlayersRange.php +++ b/src/PlayersRange.php @@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range; class PlayersRange extends Range { - protected static function create($min, $max) + protected static function create(int $min, int $max): self { return new PlayersRange($min, $max); } diff --git a/src/TeamsRange.php b/src/TeamsRange.php index d35b8a9..19261a4 100644 --- a/src/TeamsRange.php +++ b/src/TeamsRange.php @@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range; class TeamsRange extends Range { - protected static function create($min, $max) + protected static function create(int $min, int $max): self { return new TeamsRange($min, $max); } diff --git a/src/TrueSkill/FactorGraphTrueSkillCalculator.php b/src/TrueSkill/FactorGraphTrueSkillCalculator.php index 076641d..e945f5f 100644 --- a/src/TrueSkill/FactorGraphTrueSkillCalculator.php +++ b/src/TrueSkill/FactorGraphTrueSkillCalculator.php @@ -14,6 +14,7 @@ use DNW\Skills\RankSorter; use DNW\Skills\SkillCalculator; use DNW\Skills\SkillCalculatorSupportedOptions; use DNW\Skills\TeamsRange; +use DNW\Skills\RatingContainer; /** * Calculates TrueSkill using a full factor graph. @@ -27,7 +28,7 @@ class FactorGraphTrueSkillCalculator extends SkillCalculator public function calculateNewRatings(GameInfo $gameInfo, array $teams, - array $teamRanks) + array $teamRanks): RatingContainer { Guard::argumentNotNull($gameInfo, 'gameInfo'); $this->validateTeamCountAndPlayersCountPerTeam($teams); diff --git a/src/TrueSkill/Layers/PlayerSkillsToPerformancesLayer.php b/src/TrueSkill/Layers/PlayerSkillsToPerformancesLayer.php index b7190c2..ec8e9f2 100644 --- a/src/TrueSkill/Layers/PlayerSkillsToPerformancesLayer.php +++ b/src/TrueSkill/Layers/PlayerSkillsToPerformancesLayer.php @@ -6,6 +6,7 @@ use DNW\Skills\FactorGraphs\KeyedVariable; use DNW\Skills\FactorGraphs\ScheduleStep; use DNW\Skills\Numerics\BasicMath; use DNW\Skills\TrueSkill\Factors\GaussianLikelihoodFactor; +use DNW\Skills\FactorGraphs\ScheduleSequence; class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer { @@ -30,7 +31,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer } } - private function createLikelihood(KeyedVariable $playerSkill, KeyedVariable $playerPerformance) + private function createLikelihood(KeyedVariable $playerSkill, KeyedVariable $playerPerformance): GaussianLikelihoodFactor { return new GaussianLikelihoodFactor( BasicMath::square($this->getParentFactorGraph()->getGameInfo()->getBeta()), @@ -44,7 +45,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer return $this->getParentFactorGraph()->getVariableFactory()->createKeyedVariable($key, $key."'s performance"); } - public function createPriorSchedule() + public function createPriorSchedule(): ScheduleSequence { $localFactors = $this->getLocalFactors(); @@ -55,7 +56,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer 'All skill to performance sending'); } - public function createPosteriorSchedule() + public function createPosteriorSchedule(): ScheduleSequence { $localFactors = $this->getLocalFactors(); diff --git a/src/TrueSkill/TrueSkillFactorGraph.php b/src/TrueSkill/TrueSkillFactorGraph.php index 8ba0ade..dc6cf79 100644 --- a/src/TrueSkill/TrueSkillFactorGraph.php +++ b/src/TrueSkill/TrueSkillFactorGraph.php @@ -110,7 +110,7 @@ class TrueSkillFactorGraph extends FactorGraph return new ScheduleSequence('Full schedule', $fullSchedule); } - public function getUpdatedRatings() + public function getUpdatedRatings(): RatingContainer { $result = new RatingContainer(); diff --git a/src/TrueSkill/TruncatedGaussianCorrectionFunctions.php b/src/TrueSkill/TruncatedGaussianCorrectionFunctions.php index 13f4f47..b3f0125 100644 --- a/src/TrueSkill/TruncatedGaussianCorrectionFunctions.php +++ b/src/TrueSkill/TruncatedGaussianCorrectionFunctions.php @@ -17,9 +17,8 @@ class TruncatedGaussianCorrectionFunctions * @param $teamPerformanceDifference * @param number $drawMargin In the paper, it's referred to as just "ε". * @param $c - * @return float */ - public static function vExceedsMarginScaled($teamPerformanceDifference, $drawMargin, $c) + public static function vExceedsMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float { return self::vExceedsMargin($teamPerformanceDifference / $c, $drawMargin / $c); } @@ -44,14 +43,13 @@ class TruncatedGaussianCorrectionFunctions * @param $teamPerformanceDifference * @param $drawMargin * @param $c - * @return float */ - public static function wExceedsMarginScaled($teamPerformanceDifference, $drawMargin, $c) + public static function wExceedsMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float { return self::wExceedsMargin($teamPerformanceDifference / $c, $drawMargin / $c); } - public static function wExceedsMargin($teamPerformanceDifference, $drawMargin) + public static function wExceedsMargin($teamPerformanceDifference, $drawMargin): float { $denominator = GaussianDistribution::cumulativeTo($teamPerformanceDifference - $drawMargin); @@ -69,13 +67,13 @@ class TruncatedGaussianCorrectionFunctions } // the additive correction of a double-sided truncated Gaussian with unit variance - public static function vWithinMarginScaled($teamPerformanceDifference, $drawMargin, $c) + public static function vWithinMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float { return self::vWithinMargin($teamPerformanceDifference / $c, $drawMargin / $c); } // from F#: - public static function vWithinMargin($teamPerformanceDifference, $drawMargin) + public static function vWithinMargin($teamPerformanceDifference, $drawMargin): float { $teamPerformanceDifferenceAbsoluteValue = abs($teamPerformanceDifference); $denominator = @@ -101,13 +99,13 @@ class TruncatedGaussianCorrectionFunctions } // the multiplicative correction of a double-sided truncated Gaussian with unit variance - public static function wWithinMarginScaled($teamPerformanceDifference, $drawMargin, $c) + public static function wWithinMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float { return self::wWithinMargin($teamPerformanceDifference / $c, $drawMargin / $c); } // From F#: - public static function wWithinMargin($teamPerformanceDifference, $drawMargin) + public static function wWithinMargin($teamPerformanceDifference, float|int $drawMargin) { $teamPerformanceDifferenceAbsoluteValue = abs($teamPerformanceDifference); $denominator = GaussianDistribution::cumulativeTo($drawMargin - $teamPerformanceDifferenceAbsoluteValue) diff --git a/src/TrueSkill/TwoPlayerTrueSkillCalculator.php b/src/TrueSkill/TwoPlayerTrueSkillCalculator.php index 2331420..8429a53 100644 --- a/src/TrueSkill/TwoPlayerTrueSkillCalculator.php +++ b/src/TrueSkill/TwoPlayerTrueSkillCalculator.php @@ -29,7 +29,7 @@ class TwoPlayerTrueSkillCalculator extends SkillCalculator public function calculateNewRatings(GameInfo $gameInfo, array $teams, - array $teamRanks) + array $teamRanks): RatingContainer { // Basic argument checking Guard::argumentNotNull($gameInfo, 'gameInfo');