Batch allocations, replace all Vector<T> with Vector256
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
using System.Diagnostics.Contracts;
|
using System.Diagnostics.Contracts;
|
||||||
using System.Numerics;
|
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
|
|
||||||
namespace Craftimizer.Solver;
|
namespace Craftimizer.Solver;
|
||||||
@@ -10,11 +9,11 @@ public struct ArenaBuffer
|
|||||||
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
|
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
|
||||||
internal const int MaxSize = 32;
|
internal const int MaxSize = 32;
|
||||||
|
|
||||||
internal static readonly int BatchSize = Vector<float>.Count;
|
internal const int BatchSize = 8;
|
||||||
internal static readonly int BatchSizeBits = int.Log2(BatchSize);
|
internal const int BatchSizeBits = 3; // int.Log2(BatchSize);
|
||||||
internal static readonly int BatchSizeMask = BatchSize - 1;
|
internal const int BatchSizeMask = BatchSize - 1;
|
||||||
|
|
||||||
internal static readonly int BatchCount = MaxSize / BatchSize;
|
internal const int BatchCount = MaxSize / BatchSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
|
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
|
||||||
|
|||||||
+12
-61
@@ -23,7 +23,7 @@ internal static class Intrinsics
|
|||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static int HMaxIndexScalar(Vector<float> v, int len)
|
private static int HMaxIndexScalar(Vector256<float> v, int len)
|
||||||
{
|
{
|
||||||
var m = 0;
|
var m = 0;
|
||||||
for (var i = 1; i < len; ++i)
|
for (var i = 1; i < len; ++i)
|
||||||
@@ -46,10 +46,10 @@ internal static class Intrinsics
|
|||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
// https://stackoverflow.com/a/23592221
|
// https://stackoverflow.com/a/23592221
|
||||||
private static int HMaxIndexAVX2(Vector<float> v, int len)
|
private static int HMaxIndexAVX2(Vector256<float> v, int len)
|
||||||
{
|
{
|
||||||
// Remove NaNs
|
// Remove NaNs
|
||||||
var vfilt = ClearLastN(v.AsVector256(), len);
|
var vfilt = ClearLastN(v, len);
|
||||||
|
|
||||||
// Find max value and broadcast to all lanes
|
// Find max value and broadcast to all lanes
|
||||||
var vmax128 = HMax(vfilt);
|
var vmax128 = HMax(vfilt);
|
||||||
@@ -66,41 +66,11 @@ internal static class Intrinsics
|
|||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public static int HMaxIndex(Vector<float> v, int len) =>
|
public static int HMaxIndex(Vector256<float> v, int len) =>
|
||||||
Avx2.IsSupported ?
|
Avx2.IsSupported ?
|
||||||
HMaxIndexAVX2(v, len) :
|
HMaxIndexAVX2(v, len) :
|
||||||
HMaxIndexScalar(v, len);
|
HMaxIndexScalar(v, len);
|
||||||
|
|
||||||
[Pure]
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
|
||||||
private static int NthBitSetScalar(uint value, int n)
|
|
||||||
{
|
|
||||||
var mask = 0x0000FFFFu;
|
|
||||||
var size = 16;
|
|
||||||
var _base = 0;
|
|
||||||
|
|
||||||
if (n++ >= BitOperations.PopCount(value))
|
|
||||||
return 32;
|
|
||||||
|
|
||||||
while (size > 0)
|
|
||||||
{
|
|
||||||
var count = BitOperations.PopCount(value & mask);
|
|
||||||
if (n > count)
|
|
||||||
{
|
|
||||||
_base += size;
|
|
||||||
size >>= 1;
|
|
||||||
mask |= mask << size;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
size >>= 1;
|
|
||||||
mask >>= size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return _base;
|
|
||||||
}
|
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static int NthBitSetScalar(ulong value, int n)
|
private static int NthBitSetScalar(ulong value, int n)
|
||||||
@@ -131,28 +101,11 @@ internal static class Intrinsics
|
|||||||
return _base;
|
return _base;
|
||||||
}
|
}
|
||||||
|
|
||||||
[Pure]
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
|
||||||
private static int NthBitSetBMI2(uint value, int n) =>
|
|
||||||
BitOperations.TrailingZeroCount(Bmi2.ParallelBitDeposit(1u << n, value));
|
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static int NthBitSetBMI2(ulong value, int n) =>
|
private static int NthBitSetBMI2(ulong value, int n) =>
|
||||||
BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value));
|
BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value));
|
||||||
|
|
||||||
[Pure]
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
|
||||||
public static int NthBitSet(uint value, int n)
|
|
||||||
{
|
|
||||||
if (n >= BitOperations.PopCount(value))
|
|
||||||
return 32;
|
|
||||||
|
|
||||||
return Bmi2.IsSupported ?
|
|
||||||
NthBitSetBMI2(value, n) :
|
|
||||||
NthBitSetScalar(value, n);
|
|
||||||
}
|
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public static int NthBitSet(ulong value, int n)
|
public static int NthBitSet(ulong value, int n)
|
||||||
@@ -168,17 +121,15 @@ internal static class Intrinsics
|
|||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
[SkipLocalsInit]
|
[SkipLocalsInit]
|
||||||
public static Vector<float> ReciprocalSqrt(Vector<float> data)
|
public static Vector256<float> ReciprocalSqrt(Vector256<float> data)
|
||||||
{
|
{
|
||||||
if (Avx.IsSupported && Vector<float>.Count >= Vector256<float>.Count)
|
if (Avx.IsSupported && Vector256<float>.Count >= Vector256<float>.Count)
|
||||||
return Avx.ReciprocalSqrt(data.AsVector256()).AsVector();
|
return Avx.ReciprocalSqrt(data);
|
||||||
|
|
||||||
if (Sse.IsSupported && Vector<float>.Count >= Vector128<float>.Count)
|
Unsafe.SkipInit(out Vector256<float> ret);
|
||||||
return Sse.ReciprocalSqrt(data.AsVector128()).AsVector();
|
ref var result = ref Unsafe.As<Vector256<float>, float>(ref Unsafe.AsRef(in ret));
|
||||||
|
for (var i = 0; i < Vector256<float>.Count; ++i)
|
||||||
Span<float> result = stackalloc float[Vector<float>.Count];
|
Unsafe.Add(ref result, i) = MathF.ReciprocalSqrtEstimate(data[i]);
|
||||||
for (var i = 0; i < Vector<float>.Count; ++i)
|
return ret;
|
||||||
result[i] = MathF.ReciprocalSqrtEstimate(data[i]);
|
|
||||||
return new(result);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+11
-11
@@ -4,6 +4,7 @@ using System.Diagnostics.Contracts;
|
|||||||
using System.Numerics;
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
|
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
|
||||||
|
using System.Runtime.Intrinsics;
|
||||||
|
|
||||||
namespace Craftimizer.Solver;
|
namespace Craftimizer.Solver;
|
||||||
|
|
||||||
@@ -66,7 +67,7 @@ public sealed class MCTS
|
|||||||
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
|
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
|
||||||
{
|
{
|
||||||
var length = scores.Count;
|
var length = scores.Count;
|
||||||
var vecLength = Vector<float>.Count;
|
var vecLength = Vector256<float>.Count;
|
||||||
|
|
||||||
var max = (0, 0);
|
var max = (0, 0);
|
||||||
var maxScore = 0f;
|
var maxScore = 0f;
|
||||||
@@ -74,8 +75,7 @@ public sealed class MCTS
|
|||||||
{
|
{
|
||||||
var iterCount = Math.Min(vecLength, length);
|
var iterCount = Math.Min(vecLength, length);
|
||||||
|
|
||||||
ref var chunk = ref scores.Data[i];
|
var m = scores.Data![i].MaxScore;
|
||||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
|
||||||
|
|
||||||
var idx = Intrinsics.HMaxIndex(m, iterCount);
|
var idx = Intrinsics.HMaxIndex(m, iterCount);
|
||||||
|
|
||||||
@@ -116,12 +116,12 @@ public sealed class MCTS
|
|||||||
in NodeScoresBuffer scores)
|
in NodeScoresBuffer scores)
|
||||||
{
|
{
|
||||||
var length = scores.Count;
|
var length = scores.Count;
|
||||||
var vecLength = Vector<float>.Count;
|
var vecLength = Vector256<float>.Count;
|
||||||
|
|
||||||
var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits));
|
var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits));
|
||||||
var w = maxScoreWeightingConstant;
|
var w = maxScoreWeightingConstant;
|
||||||
var W = 1f - w;
|
var W = 1f - w;
|
||||||
var CVector = new Vector<float>(C);
|
var CVector = Vector256.Create(C);
|
||||||
|
|
||||||
var max = (0, 0);
|
var max = (0, 0);
|
||||||
var maxScore = 0f;
|
var maxScore = 0f;
|
||||||
@@ -129,13 +129,13 @@ public sealed class MCTS
|
|||||||
{
|
{
|
||||||
var iterCount = Math.Min(vecLength, length);
|
var iterCount = Math.Min(vecLength, length);
|
||||||
|
|
||||||
ref var chunk = ref scores.Data[i];
|
ref var chunk = ref scores.Data![i];
|
||||||
var s = new Vector<float>(chunk.ScoreSum.Span);
|
var s = chunk.ScoreSum;
|
||||||
var vInt = new Vector<int>(chunk.Visits.Span);
|
var vInt = chunk.Visits;
|
||||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
var m = chunk.MaxScore;
|
||||||
|
|
||||||
vInt = Vector.Max(vInt, Vector<int>.One);
|
vInt = Vector256.Max(vInt, Vector256<int>.One);
|
||||||
var v = Vector.ConvertToSingle(vInt);
|
var v = Vector256.ConvertToSingle(vInt);
|
||||||
|
|
||||||
var exploitation = W * (s / v) + w * m;
|
var exploitation = W * (s / v) + w * m;
|
||||||
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
|
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
|
||||||
|
|||||||
+22
-31
@@ -1,54 +1,45 @@
|
|||||||
using System.Diagnostics.Contracts;
|
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using System.Runtime.InteropServices;
|
using System.Runtime.InteropServices;
|
||||||
|
using System.Runtime.Intrinsics;
|
||||||
|
|
||||||
namespace Craftimizer.Solver;
|
namespace Craftimizer.Solver;
|
||||||
|
|
||||||
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
|
|
||||||
public struct NodeScoresBuffer
|
public struct NodeScoresBuffer
|
||||||
{
|
{
|
||||||
[StructLayout(LayoutKind.Auto)]
|
[StructLayout(LayoutKind.Auto)]
|
||||||
public readonly struct ScoresBatch
|
public struct ScoresBatch
|
||||||
{
|
{
|
||||||
public readonly Memory<float> ScoreSum;
|
public Vector256<float> ScoreSum;
|
||||||
public readonly Memory<float> MaxScore;
|
public Vector256<float> MaxScore;
|
||||||
public readonly Memory<int> Visits;
|
public Vector256<int> Visits;
|
||||||
|
|
||||||
public ScoresBatch()
|
|
||||||
{
|
|
||||||
ScoreSum = new float[ArenaBuffer.BatchSize];
|
|
||||||
MaxScore = new float[ArenaBuffer.BatchSize];
|
|
||||||
Visits = new int[ArenaBuffer.BatchSize];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScoresBatch[] Data;
|
public ScoresBatch[]? Data;
|
||||||
public int Count { get; private set; }
|
public int Count { get; private set; }
|
||||||
|
|
||||||
public void Add()
|
public void Add()
|
||||||
{
|
{
|
||||||
Data ??= new ScoresBatch[ArenaBuffer.BatchCount];
|
Data ??= GC.AllocateUninitializedArray<ScoresBatch>(ArenaBuffer.BatchCount);
|
||||||
|
var count = Count++;
|
||||||
var idx = Count++;
|
if ((count & ArenaBuffer.BatchSizeMask) == 0)
|
||||||
|
Data[count >> ArenaBuffer.BatchSizeBits] = new();
|
||||||
var (arrayIdx, subIdx) = GetArrayIndex(idx);
|
|
||||||
|
|
||||||
if (subIdx == 0)
|
|
||||||
Data[arrayIdx] = new();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public readonly void Visit((int arrayIdx, int subIdx) at, float score)
|
public readonly void Visit((int arrayIdx, int subIdx) at, float score)
|
||||||
{
|
{
|
||||||
Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score;
|
ref var batch = ref Data![at.arrayIdx];
|
||||||
Data[at.arrayIdx].MaxScore.Span[at.subIdx] = Math.Max(Data[at.arrayIdx].MaxScore.Span[at.subIdx], score);
|
batch.ScoreSum.At(at.subIdx) += score;
|
||||||
Data[at.arrayIdx].Visits.Span[at.subIdx]++;
|
ref var maxScore = ref batch.MaxScore.At(at.subIdx);
|
||||||
|
maxScore = Math.Max(maxScore, score);
|
||||||
|
batch.Visits.At(at.subIdx)++;
|
||||||
}
|
}
|
||||||
|
|
||||||
public readonly int GetVisits((int arrayIdx, int subIdx) at) =>
|
public readonly int GetVisits((int arrayIdx, int subIdx) at) =>
|
||||||
Data[at.arrayIdx].Visits.Span[at.subIdx];
|
Data![at.arrayIdx].Visits[at.subIdx];
|
||||||
|
}
|
||||||
[Pure]
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
internal static class VectorUtils
|
||||||
private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) =>
|
{
|
||||||
(idx >> ArenaBuffer.BatchSizeBits, idx & ArenaBuffer.BatchSizeMask);
|
public static ref T At<T>(this ref Vector256<T> me, int idx) =>
|
||||||
|
ref Unsafe.Add(ref Unsafe.As<Vector256<T>, T>(ref me), idx);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user