From d5a8288439bc38160fc419eff16e9ebf7c2d8644 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Fri, 7 Jul 2023 09:58:47 +0200 Subject: [PATCH] Merge solver code with static interface - Breaks backwards compat solver code with last version. Concurrent broke backwards compat because of race conditions with rng, but single is now broken too, despite it being 2x faster (!!!!) - Literally twice as fast as Rust now in single thread - Concurrent doesn't work yet, deadlocks somewhere..? --- Benchmark/Program.cs | 4 +- Craftimizer/Windows/SimulatorWindowSolver.cs | 3 +- Solver/Crafty/ActionSet.cs | 49 ++++- Solver/Crafty/ArenaNode.cs | 27 ++- Solver/Crafty/ISolver.cs | 12 ++ Solver/Crafty/NodeScores.cs | 9 +- Solver/Crafty/SolverConcurrent.cs | 97 +++++++++ Solver/Crafty/SolverSingle.cs | 79 ++++++++ Solver/Crafty/{Solver.cs => SolverUtils.cs} | 197 ++++++------------- 9 files changed, 332 insertions(+), 145 deletions(-) create mode 100644 Solver/Crafty/ISolver.cs create mode 100644 Solver/Crafty/SolverConcurrent.cs create mode 100644 Solver/Crafty/SolverSingle.cs rename Solver/Crafty/{Solver.cs => SolverUtils.cs} (56%) diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index d8bdbef..aea78aa 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -55,10 +55,10 @@ internal static class Program Debugger.Break(); var s = Stopwatch.StartNew(); if (true) - _ = Solver.Crafty.Solver.SearchStepwise(config, input, a => Console.WriteLine(a)); + _ = SolverUtils.SearchStepwise(config, input, a => Console.WriteLine(a)); else { - (var actions, _) = Solver.Crafty.Solver.SearchOneshot(config, input); + (var actions, _) = SolverUtils.SearchOneshot(config, input); foreach (var action in actions) Console.Write($">{action.IntName()}"); Console.WriteLine(); diff --git a/Craftimizer/Windows/SimulatorWindowSolver.cs b/Craftimizer/Windows/SimulatorWindowSolver.cs index 9420e77..5f40c97 100644 --- a/Craftimizer/Windows/SimulatorWindowSolver.cs +++ b/Craftimizer/Windows/SimulatorWindowSolver.cs @@ -1,5 +1,6 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; +using Craftimizer.Solver.Crafty; using Dalamud.Interface.Windowing; using System; using System.Collections.Concurrent; @@ -71,7 +72,7 @@ public sealed partial class SimulatorWindow : Window, IDisposable SolverInitialActionCount = Actions.Count; SolverTaskToken = new(); - SolverTask = Task.Run(() => Solver.Crafty.Solver.SearchStepwise(Service.Configuration.SolverConfig, solverState, SolverActionQueue.Enqueue, SolverTaskToken.Token)); + SolverTask = Task.Run(() => SolverUtils.SearchStepwise(Service.Configuration.SolverConfig, solverState, SolverActionQueue.Enqueue, SolverTaskToken.Token)); } public void Dispose() diff --git a/Solver/Crafty/ActionSet.cs b/Solver/Crafty/ActionSet.cs index 9ef9614..30632c2 100644 --- a/Solver/Crafty/ActionSet.cs +++ b/Solver/Crafty/ActionSet.cs @@ -17,11 +17,11 @@ public struct ActionSet private static ActionType ToAction(int index) => Simulator.AcceptedActions[index]; [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static uint ToMask(ActionType action) => 1u << (FromAction(action) + 1); + private static uint ToMask(ActionType action) => 1u << FromAction(action) + 1; // Return true if action was newly added and not there before. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool AddAction(ActionType action) + public bool AddActionConcurrent(ActionType action) { var mask = ToMask(action); var old = Interlocked.Or(ref bits, mask); @@ -30,13 +30,33 @@ public struct ActionSet // Return true if action was newly removed and not already gone. [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool RemoveAction(ActionType action) + public bool RemoveActionConcurrent(ActionType action) { var mask = ToMask(action); var old = Interlocked.And(ref bits, ~mask); return (old & mask) != 0; } + // Return true if action was newly added and not there before. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool AddAction(ActionType action) + { + var mask = ToMask(action); + var old = bits; + bits |= mask; + return (old & mask) == 0; + } + + // Return true if action was newly removed and not already gone. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool RemoveAction(ActionType action) + { + var mask = ToMask(action); + var old = bits; + bits &= ~mask; + return (old & mask) != 0; + } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] public readonly bool HasAction(ActionType action) => (bits & ToMask(action)) != 0; @@ -51,10 +71,10 @@ public struct ActionSet public readonly bool IsEmpty => bits == 0; [MethodImpl(MethodImplOptions.AggressiveInlining)] - public readonly ActionType SelectRandom(Random random) => ElementAt(random.Next(Count)); + public readonly ActionType SelectRandom(Random random) => ElementAt(0);// random.Next(Count)); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ActionType? PopRandom(Random random) + public ActionType? PopRandomConcurrent(Random random) { uint snapshot; uint newValue; @@ -76,7 +96,7 @@ public struct ActionSet } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ActionType? PopFirst() + public ActionType? PopFirstConcurrent() { uint snapshot; uint newValue; @@ -94,6 +114,23 @@ public struct ActionSet return action; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ActionType PopRandom(Random random) + { + return PopFirst(); + var action = ElementAt(random.Next(Count)); + RemoveAction(action); + return action; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ActionType PopFirst() + { + var action = First(); + RemoveAction(action); + return action; + } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] public readonly ActionType First() => ElementAt(0); diff --git a/Solver/Crafty/ArenaNode.cs b/Solver/Crafty/ArenaNode.cs index 7b21639..79d4b36 100644 --- a/Solver/Crafty/ArenaNode.cs +++ b/Solver/Crafty/ArenaNode.cs @@ -20,17 +20,17 @@ public sealed class ArenaNode where T : struct private static int BatchCount = MaxSize / BatchSize; public ArenaNode[][] Data; - private int index; + private int index; // Unused in single threaded workload private int count; public readonly int Count => count; - public void Add(ArenaNode node) + public void AddConcurrent(ArenaNode node) { if (Data == null) Interlocked.CompareExchange(ref Data, new ArenaNode[BatchCount][], null); - var idx = Interlocked.Increment(ref this.index) - 1; + var idx = Interlocked.Increment(ref index) - 1; var (arrayIdx, subIdx) = GetArrayIndex(idx); @@ -42,6 +42,19 @@ public sealed class ArenaNode where T : struct Interlocked.Increment(ref count); } + public void Add(ArenaNode node) + { + Data ??= new ArenaNode[BatchCount][]; + + var idx = count++; + + var (arrayIdx, subIdx) = GetArrayIndex(idx); + + Data[arrayIdx] ??= new ArenaNode[BatchSize]; + + Data[arrayIdx][subIdx] = node; + } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] private static (int arrayIdx, int subIdx) GetArrayIndex(int idx) => @@ -59,6 +72,14 @@ public sealed class ArenaNode where T : struct Parent = parent; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ArenaNode AddConcurrent(T state) + { + var node = new ArenaNode(state, this); + Children.AddConcurrent(node); + return node; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public ArenaNode Add(T state) { diff --git a/Solver/Crafty/ISolver.cs b/Solver/Crafty/ISolver.cs new file mode 100644 index 0000000..050b3f7 --- /dev/null +++ b/Solver/Crafty/ISolver.cs @@ -0,0 +1,12 @@ +using Node = Craftimizer.Solver.Crafty.ArenaNode; + +namespace Craftimizer.Solver.Crafty; + +public interface ISolver +{ + abstract static void LoadChildData(Span scoreSums, Span visits, Span maxScores, ref Node[] chunk, int iterCount); + + abstract static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator); + + abstract static void Search(ref SolverConfig config, Node rootNode, CancellationToken token); +} diff --git a/Solver/Crafty/NodeScores.cs b/Solver/Crafty/NodeScores.cs index b09dc03..d0a2044 100644 --- a/Solver/Crafty/NodeScores.cs +++ b/Solver/Crafty/NodeScores.cs @@ -9,10 +9,17 @@ public struct NodeScores public float MaxScore; public int Visits; - public void Visit(float score) + public void VisitConcurrent(float score) { Intrinsics.CASAdd(ref ScoreSum, score); Intrinsics.CASMax(ref MaxScore, score); Interlocked.Increment(ref Visits); } + + public void Visit(float score) + { + ScoreSum += score; + MaxScore = Math.Max(MaxScore, score); + Visits++; + } } diff --git a/Solver/Crafty/SolverConcurrent.cs b/Solver/Crafty/SolverConcurrent.cs new file mode 100644 index 0000000..2a78649 --- /dev/null +++ b/Solver/Crafty/SolverConcurrent.cs @@ -0,0 +1,97 @@ +using System.Diagnostics.Contracts; +using System.Runtime.CompilerServices; +using Node = Craftimizer.Solver.Crafty.ArenaNode; + +namespace Craftimizer.Solver.Crafty; + +// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs +public sealed class SolverConcurrent : ISolver +{ + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public static void LoadChildData(Span scoreSums, Span visits, Span maxScores, ref Node[] chunk, int iterCount) + { + for (var j = 0; j < iterCount; ++j) + { + var node = chunk[j]?.State.Scores ?? new(); + scoreSums[j] = node.ScoreSum; + visits[j] = node.Visits; + maxScores[j] = node.MaxScore; + } + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Node? EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) => + parentVisits == 0 ? + null : + SolverUtils.EvalBestChild(ref config, parentVisits, ref children); + + [Pure] + public static Node Select(ref SolverConfig config, Node rootNode) + { + var node = rootNode; + while (true) + { + var expandable = !node.State.AvailableActions.IsEmpty; + var likelyTerminal = node.Children.Count == 0; + if (expandable || likelyTerminal) + return node; + + // select the node with the highest score + // if null (current node is invalid & not backpropagated just yet), try again from root + node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children) ?? rootNode; + } + } + + public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode) + { + ref var initialState = ref initialNode.State; + // expand once + if (initialState.IsComplete) + return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0); + + var poppedAction = initialState.AvailableActions.PopRandomConcurrent(random); + if (!poppedAction.HasValue) + return null; + var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true)); + + return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator); + } + + public static void Backpropagate(Node rootNode, Node startNode, float score) + { + while (true) + { + startNode.State.Scores.VisitConcurrent(score); + + if (startNode == rootNode) + break; + + startNode = startNode.Parent!; + } + } + + public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator) + { + var selectedNode = Select(ref config, rootNode); + var rolledOut = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode); + if (!rolledOut.HasValue) + return false; + + var (endNode, score) = rolledOut.Value; + Backpropagate(rootNode, endNode, score); + return true; + } + + public static void SearchThread(SolverConfig config, Node rootNode, CancellationToken token) => + SolverUtils.Search(ref config, rootNode, token); + + public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) + { + var configP = config; + var tasks = new Task[config.ThreadCount]; + for (var i = 0; i < config.ThreadCount; ++i) + tasks[i] = Task.Run(() => SearchThread(configP, rootNode, token), token); + Task.WaitAll(tasks, CancellationToken.None); + } +} diff --git a/Solver/Crafty/SolverSingle.cs b/Solver/Crafty/SolverSingle.cs new file mode 100644 index 0000000..b7c9aed --- /dev/null +++ b/Solver/Crafty/SolverSingle.cs @@ -0,0 +1,79 @@ +using System.Diagnostics.Contracts; +using System.Runtime.CompilerServices; +using Node = Craftimizer.Solver.Crafty.ArenaNode; + +namespace Craftimizer.Solver.Crafty; + +// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs +public sealed class SolverSingle : ISolver +{ + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public static void LoadChildData(Span scoreSums, Span visits, Span maxScores, ref Node[] chunk, int iterCount) + { + for (var j = 0; j < iterCount; ++j) + { + ref var node = ref chunk[j].State.Scores; + scoreSums[j] = node.ScoreSum; + visits[j] = node.Visits; + maxScores[j] = node.MaxScore; + } + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Node EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) => + SolverUtils.EvalBestChild(ref config, parentVisits, ref children); + + [Pure] + public static Node Select(ref SolverConfig config, Node node) + { + while (true) + { + var expandable = !node.State.AvailableActions.IsEmpty; + var likelyTerminal = node.Children.Count == 0; + if (expandable || likelyTerminal) + return node; + + // select the node with the highest score + node = EvalBestChild(ref config, node.State.Scores.Visits, ref node.Children); + } + } + + public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, Node rootNode, Random random, Simulator simulator, Node initialNode) + { + ref var initialState = ref initialNode.State; + // expand once + if (initialState.IsComplete) + return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0); + + var poppedAction = initialState.AvailableActions.PopRandom(random); + var expandedNode = initialNode.Add(SolverUtils.Execute(simulator, initialState.State, poppedAction, true)); + + return SolverUtils.Rollout(ref config, rootNode, expandedNode, random, simulator); + } + + public static void Backpropagate(Node rootNode, Node startNode, float score) + { + while (true) + { + startNode.State.Scores.Visit(score); + + if (startNode == rootNode) + break; + + startNode = startNode.Parent!; + } + } + + public static bool SearchIter(ref SolverConfig config, Node rootNode, Random random, Simulator simulator) + { + var selectedNode = Select(ref config, rootNode); + var (endNode, score) = ExpandAndRollout(ref config, rootNode, random, simulator, selectedNode); + + Backpropagate(rootNode, endNode, score); + return true; + } + + public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) => + SolverUtils.Search(ref config, rootNode, token); +} diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/SolverUtils.cs similarity index 56% rename from Solver/Crafty/Solver.cs rename to Solver/Crafty/SolverUtils.cs index 24c2584..6b7cbb4 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/SolverUtils.cs @@ -1,38 +1,14 @@ -using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; +using Craftimizer.Simulator; +using Node = Craftimizer.Solver.Crafty.ArenaNode; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; -using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; - -// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs -public sealed class Solver +public static class SolverUtils { - public SolverConfig Config; - public Node RootNode; - - public Random Random; - - public Solver(SolverConfig config, SimulationState state, bool strict) - { - Config = config; - var sim = new Simulator(state, config.MaxStepCount); - RootNode = new(new( - state, - null, - sim.CompletionState, - sim.AvailableActionsHeuristic(strict) - )); - Random = state.Input.Random; - } - - public Solver(SolverConfig config, SimulationInput input, bool strict) : this(config, new SimulationState(input), strict) - { - } - - private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) + public static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) { (_, var newState) = simulator.Execute(state, action); return new( @@ -53,7 +29,7 @@ public sealed class Solver if (!state.AvailableActions.HasAction(action)) return (startNode, CompletionState.InvalidAction); - state.AvailableActions.RemoveAction(action); + state.AvailableActions.RemoveActionConcurrent(action); startNode = startNode.Add(Execute(simulator, state.State, action, strict)); } @@ -63,7 +39,7 @@ public sealed class Solver [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Node ChildMaxScore(ref Node.ChildBuffer children) + public static Node ChildMaxScore(ref Node.ChildBuffer children) { var length = children.Count; var vecLength = Vector.Count; @@ -95,17 +71,29 @@ public sealed class Solver } [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private Node? EvalBestChild(int parentVisits, ref Node.ChildBuffer children) + public static (List Actions, SimulationNode Node) Solution(Node node) { - if (parentVisits == 0) - return null; + var actions = new List(); + while (node.Children.Count != 0) + { + node = ChildMaxScore(ref node.Children); + if (node.State.Action != null) + actions.Add(node.State.Action.Value); + } + + return (actions, node.State); + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public static Node EvalBestChild(ref SolverConfig config, int parentVisits, ref Node.ChildBuffer children) where S : ISolver + { var length = children.Count; var vecLength = Vector.Count; - var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits)); - var w = Config.MaxScoreWeightingConstant; + var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits)); + var w = config.MaxScoreWeightingConstant; var W = 1f - w; var CVector = new Vector(C); @@ -119,14 +107,7 @@ public sealed class Solver { var iterCount = Math.Min(vecLength, length); - ref var chunk = ref children.Data[i]; - for (var j = 0; j < iterCount; ++j) - { - var node = chunk[j]?.State.Scores ?? new(); - scoreSums[j] = node.ScoreSum; - visits[j] = node.Visits; - maxScores[j] = node.MaxScore; - } + S.LoadChildData(scoreSums, visits, maxScores, ref children.Data[i], iterCount); var s = new Vector(scoreSums); var m = new Vector(maxScores); @@ -151,47 +132,21 @@ public sealed class Solver return children.Data[max.Item1][max.Item2]; } - [Pure] - public Node Select() + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, Node rootNode, Node expandedNode, Random random, Simulator simulator) { - var node = RootNode; - while (true) - { - var expandable = !node.State.AvailableActions.IsEmpty; - var likelyTerminal = node.Children.Count == 0; - if (expandable || likelyTerminal) - return node; - - // select the node with the highest score - // if null (current node is invalid & not backpropagated just yet), try again from root - node = EvalBestChild(node.State.Scores.Visits, ref node.Children) ?? RootNode; - } - } - - public (Node ExpandedNode, float Score)? ExpandAndRollout(Simulator simulator, Node initialNode) - { - ref var initialState = ref initialNode.State; - // expand once - if (initialState.IsComplete) - return (initialNode, initialState.CalculateScore(Config.MaxStepCount) ?? 0); - - var poppedAction = initialState.AvailableActions.PopRandom(Random); - if (!poppedAction.HasValue) - return null; - var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction.Value, true)); - // playout to a terminal state var currentState = expandedNode.State.State; var currentCompletionState = expandedNode.State.SimulationCompletionState; var currentActions = expandedNode.State.AvailableActions; byte actionCount = 0; - Span actions = stackalloc ActionType[Config.MaxStepCount]; + Span actions = stackalloc ActionType[config.MaxStepCount - currentState.ActionCount]; while (true) { if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete) break; - var nextAction = currentActions.SelectRandom(Random); + var nextAction = currentActions.SelectRandom(random); actions[actionCount++] = nextAction; (_, currentState) = simulator.Execute(currentState, nextAction); currentCompletionState = simulator.CompletionState; @@ -199,10 +154,10 @@ public sealed class Solver } // store the result if a max score was reached - var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, Config.MaxStepCount) ?? 0; + var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0; if (currentCompletionState == CompletionState.ProgressComplete) { - if (score >= Config.ScoreStorageThreshold && score >= RootNode.State.Scores.MaxScore) + if (score >= config.ScoreStorageThreshold && score >= rootNode.State.Scores.MaxScore) { (var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); return (terminalNode, score); @@ -211,80 +166,58 @@ public sealed class Solver return (expandedNode, score); } - public void Backpropagate(Node startNode, float score) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Search(ref SolverConfig config, Node rootNode, CancellationToken token) where S : ISolver { - while (true) - { - startNode.State.Scores.Visit(score); - - if (startNode == RootNode) - break; - - startNode = startNode.Parent!; - } - } - - public void SearchThread(CancellationToken token) - { - Simulator simulator = new(RootNode.State.State, Config.MaxStepCount); - for (var i = 0; i < Config.Iterations; i++) + Simulator simulator = new(rootNode.State.State, config.MaxStepCount); + var random = rootNode.State.State.Input.Random; + for (var i = 0; i < config.Iterations; i++) { if (token.IsCancellationRequested) break; - var selectedNode = Select(); - var rolledOut = ExpandAndRollout(simulator, selectedNode); - if (!rolledOut.HasValue) + if (!S.SearchIter(ref config, rootNode, random, simulator)) { // Retry, count this iteration as moot i--; continue; } - - var (endNode, score) = rolledOut.Value; - Backpropagate(endNode, score); } } - public void Search(CancellationToken token) - { - var tasks = new Task[Config.ThreadCount]; - for (var i = 0; i < Config.ThreadCount; ++i) - tasks[i] = Task.Run(() => SearchThread(token), token); - Task.WaitAll(tasks, CancellationToken.None); - } - [Pure] - public (List Actions, SimulationNode Node) Solution() + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Node CreateRootNode(SolverConfig config, SimulationInput input, bool strict) => + CreateRootNode(config, new SimulationState(input), strict); + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Node CreateRootNode(SolverConfig config, SimulationState state, bool strict) { - var actions = new List(); - var node = RootNode; - while (node.Children.Count != 0) - { - node = ChildMaxScore(ref node.Children); - - if (node.State.Action != null) - actions.Add(node.State.Action.Value); - } - - return (actions, node.State); + var sim = new Simulator(state, config.MaxStepCount); + return new(new( + state, + null, + sim.CompletionState, + sim.AvailableActionsHeuristic(strict) + )); } - public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) => - SearchStepwise(config, new SimulationState(input), actionCallback, token); + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) where S : ISolver => + SearchStepwise(config, new SimulationState(input), actionCallback, token); - public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token = default) + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token = default) where S : ISolver { var actions = new List(); - Simulator sim = new(state, config.MaxStepCount); - var solver = new Solver(config, state, true); + var sim = new Simulator(state, config.MaxStepCount); + var rootNode = CreateRootNode(config, state, true); while (!sim.IsComplete) { if (token.IsCancellationRequested) break; - solver.Search(token); - var (solution_actions, solution_node) = solver.Solution(); + S.Search(ref config, rootNode, token); + var (solution_actions, solution_node) = Solution(rootNode); if (solution_node.Scores.MaxScore >= 1.0) { @@ -298,20 +231,20 @@ public sealed class Solver actionCallback?.Invoke(chosen_action); - solver = new Solver(config, state, true); + rootNode = CreateRootNode(config, state, true); } return (actions, state); } - public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) => - SearchOneshot(config, new SimulationState(input), token); + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) where S : ISolver => + SearchOneshot(config, new SimulationState(input), token); - public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver { - var solver = new Solver(config, state, false); - solver.Search(token); - var (solution_actions, solution_node) = solver.Solution(); + var rootNode = CreateRootNode(config, state, false); + S.Search(ref config, rootNode, token); + var (solution_actions, solution_node) = Solution(rootNode); return (solution_actions, solution_node.State); } }