From 2f0e30e4c27f23327cc933ff2799a35f97f7797e Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Sun, 3 Mar 2024 19:02:44 -0800 Subject: [PATCH] Solver optimizations & allow action pools --- Solver/ArenaNode.cs | 6 +++--- Solver/MCTS.cs | 36 +++++++----------------------------- Solver/Simulator.cs | 19 ++++++++++++------- Solver/Solver.cs | 4 ++-- 4 files changed, 24 insertions(+), 41 deletions(-) diff --git a/Solver/ArenaNode.cs b/Solver/ArenaNode.cs index 5d65e9f..ebafacf 100644 --- a/Solver/ArenaNode.cs +++ b/Solver/ArenaNode.cs @@ -12,7 +12,7 @@ public sealed class ArenaNode where T : struct public NodeScoresBuffer? ParentScores => Parent?.ChildScores; - public ArenaNode(T state, ArenaNode? parent = null) + public ArenaNode(in T state, ArenaNode? parent = null) { State = state; Children = new(); @@ -24,9 +24,9 @@ public sealed class ArenaNode where T : struct Children.Data?[at.arrayIdx]?[at.subIdx]; [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ArenaNode Add(T state) + public ArenaNode Add(in T state) { - var node = new ArenaNode(state, this); + var node = new ArenaNode(in state, this); ChildScores.Add(); Children.Add(node); return node; diff --git a/Solver/MCTS.cs b/Solver/MCTS.cs index afcd89e..20454a4 100644 --- a/Solver/MCTS.cs +++ b/Solver/MCTS.cs @@ -3,7 +3,6 @@ using Craftimizer.Simulator; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; -using System.Text; using Node = Craftimizer.Solver.ArenaNode; namespace Craftimizer.Solver; @@ -23,7 +22,7 @@ public sealed class MCTS public MCTS(in MCTSConfig config, in SimulationState state) { this.config = config; - var sim = new Simulator(config.ActionPool, config.MaxStepCount) { State = state }; + var sim = new Simulator(config.ActionPool, config.MaxStepCount, state); rootNode = new(new( state, null, @@ -35,7 +34,7 @@ public sealed class MCTS private static SimulationNode Execute(Simulator simulator, in SimulationState state, ActionType action, bool strict) { - (_, var newState) = simulator.Execute(state, action); + var newState = simulator.ExecuteUnchecked(state, action); return new( newState, action, @@ -192,7 +191,6 @@ public sealed class MCTS var currentCompletionState = expandedNode.State.SimulationCompletionState; var currentActions = expandedNode.State.AvailableActions; - byte actionCount = 0; Span actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)]; while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete && @@ -200,7 +198,7 @@ public sealed class MCTS { var nextAction = currentActions.SelectRandom(random); actions[actionCount++] = nextAction; - (_, currentState) = simulator.Execute(currentState, nextAction); + currentState = simulator.ExecuteUnchecked(currentState, nextAction); currentCompletionState = simulator.CompletionState; if (currentCompletionState != CompletionState.Incomplete) break; @@ -235,26 +233,6 @@ public sealed class MCTS } } - private void ShowAllNodes() - { - static void ShowNodes(StringBuilder b, Node node, Stack path) - { - path.Push(node); - b.AppendLine($"{new string(' ', path.Count)}{node.State.Action}"); - { - for (var i = 0; i < node.Children.Count; ++i) - { - var n = node.ChildAt((i >> 3, i & 7))!; - ShowNodes(b, n, path); - } - path.Pop(); - } - } - var b = new StringBuilder(); - ShowNodes(b, rootNode, new()); - Console.WriteLine(b.ToString()); - } - private bool AllNodesComplete() { static bool NodesIncomplete(Node node, Stack path) @@ -280,17 +258,14 @@ public sealed class MCTS return !NodesIncomplete(rootNode, new()); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Search(int iterations, ref int progress, CancellationToken token) { - Simulator simulator = new(config.ActionPool, config.MaxStepCount); + var simulator = new Simulator(config.ActionPool, config.MaxStepCount, rootNode.State.State); var random = rootNode.State.State.Input.Random; var staleCounter = 0; var i = 0; for (; i < iterations || MaxScore == 0; i++) { - token.ThrowIfCancellationRequested(); - var selectedNode = Select(); var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); if (MaxScore == 0) @@ -315,7 +290,10 @@ public sealed class MCTS Backpropagate(endNode, score); if ((i & (ProgressUpdateFrequency - 1)) == ProgressUpdateFrequency - 1) + { + token.ThrowIfCancellationRequested(); Interlocked.Add(ref progress, ProgressUpdateFrequency); + } } Interlocked.Add(ref progress, i & (ProgressUpdateFrequency - 1)); } diff --git a/Solver/Simulator.cs b/Solver/Simulator.cs index 8ddea8f..843e86c 100644 --- a/Solver/Simulator.cs +++ b/Solver/Simulator.cs @@ -21,9 +21,15 @@ internal sealed class Simulator : SimulatorNoRandom } } - public Simulator(ActionType[] actionPool, int maxStepCount) + public Simulator(ActionType[] actionPool, int maxStepCount, SimulationState? filteringState = null) { - actionPoolObjects = actionPool.Select(x => (x.Base(), x)).ToArray(); + var pool = actionPool.Select(x => (x.Base(), x)); + if (filteringState is { } state) + { + State = state; + pool = pool.Where(x => x.Item1.IsPossible(this)); + } + actionPoolObjects = pool.OrderBy(x => x.x).ToArray(); this.maxStepCount = maxStepCount; } @@ -32,7 +38,7 @@ internal sealed class Simulator : SimulatorNoRandom [MethodImpl(MethodImplOptions.AggressiveInlining)] // It's just a bunch of if statements, I would assume this is actually quite simple to follow #pragma warning disable MA0051 // Method is too long - private bool CanUseAction(ActionType action, BaseAction baseAction, bool strict) + private bool CouldUseAction(ActionType action, BaseAction baseAction, bool strict) #pragma warning restore MA0051 // Method is too long { if (CalculateSuccessRate(baseAction.SuccessRate(this)) != 1) @@ -46,7 +52,7 @@ internal sealed class Simulator : SimulatorNoRandom { // always use Trained Eye if it's available if (action == ActionType.TrainedEye) - return baseAction.CanUse(this); + return baseAction.CouldUse(this); // don't allow quality moves under Muscle Memory for difficult crafts if (Input.Recipe.ClassJobLevel == 90 && @@ -123,7 +129,7 @@ internal sealed class Simulator : SimulatorNoRandom return false; } - return baseAction.CanUse(this); + return baseAction.CouldUse(this); } // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/craft_state.rs#L137 @@ -134,9 +140,8 @@ internal sealed class Simulator : SimulatorNoRandom var ret = new ActionSet(); foreach (var (data, action) in actionPoolObjects) - if (CanUseAction(action, data, strict)) + if (CouldUseAction(action, data, strict)) ret.AddAction(action); return ret; } - } diff --git a/Solver/Solver.cs b/Solver/Solver.cs index e60be04..6cc53a5 100644 --- a/Solver/Solver.cs +++ b/Solver/Solver.cs @@ -272,7 +272,7 @@ public sealed class Solver : IDisposable var actions = new List(); var state = State; - var sim = new Simulator(Config.ActionPool, Config.MaxStepCount) { State = state }; + var sim = new Simulator(Config.ActionPool, Config.MaxStepCount, state); while (true) { Token.ThrowIfCancellationRequested(); @@ -338,7 +338,7 @@ public sealed class Solver : IDisposable var actions = new List(); var state = State; - var sim = new Simulator(Config.ActionPool, Config.MaxStepCount) { State = state }; + var sim = new Simulator(Config.ActionPool, Config.MaxStepCount, state); while (true) { Token.ThrowIfCancellationRequested();