Remove deadlock issue, use raw segmented array types
This commit is contained in:
@@ -45,7 +45,7 @@ internal static class Program
|
|||||||
0
|
0
|
||||||
);
|
);
|
||||||
|
|
||||||
var threads = 1;
|
var threads = 8;
|
||||||
var config = new SolverConfig()
|
var config = new SolverConfig()
|
||||||
{
|
{
|
||||||
Iterations = 30_000 / threads,
|
Iterations = 30_000 / threads,
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ public struct ActionSet
|
|||||||
public readonly ActionType SelectRandom(Random random) => First();
|
public readonly ActionType SelectRandom(Random random) => First();
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public ActionType? PopRandom(Random random) => PopFirst();
|
//public ActionType? PopRandom(Random random) => PopFirst();
|
||||||
/*public ActionType? PopRandom(Random random)
|
public ActionType? PopRandom(Random random)
|
||||||
{
|
{
|
||||||
uint snapshot;
|
uint snapshot;
|
||||||
uint newValue;
|
uint newValue;
|
||||||
@@ -76,7 +76,7 @@ public struct ActionSet
|
|||||||
}
|
}
|
||||||
while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot);
|
while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot);
|
||||||
return action;
|
return action;
|
||||||
}*/
|
}
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public ActionType? PopFirst()
|
public ActionType? PopFirst()
|
||||||
|
|||||||
@@ -1,11 +1,55 @@
|
|||||||
|
using System.Diagnostics.Contracts;
|
||||||
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
|
|
||||||
public class ArenaNode<T> where T : struct
|
public sealed class ArenaNode<T> where T : struct
|
||||||
{
|
{
|
||||||
|
// Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs
|
||||||
|
public struct ChildBuffer
|
||||||
|
{
|
||||||
|
// Technically 25, but it's very unlikely to actually get to there.
|
||||||
|
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
|
||||||
|
private const int MaxSize = 24;
|
||||||
|
|
||||||
|
private static int BatchSize = Vector<float>.Count;
|
||||||
|
private static int BatchSizeBits = int.Log2(BatchSize);
|
||||||
|
private static int BatchSizeMask = BatchSize - 1;
|
||||||
|
|
||||||
|
private static int BatchCount = MaxSize / BatchSize;
|
||||||
|
|
||||||
|
public ArenaNode<T>[][] Data;
|
||||||
|
private int index;
|
||||||
|
private int count;
|
||||||
|
|
||||||
|
public readonly int Count => count;
|
||||||
|
|
||||||
|
public void Add(ArenaNode<T> node)
|
||||||
|
{
|
||||||
|
if (Data == null)
|
||||||
|
Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);
|
||||||
|
|
||||||
|
var index = Interlocked.Increment(ref this.index) - 1;
|
||||||
|
|
||||||
|
var (arrayIdx, subIdx) = GetArrayIndex(index);
|
||||||
|
|
||||||
|
if (Data[arrayIdx] == null)
|
||||||
|
Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode<T>[BatchSize], null);
|
||||||
|
|
||||||
|
Data[arrayIdx][subIdx] = node;
|
||||||
|
|
||||||
|
Interlocked.Increment(ref count);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Pure]
|
||||||
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
|
private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) =>
|
||||||
|
(idx >> BatchSizeBits, idx & BatchSizeMask);
|
||||||
|
}
|
||||||
|
|
||||||
public T State;
|
public T State;
|
||||||
public readonly List<ArenaNode<T>> Children;
|
public ChildBuffer Children;
|
||||||
public readonly ArenaNode<T>? Parent;
|
public readonly ArenaNode<T>? Parent;
|
||||||
|
|
||||||
public ArenaNode(T state, ArenaNode<T>? parent = null)
|
public ArenaNode(T state, ArenaNode<T>? parent = null)
|
||||||
|
|||||||
+30
-33
@@ -3,13 +3,12 @@ using Craftimizer.Simulator.Actions;
|
|||||||
using System.Diagnostics.Contracts;
|
using System.Diagnostics.Contracts;
|
||||||
using System.Numerics;
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using System.Runtime.InteropServices;
|
|
||||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
|
|
||||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||||
public class Solver
|
public sealed class Solver
|
||||||
{
|
{
|
||||||
public SolverConfig Config;
|
public SolverConfig Config;
|
||||||
public Node RootNode;
|
public Node RootNode;
|
||||||
@@ -64,46 +63,45 @@ public class Solver
|
|||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static Node ChildMaxScore(ReadOnlySpan<Node> children)
|
private static Node ChildMaxScore(ref Node.ChildBuffer children)
|
||||||
{
|
{
|
||||||
var length = children.Length;
|
var length = children.Count;
|
||||||
var vecLength = Vector<float>.Count;
|
var vecLength = Vector<float>.Count;
|
||||||
|
|
||||||
Span<float> scores = stackalloc float[vecLength];
|
Span<float> scores = stackalloc float[vecLength];
|
||||||
|
|
||||||
var max = 0;
|
var max = (0, 0);
|
||||||
var maxScore = 0f;
|
var maxScore = 0f;
|
||||||
for (var i = 0; i < length; i += vecLength)
|
for (var i = 0; length > 0; ++i)
|
||||||
{
|
{
|
||||||
var iterCount = i + vecLength > length ?
|
var iterCount = Math.Min(vecLength, length);
|
||||||
length - i :
|
|
||||||
vecLength;
|
|
||||||
|
|
||||||
|
ref var chunk = ref children.Data[i];
|
||||||
for (var j = 0; j < iterCount; ++j)
|
for (var j = 0; j < iterCount; ++j)
|
||||||
scores[j] = children[i + j].State.Scores.MaxScore;
|
scores[j] = chunk[j].State.Scores.MaxScore;
|
||||||
|
|
||||||
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
|
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
|
||||||
|
|
||||||
if (scores[idx] >= maxScore)
|
if (scores[idx] >= maxScore)
|
||||||
{
|
{
|
||||||
max = i + idx;
|
max = (i, idx);
|
||||||
maxScore = scores[idx];
|
maxScore = scores[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
length -= iterCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
return children[max];
|
return children.Data[max.Item1][max.Item2];
|
||||||
}
|
}
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
private Node? EvalBestChild(float parentVisits, ref Node.ChildBuffer children)
|
||||||
{
|
{
|
||||||
if (parentVisits == 0)
|
if (parentVisits == 0)
|
||||||
{
|
|
||||||
Console.WriteLine("no visits");
|
|
||||||
return null;
|
return null;
|
||||||
}
|
|
||||||
|
|
||||||
|
var length = children.Count;
|
||||||
var vecLength = Vector<float>.Count;
|
var vecLength = Vector<float>.Count;
|
||||||
|
|
||||||
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
|
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
|
||||||
@@ -114,18 +112,17 @@ public class Solver
|
|||||||
Span<float> scoreSums = stackalloc float[vecLength];
|
Span<float> scoreSums = stackalloc float[vecLength];
|
||||||
Span<float> visits = stackalloc float[vecLength];
|
Span<float> visits = stackalloc float[vecLength];
|
||||||
Span<float> maxScores = stackalloc float[vecLength];
|
Span<float> maxScores = stackalloc float[vecLength];
|
||||||
|
|
||||||
var max = 0;
|
|
||||||
var maxScore = 0f;
|
|
||||||
for (var i = 0; i < length; i += vecLength)
|
|
||||||
{
|
|
||||||
var iterCount = i + vecLength > length ?
|
|
||||||
length - i :
|
|
||||||
vecLength;
|
|
||||||
|
|
||||||
|
var max = (0, 0);
|
||||||
|
var maxScore = 0f;
|
||||||
|
for (var i = 0; length > 0; ++i)
|
||||||
|
{
|
||||||
|
var iterCount = Math.Min(vecLength, length);
|
||||||
|
|
||||||
|
ref var chunk = ref children.Data[i];
|
||||||
for (var j = 0; j < iterCount; ++j)
|
for (var j = 0; j < iterCount; ++j)
|
||||||
{
|
{
|
||||||
var node = children[i + j].State.Scores;
|
var node = chunk[j]?.State.Scores ?? new();
|
||||||
scoreSums[j] = node.ScoreSum;
|
scoreSums[j] = node.ScoreSum;
|
||||||
visits[j] = node.Visits;
|
visits[j] = node.Visits;
|
||||||
maxScores[j] = node.MaxScore;
|
maxScores[j] = node.MaxScore;
|
||||||
@@ -140,15 +137,17 @@ public class Solver
|
|||||||
var evalScores = exploitation + exploration;
|
var evalScores = exploitation + exploration;
|
||||||
|
|
||||||
var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
|
var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
|
||||||
|
|
||||||
if (evalScores[idx] >= maxScore)
|
if (evalScores[idx] >= maxScore)
|
||||||
{
|
{
|
||||||
max = i + idx;
|
max = (i, idx);
|
||||||
maxScore = evalScores[idx];
|
maxScore = evalScores[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
length -= iterCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
return children[max];
|
return children.Data[max.Item1][max.Item2];
|
||||||
}
|
}
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
@@ -164,7 +163,7 @@ public class Solver
|
|||||||
|
|
||||||
// select the node with the highest score
|
// select the node with the highest score
|
||||||
// if null (current node is invalid & not backpropagated just yet), try again from root
|
// if null (current node is invalid & not backpropagated just yet), try again from root
|
||||||
node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode;
|
node = EvalBestChild(node.State.Scores.Visits, ref node.Children) ?? RootNode;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,10 +233,8 @@ public class Solver
|
|||||||
|
|
||||||
var selectedNode = Select();
|
var selectedNode = Select();
|
||||||
var rolledOut = ExpandAndRollout(simulator, selectedNode);
|
var rolledOut = ExpandAndRollout(simulator, selectedNode);
|
||||||
//Monitor.Exit(selectedNode);
|
|
||||||
if (!rolledOut.HasValue)
|
if (!rolledOut.HasValue)
|
||||||
{
|
{
|
||||||
Console.WriteLine("Retry");
|
|
||||||
// Retry, count this iteration as moot
|
// Retry, count this iteration as moot
|
||||||
i--;
|
i--;
|
||||||
continue;
|
continue;
|
||||||
@@ -253,7 +250,7 @@ public class Solver
|
|||||||
var tasks = new Task[Config.ThreadCount];
|
var tasks = new Task[Config.ThreadCount];
|
||||||
for (var i = 0; i < Config.ThreadCount; ++i)
|
for (var i = 0; i < Config.ThreadCount; ++i)
|
||||||
tasks[i] = Task.Run(() => SearchThread(token), token);
|
tasks[i] = Task.Run(() => SearchThread(token), token);
|
||||||
Task.WaitAll(tasks, token);
|
Task.WaitAll(tasks, CancellationToken.None);
|
||||||
}
|
}
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
@@ -263,7 +260,7 @@ public class Solver
|
|||||||
var node = RootNode;
|
var node = RootNode;
|
||||||
while (node.Children.Count != 0)
|
while (node.Children.Count != 0)
|
||||||
{
|
{
|
||||||
node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children));
|
node = ChildMaxScore(ref node.Children);
|
||||||
|
|
||||||
if (node.State.Action != null)
|
if (node.State.Action != null)
|
||||||
actions.Add(node.State.Action.Value);
|
actions.Add(node.State.Action.Value);
|
||||||
|
|||||||
Reference in New Issue
Block a user