From e46d1b20fa920dc8b701a535dc7847e1225d490a Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Tue, 4 Jul 2023 08:36:41 +0200 Subject: [PATCH] Use rsqrt in child score calculation Changes sqrt -> rsqrt and div -> mul. Reduces latency from 23 to 8 cycles. --- Solver/Crafty/Intrinsics.cs | 15 +++++++++++++++ Solver/Crafty/Solver.cs | 9 ++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/Solver/Crafty/Intrinsics.cs b/Solver/Crafty/Intrinsics.cs index 7894ae3..a1baef0 100644 --- a/Solver/Crafty/Intrinsics.cs +++ b/Solver/Crafty/Intrinsics.cs @@ -110,6 +110,21 @@ internal static class Intrinsics NthBitSetScalar(value, n); } + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector ReciprocalSqrt(Vector data) + { + if (Avx.IsSupported && Vector.Count >= Vector256.Count) + return Avx.ReciprocalSqrt(data.AsVector256()).AsVector(); + + if (Sse.IsSupported && Vector.Count >= Vector128.Count) + return Sse.ReciprocalSqrt(data.AsVector128()).AsVector(); + + Span result = stackalloc float[Vector.Count]; + for (var i = 0; i < Vector.Count; ++i) + result[i] = MathF.ReciprocalSqrtEstimate(data[i]); + return new(result); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void CASMax(ref float location, float newValue) diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index df2864c..e10ac48 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -101,7 +101,7 @@ public class Solver var length = children.Length; var vecLength = Vector.Count; - var C = Config.ExplorationConstant * MathF.Log(parentVisits); + var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits)); var w = Config.MaxScoreWeightingConstant; var W = 1f - w; var CVector = new Vector(C); @@ -126,8 +126,11 @@ public class Solver maxScores[j] = node.MaxScore; } - var exploitation = (W * (new Vector(scoreSums) / new Vector(visits))) + (w * new Vector(maxScores)); - var exploration = Vector.SquareRoot(CVector / new 
Vector(visits)); + var s = new Vector(scoreSums); + var m = new Vector(maxScores); + var v = new Vector(visits); + var exploitation = (W * (s / v)) + (w * m); + var exploration = CVector * Intrinsics.ReciprocalSqrt(v); var evalScores = exploitation + exploration; var idx = Intrinsics.HMaxIndex(evalScores, iterCount);