Remove deadlock issue, use raw segmented array types

This commit is contained in:
Asriel Camora
2023-07-05 10:22:09 +02:00
parent 4d96fd173f
commit 1f5da66bc6
4 changed files with 80 additions and 39 deletions
+30 -33
View File
@@ -3,13 +3,12 @@ using Craftimizer.Simulator.Actions;
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public class Solver
public sealed class Solver
{
public SolverConfig Config;
public Node RootNode;
@@ -64,46 +63,45 @@ public class Solver
// Picks the child whose State.Scores.MaxScore is the largest and returns it.
// Children are processed in groups of at most Vector<float>.Count: each group's
// scores are copied into a stack-allocated buffer and Intrinsics.HMaxIndex selects
// the candidate index inside that group.
// NOTE(review): this listing appears to come from a commit diff view, so a replaced
// line is immediately followed by its replacement (for example the
// ReadOnlySpan<Node> signature and the ref Node.ChildBuffer signature below).
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Node ChildMaxScore(ReadOnlySpan<Node> children)
private static Node ChildMaxScore(ref Node.ChildBuffer children)
{
var length = children.Length;
var length = children.Count;
var vecLength = Vector<float>.Count;
Span<float> scores = stackalloc float[vecLength];
// Location of the best child found so far. In the tuple form it is
// (index into children.Data, index inside that chunk), matching the final
// children.Data[max.Item1][max.Item2] lookup.
var max = 0;
var max = (0, 0);
// Starts at 0f, so a group whose selected score is below 0 never replaces
// the default selection.
var maxScore = 0f;
// In the length > 0 form, i counts chunks and length is the number of children
// still to visit; it is reduced by iterCount at the end of every iteration.
for (var i = 0; i < length; i += vecLength)
for (var i = 0; length > 0; ++i)
{
// Number of valid entries in this group; the last group may be partial.
var iterCount = i + vecLength > length ?
length - i :
vecLength;
var iterCount = Math.Min(vecLength, length);
// NOTE(review): assumes every children.Data[i] chunk holds at least
// Vector<float>.Count entries, so chunk index i lines up with the
// remaining-length bookkeeping — TODO confirm against Node.ChildBuffer.
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j)
scores[j] = children[i + j].State.Scores.MaxScore;
scores[j] = chunk[j].State.Scores.MaxScore;
// Presumably returns the index of the largest of the first iterCount lanes;
// verify against Intrinsics.HMaxIndex.
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
// >= means that on equal scores the later group's candidate wins.
if (scores[idx] >= maxScore)
{
max = i + idx;
max = (i, idx);
maxScore = scores[idx];
}
length -= iterCount;
}
return children[max];
return children.Data[max.Item1][max.Item2];
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
private Node? EvalBestChild(float parentVisits, ref Node.ChildBuffer children)
{
if (parentVisits == 0)
{
Console.WriteLine("no visits");
return null;
}
var length = children.Count;
var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
@@ -114,18 +112,17 @@ public class Solver
Span<float> scoreSums = stackalloc float[vecLength];
Span<float> visits = stackalloc float[vecLength];
Span<float> maxScores = stackalloc float[vecLength];
var max = 0;
var maxScore = 0f;
for (var i = 0; i < length; i += vecLength)
{
var iterCount = i + vecLength > length ?
length - i :
vecLength;
var max = (0, 0);
var maxScore = 0f;
for (var i = 0; length > 0; ++i)
{
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j)
{
var node = children[i + j].State.Scores;
var node = chunk[j]?.State.Scores ?? new();
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
@@ -140,15 +137,17 @@ public class Solver
var evalScores = exploitation + exploration;
var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
if (evalScores[idx] >= maxScore)
{
max = i + idx;
max = (i, idx);
maxScore = evalScores[idx];
}
length -= iterCount;
}
return children[max];
return children.Data[max.Item1][max.Item2];
}
[Pure]
@@ -164,7 +163,7 @@ public class Solver
// select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode;
node = EvalBestChild(node.State.Scores.Visits, ref node.Children) ?? RootNode;
}
}
@@ -234,10 +233,8 @@ public class Solver
var selectedNode = Select();
var rolledOut = ExpandAndRollout(simulator, selectedNode);
//Monitor.Exit(selectedNode);
if (!rolledOut.HasValue)
{
Console.WriteLine("Retry");
// Retry, count this iteration as moot
i--;
continue;
@@ -253,7 +250,7 @@ public class Solver
var tasks = new Task[Config.ThreadCount];
for (var i = 0; i < Config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(token), token);
Task.WaitAll(tasks, token);
Task.WaitAll(tasks, CancellationToken.None);
}
[Pure]
@@ -263,7 +260,7 @@ public class Solver
var node = RootNode;
while (node.Children.Count != 0)
{
node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children));
node = ChildMaxScore(ref node.Children);
if (node.State.Action != null)
actions.Add(node.State.Action.Value);