Batch allocations, replace all Vector<T> with Vector256

This commit is contained in:
Asriel Camora
2024-03-15 03:32:08 -07:00
parent 94ea1f84de
commit 82ac00df7c
4 changed files with 49 additions and 108 deletions
+4 -5
View File
@@ -1,5 +1,4 @@
using System.Diagnostics.Contracts; using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Craftimizer.Solver; namespace Craftimizer.Solver;
@@ -10,11 +9,11 @@ public struct ArenaBuffer
// The benchmark reaches 20 at most, but here we have a little leeway just in case. // The benchmark reaches 20 at most, but here we have a little leeway just in case.
internal const int MaxSize = 32; internal const int MaxSize = 32;
internal static readonly int BatchSize = Vector<float>.Count; internal const int BatchSize = 8;
internal static readonly int BatchSizeBits = int.Log2(BatchSize); internal const int BatchSizeBits = 3; // int.Log2(BatchSize);
internal static readonly int BatchSizeMask = BatchSize - 1; internal const int BatchSizeMask = BatchSize - 1;
internal static readonly int BatchCount = MaxSize / BatchSize; internal const int BatchCount = MaxSize / BatchSize;
} }
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs // Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
+12 -61
View File
@@ -23,7 +23,7 @@ internal static class Intrinsics
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int HMaxIndexScalar(Vector<float> v, int len) private static int HMaxIndexScalar(Vector256<float> v, int len)
{ {
var m = 0; var m = 0;
for (var i = 1; i < len; ++i) for (var i = 1; i < len; ++i)
@@ -46,10 +46,10 @@ internal static class Intrinsics
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
// https://stackoverflow.com/a/23592221 // https://stackoverflow.com/a/23592221
private static int HMaxIndexAVX2(Vector<float> v, int len) private static int HMaxIndexAVX2(Vector256<float> v, int len)
{ {
// Remove NaNs // Remove NaNs
var vfilt = ClearLastN(v.AsVector256(), len); var vfilt = ClearLastN(v, len);
// Find max value and broadcast to all lanes // Find max value and broadcast to all lanes
var vmax128 = HMax(vfilt); var vmax128 = HMax(vfilt);
@@ -66,41 +66,11 @@ internal static class Intrinsics
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int HMaxIndex(Vector<float> v, int len) => public static int HMaxIndex(Vector256<float> v, int len) =>
Avx2.IsSupported ? Avx2.IsSupported ?
HMaxIndexAVX2(v, len) : HMaxIndexAVX2(v, len) :
HMaxIndexScalar(v, len); HMaxIndexScalar(v, len);
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSetScalar(uint value, int n)
{
var mask = 0x0000FFFFu;
var size = 16;
var _base = 0;
if (n++ >= BitOperations.PopCount(value))
return 32;
while (size > 0)
{
var count = BitOperations.PopCount(value & mask);
if (n > count)
{
_base += size;
size >>= 1;
mask |= mask << size;
}
else
{
size >>= 1;
mask >>= size;
}
}
return _base;
}
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSetScalar(ulong value, int n) private static int NthBitSetScalar(ulong value, int n)
@@ -131,28 +101,11 @@ internal static class Intrinsics
return _base; return _base;
} }
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSetBMI2(uint value, int n) =>
BitOperations.TrailingZeroCount(Bmi2.ParallelBitDeposit(1u << n, value));
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSetBMI2(ulong value, int n) => private static int NthBitSetBMI2(ulong value, int n) =>
BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value)); BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value));
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int NthBitSet(uint value, int n)
{
if (n >= BitOperations.PopCount(value))
return 32;
return Bmi2.IsSupported ?
NthBitSetBMI2(value, n) :
NthBitSetScalar(value, n);
}
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int NthBitSet(ulong value, int n) public static int NthBitSet(ulong value, int n)
@@ -168,17 +121,15 @@ internal static class Intrinsics
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[SkipLocalsInit] [SkipLocalsInit]
public static Vector<float> ReciprocalSqrt(Vector<float> data) public static Vector256<float> ReciprocalSqrt(Vector256<float> data)
{ {
if (Avx.IsSupported && Vector<float>.Count >= Vector256<float>.Count) if (Avx.IsSupported && Vector256<float>.Count >= Vector256<float>.Count)
return Avx.ReciprocalSqrt(data.AsVector256()).AsVector(); return Avx.ReciprocalSqrt(data);
if (Sse.IsSupported && Vector<float>.Count >= Vector128<float>.Count) Unsafe.SkipInit(out Vector256<float> ret);
return Sse.ReciprocalSqrt(data.AsVector128()).AsVector(); ref var result = ref Unsafe.As<Vector256<float>, float>(ref Unsafe.AsRef(in ret));
for (var i = 0; i < Vector256<float>.Count; ++i)
Span<float> result = stackalloc float[Vector<float>.Count]; Unsafe.Add(ref result, i) = MathF.ReciprocalSqrtEstimate(data[i]);
for (var i = 0; i < Vector<float>.Count; ++i) return ret;
result[i] = MathF.ReciprocalSqrtEstimate(data[i]);
return new(result);
} }
} }
+11 -11
View File
@@ -4,6 +4,7 @@ using System.Diagnostics.Contracts;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>; using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
using System.Runtime.Intrinsics;
namespace Craftimizer.Solver; namespace Craftimizer.Solver;
@@ -66,7 +67,7 @@ public sealed class MCTS
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores) private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
{ {
var length = scores.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector256<float>.Count;
var max = (0, 0); var max = (0, 0);
var maxScore = 0f; var maxScore = 0f;
@@ -74,8 +75,7 @@ public sealed class MCTS
{ {
var iterCount = Math.Min(vecLength, length); var iterCount = Math.Min(vecLength, length);
ref var chunk = ref scores.Data[i]; var m = scores.Data![i].MaxScore;
var m = new Vector<float>(chunk.MaxScore.Span);
var idx = Intrinsics.HMaxIndex(m, iterCount); var idx = Intrinsics.HMaxIndex(m, iterCount);
@@ -116,12 +116,12 @@ public sealed class MCTS
in NodeScoresBuffer scores) in NodeScoresBuffer scores)
{ {
var length = scores.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector256<float>.Count;
var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits)); var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits));
var w = maxScoreWeightingConstant; var w = maxScoreWeightingConstant;
var W = 1f - w; var W = 1f - w;
var CVector = new Vector<float>(C); var CVector = Vector256.Create(C);
var max = (0, 0); var max = (0, 0);
var maxScore = 0f; var maxScore = 0f;
@@ -129,13 +129,13 @@ public sealed class MCTS
{ {
var iterCount = Math.Min(vecLength, length); var iterCount = Math.Min(vecLength, length);
ref var chunk = ref scores.Data[i]; ref var chunk = ref scores.Data![i];
var s = new Vector<float>(chunk.ScoreSum.Span); var s = chunk.ScoreSum;
var vInt = new Vector<int>(chunk.Visits.Span); var vInt = chunk.Visits;
var m = new Vector<float>(chunk.MaxScore.Span); var m = chunk.MaxScore;
vInt = Vector.Max(vInt, Vector<int>.One); vInt = Vector256.Max(vInt, Vector256<int>.One);
var v = Vector.ConvertToSingle(vInt); var v = Vector256.ConvertToSingle(vInt);
var exploitation = W * (s / v) + w * m; var exploitation = W * (s / v) + w * m;
var exploration = CVector * Intrinsics.ReciprocalSqrt(v); var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
+22 -31
View File
@@ -1,54 +1,45 @@
using System.Diagnostics.Contracts;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
namespace Craftimizer.Solver; namespace Craftimizer.Solver;
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
public struct NodeScoresBuffer public struct NodeScoresBuffer
{ {
[StructLayout(LayoutKind.Auto)] [StructLayout(LayoutKind.Auto)]
public readonly struct ScoresBatch public struct ScoresBatch
{ {
public readonly Memory<float> ScoreSum; public Vector256<float> ScoreSum;
public readonly Memory<float> MaxScore; public Vector256<float> MaxScore;
public readonly Memory<int> Visits; public Vector256<int> Visits;
public ScoresBatch()
{
ScoreSum = new float[ArenaBuffer.BatchSize];
MaxScore = new float[ArenaBuffer.BatchSize];
Visits = new int[ArenaBuffer.BatchSize];
}
} }
public ScoresBatch[] Data; public ScoresBatch[]? Data;
public int Count { get; private set; } public int Count { get; private set; }
public void Add() public void Add()
{ {
Data ??= new ScoresBatch[ArenaBuffer.BatchCount]; Data ??= GC.AllocateUninitializedArray<ScoresBatch>(ArenaBuffer.BatchCount);
var count = Count++;
var idx = Count++; if ((count & ArenaBuffer.BatchSizeMask) == 0)
Data[count >> ArenaBuffer.BatchSizeBits] = new();
var (arrayIdx, subIdx) = GetArrayIndex(idx);
if (subIdx == 0)
Data[arrayIdx] = new();
} }
public readonly void Visit((int arrayIdx, int subIdx) at, float score) public readonly void Visit((int arrayIdx, int subIdx) at, float score)
{ {
Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score; ref var batch = ref Data![at.arrayIdx];
Data[at.arrayIdx].MaxScore.Span[at.subIdx] = Math.Max(Data[at.arrayIdx].MaxScore.Span[at.subIdx], score); batch.ScoreSum.At(at.subIdx) += score;
Data[at.arrayIdx].Visits.Span[at.subIdx]++; ref var maxScore = ref batch.MaxScore.At(at.subIdx);
maxScore = Math.Max(maxScore, score);
batch.Visits.At(at.subIdx)++;
} }
public readonly int GetVisits((int arrayIdx, int subIdx) at) => public readonly int GetVisits((int arrayIdx, int subIdx) at) =>
Data[at.arrayIdx].Visits.Span[at.subIdx]; Data![at.arrayIdx].Visits[at.subIdx];
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] internal static class VectorUtils
private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) => {
(idx >> ArenaBuffer.BatchSizeBits, idx & ArenaBuffer.BatchSizeMask); public static ref T At<T>(this ref Vector256<T> me, int idx) =>
ref Unsafe.Add(ref Unsafe.As<Vector256<T>, T>(ref me), idx);
} }