Implement ActionPool (backend only)

This commit is contained in:
Asriel Camora
2024-02-29 00:01:55 -08:00
parent 44ae3791f1
commit ecabc24517
8 changed files with 204 additions and 134 deletions
+6 -6
View File
@@ -23,7 +23,7 @@ public sealed class MCTS
public MCTS(in MCTSConfig config, in SimulationState state)
{
this.config = config;
var sim = new Simulator(config.MaxStepCount) { State = state };
var sim = new Simulator(config.ActionPool, config.MaxStepCount) { State = state };
rootNode = new(new(
state,
null,
@@ -52,9 +52,9 @@ public sealed class MCTS
if (state.IsComplete)
return startNode;
if (!state.AvailableActions.HasAction(action))
if (!state.AvailableActions.HasAction(in simulator.Pool, action))
return startNode;
state.AvailableActions.RemoveAction(action);
state.AvailableActions.RemoveAction(in simulator.Pool, action);
startNode = startNode.Add(Execute(simulator, state.State, action, strict));
}
@@ -184,7 +184,7 @@ public sealed class MCTS
if (initialState.IsComplete)
return (initialNode, initialState.CalculateScore(config) ?? 0);
var poppedAction = initialState.AvailableActions.PopRandom(random);
var poppedAction = initialState.AvailableActions.PopRandom(in simulator.Pool, random);
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true));
// playout to a terminal state
@@ -198,7 +198,7 @@ public sealed class MCTS
while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete &&
actionCount < actions.Length)
{
var nextAction = currentActions.SelectRandom(random);
var nextAction = currentActions.SelectRandom(in simulator.Pool, random);
actions[actionCount++] = nextAction;
(_, currentState) = simulator.Execute(currentState, nextAction);
currentCompletionState = simulator.CompletionState;
@@ -283,7 +283,7 @@ public sealed class MCTS
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Search(int iterations, ref int progress, CancellationToken token)
{
Simulator simulator = new(config.MaxStepCount);
Simulator simulator = new(config.ActionPool, config.MaxStepCount);
var random = rootNode.State.State.Input.Random;
var staleCounter = 0;
var i = 0;