diff --git a/.editorconfig b/.editorconfig index 5af9930..0d8e909 100644 --- a/.editorconfig +++ b/.editorconfig @@ -158,6 +158,7 @@ dotnet_diagnostic.CA1841.severity = suggestion dotnet_diagnostic.CA1845.severity = suggestion dotnet_diagnostic.MA0011.severity = silent dotnet_diagnostic.MA0076.severity = silent +dotnet_diagnostic.MA0046.severity = silent dotnet_diagnostic.MA0002.severity = silent csharp_style_prefer_switch_expression = true:suggestion csharp_style_prefer_pattern_matching = true:silent diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index 8b4b06f..f5acf7b 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -77,11 +77,11 @@ internal static class Program Console.WriteLine($"{state.Quality} {state.CP} {state.Progress} {state.Durability}"); //return; - var (_, s) = Solver.Solver.SearchStepwiseFurcated(config, state, a => Console.WriteLine(a), default); + var (_, s) = config.Invoke(state, a => Console.WriteLine(a))!.Value; Console.WriteLine($"Qual: {s.Quality}/{s.Input.Recipe.MaxQuality}"); return; - Solver.Solver.SearchStepwiseFurcated(config, new(input), null, default); + config.Invoke(new(input)); //Benchmark(() => ); } diff --git a/Craftimizer/Configuration.cs b/Craftimizer/Configuration.cs index f90d4f4..efddaa8 100644 --- a/Craftimizer/Configuration.cs +++ b/Craftimizer/Configuration.cs @@ -15,25 +15,6 @@ public class Macro public List Actions { get; set; } = new(); } -public static class AlgorithmUtils -{ - public static void Invoke(this SolverConfig me, SimulationState state, Action? actionCallback = null, CancellationToken token = default) - { - try - { - Solver.Solver.Search(me, state, actionCallback, token); - } - catch (AggregateException e) - { - e.Handle(ex => ex is OperationCanceledException); - } - catch (OperationCanceledException) - { - - } - } -} - [Serializable] public class Configuration : IPluginConfiguration { diff --git a/Simulator/CompletionState.cs b/Simulator/CompletionState.cs index 88856d0..1cb429f 100644 --- a/Simulator/CompletionState.cs +++ b/Simulator/CompletionState.cs @@ -6,5 +6,7 @@ public enum CompletionState : byte ProgressComplete, NoMoreDurability, - Other + InvalidAction, + MaxActionCountReached, + NoMoreActions } diff --git a/Simulator/Simulator.cs b/Simulator/Simulator.cs index 2d8beb0..b049eb2 100644 --- a/Simulator/Simulator.cs +++ b/Simulator/Simulator.cs @@ -21,8 +21,17 @@ public class Simulator public bool IsFirstStep => State.StepCount == 0; - public CompletionState CompletionState => CalculateCompletionState(State); - public virtual bool IsComplete => CompletionState != CompletionState.Incomplete; + public virtual CompletionState CompletionState { + get + { + if (Progress >= Input.Recipe.MaxProgress) + return CompletionState.ProgressComplete; + if (Durability <= 0) + return CompletionState.NoMoreDurability; + return CompletionState.Incomplete; + } + } + public bool IsComplete => CompletionState != CompletionState.Incomplete; public IEnumerable AvailableActions => ActionUtils.AvailableActions(this); @@ -278,13 +287,4 @@ public class Simulator [MethodImpl(MethodImplOptions.AggressiveInlining)] public void IncreaseQuality(float efficiency) => IncreaseQualityRaw(CalculateQualityGain(efficiency, false)); - - public static CompletionState CalculateCompletionState(SimulationState state) - { - if (state.Progress >= state.Input.Recipe.MaxProgress) - return CompletionState.ProgressComplete; - if (state.Durability <= 0) - return CompletionState.NoMoreDurability; - return CompletionState.Incomplete; - } } diff --git a/Solver/CompletionState.cs b/Solver/CompletionState.cs deleted file mode 100644 index 47d3297..0000000 --- a/Solver/CompletionState.cs +++ /dev/null @@ -1,20 +0,0 @@ -using CompState = Craftimizer.Simulator.CompletionState; - -namespace Craftimizer.Solver; - -public enum CompletionState : byte -{ - Incomplete, - ProgressComplete, - NoMoreDurability, - - InvalidAction, - MaxActionCountReached, - NoMoreActions -} - -internal static class CompletionStateUtils -{ - public static CompState IntoBase(this CompletionState me) => - (CompState)me >= CompState.Other ? CompState.Other : (CompState)me; -} diff --git a/Solver/MCTS.cs b/Solver/MCTS.cs new file mode 100644 index 0000000..4f976fb --- /dev/null +++ b/Solver/MCTS.cs @@ -0,0 +1,338 @@ +using Craftimizer.Simulator.Actions; +using Craftimizer.Simulator; +using System.Diagnostics.Contracts; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Text; +using Node = Craftimizer.Solver.ArenaNode; + +namespace Craftimizer.Solver; + +// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs +public sealed class MCTS +{ + private readonly MCTSConfig config; + private readonly Node rootNode; + private readonly RootScores rootScores; + + public float MaxScore => rootScores.MaxScore; + + public MCTS(MCTSConfig config, SimulationState state) + { + this.config = config; + var sim = new Simulator(state, config.MaxStepCount); + rootNode = new(new( + state, + null, + sim.CompletionState, + sim.AvailableActionsHeuristic(config.StrictActions) + )); + rootScores = new(); + } + + private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) + { + (_, var newState) = simulator.Execute(state, action); + return new( + newState, + action, + simulator.CompletionState, + simulator.AvailableActionsHeuristic(strict) + ); + } + + private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan actions, bool strict) + { + foreach (var action in actions) + { + var state = startNode.State; + if (state.IsComplete) + return startNode; + + if (!state.AvailableActions.HasAction(action)) + return startNode; + state.AvailableActions.RemoveAction(action); + + startNode = startNode.Add(Execute(simulator, state.State, action, strict)); + } + + return startNode; + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores) + { + var length = scores.Count; + var vecLength = Vector.Count; + + var max = (0, 0); + var maxScore = 0f; + for (var i = 0; length > 0; ++i) + { + var iterCount = Math.Min(vecLength, length); + + ref var chunk = ref scores.Data[i]; + var m = new Vector(chunk.MaxScore.Span); + + var idx = Intrinsics.HMaxIndex(m, iterCount); + + if (m[idx] >= maxScore) + { + max = (i, idx); + maxScore = m[idx]; + } + + length -= iterCount; + } + + return max; + } + + // Calculates the best child node to explore next + // Exploitation: ((1 - w) * (s / v)) + (w * m) + // Exploration: sqrt(c * ln(V) / v) + // w = maxScoreWeightingConstant + // s = score sum + // m = max score + // v = visits + // V = parentVisits + // c = explorationConstant + + // Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation + // Here, w_i = (1-w)*score sum + // n_i = visits + // max score is tacked onto it + // N_i = parent visits + // c = exploration constant (but crafty places it inside the sqrt..?) + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + private static (int arrayIdx, int subIdx) EvalBestChild( + float explorationConstant, + float maxScoreWeightingConstant, + int parentVisits, + ref NodeScoresBuffer scores) + { + var length = scores.Count; + var vecLength = Vector.Count; + + var C = MathF.Sqrt(explorationConstant * MathF.Log(parentVisits)); + var w = maxScoreWeightingConstant; + var W = 1f - w; + var CVector = new Vector(C); + + var max = (0, 0); + var maxScore = 0f; + for (var i = 0; length > 0; ++i) + { + var iterCount = Math.Min(vecLength, length); + + ref var chunk = ref scores.Data[i]; + var s = new Vector(chunk.ScoreSum.Span); + var vInt = new Vector(chunk.Visits.Span); + var m = new Vector(chunk.MaxScore.Span); + + vInt = Vector.Max(vInt, Vector.One); + var v = Vector.ConvertToSingle(vInt); + + var exploitation = W * (s / v) + w * m; + var exploration = CVector * Intrinsics.ReciprocalSqrt(v); + var evalScores = exploitation + exploration; + + var idx = Intrinsics.HMaxIndex(evalScores, iterCount); + + if (evalScores[idx] >= maxScore) + { + max = (i, idx); + maxScore = evalScores[idx]; + } + + length -= iterCount; + } + + return max; + } + + [Pure] + private Node Select() + { + var node = rootNode; + var nodeVisits = rootScores.Visits; + + float explorationConstant = config.ExplorationConstant, maxScoreWeightingConstant = config.MaxScoreWeightingConstant; + while (true) + { + var expandable = !node.State.AvailableActions.IsEmpty; + var likelyTerminal = node.Children.Count == 0; + if (expandable || likelyTerminal) + return node; + + // select the node with the highest score + var at = EvalBestChild(explorationConstant, maxScoreWeightingConstant, nodeVisits, ref node.ChildScores); + nodeVisits = node.ChildScores.GetVisits(at); + node = node.ChildAt(at)!; + } + } + + private (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode) + { + ref var initialState = ref initialNode.State; + // expand once + if (initialState.IsComplete) + return (initialNode, initialState.CalculateScore(config) ?? 0); + + var poppedAction = initialState.AvailableActions.PopRandom(random); + var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true)); + + // playout to a terminal state + var currentState = expandedNode.State.State; + var currentCompletionState = expandedNode.State.SimulationCompletionState; + var currentActions = expandedNode.State.AvailableActions; + + + byte actionCount = 0; + Span actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)]; + while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete && + actionCount < actions.Length) + { + var nextAction = currentActions.SelectRandom(random); + actions[actionCount++] = nextAction; + (_, currentState) = simulator.Execute(currentState, nextAction); + currentCompletionState = simulator.CompletionState; + if (currentCompletionState != CompletionState.Incomplete) + break; + currentActions = simulator.AvailableActionsHeuristic(true); + } + + // store the result if a max score was reached + var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config) ?? 0; + if (currentCompletionState == CompletionState.ProgressComplete) + { + if (score >= config.ScoreStorageThreshold && score >= MaxScore) + { + var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); + return (terminalNode, score); + } + } + return (expandedNode, score); + } + + private void Backpropagate(Node startNode, float score) + { + while (true) + { + if (startNode == rootNode) + { + rootScores.Visit(score); + break; + } + startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score); + + startNode = startNode.Parent!; + } + } + + private void ShowAllNodes() + { + static void ShowNodes(StringBuilder b, Node node, Stack path) + { + path.Push(node); + b.AppendLine($"{new string(' ', path.Count)}{node.State.Action}"); + { + for (var i = 0; i < node.Children.Count; ++i) + { + var n = node.ChildAt((i >> 3, i & 7))!; + ShowNodes(b, n, path); + } + path.Pop(); + } + } + var b = new StringBuilder(); + ShowNodes(b, rootNode, new()); + Console.WriteLine(b.ToString()); + } + + private bool AllNodesComplete() + { + static bool NodesIncomplete(Node node, Stack path) + { + path.Push(node); + if (node.Children.Count == 0) + { + if (!node.State.AvailableActions.IsEmpty) + return true; + } + else + { + for (var i = 0; i < node.Children.Count; ++i) + { + var n = node.ChildAt((i >> 3, i & 7))!; + if (NodesIncomplete(n, path)) + return true; + } + path.Pop(); + } + return false; + } + return !NodesIncomplete(rootNode, new()); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Search(int iterations, CancellationToken token) + { + Simulator simulator = new(rootNode.State.State, config.MaxStepCount); + var random = rootNode.State.State.Input.Random; + var n = 0; + for (var i = 0; i < iterations || MaxScore == 0; i++) + { + if (token.IsCancellationRequested) + break; + + var selectedNode = Select(); + var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); + if (MaxScore == 0) + { + if (endNode == selectedNode) + { + if (n++ > 5000) + { + n = 0; + if (AllNodesComplete()) + { + //Console.WriteLine("All nodes solved for. Can't find a valid solution."); + //ShowAllNodes(); + return; + } + } + } + else + n = 0; + } + + Backpropagate(endNode, score); + } + } + + [Pure] + public SolverSolution Solution() + { + var actions = new List(); + var node = rootNode; + + while (node.Children.Count != 0) + { + node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; + + if (node.State.Action != null) + actions.Add(node.State.Action.Value); + } + + //var at = node.ChildIdx; + //ref var sum = ref node.ParentScores!.Value.Data[at.arrayIdx].ScoreSum.Span[at.subIdx]; + //ref var max = ref node.ParentScores!.Value.Data[at.arrayIdx].MaxScore.Span[at.subIdx]; + //ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx]; + //Console.WriteLine($"{sum} {max} {visits}"); + + return new(actions, node.State.State); + } +} diff --git a/Solver/MCTSConfig.cs b/Solver/MCTSConfig.cs new file mode 100644 index 0000000..71ae268 --- /dev/null +++ b/Solver/MCTSConfig.cs @@ -0,0 +1,38 @@ +using System.Runtime.InteropServices; + +namespace Craftimizer.Solver; + +[StructLayout(LayoutKind.Auto)] +public readonly record struct MCTSConfig +{ + public int MaxStepCount { get; init; } + public int MaxRolloutStepCount { get; init; } + public bool StrictActions { get; init; } + + public float MaxScoreWeightingConstant { get; init; } + public float ExplorationConstant { get; init; } + public float ScoreStorageThreshold { get; init; } + + public float ScoreProgress { get; init; } + public float ScoreQuality { get; init; } + public float ScoreDurability { get; init; } + public float ScoreCP { get; init; } + public float ScoreSteps { get; init; } + + public MCTSConfig(SolverConfig config) + { + MaxStepCount= config.MaxStepCount; + MaxRolloutStepCount = config.MaxRolloutStepCount; + StrictActions = config.StrictActions; + + MaxScoreWeightingConstant = config.MaxScoreWeightingConstant; + ExplorationConstant = config.ExplorationConstant; + ScoreStorageThreshold = config.ScoreStorageThreshold; + + ScoreProgress = config.ScoreProgressBonus; + ScoreQuality = config.ScoreQualityBonus; + ScoreDurability = config.ScoreDurabilityBonus; + ScoreCP = config.ScoreCPBonus; + ScoreSteps = config.ScoreFewerStepsBonus; + } +} diff --git a/Solver/SimulationNode.cs b/Solver/SimulationNode.cs index af37ac5..e17c35f 100644 --- a/Solver/SimulationNode.cs +++ b/Solver/SimulationNode.cs @@ -30,7 +30,7 @@ public struct SimulationNode CompletionState.NoMoreActions : simCompletionState; - public readonly float? CalculateScore(SolverConfig config) => + public readonly float? CalculateScore(MCTSConfig config) => CalculateScoreForState(State, SimulationCompletionState, config); private static bool CanByregot(SimulationState state) @@ -41,7 +41,7 @@ public struct SimulationNode return BaseComboAction.VerifyDurability2(state, 10); } - public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, SolverConfig config) + public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, MCTSConfig config) { if (completionState != CompletionState.ProgressComplete) return null; @@ -50,32 +50,32 @@ public struct SimulationNode bonus * Math.Min(1f, value / target); var progressScore = Apply( - config.ScoreProgressBonus, + config.ScoreProgress, state.Progress, state.Input.Recipe.MaxProgress ); var byregotBonus = CanByregot(state) ? (state.ActiveEffects.InnerQuiet * .2f + 1) * state.Input.BaseQualityGain : 0; var qualityScore = Apply( - config.ScoreQualityBonus, + config.ScoreQuality, state.Quality + byregotBonus, state.Input.Recipe.MaxQuality ); var durabilityScore = Apply( - config.ScoreDurabilityBonus, + config.ScoreDurability, state.Durability, state.Input.Recipe.MaxDurability ); var cpScore = Apply( - config.ScoreCPBonus, + config.ScoreCP, state.CP, state.Input.Stats.CP ); var fewerStepsScore = - config.ScoreFewerStepsBonus * (1f - (float)(state.ActionCount + 1) / config.MaxStepCount); + config.ScoreSteps * (1f - (float)(state.ActionCount + 1) / config.MaxStepCount); return progressScore + qualityScore + durabilityScore + cpScore + fewerStepsScore; } diff --git a/Solver/Simulator.cs b/Solver/Simulator.cs index 9b610e5..b3f6f60 100644 --- a/Solver/Simulator.cs +++ b/Solver/Simulator.cs @@ -9,8 +9,16 @@ public sealed class Simulator : SimulatorNoRandom { private readonly int maxStepCount; - public new CompletionState CompletionState => CalculateCompletionState(State, maxStepCount); - public override bool IsComplete => CompletionState != CompletionState.Incomplete; + public override CompletionState CompletionState + { + get + { + var b = base.CompletionState; + if (b == CompletionState.Incomplete && (ActionCount + 1) >= maxStepCount) + return CompletionState.MaxActionCountReached; + return b; + } + } public Simulator(SimulationState state, int maxStepCount) : base(state) { @@ -187,8 +195,4 @@ public sealed class Simulator : SimulatorNoRandom return ret; } - public static CompletionState CalculateCompletionState(SimulationState state, int maxStepCount) => - state.ActionCount + 1 >= maxStepCount ? - CompletionState.MaxActionCountReached : - (CompletionState)CalculateCompletionState(state); } diff --git a/Solver/Solver.cs b/Solver/Solver.cs index 9f19d67..d4b324f 100644 --- a/Solver/Solver.cs +++ b/Solver/Solver.cs @@ -1,338 +1,12 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; using System.Diagnostics; -using System.Diagnostics.Contracts; -using System.Numerics; -using System.Runtime.CompilerServices; -using System.Text; -using Node = Craftimizer.Solver.ArenaNode; namespace Craftimizer.Solver; -// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs -public sealed class Solver +public static class Solver { - private SolverConfig config; - private Node rootNode; - private RootScores rootScores; - - public float MaxScore => rootScores.MaxScore; - - public Solver(SolverConfig config, SimulationState state) - { - this.config = config; - var sim = new Simulator(state, config.MaxStepCount); - rootNode = new(new( - state, - null, - sim.CompletionState, - sim.AvailableActionsHeuristic(config.StrictActions) - )); - rootScores = new(); - } - - private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) - { - (_, var newState) = simulator.Execute(state, action); - return new( - newState, - action, - simulator.CompletionState, - simulator.AvailableActionsHeuristic(strict) - ); - } - - private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan actions, bool strict) - { - foreach (var action in actions) - { - var state = startNode.State; - if (state.IsComplete) - return startNode; - - if (!state.AvailableActions.HasAction(action)) - return startNode; - state.AvailableActions.RemoveAction(action); - - startNode = startNode.Add(Execute(simulator, state.State, action, strict)); - } - - return startNode; - } - - [Pure] - private SolverSolution Solution() - { - var actions = new List(); - var node = rootNode; - - while (node.Children.Count != 0) - { - node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; - - if (node.State.Action != null) - actions.Add(node.State.Action.Value); - } - - //var at = node.ChildIdx; - //ref var sum = ref node.ParentScores!.Value.Data[at.arrayIdx].ScoreSum.Span[at.subIdx]; - //ref var max = ref node.ParentScores!.Value.Data[at.arrayIdx].MaxScore.Span[at.subIdx]; - //ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx]; - //Console.WriteLine($"{sum} {max} {visits}"); - - return new(actions, node.State.State); - } - - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores) - { - var length = scores.Count; - var vecLength = Vector.Count; - - var max = (0, 0); - var maxScore = 0f; - for (var i = 0; length > 0; ++i) - { - var iterCount = Math.Min(vecLength, length); - - ref var chunk = ref scores.Data[i]; - var m = new Vector(chunk.MaxScore.Span); - - var idx = Intrinsics.HMaxIndex(m, iterCount); - - if (m[idx] >= maxScore) - { - max = (i, idx); - maxScore = m[idx]; - } - - length -= iterCount; - } - - return max; - } - - // Calculates the best child node to explore next - // Exploitation: ((1 - w) * (s / v)) + (w * m) - // Exploration: sqrt(c * ln(V) / v) - // w = maxScoreWeightingConstant - // s = score sum - // m = max score - // v = visits - // V = parentVisits - // c = explorationConstant - - // Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation - // Here, w_i = (1-w)*score sum - // n_i = visits - // max score is tacked onto it - // N_i = parent visits - // c = exploration constant (but crafty places it inside the sqrt..?) - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] - private (int arrayIdx, int subIdx) EvalBestChild(int parentVisits, ref NodeScoresBuffer scores) - { - var length = scores.Count; - var vecLength = Vector.Count; - - var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits)); - var w = config.MaxScoreWeightingConstant; - var W = 1f - w; - var CVector = new Vector(C); - - var max = (0, 0); - var maxScore = 0f; - for (var i = 0; length > 0; ++i) - { - var iterCount = Math.Min(vecLength, length); - - ref var chunk = ref scores.Data[i]; - var s = new Vector(chunk.ScoreSum.Span); - var vInt = new Vector(chunk.Visits.Span); - var m = new Vector(chunk.MaxScore.Span); - - vInt = Vector.Max(vInt, Vector.One); - var v = Vector.ConvertToSingle(vInt); - - var exploitation = W * (s / v) + w * m; - var exploration = CVector * Intrinsics.ReciprocalSqrt(v); - var evalScores = exploitation + exploration; - - var idx = Intrinsics.HMaxIndex(evalScores, iterCount); - - if (evalScores[idx] >= maxScore) - { - max = (i, idx); - maxScore = evalScores[idx]; - } - - length -= iterCount; - } - - return max; - } - - [Pure] - public Node Select() - { - var node = rootNode; - var nodeVisits = rootScores.Visits; - - while (true) - { - var expandable = !node.State.AvailableActions.IsEmpty; - var likelyTerminal = node.Children.Count == 0; - if (expandable || likelyTerminal) - return node; - - // select the node with the highest score - var at = EvalBestChild(nodeVisits, ref node.ChildScores); - nodeVisits = node.ChildScores.GetVisits(at); - node = node.ChildAt(at)!; - } - } - - public (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode) - { - ref var initialState = ref initialNode.State; - // expand once - if (initialState.IsComplete) - return (initialNode, initialState.CalculateScore(config) ?? 0); - - var poppedAction = initialState.AvailableActions.PopRandom(random); - var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true)); - - // playout to a terminal state - var currentState = expandedNode.State.State; - var currentCompletionState = expandedNode.State.SimulationCompletionState; - var currentActions = expandedNode.State.AvailableActions; - - - byte actionCount = 0; - Span actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)]; - while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete && - actionCount < actions.Length) - { - var nextAction = currentActions.SelectRandom(random); - actions[actionCount++] = nextAction; - (_, currentState) = simulator.Execute(currentState, nextAction); - currentCompletionState = simulator.CompletionState; - if (currentCompletionState != CompletionState.Incomplete) - break; - currentActions = simulator.AvailableActionsHeuristic(true); - } - - // store the result if a max score was reached - var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config) ?? 0; - if (currentCompletionState == CompletionState.ProgressComplete) - { - if (score >= config.ScoreStorageThreshold && score >= MaxScore) - { - var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true); - return (terminalNode, score); - } - } - return (expandedNode, score); - } - - public void Backpropagate(Node startNode, float score) - { - while (true) - { - if (startNode == rootNode) - { - rootScores.Visit(score); - break; - } - startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score); - - startNode = startNode.Parent!; - } - } - - private void ShowAllNodes() - { - static void ShowNodes(StringBuilder b, Node node, Stack path) - { - path.Push(node); - b.AppendLine($"{new string(' ', path.Count)}{node.State.Action}"); - { - for (var i = 0; i < node.Children.Count; ++i) - { - var n = node.ChildAt((i >> 3, i & 7))!; - ShowNodes(b, n, path); - } - path.Pop(); - } - } - var b = new StringBuilder(); - ShowNodes(b, rootNode, new()); - Console.WriteLine(b.ToString()); - } - - private bool AllNodesComplete() - { - static bool NodesIncomplete(Node node, Stack path) - { - path.Push(node); - if (node.Children.Count == 0) - { - if (!node.State.AvailableActions.IsEmpty) - return true; - } - else - { - for (var i = 0; i < node.Children.Count; ++i) - { - var n = node.ChildAt((i >> 3, i & 7))!; - if (NodesIncomplete(n, path)) - return true; - } - path.Pop(); - } - return false; - } - return !NodesIncomplete(rootNode, new()); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void Search(int iterations, CancellationToken token) - { - Simulator simulator = new(rootNode.State.State, config.MaxStepCount); - var random = rootNode.State.State.Input.Random; - var n = 0; - for (var i = 0; i < iterations || MaxScore == 0; i++) - { - if (token.IsCancellationRequested) - break; - - var selectedNode = Select(); - var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); - if (MaxScore == 0) - { - if (endNode == selectedNode) - { - if (n++ > 5000) - { - n = 0; - if (AllNodesComplete()) - { - //Console.WriteLine("All nodes solved for. Can't find a valid solution."); - //ShowAllNodes(); - return; - } - } - } - else - n = 0; - } - - Backpropagate(endNode, score); - } - } - - public static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) + private static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { var definiteActionCount = 0; var bestSims = new List<(float Score, SolverSolution Result)>(); @@ -354,7 +28,7 @@ public sealed class Solver var st = activeStates[stateIdx]; tasks[i] = Task.Run(() => { - var solver = new Solver(config, activeStates[stateIdx].State); + var solver = new MCTS(new(config), activeStates[stateIdx].State); solver.Search(config.Iterations / config.ForkCount, token); return (solver.MaxScore, stateIdx, solver.Solution()); }, token); @@ -442,7 +116,7 @@ public sealed class Solver return result; } - public static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) + private static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { var actions = new List(); var sim = new Simulator(state, config.MaxStepCount); @@ -460,7 +134,7 @@ public sealed class Solver for (var i = 0; i < config.ForkCount; ++i) tasks[i] = Task.Run(() => { - var solver = new Solver(config, state); + var solver = new MCTS(new(config), state); solver.Search(config.Iterations / config.ForkCount, token); return (solver.MaxScore, solver.Solution()); }, token); @@ -489,7 +163,7 @@ public sealed class Solver return new(actions, state); } - public static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) + private static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { var actions = new List(); var sim = new Simulator(state, config.MaxStepCount); @@ -501,7 +175,7 @@ public sealed class Solver if (sim.IsComplete) break; - var solver = new Solver(config, state); + var solver = new MCTS(new(config), state); var s = Stopwatch.StartNew(); solver.Search(config.Iterations, token); @@ -526,13 +200,13 @@ public sealed class Solver return new(actions, state); } - public static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) + private static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount]; for (var i = 0; i < config.ForkCount; ++i) tasks[i] = Task.Run(() => { - var solver = new Solver(config, state); + var solver = new MCTS(new(config), state); solver.Search(config.Iterations / config.ForkCount, token); return (solver.MaxScore, solver.Solution()); }, token); @@ -545,9 +219,9 @@ public sealed class Solver return solution; } - public static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) + private static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { - var solver = new Solver(config, state); + var solver = new MCTS(new(config), state); solver.Search(config.Iterations, token); var solution = solver.Solution(); foreach (var action in solution.Actions) @@ -556,7 +230,7 @@ public sealed class Solver return solution; } - public static SolverSolution Search(SolverConfig config, SimulationState state, Action? actionCallback = null, CancellationToken token = default) + public static SolverSolution Search(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token) { Func?, CancellationToken, SolverSolution> func = config.Algorithm switch { @@ -564,7 +238,8 @@ public sealed class Solver SolverAlgorithm.OneshotForked => SearchOneshotForked, SolverAlgorithm.Stepwise => SearchStepwise, SolverAlgorithm.StepwiseForked => SearchStepwiseForked, - SolverAlgorithm.StepwiseFurcated or _ => SearchStepwiseFurcated, + SolverAlgorithm.StepwiseFurcated => SearchStepwiseFurcated, + _ => throw new ArgumentOutOfRangeException(nameof(config), config, $"Invalid algorithm: {config.Algorithm}") }; return func(config, state, actionCallback, token); } diff --git a/Solver/SolverConfig.cs b/Solver/SolverConfig.cs index 15be1db..fffcff3 100644 --- a/Solver/SolverConfig.cs +++ b/Solver/SolverConfig.cs @@ -1,3 +1,5 @@ +using Craftimizer.Simulator.Actions; +using Craftimizer.Simulator; using System.Runtime.InteropServices; namespace Craftimizer.Solver; @@ -65,4 +67,21 @@ public readonly record struct SolverConfig FurcatedActionCount = Environment.ProcessorCount / 2, Algorithm = SolverAlgorithm.StepwiseForked }; -} + + public SolverSolution? Invoke(SimulationState state, Action? actionCallback = null, CancellationToken token = default) + { + try + { + return Solver.Search(this, state, actionCallback, token); + } + catch (AggregateException e) + { + e.Handle(ex => ex is OperationCanceledException); + } + catch (OperationCanceledException) + { + + } + return null; + } +} \ No newline at end of file