From 1f5da66bc610bf68b9f0b18a35fe6839c0296820 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Wed, 5 Jul 2023 10:22:09 +0200 Subject: [PATCH] Remove deadlock issue, use raw segmented array types --- Benchmark/Program.cs | 2 +- Solver/Crafty/ActionSet.cs | 6 ++-- Solver/Crafty/ArenaNode.cs | 48 +++++++++++++++++++++++++++-- Solver/Crafty/Solver.cs | 63 ++++++++++++++++++-------------------- 4 files changed, 80 insertions(+), 39 deletions(-) diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index c4442d0..d8bdbef 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -45,7 +45,7 @@ internal static class Program 0 ); - var threads = 1; + var threads = 8; var config = new SolverConfig() { Iterations = 30_000 / threads, diff --git a/Solver/Crafty/ActionSet.cs b/Solver/Crafty/ActionSet.cs index b558030..28941e5 100644 --- a/Solver/Crafty/ActionSet.cs +++ b/Solver/Crafty/ActionSet.cs @@ -56,8 +56,8 @@ public struct ActionSet public readonly ActionType SelectRandom(Random random) => First(); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ActionType? PopRandom(Random random) => PopFirst(); - /*public ActionType? PopRandom(Random random) + //public ActionType? PopRandom(Random random) => PopFirst(); + public ActionType? PopRandom(Random random) { uint snapshot; uint newValue; @@ -76,7 +76,7 @@ public struct ActionSet } while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot); return action; - }*/ + } [MethodImpl(MethodImplOptions.AggressiveInlining)] public ActionType? PopFirst() diff --git a/Solver/Crafty/ArenaNode.cs b/Solver/Crafty/ArenaNode.cs index ca15bb2..68e0f0e 100644 --- a/Solver/Crafty/ArenaNode.cs +++ b/Solver/Crafty/ArenaNode.cs @@ -1,11 +1,55 @@ +using System.Diagnostics.Contracts; +using System.Numerics; using System.Runtime.CompilerServices; namespace Craftimizer.Solver.Crafty; -public class ArenaNode where T : struct +public sealed class ArenaNode where T : struct { + // Adapted from https://github.com/dtao/ConcurrentList/blob/4fcf1c76e93021a41af5abb2d61a63caeba2adad/ConcurrentList/ConcurrentList.cs + public struct ChildBuffer + { + // Technically 25, but it's very unlikely to actually get to there. + // The benchmark reaches 20 at most, but here we have a little leeway just in case. + private const int MaxSize = 24; + + private static int BatchSize = Vector.Count; + private static int BatchSizeBits = int.Log2(BatchSize); + private static int BatchSizeMask = BatchSize - 1; + + private static int BatchCount = MaxSize / BatchSize; + + public ArenaNode[][] Data; + private int index; + private int count; + + public readonly int Count => count; + + public void Add(ArenaNode node) + { + if (Data == null) + Interlocked.CompareExchange(ref Data, new ArenaNode[BatchCount][], null); + + var index = Interlocked.Increment(ref this.index) - 1; + + var (arrayIdx, subIdx) = GetArrayIndex(index); + + if (Data[arrayIdx] == null) + Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode[BatchSize], null); + + Data[arrayIdx][subIdx] = node; + + Interlocked.Increment(ref count); + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) => + (idx >> BatchSizeBits, idx & BatchSizeMask); + } + public T State; - public readonly List> Children; + public ChildBuffer Children; public readonly ArenaNode? Parent; public ArenaNode(T state, ArenaNode? parent = null) diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 3bed82d..bbd9846 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -3,13 +3,12 @@ using Craftimizer.Simulator.Actions; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs -public class Solver +public sealed class Solver { public SolverConfig Config; public Node RootNode; @@ -64,46 +63,45 @@ public class Solver [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Node ChildMaxScore(ReadOnlySpan children) + private static Node ChildMaxScore(ref Node.ChildBuffer children) { - var length = children.Length; + var length = children.Count; var vecLength = Vector.Count; Span scores = stackalloc float[vecLength]; - var max = 0; + var max = (0, 0); var maxScore = 0f; - for (var i = 0; i < length; i += vecLength) + for (var i = 0; length > 0; ++i) { - var iterCount = i + vecLength > length ? - length - i : - vecLength; + var iterCount = Math.Min(vecLength, length); + ref var chunk = ref children.Data[i]; for (var j = 0; j < iterCount; ++j) - scores[j] = children[i + j].State.Scores.MaxScore; + scores[j] = chunk[j].State.Scores.MaxScore; var idx = Intrinsics.HMaxIndex(new Vector(scores), iterCount); if (scores[idx] >= maxScore) { - max = i + idx; + max = (i, idx); maxScore = scores[idx]; } + + length -= iterCount; } - return children[max]; + return children.Data[max.Item1][max.Item2]; } [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private Node EvalBestChild(float parentVisits, ReadOnlySpan children) + private Node? EvalBestChild(float parentVisits, ref Node.ChildBuffer children) { if (parentVisits == 0) - { - Console.WriteLine("no visits"); return null; - } + var length = children.Count; var vecLength = Vector.Count; var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits)); @@ -114,18 +112,17 @@ public class Solver Span scoreSums = stackalloc float[vecLength]; Span visits = stackalloc float[vecLength]; Span maxScores = stackalloc float[vecLength]; - - var max = 0; - var maxScore = 0f; - for (var i = 0; i < length; i += vecLength) - { - var iterCount = i + vecLength > length ? - length - i : - vecLength; + var max = (0, 0); + var maxScore = 0f; + for (var i = 0; length > 0; ++i) + { + var iterCount = Math.Min(vecLength, length); + + ref var chunk = ref children.Data[i]; for (var j = 0; j < iterCount; ++j) { - var node = children[i + j].State.Scores; + var node = chunk[j]?.State.Scores ?? new(); scoreSums[j] = node.ScoreSum; visits[j] = node.Visits; maxScores[j] = node.MaxScore; @@ -140,15 +137,17 @@ public class Solver var evalScores = exploitation + exploration; var idx = Intrinsics.HMaxIndex(evalScores, iterCount); - + if (evalScores[idx] >= maxScore) { - max = i + idx; + max = (i, idx); maxScore = evalScores[idx]; } + + length -= iterCount; } - return children[max]; + return children.Data[max.Item1][max.Item2]; } [Pure] @@ -164,7 +163,7 @@ public class Solver // select the node with the highest score // if null (current node is invalid & not backpropagated just yet), try again from root - node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode; + node = EvalBestChild(node.State.Scores.Visits, ref node.Children) ?? RootNode; } } @@ -234,10 +233,8 @@ public class Solver var selectedNode = Select(); var rolledOut = ExpandAndRollout(simulator, selectedNode); - //Monitor.Exit(selectedNode); if (!rolledOut.HasValue) { - Console.WriteLine("Retry"); // Retry, count this iteration as moot i--; continue; @@ -253,7 +250,7 @@ public class Solver var tasks = new Task[Config.ThreadCount]; for (var i = 0; i < Config.ThreadCount; ++i) tasks[i] = Task.Run(() => SearchThread(token), token); - Task.WaitAll(tasks, token); + Task.WaitAll(tasks, CancellationToken.None); } [Pure] @@ -263,7 +260,7 @@ public class Solver var node = RootNode; while (node.Children.Count != 0) { - node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children)); + node = ChildMaxScore(ref node.Children); if (node.State.Action != null) actions.Add(node.State.Action.Value);