Offload node score buffers

This commit is contained in:
Asriel Camora
2023-07-07 15:45:42 +02:00
parent 1386f9150c
commit e4d9e3a52e
10 changed files with 188 additions and 97 deletions
+18 -26
View File
@@ -7,25 +7,13 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverSingle : ISolver
{
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
ref var node = ref chunk[j].State.Scores;
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) =>
public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children);
[Pure]
public static Node Select(ref SolverConfig config, Node node)
public static Node Select(ref SolverConfig config, int nodeVisits, Node node)
{
while (true)
{
@@ -35,11 +23,13 @@ public sealed class SolverSingle : ISolver
return node;
// select the node with the highest score
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children);
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
nodeVisits = node.ChildScores.GetVisits(at);
node = node.ChildAt(at);
}
}
public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode)
public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
@@ -49,31 +39,33 @@ public sealed class SolverSingle : ISolver
var poppedAction = initialState.AvailableActions.PopRandom(random);
var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator);
return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
}
public static void Backpropagate(Node rootNode, Node startNode, float score)
public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{
while (true)
{
startNode.State.Scores.Visit(score);
if (startNode == rootNode)
{
rootScores.Visit(score);
break;
}
startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score);
startNode = startNode.Parent!;
}
}
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator)
public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{
var selectedNode = Select(ref config, rootNode);
var (endNode, score) = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode);
var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var (endNode, score) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
Backpropagate(rootNode, endNode, score);
Backpropagate(rootScores, rootNode, endNode, score);
return true;
}
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootNode, token);
public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootScores, rootNode, token);
}