Use rqrt in child score calculation

Changes sqrt -> rsqrt and div->mul. Reduces latency by 23 -> 8 cycles.
This commit is contained in:
Asriel Camora
2023-07-04 08:36:41 +02:00
parent 76853e2f0d
commit e46d1b20fa
2 changed files with 21 additions and 3 deletions
+15
View File
@@ -110,6 +110,21 @@ internal static class Intrinsics
NthBitSetScalar(value, n); NthBitSetScalar(value, n);
} }
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector<float> ReciprocalSqrt(Vector<float> data)
{
if (Avx.IsSupported && Vector<float>.Count >= Vector256<float>.Count)
return Avx.ReciprocalSqrt(data.AsVector256()).AsVector();
if (Sse.IsSupported && Vector<float>.Count >= Vector128<float>.Count)
return Sse.ReciprocalSqrt(data.AsVector128()).AsVector();
Span<float> result = stackalloc float[Vector<float>.Count];
for (var i = 0; i < Vector<float>.Count; ++i)
result[i] = MathF.ReciprocalSqrtEstimate(data[i]);
return new(result);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void CASMax(ref float location, float newValue) public static void CASMax(ref float location, float newValue)
+6 -3
View File
@@ -101,7 +101,7 @@ public class Solver
var length = children.Length; var length = children.Length;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
var C = Config.ExplorationConstant * MathF.Log(parentVisits); var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
var w = Config.MaxScoreWeightingConstant; var w = Config.MaxScoreWeightingConstant;
var W = 1f - w; var W = 1f - w;
var CVector = new Vector<float>(C); var CVector = new Vector<float>(C);
@@ -126,8 +126,11 @@ public class Solver
maxScores[j] = node.MaxScore; maxScores[j] = node.MaxScore;
} }
var exploitation = (W * (new Vector<float>(scoreSums) / new Vector<float>(visits))) + (w * new Vector<float>(maxScores)); var s = new Vector<float>(scoreSums);
var exploration = Vector.SquareRoot(CVector / new Vector<float>(visits)); var m = new Vector<float>(maxScores);
var v = new Vector<float>(visits);
var exploitation = (W * (s / v)) + (w * m);
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
var evalScores = exploitation + exploration; var evalScores = exploitation + exploration;
var idx = Intrinsics.HMaxIndex(evalScores, iterCount); var idx = Intrinsics.HMaxIndex(evalScores, iterCount);