More type checks

This commit is contained in:
2023-08-01 12:13:24 +00:00
parent 068b6f18aa
commit d5bba04f4f
15 changed files with 72 additions and 73 deletions

View File

@ -2,7 +2,6 @@
namespace DNW\Skills\FactorGraphs;
// edit this
abstract class FactorGraphLayer
{
private array $_localFactors = [];
@ -47,7 +46,7 @@ abstract class FactorGraphLayer
$this->_inputVariablesGroups = $value;
}
protected function scheduleSequence(array $itemsToSequence, $name)
protected function scheduleSequence(array $itemsToSequence, $name): ScheduleSequence
{
return new ScheduleSequence($name, $itemsToSequence);
}

View File

@ -8,7 +8,7 @@ abstract class Schedule implements \Stringable
{
}
abstract public function visit($depth = -1, $maxDepth = 0);
abstract public function visit(int $depth = -1, int $maxDepth = 0);
public function __toString(): string
{

View File

@ -9,7 +9,7 @@ class ScheduleLoop extends Schedule
parent::__construct($name);
}
public function visit($depth = -1, $maxDepth = 0)
public function visit(int $depth = -1, int $maxDepth = 0)
{
$totalIterations = 1;
$delta = $this->_scheduleToLoop->visit($depth + 1, $maxDepth);

View File

@ -4,22 +4,22 @@ namespace DNW\Skills\FactorGraphs;
class Variable implements \Stringable
{
private $_name;
private string $_name;
private $_value;
private mixed $_value;
public function __construct($name, private $_prior)
public function __construct(string $name, private mixed $_prior)
{
$this->_name = 'Variable['.$name.']';
$this->resetToPrior();
}
public function getValue()
public function getValue(): mixed
{
return $this->_value;
}
public function setValue($value)
public function setValue(mixed $value): void
{
$this->_value = $value;
}

View File

@ -16,7 +16,7 @@ class BasicMath
* @param number $x Value to square (x)
* @return number The squared value (x^2)
*/
public static function square($x)
public static function square($x): float|int
{
return $x * $x;
}
@ -28,7 +28,7 @@ class BasicMath
* @param callable $callback The function to apply to each array element before summing.
* @return number The sum.
*/
public static function sum(array $itemsToSum, $callback)
public static function sum(array $itemsToSum, \Closure $callback): float|int
{
$mappedItems = array_map($callback, $itemsToSum);

View File

@ -18,7 +18,7 @@ class GaussianDistribution implements \Stringable
private $_variance;
public function __construct(private $_mean = 0.0, private $_standardDeviation = 1.0)
public function __construct(private float $_mean = 0.0, private float $_standardDeviation = 1.0)
{
$this->_variance = BasicMath::square($_standardDeviation);
@ -32,32 +32,32 @@ class GaussianDistribution implements \Stringable
}
}
public function getMean()
public function getMean(): float
{
return $this->_mean;
}
public function getVariance()
public function getVariance(): float
{
return $this->_variance;
}
public function getStandardDeviation()
public function getStandardDeviation(): float
{
return $this->_standardDeviation;
}
public function getPrecision()
public function getPrecision(): float
{
return $this->_precision;
}
public function getPrecisionMean()
public function getPrecisionMean(): float
{
return $this->_precisionMean;
}
public function getNormalizationConstant()
public function getNormalizationConstant(): float
{
// Great derivation of this is at http://www.astro.psu.edu/~mce/A451_2/A451/downloads/notes0.pdf
return 1.0 / (sqrt(2 * M_PI) * $this->_standardDeviation);
@ -75,7 +75,7 @@ class GaussianDistribution implements \Stringable
return $result;
}
public static function fromPrecisionMean($precisionMean, $precision)
public static function fromPrecisionMean(float $precisionMean, float $precision): self
{
$result = new GaussianDistribution();
$result->_precision = $precision;
@ -96,13 +96,13 @@ class GaussianDistribution implements \Stringable
// For details, see http://www.tina-vision.net/tina-knoppix/tina-memo/2003-003.pdf
// for multiplication, the precision mean ones are easier to write :)
public static function multiply(GaussianDistribution $left, GaussianDistribution $right)
public static function multiply(GaussianDistribution $left, GaussianDistribution $right): self
{
return GaussianDistribution::fromPrecisionMean($left->_precisionMean + $right->_precisionMean, $left->_precision + $right->_precision);
}
// Computes the absolute difference between two Gaussians
public static function absoluteDifference(GaussianDistribution $left, GaussianDistribution $right)
public static function absoluteDifference(GaussianDistribution $left, GaussianDistribution $right): float
{
return max(
abs($left->_precisionMean - $right->_precisionMean),
@ -111,12 +111,12 @@ class GaussianDistribution implements \Stringable
}
// Computes the absolute difference between two Gaussians
public static function subtract(GaussianDistribution $left, GaussianDistribution $right)
public static function subtract(GaussianDistribution $left, GaussianDistribution $right): float
{
return GaussianDistribution::absoluteDifference($left, $right);
}
public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right)
public static function logProductNormalization(GaussianDistribution $left, GaussianDistribution $right): float
{
if (($left->_precision == 0) || ($right->_precision == 0)) {
return 0;
@ -130,7 +130,7 @@ class GaussianDistribution implements \Stringable
return -$logSqrt2Pi - (log($varianceSum) / 2.0) - (BasicMath::square($meanDifference) / (2.0 * $varianceSum));
}
public static function divide(GaussianDistribution $numerator, GaussianDistribution $denominator)
public static function divide(GaussianDistribution $numerator, GaussianDistribution $denominator): self
{
return GaussianDistribution::fromPrecisionMean(
$numerator->_precisionMean - $denominator->_precisionMean,
@ -138,7 +138,7 @@ class GaussianDistribution implements \Stringable
);
}
public static function logRatioNormalization(GaussianDistribution $numerator, GaussianDistribution $denominator)
public static function logRatioNormalization(GaussianDistribution $numerator, GaussianDistribution $denominator): float
{
if (($numerator->_precision == 0) || ($denominator->_precision == 0)) {
return 0;
@ -153,7 +153,7 @@ class GaussianDistribution implements \Stringable
BasicMath::square($meanDifference) / (2 * $varianceDifference);
}
public static function at($x, $mean = 0.0, $standardDeviation = 1.0)
public static function at(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float
{
// See http://mathworld.wolfram.com/NormalDistribution.html
// 1 -(x-mean)^2 / (2*stdDev^2)
@ -166,7 +166,7 @@ class GaussianDistribution implements \Stringable
return $multiplier * $expPart;
}
public static function cumulativeTo($x, $mean = 0.0, $standardDeviation = 1.0)
public static function cumulativeTo(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float
{
$invsqrt2 = -0.707106781186547524400844362104;
$result = GaussianDistribution::errorFunctionCumulativeTo($invsqrt2 * $x);
@ -174,7 +174,7 @@ class GaussianDistribution implements \Stringable
return 0.5 * $result;
}
private static function errorFunctionCumulativeTo($x)
private static function errorFunctionCumulativeTo($x): float
{
// Derived from page 265 of Numerical Recipes 3rd Edition
$z = abs($x);
@ -227,7 +227,7 @@ class GaussianDistribution implements \Stringable
return ($x >= 0.0) ? $ans : (2.0 - $ans);
}
private static function inverseErrorFunctionCumulativeTo($p)
private static function inverseErrorFunctionCumulativeTo(float $p): float
{
// From page 265 of numerical recipes
@ -250,7 +250,7 @@ class GaussianDistribution implements \Stringable
return ($p < 1.0) ? $x : -$x;
}
public static function inverseCumulativeTo($x, $mean = 0.0, $standardDeviation = 1.0)
public static function inverseCumulativeTo(float $x, float $mean = 0.0, float $standardDeviation = 1.0): float
{
// From numerical recipes, page 320
return $mean - sqrt(2) * $standardDeviation * GaussianDistribution::inverseErrorFunctionCumulativeTo(2 * $x);

View File

@ -8,11 +8,11 @@ class Matrix
{
public const ERROR_TOLERANCE = 0.0000000001;
public function __construct(private $_rowCount = 0, private $_columnCount = 0, private $_matrixRowData = null)
public function __construct(private int $_rowCount = 0, private int $_columnCount = 0, private $_matrixRowData = null)
{
}
public static function fromColumnValues($rows, $columns, $columnValues)
public static function fromColumnValues(int $rows, int $columns, array $columnValues): self
{
$data = [];
$result = new Matrix($rows, $columns, $data);
@ -28,7 +28,7 @@ class Matrix
return $result;
}
public static function fromRowsColumns(...$args)
public static function fromRowsColumns(...$args): Matrix
{
$rows = $args[0];
$cols = $args[1];
@ -44,27 +44,27 @@ class Matrix
return $result;
}
public function getRowCount()
public function getRowCount(): int
{
return $this->_rowCount;
}
public function getColumnCount()
public function getColumnCount(): int
{
return $this->_columnCount;
}
public function getValue($row, $col)
public function getValue(int $row, int $col): float|int
{
return $this->_matrixRowData[$row][$col];
}
public function setValue($row, $col, $value)
public function setValue(int $row, int $col, float|int $value)
{
$this->_matrixRowData[$row][$col] = $value;
}
public function getTranspose()
public function getTranspose(): self
{
// Just flip everything
$transposeMatrix = [];
@ -84,12 +84,12 @@ class Matrix
return new Matrix($this->_columnCount, $this->_rowCount, $transposeMatrix);
}
private function isSquare()
private function isSquare(): bool
{
return ($this->_rowCount == $this->_columnCount) && ($this->_rowCount > 0);
}
public function getDeterminant()
public function getDeterminant(): float
{
// Basic argument checking
if (! $this->isSquare()) {
@ -134,7 +134,7 @@ class Matrix
return $result;
}
public function getAdjugate()
public function getAdjugate(): SquareMatrix|self
{
if (! $this->isSquare()) {
throw new Exception('Matrix must be square!');
@ -171,7 +171,7 @@ class Matrix
return new Matrix($this->_columnCount, $this->_rowCount, $result);
}
public function getInverse()
public function getInverse(): Matrix|SquareMatrix
{
if (($this->_rowCount == 1) && ($this->_columnCount == 1)) {
return new SquareMatrix(1.0 / $this->_matrixRowData[0][0]);
@ -185,7 +185,7 @@ class Matrix
return self::scalarMultiply($determinantInverse, $adjugate);
}
public static function scalarMultiply($scalarValue, $matrix)
public static function scalarMultiply(float|int $scalarValue, Matrix $matrix): Matrix
{
$rows = $matrix->getRowCount();
$columns = $matrix->getColumnCount();
@ -200,7 +200,7 @@ class Matrix
return new Matrix($rows, $columns, $newValues);
}
public static function add($left, $right)
public static function add(Matrix $left, Matrix $right): Matrix
{
if (
($left->getRowCount() != $right->getRowCount())
@ -226,7 +226,7 @@ class Matrix
return new Matrix($left->getRowCount(), $right->getColumnCount(), $resultMatrix);
}
public static function multiply($left, $right)
public static function multiply(Matrix $left, Matrix $right): Matrix
{
// Just your standard matrix multiplication.
// See http://en.wikipedia.org/wiki/Matrix_multiplication for details
@ -258,7 +258,7 @@ class Matrix
return new Matrix($resultRows, $resultColumns, $resultMatrix);
}
private function getMinorMatrix($rowToRemove, $columnToRemove)
private function getMinorMatrix(int $rowToRemove, int $columnToRemove): Matrix
{
// See http://en.wikipedia.org/wiki/Minor_(linear_algebra)

View File

@ -9,11 +9,11 @@ use Exception;
class Range
{
private $_min;
private int $_min;
private $_max;
private int $_max;
public function __construct($min, $max)
public function __construct(int $min, int $max)
{
if ($min > $max) {
throw new Exception('min > max');
@ -23,39 +23,39 @@ class Range
$this->_max = $max;
}
public function getMin()
public function getMin(): int
{
return $this->_min;
}
public function getMax()
public function getMax(): int
{
return $this->_max;
}
protected static function create($min, $max)
protected static function create(int $min, int $max): self
{
return new Range($min, $max);
}
// REVIEW: It's probably bad form to have access statics via a derived class, but the syntax looks better :-)
public static function inclusive($min, $max)
public static function inclusive(int $min, int $max): self
{
return static::create($min, $max);
}
public static function exactly($value)
public static function exactly(int $value): self
{
return static::create($value, $value);
}
public static function atLeast($minimumValue)
public static function atLeast(int $minimumValue): self
{
return static::create($minimumValue, PHP_INT_MAX);
}
public function isInRange($value)
public function isInRange(int $value): bool
{
return ($this->_min <= $value) && ($value <= $this->_max);
}

View File

@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range;
class PlayersRange extends Range
{
protected static function create($min, $max)
protected static function create(int $min, int $max): self
{
return new PlayersRange($min, $max);
}

View File

@ -6,7 +6,7 @@ use DNW\Skills\Numerics\Range;
class TeamsRange extends Range
{
protected static function create($min, $max)
protected static function create(int $min, int $max): self
{
return new TeamsRange($min, $max);
}

View File

@ -14,6 +14,7 @@ use DNW\Skills\RankSorter;
use DNW\Skills\SkillCalculator;
use DNW\Skills\SkillCalculatorSupportedOptions;
use DNW\Skills\TeamsRange;
use DNW\Skills\RatingContainer;
/**
* Calculates TrueSkill using a full factor graph.
@ -27,7 +28,7 @@ class FactorGraphTrueSkillCalculator extends SkillCalculator
public function calculateNewRatings(GameInfo $gameInfo,
array $teams,
array $teamRanks)
array $teamRanks): RatingContainer
{
Guard::argumentNotNull($gameInfo, 'gameInfo');
$this->validateTeamCountAndPlayersCountPerTeam($teams);

View File

@ -6,6 +6,7 @@ use DNW\Skills\FactorGraphs\KeyedVariable;
use DNW\Skills\FactorGraphs\ScheduleStep;
use DNW\Skills\Numerics\BasicMath;
use DNW\Skills\TrueSkill\Factors\GaussianLikelihoodFactor;
use DNW\Skills\FactorGraphs\ScheduleSequence;
class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer
{
@ -30,7 +31,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer
}
}
private function createLikelihood(KeyedVariable $playerSkill, KeyedVariable $playerPerformance)
private function createLikelihood(KeyedVariable $playerSkill, KeyedVariable $playerPerformance): GaussianLikelihoodFactor
{
return new GaussianLikelihoodFactor(
BasicMath::square($this->getParentFactorGraph()->getGameInfo()->getBeta()),
@ -44,7 +45,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer
return $this->getParentFactorGraph()->getVariableFactory()->createKeyedVariable($key, $key."'s performance");
}
public function createPriorSchedule()
public function createPriorSchedule(): ScheduleSequence
{
$localFactors = $this->getLocalFactors();
@ -55,7 +56,7 @@ class PlayerSkillsToPerformancesLayer extends TrueSkillFactorGraphLayer
'All skill to performance sending');
}
public function createPosteriorSchedule()
public function createPosteriorSchedule(): ScheduleSequence
{
$localFactors = $this->getLocalFactors();

View File

@ -110,7 +110,7 @@ class TrueSkillFactorGraph extends FactorGraph
return new ScheduleSequence('Full schedule', $fullSchedule);
}
public function getUpdatedRatings()
public function getUpdatedRatings(): RatingContainer
{
$result = new RatingContainer();

View File

@ -17,9 +17,8 @@ class TruncatedGaussianCorrectionFunctions
* @param $teamPerformanceDifference
* @param number $drawMargin In the paper, it's referred to as just "ε".
* @param $c
* @return float
*/
public static function vExceedsMarginScaled($teamPerformanceDifference, $drawMargin, $c)
public static function vExceedsMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float
{
return self::vExceedsMargin($teamPerformanceDifference / $c, $drawMargin / $c);
}
@ -44,14 +43,13 @@ class TruncatedGaussianCorrectionFunctions
* @param $teamPerformanceDifference
* @param $drawMargin
* @param $c
* @return float
*/
public static function wExceedsMarginScaled($teamPerformanceDifference, $drawMargin, $c)
public static function wExceedsMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float
{
return self::wExceedsMargin($teamPerformanceDifference / $c, $drawMargin / $c);
}
public static function wExceedsMargin($teamPerformanceDifference, $drawMargin)
public static function wExceedsMargin($teamPerformanceDifference, $drawMargin): float
{
$denominator = GaussianDistribution::cumulativeTo($teamPerformanceDifference - $drawMargin);
@ -69,13 +67,13 @@ class TruncatedGaussianCorrectionFunctions
}
// the additive correction of a double-sided truncated Gaussian with unit variance
public static function vWithinMarginScaled($teamPerformanceDifference, $drawMargin, $c)
public static function vWithinMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float
{
return self::vWithinMargin($teamPerformanceDifference / $c, $drawMargin / $c);
}
// from F#:
public static function vWithinMargin($teamPerformanceDifference, $drawMargin)
public static function vWithinMargin($teamPerformanceDifference, $drawMargin): float
{
$teamPerformanceDifferenceAbsoluteValue = abs($teamPerformanceDifference);
$denominator =
@ -101,13 +99,13 @@ class TruncatedGaussianCorrectionFunctions
}
// the multiplicative correction of a double-sided truncated Gaussian with unit variance
public static function wWithinMarginScaled($teamPerformanceDifference, $drawMargin, $c)
public static function wWithinMarginScaled($teamPerformanceDifference, float|int $drawMargin, $c): float
{
return self::wWithinMargin($teamPerformanceDifference / $c, $drawMargin / $c);
}
// From F#:
public static function wWithinMargin($teamPerformanceDifference, $drawMargin)
public static function wWithinMargin($teamPerformanceDifference, float|int $drawMargin)
{
$teamPerformanceDifferenceAbsoluteValue = abs($teamPerformanceDifference);
$denominator = GaussianDistribution::cumulativeTo($drawMargin - $teamPerformanceDifferenceAbsoluteValue)

View File

@ -29,7 +29,7 @@ class TwoPlayerTrueSkillCalculator extends SkillCalculator
public function calculateNewRatings(GameInfo $gameInfo,
array $teams,
array $teamRanks)
array $teamRanks): RatingContainer
{
// Basic argument checking
Guard::argumentNotNull($gameInfo, 'gameInfo');