Offload node score buffers
This commit is contained in:
@@ -39,35 +39,32 @@ public static class SolverUtils
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Node ChildMaxScore(ref Node.ChildBuffer children)
|
||||
public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
|
||||
{
|
||||
var length = children.Count;
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
Span<float> scores = stackalloc float[vecLength];
|
||||
|
||||
var max = (0, 0);
|
||||
var maxScore = 0f;
|
||||
for (var i = 0; length > 0; ++i)
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
ref var chunk = ref children.Data[i];
|
||||
for (var j = 0; j < iterCount; ++j)
|
||||
scores[j] = chunk[j].State.Scores.MaxScore;
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
|
||||
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
|
||||
var idx = Intrinsics.HMaxIndex(m, iterCount);
|
||||
|
||||
if (scores[idx] >= maxScore)
|
||||
if (m[idx] >= maxScore)
|
||||
{
|
||||
max = (i, idx);
|
||||
maxScore = scores[idx];
|
||||
maxScore = m[idx];
|
||||
}
|
||||
|
||||
length -= iterCount;
|
||||
}
|
||||
|
||||
return children.Data[max.Item1][max.Item2];
|
||||
return max;
|
||||
}
|
||||
|
||||
[Pure]
|
||||
@@ -76,7 +73,7 @@ public static class SolverUtils
|
||||
var actions = new List<ActionType>();
|
||||
while (node.Children.Count != 0)
|
||||
{
|
||||
node = ChildMaxScore(ref node.Children);
|
||||
node = node.ChildAt(ChildMaxScore(ref node.ChildScores));
|
||||
|
||||
if (node.State.Action != null)
|
||||
actions.Add(node.State.Action.Value);
|
||||
@@ -87,9 +84,9 @@ public static class SolverUtils
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
|
||||
public static Node EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) where S : ISolver
|
||||
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
|
||||
{
|
||||
var length = children.Count;
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits));
|
||||
@@ -107,13 +104,14 @@ public static class SolverUtils
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
S.LoadChildData(scoreSums, visits, maxScores, ref children.Data[i], iterCount);
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var s = new Vector<float>(chunk.ScoreSum.Span);
|
||||
var vInt = new Vector<int>(chunk.Visits.Span);
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
|
||||
var s = new Vector<float>(scoreSums);
|
||||
var m = new Vector<float>(maxScores);
|
||||
var vInt = new Vector<int>(visits);
|
||||
vInt = Vector.Max(vInt, Vector<int>.One);
|
||||
var v = Vector.ConvertToSingle(vInt);
|
||||
|
||||
var exploitation = (W * (s / v)) + (w * m);
|
||||
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
|
||||
var evalScores = exploitation + exploration;
|
||||
@@ -129,11 +127,11 @@ public static class SolverUtils
|
||||
length -= iterCount;
|
||||
}
|
||||
|
||||
return children.Data[max.Item1][max.Item2];
|
||||
return max;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, Node rootNode, Node expandedNode, Random random, Simulator simulator)
|
||||
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
|
||||
{
|
||||
// playout to a terminal state
|
||||
var currentState = expandedNode.State.State;
|
||||
@@ -157,7 +155,7 @@ public static class SolverUtils
|
||||
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
|
||||
if (currentCompletionState == CompletionState.ProgressComplete)
|
||||
{
|
||||
if (score >= config.ScoreStorageThreshold && score >= rootNode.State.Scores.MaxScore)
|
||||
if (score >= config.ScoreStorageThreshold && score >= maxScore)
|
||||
{
|
||||
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
|
||||
return (terminalNode, score);
|
||||
@@ -167,7 +165,7 @@ public static class SolverUtils
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static void Search<S>(ref SolverConfig config, int iterations, Node rootNode, CancellationToken token) where S : ISolver
|
||||
public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
|
||||
{
|
||||
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
|
||||
var random = rootNode.State.State.Input.Random;
|
||||
@@ -176,7 +174,7 @@ public static class SolverUtils
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
if (!S.SearchIter(ref config, rootNode, random, simulator))
|
||||
if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
|
||||
{
|
||||
// Retry, count this iteration as moot
|
||||
i--;
|
||||
@@ -211,15 +209,16 @@ public static class SolverUtils
|
||||
var actions = new List<ActionType>();
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
var rootNode = CreateRootNode(config, state, true);
|
||||
RootScores rootScores = new();
|
||||
while (!sim.IsComplete)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
S.Search(ref config, rootNode, token);
|
||||
S.Search(ref config, rootScores, rootNode, token);
|
||||
var (solution_actions, solution_node) = Solution(rootNode);
|
||||
|
||||
if (solution_node.Scores.MaxScore >= 1.0)
|
||||
if (rootScores.MaxScore >= 1.0)
|
||||
{
|
||||
actions.AddRange(solution_actions);
|
||||
return (actions, solution_node.State);
|
||||
@@ -243,7 +242,8 @@ public static class SolverUtils
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
|
||||
{
|
||||
var rootNode = CreateRootNode(config, state, false);
|
||||
S.Search(ref config, rootNode, token);
|
||||
RootScores rootScores = new();
|
||||
S.Search(ref config, rootScores, rootNode, token);
|
||||
var (solution_actions, solution_node) = Solution(rootNode);
|
||||
return (solution_actions, solution_node.State);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user