Batch allocations, replace all Vector<T> with Vector256
This commit is contained in:
+11
-11
@@ -4,6 +4,7 @@ using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
|
||||
using System.Runtime.Intrinsics;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
@@ -66,7 +67,7 @@ public sealed class MCTS
|
||||
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
var vecLength = Vector256<float>.Count;
|
||||
|
||||
var max = (0, 0);
|
||||
var maxScore = 0f;
|
||||
@@ -74,8 +75,7 @@ public sealed class MCTS
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
var m = scores.Data![i].MaxScore;
|
||||
|
||||
var idx = Intrinsics.HMaxIndex(m, iterCount);
|
||||
|
||||
@@ -116,12 +116,12 @@ public sealed class MCTS
|
||||
in NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
var vecLength = Vector256<float>.Count;
|
||||
|
||||
var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits));
|
||||
var w = maxScoreWeightingConstant;
|
||||
var W = 1f - w;
|
||||
var CVector = new Vector<float>(C);
|
||||
var CVector = Vector256.Create(C);
|
||||
|
||||
var max = (0, 0);
|
||||
var maxScore = 0f;
|
||||
@@ -129,13 +129,13 @@ public sealed class MCTS
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var s = new Vector<float>(chunk.ScoreSum.Span);
|
||||
var vInt = new Vector<int>(chunk.Visits.Span);
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
ref var chunk = ref scores.Data![i];
|
||||
var s = chunk.ScoreSum;
|
||||
var vInt = chunk.Visits;
|
||||
var m = chunk.MaxScore;
|
||||
|
||||
vInt = Vector.Max(vInt, Vector<int>.One);
|
||||
var v = Vector.ConvertToSingle(vInt);
|
||||
vInt = Vector256.Max(vInt, Vector256<int>.One);
|
||||
var v = Vector256.ConvertToSingle(vInt);
|
||||
|
||||
var exploitation = W * (s / v) + w * m;
|
||||
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
|
||||
|
||||
Reference in New Issue
Block a user