Offload node score buffers

This commit is contained in:
Asriel Camora
2023-07-07 15:45:42 +02:00
parent 1386f9150c
commit e4d9e3a52e
10 changed files with 188 additions and 97 deletions
+26 -26
View File
@@ -39,35 +39,32 @@ public static class SolverUtils
/// <summary>
/// Scans the per-chunk max-score values of a node's children and reports which child
/// holds the largest one. The post-change form returns the position as
/// (chunk index, index within chunk); the pre-change form returned the child Node itself.
/// </summary>
/// <remarks>
/// NOTE(review): this text is a diff rendered without +/- markers, so the pre-change and
/// post-change variants of several statements appear back to back. Lines that use
/// <c>children</c> or the stackalloc'd <c>scores</c> span look like the removed side and
/// lines that use the <c>scores</c> parameter / <c>m</c> look like the added side - confirm
/// against the real diff before treating any line as live code.
/// The best score starts at 0f with position (0, 0) and the comparison is <c>&gt;=</c>, so
/// a later entry wins ties and (0, 0) is returned when no score reaches 0.
/// </remarks>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
// Presumably the pre-change signature: took the child buffer and returned a Node.
public static Node ChildMaxScore(ref Node.ChildBuffer children)
// Presumably the post-change signature: takes only the score buffer and returns indices.
public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
{
var length = children.Count;
var length = scores.Count;
// Entries are consumed in groups of Vector<float>.Count, one Data[i] chunk per loop pass.
var vecLength = Vector<float>.Count;
// Scratch buffer used only by the variant that gathers scores one child at a time.
Span<float> scores = stackalloc float[vecLength];
var max = (0, 0);
var maxScore = 0f;
// `length` counts the entries still to be examined; the loop ends when it reaches 0.
for (var i = 0; length > 0; ++i)
{
// The last chunk may be only partially filled, so cap the lane count at what remains.
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j)
scores[j] = chunk[j].State.Scores.MaxScore;
ref var chunk = ref scores.Data[i];
// Loads the chunk's max scores directly as one vector instead of gathering them.
// assumes chunk.MaxScore.Span holds at least Vector<float>.Count elements - TODO confirm
var m = new Vector<float>(chunk.MaxScore.Span);
// presumably HMaxIndex returns the lane index of the largest of the first iterCount lanes - verify in Intrinsics
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
var idx = Intrinsics.HMaxIndex(m, iterCount);
if (scores[idx] >= maxScore)
if (m[idx] >= maxScore)
{
// Remember both the chunk index and the lane index of the best entry so far.
max = (i, idx);
maxScore = scores[idx];
maxScore = m[idx];
}
length -= iterCount;
}
return children.Data[max.Item1][max.Item2];
return max;
}
[Pure]
@@ -76,7 +73,7 @@ public static class SolverUtils
var actions = new List<ActionType>();
while (node.Children.Count != 0)
{
node = ChildMaxScore(ref node.Children);
node = node.ChildAt(ChildMaxScore(ref node.ChildScores));
if (node.State.Action != null)
actions.Add(node.State.Action.Value);
@@ -87,9 +84,9 @@ public static class SolverUtils
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static Node EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) where S : ISolver
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
{
var length = children.Count;
var length = scores.Count;
var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits));
@@ -107,13 +104,14 @@ public static class SolverUtils
{
var iterCount = Math.Min(vecLength, length);
S.LoadChildData(scoreSums, visits, maxScores, ref children.Data[i], iterCount);
ref var chunk = ref scores.Data[i];
var s = new Vector<float>(chunk.ScoreSum.Span);
var vInt = new Vector<int>(chunk.Visits.Span);
var m = new Vector<float>(chunk.MaxScore.Span);
var s = new Vector<float>(scoreSums);
var m = new Vector<float>(maxScores);
var vInt = new Vector<int>(visits);
vInt = Vector.Max(vInt, Vector<int>.One);
var v = Vector.ConvertToSingle(vInt);
var exploitation = (W * (s / v)) + (w * m);
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
var evalScores = exploitation + exploration;
@@ -129,11 +127,11 @@ public static class SolverUtils
length -= iterCount;
}
return children.Data[max.Item1][max.Item2];
return max;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, Node rootNode, Node expandedNode, Random random, Simulator simulator)
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
{
// playout to a terminal state
var currentState = expandedNode.State.State;
@@ -157,7 +155,7 @@ public static class SolverUtils
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
if (currentCompletionState == CompletionState.ProgressComplete)
{
if (score >= config.ScoreStorageThreshold && score >= rootNode.State.Scores.MaxScore)
if (score >= config.ScoreStorageThreshold && score >= maxScore)
{
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
return (terminalNode, score);
@@ -167,7 +165,7 @@ public static class SolverUtils
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Search<S>(ref SolverConfig config, int iterations, Node rootNode, CancellationToken token) where S : ISolver
public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
{
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
var random = rootNode.State.State.Input.Random;
@@ -176,7 +174,7 @@ public static class SolverUtils
if (token.IsCancellationRequested)
break;
if (!S.SearchIter(ref config, rootNode, random, simulator))
if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
{
// Retry, count this iteration as moot
i--;
@@ -211,15 +209,16 @@ public static class SolverUtils
var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount);
var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
while (!sim.IsComplete)
{
if (token.IsCancellationRequested)
break;
S.Search(ref config, rootNode, token);
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
if (solution_node.Scores.MaxScore >= 1.0)
if (rootScores.MaxScore >= 1.0)
{
actions.AddRange(solution_actions);
return (actions, solution_node.State);
@@ -243,7 +242,8 @@ public static class SolverUtils
/// <summary>
/// Runs a single search from <paramref name="state"/> and returns the actions and final
/// simulation state of the solution extracted from the resulting tree.
/// </summary>
/// <param name="config">Solver configuration, passed by reference to <c>S.Search</c>.</param>
/// <param name="state">Simulation state used to create the root node.</param>
/// <param name="token">Cancellation token forwarded to <c>S.Search</c>.</param>
/// <returns>The solution's action list and the state stored on the solution node.</returns>
/// <remarks>
/// NOTE(review): this text is a diff rendered without +/- markers. The two
/// <c>S.Search</c> calls below look like the removed and added variants of one statement -
/// confirm against the real diff.
/// </remarks>
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
{
var rootNode = CreateRootNode(config, state, false);
// Presumably the pre-change call, which did not take a RootScores argument.
S.Search(ref config, rootNode, token);
// A new RootScores instance is created per call and handed to the search alongside the root node.
RootScores rootScores = new();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
return (solution_actions, solution_node.State);
}