From 636501ab8659d50d8360e3c2b2f3c61515d0fb4c Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Fri, 7 Jul 2023 20:17:35 +0200 Subject: [PATCH] Remove all concurrency code Muddled the code too much, and only gave a marginal performance improvement in the grand scheme of things. Other ways to parallelize MCTS will be nicer to implement and could yield better results. --- Benchmark/Program.cs | 12 +- Solver/Crafty/ActionSet.cs | 69 +------ Solver/Crafty/ArenaBuffer.cs | 33 +--- Solver/Crafty/ArenaNode.cs | 9 - Solver/Crafty/ISolver.cs | 10 - Solver/Crafty/Intrinsics.cs | 24 --- Solver/Crafty/NodeScoresBuffer.cs | 40 +--- Solver/Crafty/RootScores.cs | 7 - Solver/Crafty/{SolverUtils.cs => Solver.cs} | 206 ++++++++++++-------- Solver/Crafty/SolverConcurrent.cs | 103 ---------- Solver/Crafty/SolverSingle.cs | 71 ------- 11 files changed, 153 insertions(+), 431 deletions(-) delete mode 100644 Solver/Crafty/ISolver.cs rename Solver/Crafty/{SolverUtils.cs => Solver.cs} (61%) delete mode 100644 Solver/Crafty/SolverConcurrent.cs delete mode 100644 Solver/Crafty/SolverSingle.cs diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index 69f0147..6ebdbdc 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -46,22 +46,22 @@ internal static class Program var config = new SolverConfig() { - Iterations = 100_000, + Iterations = 30_000, ThreadCount = 8, }; Debugger.Break(); var s = Stopwatch.StartNew(); if (true) { - (_, var state) = SolverUtils.SearchStepwise(config, input, a => Console.WriteLine(a)); + (_, var state) = Solver.Crafty.Solver.SearchStepwise(config, input, a => Console.WriteLine(a)); Console.WriteLine($"Qual: {state.Quality}/{state.Input.Recipe.MaxQuality}"); } else { - (var actions, _) = SolverUtils.SearchOneshot(config, input); - foreach (var action in actions) - Console.Write($">{action.IntName()}"); - Console.WriteLine(); + //(var actions, _) = SolverUtils.SearchOneshot(config, input); + //foreach (var action in actions) + // Console.Write($">{action.IntName()}"); + //Console.WriteLine(); } s.Stop(); Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}"); diff --git a/Solver/Crafty/ActionSet.cs b/Solver/Crafty/ActionSet.cs index 8d0a49c..cd1d067 100644 --- a/Solver/Crafty/ActionSet.cs +++ b/Solver/Crafty/ActionSet.cs @@ -7,6 +7,8 @@ namespace Craftimizer.Solver.Crafty; public struct ActionSet { + private const bool IsDeterministic = true; + private uint bits; [Pure] @@ -19,24 +21,6 @@ public struct ActionSet [MethodImpl(MethodImplOptions.AggressiveInlining)] private static uint ToMask(ActionType action) => 1u << FromAction(action) + 1; - // Return true if action was newly added and not there before. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool AddActionConcurrent(ActionType action) - { - var mask = ToMask(action); - var old = Interlocked.Or(ref bits, mask); - return (old & mask) == 0; - } - - // Return true if action was newly removed and not already gone. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool RemoveActionConcurrent(ActionType action) - { - var mask = ToMask(action); - var old = Interlocked.And(ref bits, ~mask); - return (old & mask) != 0; - } - // Return true if action was newly added and not there before. [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool AddAction(ActionType action) @@ -71,52 +55,17 @@ public struct ActionSet public readonly bool IsEmpty => bits == 0; [MethodImpl(MethodImplOptions.AggressiveInlining)] - public readonly ActionType SelectRandom(Random random) => ElementAt(random.Next(Count)); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ActionType? PopRandomConcurrent(Random random) - { - uint snapshot; - uint newValue; - ActionType action; - do - { - snapshot = bits; - if (snapshot == 0) - return null; - - var count = BitOperations.PopCount(snapshot); - var index = random.Next(count); - - action = ToAction(Intrinsics.NthBitSet(snapshot, index) - 1); - newValue = snapshot & ~ToMask(action); - } - while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot); - return action; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ActionType? PopFirstConcurrent() - { - uint snapshot; - uint newValue; - ActionType action; - do - { - snapshot = bits; - if (snapshot == 0) - return null; - - action = ToAction(Intrinsics.NthBitSet(snapshot, 0) - 1); - newValue = snapshot & ~ToMask(action); - } - while (Interlocked.CompareExchange(ref bits, newValue, snapshot) != snapshot); - return action; - } + public readonly ActionType SelectRandom(Random random) => + IsDeterministic ? + First() : + ElementAt(random.Next(Count)); [MethodImpl(MethodImplOptions.AggressiveInlining)] public ActionType PopRandom(Random random) { + if (IsDeterministic) + return PopFirst(); + var action = ElementAt(random.Next(Count)); RemoveAction(action); return action; diff --git a/Solver/Crafty/ArenaBuffer.cs b/Solver/Crafty/ArenaBuffer.cs index 3ec0f9c..c3775c2 100644 --- a/Solver/Crafty/ArenaBuffer.cs +++ b/Solver/Crafty/ArenaBuffer.cs @@ -11,41 +11,20 @@ public struct ArenaBuffer where T : struct // The benchmark reaches 20 at most, but here we have a little leeway just in case. private const int MaxSize = 24; - private static int BatchSize = Vector.Count; - private static int BatchSizeBits = int.Log2(BatchSize); - private static int BatchSizeMask = BatchSize - 1; + private static readonly int BatchSize = Vector.Count; + private static readonly int BatchSizeBits = int.Log2(BatchSize); + private static readonly int BatchSizeMask = BatchSize - 1; - private static int BatchCount = MaxSize / BatchSize; + private static readonly int BatchCount = MaxSize / BatchSize; public ArenaNode[][] Data; - private int index; // Unused in single threaded workload - private int count; - - public readonly int Count => count; - - public void AddConcurrent(ArenaNode node) - { - if (Data == null) - Interlocked.CompareExchange(ref Data, new ArenaNode[BatchCount][], null); - - var idx = Interlocked.Increment(ref index) - 1; - - var (arrayIdx, subIdx) = GetArrayIndex(idx); - - if (Data[arrayIdx] == null) - Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode[BatchSize], null); - - node.ChildIdx = (arrayIdx, subIdx); - Data[arrayIdx][subIdx] = node; - - Interlocked.Increment(ref count); - } + public int Count { get; private set; } public void Add(ArenaNode node) { Data ??= new ArenaNode[BatchCount][]; - var idx = count++; + var idx = Count++; var (arrayIdx, subIdx) = GetArrayIndex(idx); diff --git a/Solver/Crafty/ArenaNode.cs b/Solver/Crafty/ArenaNode.cs index 89fba01..5ac88ea 100644 --- a/Solver/Crafty/ArenaNode.cs +++ b/Solver/Crafty/ArenaNode.cs @@ -23,15 +23,6 @@ public sealed class ArenaNode where T : struct public ArenaNode? ChildAt((int arrayIdx, int subIdx) at) => Children.Data?[at.arrayIdx]?[at.subIdx]; - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ArenaNode AddConcurrent(T state) - { - var node = new ArenaNode(state, this); - ChildScores.AddConcurrent(); - Children.AddConcurrent(node); - return node; - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public ArenaNode Add(T state) { diff --git a/Solver/Crafty/ISolver.cs b/Solver/Crafty/ISolver.cs deleted file mode 100644 index 7c760f2..0000000 --- a/Solver/Crafty/ISolver.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Node = Craftimizer.Solver.Crafty.ArenaNode; - -namespace Craftimizer.Solver.Crafty; - -public interface ISolver -{ - abstract static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator); - - abstract static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token); -} diff --git a/Solver/Crafty/Intrinsics.cs b/Solver/Crafty/Intrinsics.cs index 33e95f0..1380ae9 100644 --- a/Solver/Crafty/Intrinsics.cs +++ b/Solver/Crafty/Intrinsics.cs @@ -124,28 +124,4 @@ internal static class Intrinsics result[i] = MathF.ReciprocalSqrtEstimate(data[i]); return new(result); } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void CASMax(ref float location, float newValue) - { - float snapshot; - do - { - snapshot = location; - if (snapshot >= newValue) return; - } while (Interlocked.CompareExchange(ref location, newValue, snapshot) != snapshot); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void CASAdd(ref float location, float value) - { - float snapshot; - float newValue; - do - { - snapshot = location; - newValue = snapshot + value; - } - while (Interlocked.CompareExchange(ref location, newValue, snapshot) != snapshot); - } } diff --git a/Solver/Crafty/NodeScoresBuffer.cs b/Solver/Crafty/NodeScoresBuffer.cs index 1e93ee9..fa394bb 100644 --- a/Solver/Crafty/NodeScoresBuffer.cs +++ b/Solver/Crafty/NodeScoresBuffer.cs @@ -1,9 +1,6 @@ -using System; -using System.ComponentModel; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; namespace Craftimizer.Solver.Crafty; @@ -28,51 +25,26 @@ public struct NodeScoresBuffer // The benchmark reaches 20 at most, but here we have a little leeway just in case. private const int MaxSize = 24; - private static int BatchSize = Vector.Count; - private static int BatchSizeBits = int.Log2(BatchSize); - private static int BatchSizeMask = BatchSize - 1; + private static readonly int BatchSize = Vector.Count; + private static readonly int BatchSizeBits = int.Log2(BatchSize); + private static readonly int BatchSizeMask = BatchSize - 1; - private static int BatchCount = MaxSize / BatchSize; + private static readonly int BatchCount = MaxSize / BatchSize; public ScoresBatch[] Data; - private int index; - private int count; - - public readonly int Count => count; - - public void AddConcurrent() - { - if (Data == null) - Interlocked.CompareExchange(ref Data, new ScoresBatch[BatchCount], null); - - var idx = Interlocked.Increment(ref index) - 1; - - var (arrayIdx, _) = GetArrayIndex(idx); - - if (Data[arrayIdx] == null) - Interlocked.CompareExchange(ref Data[arrayIdx], new ScoresBatch(), null); - - Interlocked.Increment(ref count); - } + public int Count { get; private set; } public void Add() { Data ??= new ScoresBatch[BatchCount]; - var idx = count++; + var idx = Count++; var (arrayIdx, _) = GetArrayIndex(idx); Data[arrayIdx] ??= new(); } - public readonly void VisitConcurrent((int arrayIdx, int subIdx) at, float score) - { - Intrinsics.CASAdd(ref Data[at.arrayIdx].ScoreSum.Span[at.subIdx], score); - Intrinsics.CASMax(ref Data[at.arrayIdx].MaxScore.Span[at.subIdx], score); - Interlocked.Increment(ref Data[at.arrayIdx].Visits.Span[at.subIdx]); - } - public readonly void Visit((int arrayIdx, int subIdx) at, float score) { Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score; diff --git a/Solver/Crafty/RootScores.cs b/Solver/Crafty/RootScores.cs index fa7350d..2ac7c2f 100644 --- a/Solver/Crafty/RootScores.cs +++ b/Solver/Crafty/RootScores.cs @@ -9,13 +9,6 @@ public sealed class RootScores public float MaxScore; public int Visits; - public void VisitConcurrent(float score) - { - Intrinsics.CASAdd(ref ScoreSum, score); - Intrinsics.CASMax(ref MaxScore, score); - Interlocked.Increment(ref Visits); - } - public void Visit(float score) { ScoreSum += score; diff --git a/Solver/Crafty/SolverUtils.cs b/Solver/Crafty/Solver.cs similarity index 61% rename from Solver/Crafty/SolverUtils.cs rename to Solver/Crafty/Solver.cs index 959a8c8..a6f4016 100644 --- a/Solver/Crafty/SolverUtils.cs +++ b/Solver/Crafty/Solver.cs @@ -1,14 +1,35 @@ -using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; +using Craftimizer.Simulator; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; -public static class SolverUtils + +// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs +public sealed class Solver { - public static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) + private SolverConfig config; + private Node rootNode; + private RootScores rootScores; + + public float MaxScore => rootScores.MaxScore; + + public Solver(SolverConfig config, SimulationState state, bool strict) + { + this.config = config; + var sim = new Simulator(state, config.MaxStepCount); + rootNode = new(new( + state, + null, + sim.CompletionState, + sim.AvailableActionsHeuristic(strict) + )); + rootScores = new(); + } + + private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) { (_, var newState) = simulator.Execute(state, action); return new( @@ -19,27 +40,50 @@ public static class SolverUtils ); } - public static (Node EndNode, CompletionState State) ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan actions, bool strict = false) + private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan actions, bool strict) { foreach (var action in actions) { var state = startNode.State; if (state.IsComplete) - return (startNode, state.CompletionState); + return startNode; if (!state.AvailableActions.HasAction(action)) - return (startNode, CompletionState.InvalidAction); + return startNode; state.AvailableActions.RemoveAction(action); startNode = startNode.Add(Execute(simulator, state.State, action, strict)); } - return (startNode, startNode.State.CompletionState); + return startNode; + } + + [Pure] + private (List Actions, SimulationNode Node) Solution() + { + var actions = new List(); + var node = rootNode; + + while (node.Children.Count != 0) + { + node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; + + if (node.State.Action != null) + actions.Add(node.State.Action.Value); + } + + var at = node.ChildIdx; + ref var sum = ref node.ParentScores!.Value.Data[at.arrayIdx].ScoreSum.Span[at.subIdx]; + ref var max = ref node.ParentScores!.Value.Data[at.arrayIdx].MaxScore.Span[at.subIdx]; + ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx]; + //Console.WriteLine($"{sum} {max} {visits}"); + + return (actions, node.State); } [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores) + private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores) { var length = scores.Count; var vecLength = Vector.Count; @@ -67,21 +111,6 @@ public static class SolverUtils return max; } - [Pure] - public static (List Actions, SimulationNode Node) Solution(Node node) - { - var actions = new List(); - while (node.Children.Count != 0) - { - node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; - - if (node.State.Action != null) - actions.Add(node.State.Action.Value); - } - - return (actions, node.State); - } - // Calculates the best child node to explore next // Exploitation: ((1 - w) * (s / v)) + (w * m) // Exploration: sqrt(c * ln(V) / v) @@ -98,10 +127,9 @@ public static class SolverUtils // max score is tacked onto it // N_i = parent visits // c = exploration constant (but crafty places it inside the sqrt..?) - [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] - public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver + private (int arrayIdx, int subIdx) EvalBestChild(int parentVisits, ref NodeScoresBuffer scores) { var length = scores.Count; var vecLength = Vector.Count; @@ -143,9 +171,36 @@ public static class SolverUtils return max; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator) + [Pure] + public Node Select() { + var node = rootNode; + var nodeVisits = rootScores.Visits; + + while (true) + { + var expandable = !node.State.AvailableActions.IsEmpty; + var likelyTerminal = node.Children.Count == 0; + if (expandable || likelyTerminal) + return node; + + // select the node with the highest score + var at = EvalBestChild(nodeVisits, ref node.ChildScores); + nodeVisits = node.ChildScores.GetVisits(at); + node = node.ChildAt(at)!; + } + } + + public (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode) + { + ref var initialState = ref initialNode.State; + // expand once + if (initialState.IsComplete) + return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0); + + var poppedAction = initialState.AvailableActions.PopRandom(random); + var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true)); + // playout to a terminal state var currentState = expandedNode.State.State; var currentCompletionState = expandedNode.State.SimulationCompletionState; @@ -168,56 +223,51 @@ public static class SolverUtils var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0; if (currentCompletionState == CompletionState.ProgressComplete) { - if (score >= config.ScoreStorageThreshold && score >= maxScore) + if (score >= config.ScoreStorageThreshold && score >= MaxScore) { - (var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); + var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); return (terminalNode, score); } } return (expandedNode, score); } + public void Backpropagate(Node startNode, float score) + { + while (true) + { + if (startNode == rootNode) + { + rootScores.Visit(score); + break; + } + startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score); + + startNode = startNode.Parent!; + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void Search(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver + private void Search(CancellationToken token) { Simulator simulator = new(rootNode.State.State, config.MaxStepCount); var random = rootNode.State.State.Input.Random; - for (var i = 0; i < iterations; i++) + for (var i = 0; i < config.Iterations; i++) { if (token.IsCancellationRequested) break; - if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator)) - { - // Retry, count this iteration as moot - i--; - continue; - } + var selectedNode = Select(); + var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); + + Backpropagate(endNode, score); } } - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Node CreateRootNode(SolverConfig config, SimulationInput input, bool strict) => - CreateRootNode(config, new SimulationState(input), strict); + public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationInput input, Action? actionCallback, CancellationToken token = default) => + SearchStepwiseForked(config, forkCount, new SimulationState(input), actionCallback, token); - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Node CreateRootNode(SolverConfig config, SimulationState state, bool strict) - { - var sim = new Simulator(state, config.MaxStepCount); - return new(new( - state, - null, - sim.CompletionState, - sim.AvailableActionsHeuristic(strict) - )); - } - - public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationInput input, Action? actionCallback, CancellationToken token = default) where S : ISolver => - SearchStepwiseForked(config, forkCount, new SimulationState(input), actionCallback, token); - - public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationState state, Action? actionCallback, CancellationToken token = default) where S : ISolver + public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationState state, Action? actionCallback, CancellationToken token = default) { var actions = new List(); var sim = new Simulator(state, config.MaxStepCount); @@ -230,13 +280,11 @@ public static class SolverUtils for (var i = 0; i < forkCount; ++i) tasks[i] = Task.Run(() => { - var rootNode = CreateRootNode(config, state, true); - RootScores rootScores = new(); + var solver = new Solver(config, state, true); + solver.Search(token); + var (solution_actions, solution_node) = solver.Solution(); - S.Search(ref config, rootScores, rootNode, token); - var (solution_actions, solution_node) = Solution(rootNode); - - return (rootScores.MaxScore, solution_actions, solution_node.State); + return (solver.MaxScore, solution_actions, solution_node.State); }, token); Task.WaitAll(tasks, CancellationToken.None); @@ -258,24 +306,23 @@ public static class SolverUtils return (actions, state); } - public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) where S : ISolver => - SearchStepwise(config, new SimulationState(input), actionCallback, token); + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) => + SearchStepwise(config, new SimulationState(input), actionCallback, token); - public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token = default) where S : ISolver + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token = default) { var actions = new List(); var sim = new Simulator(state, config.MaxStepCount); - var rootNode = CreateRootNode(config, state, true); - RootScores rootScores = new(); + var solver = new Solver(config, state, true); while (!sim.IsComplete) { if (token.IsCancellationRequested) break; - S.Search(ref config, rootScores, rootNode, token); - var (solution_actions, solution_node) = Solution(rootNode); + solver.Search(token); + var (solution_actions, solution_node) = solver.Solution(); - if (rootScores.MaxScore >= 1.0) + if (solver.MaxScore >= 1.0) { actions.AddRange(solution_actions); return (actions, solution_node.State); @@ -287,21 +334,20 @@ public static class SolverUtils actionCallback?.Invoke(chosen_action); - rootNode = CreateRootNode(config, state, true); + solver = new Solver(config, state, true); } return (actions, state); } - public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) where S : ISolver => - SearchOneshot(config, new SimulationState(input), token); + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) => + SearchOneshot(config, new SimulationState(input), token); - public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) { - var rootNode = CreateRootNode(config, state, false); - RootScores rootScores = new(); - S.Search(ref config, rootScores, rootNode, token); - var (solution_actions, solution_node) = Solution(rootNode); + var solver = new Solver(config, state, false); + solver.Search(token); + var (solution_actions, solution_node) = solver.Solution(); return (solution_actions, solution_node.State); } } diff --git a/Solver/Crafty/SolverConcurrent.cs b/Solver/Crafty/SolverConcurrent.cs deleted file mode 100644 index 3bb6032..0000000 --- a/Solver/Crafty/SolverConcurrent.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System.Diagnostics.Contracts; -using System.Runtime.CompilerServices; -using Node = Craftimizer.Solver.Crafty.ArenaNode; - -namespace Craftimizer.Solver.Crafty; - -// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs -public sealed class SolverConcurrent : ISolver -{ - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) => - parentVisits == 0 ? - null : - SolverUtils.EvalBestChild(ref config, parentVisits, ref children); - - [Pure] - public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode) - { - var node = rootNode; - var nodeVisits = rootNodeVisits; - while (true) - { - var expandable = !node.State.AvailableActions.IsEmpty; - var likelyTerminal = node.Children.Count == 0; - if (expandable || likelyTerminal) - return node; - - // select the node with the highest score - // if null (current node is invalid & not backpropagated just yet), try again from root - var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores); - if (at.HasValue) - { - nodeVisits = node.ChildScores.GetVisits(at.Value); - node = node.ChildAt(at.Value); - if (node == null) - { - node = rootNode; - nodeVisits = rootNodeVisits; - } - } - else - { - node = rootNode; - nodeVisits = rootNodeVisits; - } - } - } - - public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode) - { - ref var initialState = ref initialNode.State; - // expand once - if (initialState.IsComplete) - return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0); - - var poppedAction = initialState.AvailableActions.PopRandomConcurrent(random); - if (!poppedAction.HasValue) - return null; - var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true)); - - return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator); - } - - public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score) - { - while (true) - { - if (startNode == rootNode) - { - rootScores.VisitConcurrent(score); - break; - } - startNode.ParentScores!.Value.VisitConcurrent(startNode.ChildIdx, score); - - startNode = startNode.Parent!; - } - } - - public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator) - { - var selectedNode = Select(ref config, rootScores.Visits, rootNode); - var rolledOut = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode); - if (!rolledOut.HasValue) - return false; - - var (endNode, score) = rolledOut.Value; - Backpropagate(rootScores, rootNode, endNode, score); - return true; - } - - public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) => - SolverUtils.Search(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token); - - public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) - { - var configP = config; - var tasks = new Task[config.ThreadCount]; - for (var i = 0; i < config.ThreadCount; ++i) - tasks[i] = Task.Run(() => SearchThread(configP, rootScores, rootNode, token), token); - Task.WaitAll(tasks, CancellationToken.None); - } -} diff --git a/Solver/Crafty/SolverSingle.cs b/Solver/Crafty/SolverSingle.cs deleted file mode 100644 index 0ecc287..0000000 --- a/Solver/Crafty/SolverSingle.cs +++ /dev/null @@ -1,71 +0,0 @@ -using System.Diagnostics.Contracts; -using System.Runtime.CompilerServices; -using Node = Craftimizer.Solver.Crafty.ArenaNode; - -namespace Craftimizer.Solver.Crafty; - -// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs -public sealed class SolverSingle : ISolver -{ - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) => - SolverUtils.EvalBestChild(ref config, parentVisits, ref children); - - [Pure] - public static Node Select(ref SolverConfig config, int nodeVisits, Node node) - { - while (true) - { - var expandable = !node.State.AvailableActions.IsEmpty; - var likelyTerminal = node.Children.Count == 0; - if (expandable || likelyTerminal) - return node; - - // select the node with the highest score - var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores); - nodeVisits = node.ChildScores.GetVisits(at); - node = node.ChildAt(at)!; - } - } - - public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode) - { - ref var initialState = ref initialNode.State; - // expand once - if (initialState.IsComplete) - return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0); - - var poppedAction = initialState.AvailableActions.PopRandom(random); - var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true)); - - return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator); - } - - public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score) - { - while (true) - { - if (startNode == rootNode) - { - rootScores.Visit(score); - break; - } - startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score); - - startNode = startNode.Parent!; - } - } - - public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator) - { - var selectedNode = Select(ref config, rootScores.Visits, rootNode); - var (endNode, score) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode); - - Backpropagate(rootScores, rootNode, endNode, score); - return true; - } - - public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) => - SolverUtils.Search(ref config, config.Iterations, rootScores, rootNode, token); -}