From 82ac00df7c5c38a64091393d5e5cdac2e8808059 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Fri, 15 Mar 2024 03:32:08 -0700 Subject: [PATCH] Batch allocations, replace all Vector with Vector256 --- Solver/ArenaBuffer.cs | 9 +++-- Solver/Intrinsics.cs | 73 +++++++------------------------------- Solver/MCTS.cs | 22 ++++++------ Solver/NodeScoresBuffer.cs | 53 ++++++++++++--------------- 4 files changed, 49 insertions(+), 108 deletions(-) diff --git a/Solver/ArenaBuffer.cs b/Solver/ArenaBuffer.cs index dda988a..bf32635 100644 --- a/Solver/ArenaBuffer.cs +++ b/Solver/ArenaBuffer.cs @@ -1,5 +1,4 @@ using System.Diagnostics.Contracts; -using System.Numerics; using System.Runtime.CompilerServices; namespace Craftimizer.Solver; @@ -10,11 +9,11 @@ public struct ArenaBuffer // The benchmark reaches 20 at most, but here we have a little leeway just in case. internal const int MaxSize = 32; - internal static readonly int BatchSize = Vector.Count; - internal static readonly int BatchSizeBits = int.Log2(BatchSize); - internal static readonly int BatchSizeMask = BatchSize - 1; + internal const int BatchSize = 8; + internal const int BatchSizeBits = 3; // int.Log2(BatchSize); + internal const int BatchSizeMask = BatchSize - 1; - internal static readonly int BatchCount = MaxSize / BatchSize; + internal const int BatchCount = MaxSize / BatchSize; } // Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs diff --git a/Solver/Intrinsics.cs b/Solver/Intrinsics.cs index 6eca3b9..f4ae781 100644 --- a/Solver/Intrinsics.cs +++ b/Solver/Intrinsics.cs @@ -23,7 +23,7 @@ internal static class Intrinsics [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int HMaxIndexScalar(Vector v, int len) + private static int HMaxIndexScalar(Vector256 v, int len) { var m = 0; for (var i = 1; i < len; ++i) @@ -46,10 +46,10 @@ internal static class Intrinsics [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] // https://stackoverflow.com/a/23592221 - private static int HMaxIndexAVX2(Vector v, int len) + private static int HMaxIndexAVX2(Vector256 v, int len) { // Remove NaNs - var vfilt = ClearLastN(v.AsVector256(), len); + var vfilt = ClearLastN(v, len); // Find max value and broadcast to all lanes var vmax128 = HMax(vfilt); @@ -66,41 +66,11 @@ internal static class Intrinsics [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int HMaxIndex(Vector v, int len) => + public static int HMaxIndex(Vector256 v, int len) => Avx2.IsSupported ? HMaxIndexAVX2(v, len) : HMaxIndexScalar(v, len); - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int NthBitSetScalar(uint value, int n) - { - var mask = 0x0000FFFFu; - var size = 16; - var _base = 0; - - if (n++ >= BitOperations.PopCount(value)) - return 32; - - while (size > 0) - { - var count = BitOperations.PopCount(value & mask); - if (n > count) - { - _base += size; - size >>= 1; - mask |= mask << size; - } - else - { - size >>= 1; - mask >>= size; - } - } - - return _base; - } - [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int NthBitSetScalar(ulong value, int n) @@ -131,28 +101,11 @@ internal static class Intrinsics return _base; } - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int NthBitSetBMI2(uint value, int n) => - BitOperations.TrailingZeroCount(Bmi2.ParallelBitDeposit(1u << n, value)); - [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int NthBitSetBMI2(ulong value, int n) => BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value)); - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int NthBitSet(uint value, int n) - { - if (n >= BitOperations.PopCount(value)) - return 32; - - return Bmi2.IsSupported ? - NthBitSetBMI2(value, n) : - NthBitSetScalar(value, n); - } - [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int NthBitSet(ulong value, int n) @@ -168,17 +121,15 @@ internal static class Intrinsics [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] [SkipLocalsInit] - public static Vector ReciprocalSqrt(Vector data) + public static Vector256 ReciprocalSqrt(Vector256 data) { - if (Avx.IsSupported && Vector.Count >= Vector256.Count) - return Avx.ReciprocalSqrt(data.AsVector256()).AsVector(); + if (Avx.IsSupported && Vector256.Count >= Vector256.Count) + return Avx.ReciprocalSqrt(data); - if (Sse.IsSupported && Vector.Count >= Vector128.Count) - return Sse.ReciprocalSqrt(data.AsVector128()).AsVector(); - - Span result = stackalloc float[Vector.Count]; - for (var i = 0; i < Vector.Count; ++i) - result[i] = MathF.ReciprocalSqrtEstimate(data[i]); - return new(result); + Unsafe.SkipInit(out Vector256 ret); + ref var result = ref Unsafe.As, float>(ref Unsafe.AsRef(in ret)); + for (var i = 0; i < Vector256.Count; ++i) + Unsafe.Add(ref result, i) = MathF.ReciprocalSqrtEstimate(data[i]); + return ret; } } diff --git a/Solver/MCTS.cs b/Solver/MCTS.cs index a22a5ad..b2030c4 100644 --- a/Solver/MCTS.cs +++ b/Solver/MCTS.cs @@ -4,6 +4,7 @@ using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; using Node = Craftimizer.Solver.ArenaNode; +using System.Runtime.Intrinsics; namespace Craftimizer.Solver; @@ -66,7 +67,7 @@ public sealed class MCTS private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores) { var length = scores.Count; - var vecLength = Vector.Count; + var vecLength = Vector256.Count; var max = (0, 0); var maxScore = 0f; @@ -74,8 +75,7 @@ public sealed class MCTS { var iterCount = Math.Min(vecLength, length); - ref var chunk = ref scores.Data[i]; - var m = new Vector(chunk.MaxScore.Span); + var m = scores.Data![i].MaxScore; var idx = Intrinsics.HMaxIndex(m, iterCount); @@ -116,12 +116,12 @@ public sealed class MCTS in NodeScoresBuffer scores) { var length = scores.Count; - var vecLength = Vector.Count; + var vecLength = Vector256.Count; var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits)); var w = maxScoreWeightingConstant; var W = 1f - w; - var CVector = new Vector(C); + var CVector = Vector256.Create(C); var max = (0, 0); var maxScore = 0f; @@ -129,13 +129,13 @@ public sealed class MCTS { var iterCount = Math.Min(vecLength, length); - ref var chunk = ref scores.Data[i]; - var s = new Vector(chunk.ScoreSum.Span); - var vInt = new Vector(chunk.Visits.Span); - var m = new Vector(chunk.MaxScore.Span); + ref var chunk = ref scores.Data![i]; + var s = chunk.ScoreSum; + var vInt = chunk.Visits; + var m = chunk.MaxScore; - vInt = Vector.Max(vInt, Vector.One); - var v = Vector.ConvertToSingle(vInt); + vInt = Vector256.Max(vInt, Vector256.One); + var v = Vector256.ConvertToSingle(vInt); var exploitation = W * (s / v) + w * m; var exploration = CVector * Intrinsics.ReciprocalSqrt(v); diff --git a/Solver/NodeScoresBuffer.cs b/Solver/NodeScoresBuffer.cs index 1178267..87cc276 100644 --- a/Solver/NodeScoresBuffer.cs +++ b/Solver/NodeScoresBuffer.cs @@ -1,54 +1,45 @@ -using System.Diagnostics.Contracts; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; namespace Craftimizer.Solver; -// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs public struct NodeScoresBuffer { [StructLayout(LayoutKind.Auto)] - public readonly struct ScoresBatch + public struct ScoresBatch { - public readonly Memory ScoreSum; - public readonly Memory MaxScore; - public readonly Memory Visits; - - public ScoresBatch() - { - ScoreSum = new float[ArenaBuffer.BatchSize]; - MaxScore = new float[ArenaBuffer.BatchSize]; - Visits = new int[ArenaBuffer.BatchSize]; - } + public Vector256 ScoreSum; + public Vector256 MaxScore; + public Vector256 Visits; } - public ScoresBatch[] Data; + public ScoresBatch[]? Data; public int Count { get; private set; } public void Add() { - Data ??= new ScoresBatch[ArenaBuffer.BatchCount]; - - var idx = Count++; - - var (arrayIdx, subIdx) = GetArrayIndex(idx); - - if (subIdx == 0) - Data[arrayIdx] = new(); + Data ??= GC.AllocateUninitializedArray(ArenaBuffer.BatchCount); + var count = Count++; + if ((count & ArenaBuffer.BatchSizeMask) == 0) + Data[count >> ArenaBuffer.BatchSizeBits] = new(); } public readonly void Visit((int arrayIdx, int subIdx) at, float score) { - Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score; - Data[at.arrayIdx].MaxScore.Span[at.subIdx] = Math.Max(Data[at.arrayIdx].MaxScore.Span[at.subIdx], score); - Data[at.arrayIdx].Visits.Span[at.subIdx]++; + ref var batch = ref Data![at.arrayIdx]; + batch.ScoreSum.At(at.subIdx) += score; + ref var maxScore = ref batch.MaxScore.At(at.subIdx); + maxScore = Math.Max(maxScore, score); + batch.Visits.At(at.subIdx)++; } public readonly int GetVisits((int arrayIdx, int subIdx) at) => - Data[at.arrayIdx].Visits.Span[at.subIdx]; - - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) => - (idx >> ArenaBuffer.BatchSizeBits, idx & ArenaBuffer.BatchSizeMask); + Data![at.arrayIdx].Visits[at.subIdx]; +} + +internal static class VectorUtils +{ + public static ref T At(this ref Vector256 me, int idx) => + ref Unsafe.Add(ref Unsafe.As, T>(ref me), idx); }