Add max iteration cap
This commit is contained in:
+6
-2
@@ -253,14 +253,16 @@ public sealed class MCTS
|
||||
}
|
||||
|
||||
[SkipLocalsInit]
|
||||
public unsafe void Search(int iterations, ref int progress, CancellationToken token)
|
||||
public unsafe void Search(int iterations, int maxIterations, ref int progress, CancellationToken token)
|
||||
{
|
||||
maxIterations = Math.Max(iterations, maxIterations);
|
||||
var simulator = new Simulator(config.ActionPool, config.MaxStepCount, rootNode.State.State);
|
||||
var random = rootNode.State.State.Input.Random;
|
||||
var staleCounter = 0;
|
||||
var i = 0;
|
||||
|
||||
Span<ActionType> actionBuffer = stackalloc ActionType[Math.Min(config.MaxStepCount, config.MaxRolloutStepCount)];
|
||||
for (; i < iterations || MaxScore == 0; i++)
|
||||
for (; (i < iterations || MaxScore == 0); i++)
|
||||
{
|
||||
var selectedNode = Select();
|
||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode, actionBuffer);
|
||||
@@ -268,6 +270,8 @@ public sealed class MCTS
|
||||
{
|
||||
if (endNode == selectedNode)
|
||||
{
|
||||
if (i >= maxIterations)
|
||||
return;
|
||||
if (staleCounter++ >= StaleProgressThreshold)
|
||||
{
|
||||
staleCounter = 0;
|
||||
|
||||
Reference in New Issue
Block a user