Batch allocations, replace all Vector<T> with Vector256

This commit is contained in:
Asriel Camora
2024-03-15 03:32:08 -07:00
parent 94ea1f84de
commit 82ac00df7c
4 changed files with 49 additions and 108 deletions
+11 -11
View File
@@ -4,6 +4,7 @@ using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
using System.Runtime.Intrinsics;
namespace Craftimizer.Solver;
@@ -66,7 +67,7 @@ public sealed class MCTS
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
{
var length = scores.Count;
var vecLength = Vector<float>.Count;
var vecLength = Vector256<float>.Count;
var max = (0, 0);
var maxScore = 0f;
@@ -74,8 +75,7 @@ public sealed class MCTS
{
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref scores.Data[i];
var m = new Vector<float>(chunk.MaxScore.Span);
var m = scores.Data![i].MaxScore;
var idx = Intrinsics.HMaxIndex(m, iterCount);
@@ -116,12 +116,12 @@ public sealed class MCTS
in NodeScoresBuffer scores)
{
var length = scores.Count;
var vecLength = Vector<float>.Count;
var vecLength = Vector256<float>.Count;
var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits));
var w = maxScoreWeightingConstant;
var W = 1f - w;
var CVector = new Vector<float>(C);
var CVector = Vector256.Create(C);
var max = (0, 0);
var maxScore = 0f;
@@ -129,13 +129,13 @@ public sealed class MCTS
{
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref scores.Data[i];
var s = new Vector<float>(chunk.ScoreSum.Span);
var vInt = new Vector<int>(chunk.Visits.Span);
var m = new Vector<float>(chunk.MaxScore.Span);
ref var chunk = ref scores.Data![i];
var s = chunk.ScoreSum;
var vInt = chunk.Visits;
var m = chunk.MaxScore;
vInt = Vector.Max(vInt, Vector<int>.One);
var v = Vector.ConvertToSingle(vInt);
vInt = Vector256.Max(vInt, Vector256<int>.One);
var v = Vector256.ConvertToSingle(vInt);
var exploitation = W * (s / v) + w * m;
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);