Offload node score buffers

This commit is contained in:
Asriel Camora
2023-07-07 15:45:42 +02:00
parent 1386f9150c
commit e4d9e3a52e
10 changed files with 188 additions and 97 deletions
+10 -8
View File
@@ -5,7 +5,7 @@ using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty;
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
public struct ArenaBuffer<T>
public struct ArenaBuffer<T> where T : struct
{
// Technically 25, but it's very unlikely to actually get to there.
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
@@ -17,39 +17,41 @@ public struct ArenaBuffer<T>
private static int BatchCount = MaxSize / BatchSize;
public T[][] Data;
public ArenaNode<T>[][] Data;
private int index; // Unused in single threaded workload
private int count;
public readonly int Count => count;
public void AddConcurrent(T node)
public void AddConcurrent(ArenaNode<T> node)
{
if (Data == null)
Interlocked.CompareExchange(ref Data, new T[BatchCount][], null);
Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);
var idx = Interlocked.Increment(ref index) - 1;
var (arrayIdx, subIdx) = GetArrayIndex(idx);
if (Data[arrayIdx] == null)
Interlocked.CompareExchange(ref Data[arrayIdx], new T[BatchSize], null);
Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode<T>[BatchSize], null);
node.ChildIdx = (arrayIdx, subIdx);
Data[arrayIdx][subIdx] = node;
Interlocked.Increment(ref count);
}
public void Add(T node)
public void Add(ArenaNode<T> node)
{
Data ??= new T[BatchCount][];
Data ??= new ArenaNode<T>[BatchCount][];
var idx = count++;
var (arrayIdx, subIdx) = GetArrayIndex(idx);
Data[arrayIdx] ??= new T[BatchSize];
Data[arrayIdx] ??= new ArenaNode<T>[BatchSize];
node.ChildIdx = (arrayIdx, subIdx);
Data[arrayIdx][subIdx] = node;
}
+11 -1
View File
@@ -5,21 +5,30 @@ namespace Craftimizer.Solver.Crafty;
public sealed class ArenaNode<T> where T : struct
{
public T State;
public ArenaBuffer<ArenaNode<T>> Children;
public ArenaBuffer<T> Children;
public NodeScoresBuffer ChildScores;
public (int arrayIdx, int subIdx) ChildIdx;
public readonly ArenaNode<T>? Parent;
public NodeScoresBuffer? ParentScores => Parent?.ChildScores;
public ArenaNode(T state, ArenaNode<T>? parent = null)
{
State = state;
Children = new();
ChildScores = new();
Parent = parent;
}
public ArenaNode<T> ChildAt((int arrayIdx, int subIdx) at) =>
Children.Data[at.arrayIdx][at.subIdx];
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> AddConcurrent(T state)
{
var node = new ArenaNode<T>(state, this);
Children.AddConcurrent(node);
ChildScores.AddConcurrent();
return node;
}
@@ -28,6 +37,7 @@ public sealed class ArenaNode<T> where T : struct
{
var node = new ArenaNode<T>(state, this);
Children.Add(node);
ChildScores.Add();
return node;
}
}
+2 -4
View File
@@ -4,9 +4,7 @@ namespace Craftimizer.Solver.Crafty;
public interface ISolver
{
abstract static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount);
abstract static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator);
abstract static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator);
abstract static void Search(ref SolverConfig config, Node rootNode, CancellationToken token);
abstract static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token);
}
+1 -2
View File
@@ -101,9 +101,8 @@ internal static class Intrinsics
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int NthBitSet(uint value, int n)
{
// TODO: debug
if (n >= BitOperations.PopCount(value))
throw new ArgumentException(null, nameof(value));
return 32;
return Bmi2.IsSupported ?
NthBitSetBMI2(value, n) :
+90
View File
@@ -0,0 +1,90 @@
using System;
using System.ComponentModel;
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Craftimizer.Solver.Crafty;
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
/// <summary>
/// Batch-allocated storage for per-child search statistics: score sum, max score and visit count.
/// Statistics are kept as parallel arrays in batches of <c>Vector&lt;float&gt;.Count</c> slots,
/// and a slot is addressed by an <c>(arrayIdx, subIdx)</c> pair (batch index, index within the batch).
/// </summary>
/// <remarks>
/// This is a mutable struct: <see cref="Add"/> and <see cref="AddConcurrent"/> must be called on the
/// stored field itself, not on a copy. The <c>Visit*</c> and <see cref="GetVisits"/> members are
/// <c>readonly</c> and only go through the <see cref="Data"/> reference.
/// NOTE(review): <see cref="AddConcurrent"/> reserves its slot from this buffer's own counter; confirm
/// that callers which take slot indices from another buffer cannot visit a batch before it is allocated here.
/// </remarks>
public struct NodeScoresBuffer
{
    /// <summary>
    /// One batch of statistics, held as three parallel arrays of <c>BatchSize</c> elements each.
    /// </summary>
    public sealed class ScoresBatch
    {
        public Memory<float> ScoreSum;
        public Memory<float> MaxScore;
        public Memory<int> Visits;

        /// <summary>Allocates zero-initialised arrays for every statistic in the batch.</summary>
        public ScoresBatch()
        {
            ScoreSum = new float[BatchSize];
            MaxScore = new float[BatchSize];
            Visits = new int[BatchSize];
        }
    }

    // Technically 25, but it's very unlikely to actually get there.
    // The benchmark reaches 20 at most, but here we have a little leeway just in case.
    private const int MaxSize = 24;

    // Sizing values are fixed for the lifetime of the process, so they are readonly.
    private static readonly int BatchSize = Vector<float>.Count;
    // BatchSize is a power of two, which lets GetArrayIndex split an index with a shift and a mask.
    private static readonly int BatchSizeBits = int.Log2(BatchSize);
    private static readonly int BatchSizeMask = BatchSize - 1;
    // Round up: a plain MaxSize / BatchSize truncates when BatchSize does not divide MaxSize
    // (e.g. 24 / 16 == 1), which would leave fewer than MaxSize addressable slots.
    private static readonly int BatchCount = (MaxSize + BatchSize - 1) / BatchSize;

    /// <summary>Batches indexed by <c>arrayIdx</c>; the array and its entries are allocated lazily.</summary>
    public ScoresBatch[] Data;

    // Next slot to hand out in AddConcurrent; not used by the non-concurrent Add.
    private int index;
    private int count;

    /// <summary>Number of slots that have been added so far.</summary>
    public readonly int Count => count;

    /// <summary>
    /// Reserves the next slot using interlocked operations, allocating the batch table and the
    /// slot's batch if they do not exist yet.
    /// </summary>
    public void AddConcurrent()
    {
        // If several callers race here, CompareExchange keeps exactly one allocation.
        if (Data is null)
            Interlocked.CompareExchange(ref Data, new ScoresBatch[BatchCount], null);

        var slot = Interlocked.Increment(ref index) - 1;
        var (arrayIdx, _) = GetArrayIndex(slot);

        if (Data[arrayIdx] is null)
            Interlocked.CompareExchange(ref Data[arrayIdx], new ScoresBatch(), null);

        // Count is bumped only after the batch for this slot exists.
        Interlocked.Increment(ref count);
    }

    /// <summary>
    /// Reserves the next slot without any synchronisation, allocating the batch table and the
    /// slot's batch if they do not exist yet.
    /// </summary>
    public void Add()
    {
        Data ??= new ScoresBatch[BatchCount];

        var slot = count++;
        var (arrayIdx, _) = GetArrayIndex(slot);

        Data[arrayIdx] ??= new();
    }

    /// <summary>
    /// Records one visit with the given score at <paramref name="at"/>, updating each statistic
    /// through the <c>Intrinsics</c> compare-and-swap helpers / <see cref="Interlocked"/>.
    /// </summary>
    /// <param name="at">Batch index and index within the batch of the slot to update.</param>
    /// <param name="score">Score to add to the sum and fold into the maximum.</param>
    public readonly void VisitConcurrent((int arrayIdx, int subIdx) at, float score)
    {
        var batch = Data[at.arrayIdx];
        Intrinsics.CASAdd(ref batch.ScoreSum.Span[at.subIdx], score);
        Intrinsics.CASMax(ref batch.MaxScore.Span[at.subIdx], score);
        Interlocked.Increment(ref batch.Visits.Span[at.subIdx]);
    }

    /// <summary>
    /// Records one visit with the given score at <paramref name="at"/> without synchronisation.
    /// </summary>
    /// <param name="at">Batch index and index within the batch of the slot to update.</param>
    /// <param name="score">Score to add to the sum and fold into the maximum.</param>
    public readonly void Visit((int arrayIdx, int subIdx) at, float score)
    {
        // Resolve the batch once instead of re-indexing Data for every statistic.
        var batch = Data[at.arrayIdx];

        batch.ScoreSum.Span[at.subIdx] += score;

        ref var maxScore = ref batch.MaxScore.Span[at.subIdx];
        maxScore = Math.Max(maxScore, score);

        batch.Visits.Span[at.subIdx]++;
    }

    /// <summary>Returns the visit count stored at <paramref name="at"/>.</summary>
    /// <param name="at">Batch index and index within the batch of the slot to read.</param>
    public readonly int GetVisits((int arrayIdx, int subIdx) at) =>
        Data[at.arrayIdx].Visits.Span[at.subIdx];

    /// <summary>Splits a flat slot index into its batch index and its index within that batch.</summary>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) =>
        (idx >> BatchSizeBits, idx & BatchSizeMask);
}
@@ -3,7 +3,7 @@ using System.Runtime.InteropServices;
namespace Craftimizer.Solver.Crafty;
[StructLayout(LayoutKind.Auto)]
public struct NodeScores
public sealed class RootScores
{
public float ScoreSum;
public float MaxScore;
-1
View File
@@ -12,7 +12,6 @@ public struct SimulationNode
public readonly CompletionState SimulationCompletionState;
public ActionSet AvailableActions;
public NodeScores Scores;
public readonly CompletionState CompletionState => GetCompletionState(SimulationCompletionState, AvailableActions);
+29 -28
View File
@@ -7,29 +7,18 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverConcurrent : ISolver
{
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
var node = chunk[j]?.State.Scores ?? new();
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node? EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) =>
public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
parentVisits == 0 ?
null :
SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children);
[Pure]
public static Node Select(ref SolverConfig config, Node rootNode)
public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode)
{
var node = rootNode;
var nodeVisits = rootNodeVisits;
while (true)
{
var expandable = !node.State.AvailableActions.IsEmpty;
@@ -39,11 +28,21 @@ public sealed class SolverConcurrent : ISolver
// select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children) ?? rootNode;
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
if (at.HasValue)
{
nodeVisits = node.ChildScores.GetVisits(at.Value);
node = node.ChildAt(at.Value);
}
else
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
}
}
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode)
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
@@ -55,43 +54,45 @@ public sealed class SolverConcurrent : ISolver
return null;
var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator);
return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
}
public static void Backpropagate(Node rootNode, Node startNode, float score)
public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{
while (true)
{
startNode.State.Scores.VisitConcurrent(score);
if (startNode == rootNode)
{
rootScores.VisitConcurrent(score);
break;
}
startNode.ParentScores!.Value.VisitConcurrent(startNode.ChildIdx, score);
startNode = startNode.Parent!;
}
}
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator)
public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{
var selectedNode = Select(ref config, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode);
var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
if (!rolledOut.HasValue)
return false;
var (endNode, score) = rolledOut.Value;
Backpropagate(rootNode, endNode, score);
Backpropagate(rootScores, rootNode, endNode, score);
return true;
}
public static void SearchThread(SolverConfig config, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootNode, token);
public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token);
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token)
public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
{
var configP = config;
var tasks = new Task[config.ThreadCount];
for (var i = 0; i < config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(configP, rootNode, token), token);
tasks[i] = Task.Run(() => SearchThread(configP, rootScores, rootNode, token), token);
Task.WaitAll(tasks, CancellationToken.None);
}
}
+18 -26
View File
@@ -7,25 +7,13 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverSingle : ISolver
{
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
ref var node = ref chunk[j].State.Scores;
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) =>
public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children);
[Pure]
public static Node Select(ref SolverConfig config, Node node)
public static Node Select(ref SolverConfig config, int nodeVisits, Node node)
{
while (true)
{
@@ -35,11 +23,13 @@ public sealed class SolverSingle : ISolver
return node;
// select the node with the highest score
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children);
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
nodeVisits = node.ChildScores.GetVisits(at);
node = node.ChildAt(at);
}
}
public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode)
public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
@@ -49,31 +39,33 @@ public sealed class SolverSingle : ISolver
var poppedAction = initialState.AvailableActions.PopRandom(random);
var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator);
return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
}
public static void Backpropagate(Node rootNode, Node startNode, float score)
public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{
while (true)
{
startNode.State.Scores.Visit(score);
if (startNode == rootNode)
{
rootScores.Visit(score);
break;
}
startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score);
startNode = startNode.Parent!;
}
}
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator)
public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{
var selectedNode = Select(ref config, rootNode);
var (endNode, score) = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode);
var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var (endNode, score) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
Backpropagate(rootNode, endNode, score);
Backpropagate(rootScores, rootNode, endNode, score);
return true;
}
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootNode, token);
public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootScores, rootNode, token);
}
+26 -26
View File
@@ -39,35 +39,32 @@ public static class SolverUtils
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node ChildMaxScore(ref Node.ChildBuffer children)
public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
{
var length = children.Count;
var length = scores.Count;
var vecLength = Vector<float>.Count;
Span<float> scores = stackalloc float[vecLength];
var max = (0, 0);
var maxScore = 0f;
for (var i = 0; length > 0; ++i)
{
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j)
scores[j] = chunk[j].State.Scores.MaxScore;
ref var chunk = ref scores.Data[i];
var m = new Vector<float>(chunk.MaxScore.Span);
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
var idx = Intrinsics.HMaxIndex(m, iterCount);
if (scores[idx] >= maxScore)
if (m[idx] >= maxScore)
{
max = (i, idx);
maxScore = scores[idx];
maxScore = m[idx];
}
length -= iterCount;
}
return children.Data[max.Item1][max.Item2];
return max;
}
[Pure]
@@ -76,7 +73,7 @@ public static class SolverUtils
var actions = new List<ActionType>();
while (node.Children.Count != 0)
{
node = ChildMaxScore(ref node.Children);
node = node.ChildAt(ChildMaxScore(ref node.ChildScores));
if (node.State.Action != null)
actions.Add(node.State.Action.Value);
@@ -87,9 +84,9 @@ public static class SolverUtils
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static Node EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) where S : ISolver
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
{
var length = children.Count;
var length = scores.Count;
var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits));
@@ -107,13 +104,14 @@ public static class SolverUtils
{
var iterCount = Math.Min(vecLength, length);
S.LoadChildData(scoreSums, visits, maxScores, ref children.Data[i], iterCount);
ref var chunk = ref scores.Data[i];
var s = new Vector<float>(chunk.ScoreSum.Span);
var vInt = new Vector<int>(chunk.Visits.Span);
var m = new Vector<float>(chunk.MaxScore.Span);
var s = new Vector<float>(scoreSums);
var m = new Vector<float>(maxScores);
var vInt = new Vector<int>(visits);
vInt = Vector.Max(vInt, Vector<int>.One);
var v = Vector.ConvertToSingle(vInt);
var exploitation = (W * (s / v)) + (w * m);
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
var evalScores = exploitation + exploration;
@@ -129,11 +127,11 @@ public static class SolverUtils
length -= iterCount;
}
return children.Data[max.Item1][max.Item2];
return max;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, Node rootNode, Node expandedNode, Random random, Simulator simulator)
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
{
// playout to a terminal state
var currentState = expandedNode.State.State;
@@ -157,7 +155,7 @@ public static class SolverUtils
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
if (currentCompletionState == CompletionState.ProgressComplete)
{
if (score >= config.ScoreStorageThreshold && score >= rootNode.State.Scores.MaxScore)
if (score >= config.ScoreStorageThreshold && score >= maxScore)
{
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
return (terminalNode, score);
@@ -167,7 +165,7 @@ public static class SolverUtils
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Search<S>(ref SolverConfig config, int iterations, Node rootNode, CancellationToken token) where S : ISolver
public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
{
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
var random = rootNode.State.State.Input.Random;
@@ -176,7 +174,7 @@ public static class SolverUtils
if (token.IsCancellationRequested)
break;
if (!S.SearchIter(ref config, rootNode, random, simulator))
if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
{
// Retry, count this iteration as moot
i--;
@@ -211,15 +209,16 @@ public static class SolverUtils
var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount);
var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
while (!sim.IsComplete)
{
if (token.IsCancellationRequested)
break;
S.Search(ref config, rootNode, token);
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
if (solution_node.Scores.MaxScore >= 1.0)
if (rootScores.MaxScore >= 1.0)
{
actions.AddRange(solution_actions);
return (actions, solution_node.State);
@@ -243,7 +242,8 @@ public static class SolverUtils
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
{
var rootNode = CreateRootNode(config, state, false);
S.Search(ref config, rootNode, token);
RootScores rootScores = new();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
return (solution_actions, solution_node.State);
}