diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 0d3c5d0..69da71d 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -205,7 +205,7 @@ public class Solver if (initialState.IsComplete) return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0); - var randomAction = initialState.AvailableActions.First();//.SelectRandom(Random); + var randomAction = initialState.AvailableActions.SelectRandom(Random); initialState.AvailableActions.RemoveAction(randomAction); var expandedNode = initialNode.Add(Execute(initialState.State, randomAction, true)); @@ -220,7 +220,7 @@ public class Solver { if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete) break; - randomAction = currentActions.First();//.SelectRandom(Random); + randomAction = currentActions.SelectRandom(Random); actions[actionCount++] = randomAction; (_, currentState) = Simulator.Execute(currentState, randomAction); currentCompletionState = Simulator.CompletionState; @@ -253,10 +253,13 @@ public class Solver } } - public void Search(Node startNode) + public void Search(Node startNode, CancellationToken token) { for (var i = 0; i < Config.Iterations; i++) { + if (token.IsCancellationRequested) + break; + var selectedNode = Select(startNode); var (endNode, _, score) = ExpandAndRollout(selectedNode); @@ -278,14 +281,19 @@ public class Solver return (actions, node.State); } - public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback) + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) => + SearchStepwise(config, new SimulationState(input), actionCallback, token); + + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action? actionCallback, CancellationToken token = default) { - var state = new SimulationState(input); var actions = new List(); var solver = new Solver(config, state, true); while (!solver.Simulator.IsComplete) { - solver.Search(solver.RootNode); + if (token.IsCancellationRequested) + break; + + solver.Search(solver.RootNode, token); var (solution_actions, solution_node) = solver.Solution(); if (solution_node.Scores.MaxScore >= 1.0) @@ -306,10 +314,13 @@ public class Solver return (actions, state); } - public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input) + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) => + SearchOneshot(config, new SimulationState(input), token); + + public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) { - var solver = new Solver(config, input, false); - solver.Search(solver.RootNode); + var solver = new Solver(config, state, false); + solver.Search(solver.RootNode, token); var (solution_actions, solution_node) = solver.Solution(); return (solution_actions, solution_node.State); } diff --git a/Solver/Crafty/SolverConfig.cs b/Solver/Crafty/SolverConfig.cs index f4a4aa4..5fe72c4 100644 --- a/Solver/Crafty/SolverConfig.cs +++ b/Solver/Crafty/SolverConfig.cs @@ -13,7 +13,7 @@ public readonly record struct SolverConfig public SolverConfig() { - Iterations = 30000; + Iterations = 300000; ScoreStorageThreshold = 1f; MaxScoreWeightingConstant = 0.1f; ExplorationConstant = 4f;