Cancellable solving, re-added solving randomization

This commit is contained in:
Asriel Camora
2023-06-29 01:46:21 -07:00
parent 1370957fc7
commit 37950b557e
2 changed files with 21 additions and 10 deletions
+20 -9
View File
@@ -205,7 +205,7 @@ public class Solver
if (initialState.IsComplete) if (initialState.IsComplete)
return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0); return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
var randomAction = initialState.AvailableActions.First();//.SelectRandom(Random); var randomAction = initialState.AvailableActions.SelectRandom(Random);
initialState.AvailableActions.RemoveAction(randomAction); initialState.AvailableActions.RemoveAction(randomAction);
var expandedNode = initialNode.Add(Execute(initialState.State, randomAction, true)); var expandedNode = initialNode.Add(Execute(initialState.State, randomAction, true));
@@ -220,7 +220,7 @@ public class Solver
{ {
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete) if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
break; break;
randomAction = currentActions.First();//.SelectRandom(Random); randomAction = currentActions.SelectRandom(Random);
actions[actionCount++] = randomAction; actions[actionCount++] = randomAction;
(_, currentState) = Simulator.Execute(currentState, randomAction); (_, currentState) = Simulator.Execute(currentState, randomAction);
currentCompletionState = Simulator.CompletionState; currentCompletionState = Simulator.CompletionState;
@@ -253,10 +253,13 @@ public class Solver
} }
} }
public void Search(Node startNode) public void Search(Node startNode, CancellationToken token)
{ {
for (var i = 0; i < Config.Iterations; i++) for (var i = 0; i < Config.Iterations; i++)
{ {
if (token.IsCancellationRequested)
break;
var selectedNode = Select(startNode); var selectedNode = Select(startNode);
var (endNode, _, score) = ExpandAndRollout(selectedNode); var (endNode, _, score) = ExpandAndRollout(selectedNode);
@@ -278,14 +281,19 @@ public class Solver
return (actions, node.State); return (actions, node.State);
} }
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback) public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) =>
SearchStepwise(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default)
{ {
var state = new SimulationState(input);
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var solver = new Solver(config, state, true); var solver = new Solver(config, state, true);
while (!solver.Simulator.IsComplete) while (!solver.Simulator.IsComplete)
{ {
solver.Search(solver.RootNode); if (token.IsCancellationRequested)
break;
solver.Search(solver.RootNode, token);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
if (solution_node.Scores.MaxScore >= 1.0) if (solution_node.Scores.MaxScore >= 1.0)
@@ -306,10 +314,13 @@ public class Solver
return (actions, state); return (actions, state);
} }
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input) public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) =>
SearchOneshot(config, new SimulationState(input), token);
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default)
{ {
var solver = new Solver(config, input, false); var solver = new Solver(config, state, false);
solver.Search(solver.RootNode); solver.Search(solver.RootNode, token);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
return (solution_actions, solution_node.State); return (solution_actions, solution_node.State);
} }
+1 -1
View File
@@ -13,7 +13,7 @@ public readonly record struct SolverConfig
public SolverConfig() public SolverConfig()
{ {
Iterations = 30000; Iterations = 300000;
ScoreStorageThreshold = 1f; ScoreStorageThreshold = 1f;
MaxScoreWeightingConstant = 0.1f; MaxScoreWeightingConstant = 0.1f;
ExplorationConstant = 4f; ExplorationConstant = 4f;