Offload node score buffers

This commit is contained in:
Asriel Camora
2023-07-07 15:45:42 +02:00
parent 1386f9150c
commit e4d9e3a52e
10 changed files with 188 additions and 97 deletions
+29 -28
View File
@@ -7,29 +7,18 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverConcurrent : ISolver
{
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
var node = chunk[j]?.State.Scores ?? new();
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node? EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) =>
public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
parentVisits == 0 ?
null :
SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children);
[Pure]
public static Node Select(ref SolverConfig config, Node rootNode)
public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode)
{
var node = rootNode;
var nodeVisits = rootNodeVisits;
while (true)
{
var expandable = !node.State.AvailableActions.IsEmpty;
@@ -39,11 +28,21 @@ public sealed class SolverConcurrent : ISolver
// select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children) ?? rootNode;
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
if (at.HasValue)
{
nodeVisits = node.ChildScores.GetVisits(at.Value);
node = node.ChildAt(at.Value);
}
else
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
}
}
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode)
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
@@ -55,43 +54,45 @@ public sealed class SolverConcurrent : ISolver
return null;
var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator);
return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
}
public static void Backpropagate(Node rootNode, Node startNode, float score)
public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{
while (true)
{
startNode.State.Scores.VisitConcurrent(score);
if (startNode == rootNode)
{
rootScores.VisitConcurrent(score);
break;
}
startNode.ParentScores!.Value.VisitConcurrent(startNode.ChildIdx, score);
startNode = startNode.Parent!;
}
}
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator)
public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{
var selectedNode = Select(ref config, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode);
var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
if (!rolledOut.HasValue)
return false;
var (endNode, score) = rolledOut.Value;
Backpropagate(rootNode, endNode, score);
Backpropagate(rootScores, rootNode, endNode, score);
return true;
}
public static void SearchThread(SolverConfig config, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootNode, token);
public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token);
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token)
public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
{
var configP = config;
var tasks = new Task[config.ThreadCount];
for (var i = 0; i < config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(configP, rootNode, token), token);
tasks[i] = Task.Run(() => SearchThread(configP, rootScores, rootNode, token), token);
Task.WaitAll(tasks, CancellationToken.None);
}
}