Offload node score buffers

This commit is contained in:
Asriel Camora
2023-07-07 15:45:42 +02:00
parent 1386f9150c
commit e4d9e3a52e
10 changed files with 188 additions and 97 deletions
+10 -8
View File
@@ -5,7 +5,7 @@ using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs // Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
public struct ArenaBuffer<T> public struct ArenaBuffer<T> where T : struct
{ {
// Technically 25, but it's very unlikely to actually get to there. // Technically 25, but it's very unlikely to actually get to there.
// The benchmark reaches 20 at most, but here we have a little leeway just in case. // The benchmark reaches 20 at most, but here we have a little leeway just in case.
@@ -17,39 +17,41 @@ public struct ArenaBuffer<T>
private static int BatchCount = MaxSize / BatchSize; private static int BatchCount = MaxSize / BatchSize;
public T[][] Data; public ArenaNode<T>[][] Data;
private int index; // Unused in single threaded workload private int index; // Unused in single threaded workload
private int count; private int count;
public readonly int Count => count; public readonly int Count => count;
public void AddConcurrent(T node) public void AddConcurrent(ArenaNode<T> node)
{ {
if (Data == null) if (Data == null)
Interlocked.CompareExchange(ref Data, new T[BatchCount][], null); Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);
var idx = Interlocked.Increment(ref index) - 1; var idx = Interlocked.Increment(ref index) - 1;
var (arrayIdx, subIdx) = GetArrayIndex(idx); var (arrayIdx, subIdx) = GetArrayIndex(idx);
if (Data[arrayIdx] == null) if (Data[arrayIdx] == null)
Interlocked.CompareExchange(ref Data[arrayIdx], new T[BatchSize], null); Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode<T>[BatchSize], null);
node.ChildIdx = (arrayIdx, subIdx);
Data[arrayIdx][subIdx] = node; Data[arrayIdx][subIdx] = node;
Interlocked.Increment(ref count); Interlocked.Increment(ref count);
} }
public void Add(T node) public void Add(ArenaNode<T> node)
{ {
Data ??= new T[BatchCount][]; Data ??= new ArenaNode<T>[BatchCount][];
var idx = count++; var idx = count++;
var (arrayIdx, subIdx) = GetArrayIndex(idx); var (arrayIdx, subIdx) = GetArrayIndex(idx);
Data[arrayIdx] ??= new T[BatchSize]; Data[arrayIdx] ??= new ArenaNode<T>[BatchSize];
node.ChildIdx = (arrayIdx, subIdx);
Data[arrayIdx][subIdx] = node; Data[arrayIdx][subIdx] = node;
} }
+11 -1
View File
@@ -5,21 +5,30 @@ namespace Craftimizer.Solver.Crafty;
public sealed class ArenaNode<T> where T : struct public sealed class ArenaNode<T> where T : struct
{ {
public T State; public T State;
public ArenaBuffer<ArenaNode<T>> Children; public ArenaBuffer<T> Children;
public NodeScoresBuffer ChildScores;
public (int arrayIdx, int subIdx) ChildIdx;
public readonly ArenaNode<T>? Parent; public readonly ArenaNode<T>? Parent;
public NodeScoresBuffer? ParentScores => Parent?.ChildScores;
public ArenaNode(T state, ArenaNode<T>? parent = null) public ArenaNode(T state, ArenaNode<T>? parent = null)
{ {
State = state; State = state;
Children = new(); Children = new();
ChildScores = new();
Parent = parent; Parent = parent;
} }
public ArenaNode<T> ChildAt((int arrayIdx, int subIdx) at) =>
Children.Data[at.arrayIdx][at.subIdx];
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> AddConcurrent(T state) public ArenaNode<T> AddConcurrent(T state)
{ {
var node = new ArenaNode<T>(state, this); var node = new ArenaNode<T>(state, this);
Children.AddConcurrent(node); Children.AddConcurrent(node);
ChildScores.AddConcurrent();
return node; return node;
} }
@@ -28,6 +37,7 @@ public sealed class ArenaNode<T> where T : struct
{ {
var node = new ArenaNode<T>(state, this); var node = new ArenaNode<T>(state, this);
Children.Add(node); Children.Add(node);
ChildScores.Add();
return node; return node;
} }
} }
+2 -4
View File
@@ -4,9 +4,7 @@ namespace Craftimizer.Solver.Crafty;
public interface ISolver public interface ISolver
{ {
abstract static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount); abstract static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator);
abstract static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator); abstract static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token);
abstract static void Search(ref SolverConfig config, Node rootNode, CancellationToken token);
} }
+1 -2
View File
@@ -101,9 +101,8 @@ internal static class Intrinsics
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int NthBitSet(uint value, int n) public static int NthBitSet(uint value, int n)
{ {
// TODO: debug
if (n >= BitOperations.PopCount(value)) if (n >= BitOperations.PopCount(value))
throw new ArgumentException(null, nameof(value)); return 32;
return Bmi2.IsSupported ? return Bmi2.IsSupported ?
NthBitSetBMI2(value, n) : NthBitSetBMI2(value, n) :
+90
View File
@@ -0,0 +1,90 @@
using System;
using System.ComponentModel;
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Craftimizer.Solver.Crafty;
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
/// <summary>
/// Append-only store of per-child score statistics (score sum, max score and visit count),
/// kept in fixed-size batches of <see cref="Vector{T}.Count"/> slots each.
/// A slot is addressed by an <c>(arrayIdx, subIdx)</c> pair: the batch and the position inside it.
/// </summary>
/// <remarks>
/// This is a mutable struct whose state lives behind the <see cref="Data"/> reference; a copy of the
/// struct made after <see cref="Data"/> was allocated shares the same batches as the original.
/// </remarks>
public struct NodeScoresBuffer
{
    /// <summary>
    /// One batch of score slots. Each member holds <c>BatchSize</c> elements, all starting at zero.
    /// </summary>
    public sealed class ScoresBatch
    {
        /// <summary>Sum of every score recorded for each slot.</summary>
        public Memory<float> ScoreSum;

        /// <summary>Highest score recorded for each slot.</summary>
        public Memory<float> MaxScore;

        /// <summary>Number of scores recorded for each slot.</summary>
        public Memory<int> Visits;

        public ScoresBatch()
        {
            ScoreSum = new float[BatchSize];
            MaxScore = new float[BatchSize];
            Visits = new int[BatchSize];
        }
    }

    // Upper bound on the number of slots the buffer is sized for.
    // Technically 25, but it's very unlikely to actually get there.
    // The benchmark reaches 20 at most, but here we have a little leeway just in case.
    private const int MaxSize = 24;

    // Batch geometry. A batch is exactly one Vector<float> wide; the bit/mask pair lets
    // GetArrayIndex split a flat index with a shift and an AND.
    // NOTE: initializer order matters here - the later fields read BatchSize.
    private static readonly int BatchSize = Vector<float>.Count;
    private static readonly int BatchSizeBits = int.Log2(BatchSize);
    private static readonly int BatchSizeMask = BatchSize - 1;

    // Round up so that MaxSize slots always fit, even when BatchSize does not divide MaxSize
    // (e.g. a 16-lane vector would otherwise yield a single batch of only 16 slots).
    private static readonly int BatchCount = (MaxSize + BatchSize - 1) / BatchSize;

    /// <summary>
    /// Batch table; <c>null</c> until the first add, with individual batches allocated lazily.
    /// </summary>
    public ScoresBatch[] Data;

    // Next flat index to hand out; only used by AddConcurrent.
    private int index;

    // Number of completed adds.
    private int count;

    /// <summary>Number of slots that have been added so far.</summary>
    public readonly int Count => count;

    /// <summary>
    /// Reserves the next slot using interlocked operations, allocating the batch table and the
    /// slot's batch if they do not exist yet. <see cref="Count"/> is incremented last.
    /// </summary>
    public void AddConcurrent()
    {
        if (Data == null)
            Interlocked.CompareExchange(ref Data, new ScoresBatch[BatchCount], null);

        var idx = Interlocked.Increment(ref index) - 1;

        // Only the batch index matters here; the slot itself is already zero-initialized.
        var (arrayIdx, _) = GetArrayIndex(idx);

        if (Data[arrayIdx] == null)
            Interlocked.CompareExchange(ref Data[arrayIdx], new ScoresBatch(), null);

        Interlocked.Increment(ref count);
    }

    /// <summary>
    /// Reserves the next slot without any synchronization, allocating the batch table and the
    /// slot's batch if they do not exist yet.
    /// </summary>
    public void Add()
    {
        Data ??= new ScoresBatch[BatchCount];

        var idx = count++;

        var (arrayIdx, _) = GetArrayIndex(idx);

        Data[arrayIdx] ??= new();
    }

    /// <summary>
    /// Records <paramref name="score"/> for the slot at <paramref name="at"/> using atomic
    /// add / max / increment operations.
    /// </summary>
    /// <param name="at">Batch index and position inside the batch.</param>
    /// <param name="score">Score to accumulate.</param>
    // NOTE(review): assumes the batch at at.arrayIdx has already been allocated by an add - confirm against callers.
    public readonly void VisitConcurrent((int arrayIdx, int subIdx) at, float score)
    {
        Intrinsics.CASAdd(ref Data[at.arrayIdx].ScoreSum.Span[at.subIdx], score);
        Intrinsics.CASMax(ref Data[at.arrayIdx].MaxScore.Span[at.subIdx], score);
        Interlocked.Increment(ref Data[at.arrayIdx].Visits.Span[at.subIdx]);
    }

    /// <summary>
    /// Records <paramref name="score"/> for the slot at <paramref name="at"/> without synchronization:
    /// adds it to the sum, raises the max if needed and bumps the visit count.
    /// </summary>
    /// <param name="at">Batch index and position inside the batch.</param>
    /// <param name="score">Score to accumulate.</param>
    public readonly void Visit((int arrayIdx, int subIdx) at, float score)
    {
        Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score;
        Data[at.arrayIdx].MaxScore.Span[at.subIdx] = Math.Max(Data[at.arrayIdx].MaxScore.Span[at.subIdx], score);
        Data[at.arrayIdx].Visits.Span[at.subIdx]++;
    }

    /// <summary>Returns the visit count stored for the slot at <paramref name="at"/>.</summary>
    public readonly int GetVisits((int arrayIdx, int subIdx) at) =>
        Data[at.arrayIdx].Visits.Span[at.subIdx];

    /// <summary>
    /// Splits a flat slot index into its batch index (high bits) and position inside the batch (low bits).
    /// </summary>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) =>
        (idx >> BatchSizeBits, idx & BatchSizeMask);
}
@@ -3,7 +3,7 @@ using System.Runtime.InteropServices;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
[StructLayout(LayoutKind.Auto)] [StructLayout(LayoutKind.Auto)]
public struct NodeScores public sealed class RootScores
{ {
public float ScoreSum; public float ScoreSum;
public float MaxScore; public float MaxScore;
-1
View File
@@ -12,7 +12,6 @@ public struct SimulationNode
public readonly CompletionState SimulationCompletionState; public readonly CompletionState SimulationCompletionState;
public ActionSet AvailableActions; public ActionSet AvailableActions;
public NodeScores Scores;
public readonly CompletionState CompletionState => GetCompletionState(SimulationCompletionState, AvailableActions); public readonly CompletionState CompletionState => GetCompletionState(SimulationCompletionState, AvailableActions);
+29 -28
View File
@@ -7,29 +7,18 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverConcurrent : ISolver public sealed class SolverConcurrent : ISolver
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
var node = chunk[j]?.State.Scores ?? new();
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node? EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) => public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
parentVisits == 0 ? parentVisits == 0 ?
null : null :
SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children); SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children);
[Pure] [Pure]
public static Node Select(ref SolverConfig config, Node rootNode) public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode)
{ {
var node = rootNode; var node = rootNode;
var nodeVisits = rootNodeVisits;
while (true) while (true)
{ {
var expandable = !node.State.AvailableActions.IsEmpty; var expandable = !node.State.AvailableActions.IsEmpty;
@@ -39,11 +28,21 @@ public sealed class SolverConcurrent : ISolver
// select the node with the highest score // select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root // if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children) ?? rootNode; var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
if (at.HasValue)
{
nodeVisits = node.ChildScores.GetVisits(at.Value);
node = node.ChildAt(at.Value);
}
else
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
} }
} }
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode) public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{ {
ref var initialState = ref initialNode.State; ref var initialState = ref initialNode.State;
// expand once // expand once
@@ -55,43 +54,45 @@ public sealed class SolverConcurrent : ISolver
return null; return null;
var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true)); var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator); return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
} }
public static void Backpropagate(Node rootNode, Node startNode, float score) public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{ {
while (true) while (true)
{ {
startNode.State.Scores.VisitConcurrent(score);
if (startNode == rootNode) if (startNode == rootNode)
{
rootScores.VisitConcurrent(score);
break; break;
}
startNode.ParentScores!.Value.VisitConcurrent(startNode.ChildIdx, score);
startNode = startNode.Parent!; startNode = startNode.Parent!;
} }
} }
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator) public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{ {
var selectedNode = Select(ref config, rootNode); var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode); var rolledOut = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
if (!rolledOut.HasValue) if (!rolledOut.HasValue)
return false; return false;
var (endNode, score) = rolledOut.Value; var (endNode, score) = rolledOut.Value;
Backpropagate(rootNode, endNode, score); Backpropagate(rootScores, rootNode, endNode, score);
return true; return true;
} }
public static void SearchThread(SolverConfig config, Node rootNode, CancellationToken token) => public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootNode, token); SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token);
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
{ {
var configP = config; var configP = config;
var tasks = new Task[config.ThreadCount]; var tasks = new Task[config.ThreadCount];
for (var i = 0; i < config.ThreadCount; ++i) for (var i = 0; i < config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(configP, rootNode, token), token); tasks[i] = Task.Run(() => SearchThread(configP, rootScores, rootNode, token), token);
Task.WaitAll(tasks, CancellationToken.None); Task.WaitAll(tasks, CancellationToken.None);
} }
} }
+18 -26
View File
@@ -7,25 +7,13 @@ namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class SolverSingle : ISolver public sealed class SolverSingle : ISolver
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static void LoadChildData(Span<float> scoreSums, Span<int> visits, Span<float> maxScores, ref Node[] chunk, int iterCount)
{
for (var j = 0; j < iterCount; ++j)
{
ref var node = ref chunk[j].State.Scores;
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
}
}
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) => public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children); SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children);
[Pure] [Pure]
public static Node Select(ref SolverConfig config, Node node) public static Node Select(ref SolverConfig config, int nodeVisits, Node node)
{ {
while (true) while (true)
{ {
@@ -35,11 +23,13 @@ public sealed class SolverSingle : ISolver
return node; return node;
// select the node with the highest score // select the node with the highest score
node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children); var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
nodeVisits = node.ChildScores.GetVisits(at);
node = node.ChildAt(at);
} }
} }
public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode) public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{ {
ref var initialState = ref initialNode.State; ref var initialState = ref initialNode.State;
// expand once // expand once
@@ -49,31 +39,33 @@ public sealed class SolverSingle : ISolver
var poppedAction = initialState.AvailableActions.PopRandom(random); var poppedAction = initialState.AvailableActions.PopRandom(random);
var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true)); var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true));
return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator); return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
} }
public static void Backpropagate(Node rootNode, Node startNode, float score) public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{ {
while (true) while (true)
{ {
startNode.State.Scores.Visit(score);
if (startNode == rootNode) if (startNode == rootNode)
{
rootScores.Visit(score);
break; break;
}
startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score);
startNode = startNode.Parent!; startNode = startNode.Parent!;
} }
} }
public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator) public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{ {
var selectedNode = Select(ref config, rootNode); var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var (endNode, score) = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode); var (endNode, score) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
Backpropagate(rootNode, endNode, score); Backpropagate(rootScores, rootNode, endNode, score);
return true; return true;
} }
public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) => public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootNode, token); SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootScores, rootNode, token);
} }
+26 -26
View File
@@ -39,35 +39,32 @@ public static class SolverUtils
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node ChildMaxScore(ref Node.ChildBuffer children) public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
{ {
var length = children.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
Span<float> scores = stackalloc float[vecLength];
var max = (0, 0); var max = (0, 0);
var maxScore = 0f; var maxScore = 0f;
for (var i = 0; length > 0; ++i) for (var i = 0; length > 0; ++i)
{ {
var iterCount = Math.Min(vecLength, length); var iterCount = Math.Min(vecLength, length);
ref var chunk = ref children.Data[i]; ref var chunk = ref scores.Data[i];
for (var j = 0; j < iterCount; ++j) var m = new Vector<float>(chunk.MaxScore.Span);
scores[j] = chunk[j].State.Scores.MaxScore;
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount); var idx = Intrinsics.HMaxIndex(m, iterCount);
if (scores[idx] >= maxScore) if (m[idx] >= maxScore)
{ {
max = (i, idx); max = (i, idx);
maxScore = scores[idx]; maxScore = m[idx];
} }
length -= iterCount; length -= iterCount;
} }
return children.Data[max.Item1][max.Item2]; return max;
} }
[Pure] [Pure]
@@ -76,7 +73,7 @@ public static class SolverUtils
var actions = new List<ActionType>(); var actions = new List<ActionType>();
while (node.Children.Count != 0) while (node.Children.Count != 0)
{ {
node = ChildMaxScore(ref node.Children); node = node.ChildAt(ChildMaxScore(ref node.ChildScores));
if (node.State.Action != null) if (node.State.Action != null)
actions.Add(node.State.Action.Value); actions.Add(node.State.Action.Value);
@@ -87,9 +84,9 @@ public static class SolverUtils
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static Node EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) where S : ISolver public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
{ {
var length = children.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits)); var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits));
@@ -107,13 +104,14 @@ public static class SolverUtils
{ {
var iterCount = Math.Min(vecLength, length); var iterCount = Math.Min(vecLength, length);
S.LoadChildData(scoreSums, visits, maxScores, ref children.Data[i], iterCount); ref var chunk = ref scores.Data[i];
var s = new Vector<float>(chunk.ScoreSum.Span);
var vInt = new Vector<int>(chunk.Visits.Span);
var m = new Vector<float>(chunk.MaxScore.Span);
var s = new Vector<float>(scoreSums);
var m = new Vector<float>(maxScores);
var vInt = new Vector<int>(visits);
vInt = Vector.Max(vInt, Vector<int>.One); vInt = Vector.Max(vInt, Vector<int>.One);
var v = Vector.ConvertToSingle(vInt); var v = Vector.ConvertToSingle(vInt);
var exploitation = (W * (s / v)) + (w * m); var exploitation = (W * (s / v)) + (w * m);
var exploration = CVector * Intrinsics.ReciprocalSqrt(v); var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
var evalScores = exploitation + exploration; var evalScores = exploitation + exploration;
@@ -129,11 +127,11 @@ public static class SolverUtils
length -= iterCount; length -= iterCount;
} }
return children.Data[max.Item1][max.Item2]; return max;
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, Node rootNode, Node expandedNode, Random random, Simulator simulator) public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
{ {
// playout to a terminal state // playout to a terminal state
var currentState = expandedNode.State.State; var currentState = expandedNode.State.State;
@@ -157,7 +155,7 @@ public static class SolverUtils
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0; var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
if (currentCompletionState == CompletionState.ProgressComplete) if (currentCompletionState == CompletionState.ProgressComplete)
{ {
if (score >= config.ScoreStorageThreshold && score >= rootNode.State.Scores.MaxScore) if (score >= config.ScoreStorageThreshold && score >= maxScore)
{ {
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); (var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
return (terminalNode, score); return (terminalNode, score);
@@ -167,7 +165,7 @@ public static class SolverUtils
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Search<S>(ref SolverConfig config, int iterations, Node rootNode, CancellationToken token) where S : ISolver public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
{ {
Simulator simulator = new(rootNode.State.State, config.MaxStepCount); Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
var random = rootNode.State.State.Input.Random; var random = rootNode.State.State.Input.Random;
@@ -176,7 +174,7 @@ public static class SolverUtils
if (token.IsCancellationRequested) if (token.IsCancellationRequested)
break; break;
if (!S.SearchIter(ref config, rootNode, random, simulator)) if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
{ {
// Retry, count this iteration as moot // Retry, count this iteration as moot
i--; i--;
@@ -211,15 +209,16 @@ public static class SolverUtils
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount); var sim = new Simulator(state, config.MaxStepCount);
var rootNode = CreateRootNode(config, state, true); var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
while (!sim.IsComplete) while (!sim.IsComplete)
{ {
if (token.IsCancellationRequested) if (token.IsCancellationRequested)
break; break;
S.Search(ref config, rootNode, token); S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode); var (solution_actions, solution_node) = Solution(rootNode);
if (solution_node.Scores.MaxScore >= 1.0) if (rootScores.MaxScore >= 1.0)
{ {
actions.AddRange(solution_actions); actions.AddRange(solution_actions);
return (actions, solution_node.State); return (actions, solution_node.State);
@@ -243,7 +242,8 @@ public static class SolverUtils
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
{ {
var rootNode = CreateRootNode(config, state, false); var rootNode = CreateRootNode(config, state, false);
S.Search(ref config, rootNode, token); RootScores rootScores = new();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode); var (solution_actions, solution_node) = Solution(rootNode);
return (solution_actions, solution_node.State); return (solution_actions, solution_node.State);
} }