Remove deadlock issue, use raw segmented array types

This commit is contained in:
Asriel Camora
2023-07-05 10:22:09 +02:00
parent 4d96fd173f
commit 1f5da66bc6
4 changed files with 80 additions and 39 deletions
+1 -1
View File
@@ -45,7 +45,7 @@ internal static class Program
0 0
); );
var threads = 1; var threads = 8;
var config = new SolverConfig() var config = new SolverConfig()
{ {
Iterations = 30_000 / threads, Iterations = 30_000 / threads,
+3 -3
View File
@@ -56,8 +56,8 @@ public struct ActionSet
public readonly ActionType SelectRandom(Random random) => First(); public readonly ActionType SelectRandom(Random random) => First();
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopRandom(Random random) => PopFirst(); //public ActionType? PopRandom(Random random) => PopFirst();
/*public ActionType? PopRandom(Random random) public ActionType? PopRandom(Random random)
{ {
uint snapshot; uint snapshot;
uint newValue; uint newValue;
@@ -76,7 +76,7 @@ public struct ActionSet
} }
while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot); while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot);
return action; return action;
}*/ }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopFirst() public ActionType? PopFirst()
+46 -2
View File
@@ -1,11 +1,55 @@
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
public class ArenaNode<T> where T : struct public sealed class ArenaNode<T> where T : struct
{ {
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
/// <summary>
/// Append-only list of child nodes stored as lazily allocated, fixed-size batches.
/// Slots are reserved and batches are published with <see cref="Interlocked"/>
/// operations rather than locks.
/// </summary>
/// <remarks>
/// NOTE(review): <see cref="Count"/> is incremented after a slot is written, but slots are
/// reserved before they are written. When calls to <see cref="Add"/> overlap, a reader may
/// therefore see a <see cref="Count"/> that covers a slot which is still null - confirm
/// that every reader tolerates null entries.
/// </remarks>
public struct ChildBuffer
{
    // Technically 25, but it's very unlikely to actually get to there.
    // The benchmark reaches 20 at most, but here we have a little leeway just in case.
    private const int MaxSize = 24;

    // One batch holds exactly one Vector<float> worth of nodes.
    // The shift/mask indexing in GetArrayIndex requires this to be a power of two.
    private static readonly int BatchSize = Vector<float>.Count;
    private static readonly int BatchSizeBits = int.Log2(BatchSize);
    private static readonly int BatchSizeMask = BatchSize - 1;

    // Round up so the batches always cover at least MaxSize slots, even when
    // BatchSize does not divide MaxSize evenly (e.g. 24 slots with 16-wide vectors).
    private static readonly int BatchCount = (MaxSize + BatchSize - 1) / BatchSize;

    /// <summary>
    /// Batch table indexed as <c>Data[batch][slot]</c>. Null until the first <see cref="Add"/>;
    /// individual batches are null until a slot inside them is first reserved.
    /// </summary>
    public ArenaNode<T>[][] Data;

    // Next slot to hand out; advanced atomically by Add.
    private int index;

    // Number of Add calls that have finished storing their node.
    private int count;

    /// <summary>Number of nodes whose store has completed.</summary>
    public readonly int Count => count;

    /// <summary>
    /// Appends <paramref name="node"/> to the buffer.
    /// </summary>
    /// <param name="node">The child node to store.</param>
    public void Add(ArenaNode<T> node)
    {
        // Lazily create the batch table; if several callers race, only one allocation is kept.
        if (Data is null)
            Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);

        // Reserve a unique slot for this call.
        var slot = Interlocked.Increment(ref index) - 1;
        var (arrayIdx, subIdx) = GetArrayIndex(slot);

        // Lazily create the batch that owns the slot, again keeping a single winner.
        if (Data[arrayIdx] is null)
            Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode<T>[BatchSize], null);

        Data[arrayIdx][subIdx] = node;
        Interlocked.Increment(ref count);
    }

    /// <summary>
    /// Splits a flat slot index into its batch index and the offset inside that batch.
    /// </summary>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) =>
        (idx >> BatchSizeBits, idx & BatchSizeMask);
}
public T State; public T State;
public readonly List<ArenaNode<T>> Children; public ChildBuffer Children;
public readonly ArenaNode<T>? Parent; public readonly ArenaNode<T>? Parent;
public ArenaNode(T state, ArenaNode<T>? parent = null) public ArenaNode(T state, ArenaNode<T>? parent = null)
+30 -33
View File
@@ -3,13 +3,12 @@ using Craftimizer.Simulator.Actions;
using System.Diagnostics.Contracts; using System.Diagnostics.Contracts;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>; using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public class Solver public sealed class Solver
{ {
public SolverConfig Config; public SolverConfig Config;
public Node RootNode; public Node RootNode;
@@ -64,46 +63,45 @@ public class Solver
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Node ChildMaxScore(ReadOnlySpan<Node> children) private static Node ChildMaxScore(ref Node.ChildBuffer children)
{ {
var length = children.Length; var length = children.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
Span<float> scores = stackalloc float[vecLength]; Span<float> scores = stackalloc float[vecLength];
var max = 0; var max = (0, 0);
var maxScore = 0f; var maxScore = 0f;
for (var i = 0; i < length; i += vecLength) for (var i = 0; length > 0; ++i)
{ {
var iterCount = i + vecLength > length ? var iterCount = Math.Min(vecLength, length);
length - i :
vecLength;
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j) for (var j = 0; j < iterCount; ++j)
scores[j] = children[i + j].State.Scores.MaxScore; scores[j] = chunk[j].State.Scores.MaxScore;
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount); var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
if (scores[idx] >= maxScore) if (scores[idx] >= maxScore)
{ {
max = i + idx; max = (i, idx);
maxScore = scores[idx]; maxScore = scores[idx];
} }
length -= iterCount;
} }
return children[max]; return children.Data[max.Item1][max.Item2];
} }
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children) private Node? EvalBestChild(float parentVisits, ref Node.ChildBuffer children)
{ {
if (parentVisits == 0) if (parentVisits == 0)
{
Console.WriteLine("no visits");
return null; return null;
}
var length = children.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits)); var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
@@ -114,18 +112,17 @@ public class Solver
Span<float> scoreSums = stackalloc float[vecLength]; Span<float> scoreSums = stackalloc float[vecLength];
Span<float> visits = stackalloc float[vecLength]; Span<float> visits = stackalloc float[vecLength];
Span<float> maxScores = stackalloc float[vecLength]; Span<float> maxScores = stackalloc float[vecLength];
var max = 0;
var maxScore = 0f;
for (var i = 0; i < length; i += vecLength)
{
var iterCount = i + vecLength > length ?
length - i :
vecLength;
var max = (0, 0);
var maxScore = 0f;
for (var i = 0; length > 0; ++i)
{
var iterCount = Math.Min(vecLength, length);
ref var chunk = ref children.Data[i];
for (var j = 0; j < iterCount; ++j) for (var j = 0; j < iterCount; ++j)
{ {
var node = children[i + j].State.Scores; var node = chunk[j]?.State.Scores ?? new();
scoreSums[j] = node.ScoreSum; scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits; visits[j] = node.Visits;
maxScores[j] = node.MaxScore; maxScores[j] = node.MaxScore;
@@ -140,15 +137,17 @@ public class Solver
var evalScores = exploitation + exploration; var evalScores = exploitation + exploration;
var idx = Intrinsics.HMaxIndex(evalScores, iterCount); var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
if (evalScores[idx] >= maxScore) if (evalScores[idx] >= maxScore)
{ {
max = i + idx; max = (i, idx);
maxScore = evalScores[idx]; maxScore = evalScores[idx];
} }
length -= iterCount;
} }
return children[max]; return children.Data[max.Item1][max.Item2];
} }
[Pure] [Pure]
@@ -164,7 +163,7 @@ public class Solver
// select the node with the highest score // select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root // if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode; node = EvalBestChild(node.State.Scores.Visits, ref node.Children) ?? RootNode;
} }
} }
@@ -234,10 +233,8 @@ public class Solver
var selectedNode = Select(); var selectedNode = Select();
var rolledOut = ExpandAndRollout(simulator, selectedNode); var rolledOut = ExpandAndRollout(simulator, selectedNode);
//Monitor.Exit(selectedNode);
if (!rolledOut.HasValue) if (!rolledOut.HasValue)
{ {
Console.WriteLine("Retry");
// Retry, count this iteration as moot // Retry, count this iteration as moot
i--; i--;
continue; continue;
@@ -253,7 +250,7 @@ public class Solver
var tasks = new Task[Config.ThreadCount]; var tasks = new Task[Config.ThreadCount];
for (var i = 0; i < Config.ThreadCount; ++i) for (var i = 0; i < Config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(token), token); tasks[i] = Task.Run(() => SearchThread(token), token);
Task.WaitAll(tasks, token); Task.WaitAll(tasks, CancellationToken.None);
} }
[Pure] [Pure]
@@ -263,7 +260,7 @@ public class Solver
var node = RootNode; var node = RootNode;
while (node.Children.Count != 0) while (node.Children.Count != 0)
{ {
node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children)); node = ChildMaxScore(ref node.Children);
if (node.State.Action != null) if (node.State.Action != null)
actions.Add(node.State.Action.Value); actions.Add(node.State.Action.Value);