Add max iteration cap
This commit is contained in:
+6
-2
@@ -253,14 +253,16 @@ public sealed class MCTS
|
||||
}
|
||||
|
||||
[SkipLocalsInit]
|
||||
public unsafe void Search(int iterations, ref int progress, CancellationToken token)
|
||||
public unsafe void Search(int iterations, int maxIterations, ref int progress, CancellationToken token)
|
||||
{
|
||||
maxIterations = Math.Max(iterations, maxIterations);
|
||||
var simulator = new Simulator(config.ActionPool, config.MaxStepCount, rootNode.State.State);
|
||||
var random = rootNode.State.State.Input.Random;
|
||||
var staleCounter = 0;
|
||||
var i = 0;
|
||||
|
||||
Span<ActionType> actionBuffer = stackalloc ActionType[Math.Min(config.MaxStepCount, config.MaxRolloutStepCount)];
|
||||
for (; i < iterations || MaxScore == 0; i++)
|
||||
for (; (i < iterations || MaxScore == 0); i++)
|
||||
{
|
||||
var selectedNode = Select();
|
||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode, actionBuffer);
|
||||
@@ -268,6 +270,8 @@ public sealed class MCTS
|
||||
{
|
||||
if (endNode == selectedNode)
|
||||
{
|
||||
if (i >= maxIterations)
|
||||
return;
|
||||
if (staleCounter++ >= StaleProgressThreshold)
|
||||
{
|
||||
staleCounter = 0;
|
||||
|
||||
+8
-5
@@ -139,6 +139,7 @@ public sealed class Solver : IDisposable
|
||||
private async Task<SolverSolution> SearchStepwiseGenetic()
|
||||
{
|
||||
var iterCount = Config.Iterations / Config.ForkCount;
|
||||
var maxIterCount = Math.Max(Config.Iterations, Config.MaxIterations) / Config.ForkCount;
|
||||
maxProgress = iterCount * Config.ForkCount;
|
||||
|
||||
var definiteActionCount = 0;
|
||||
@@ -165,7 +166,7 @@ public sealed class Solver : IDisposable
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(iterCount, ref progress, Token);
|
||||
solver.Search(iterCount, maxIterCount, ref progress, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
@@ -256,6 +257,7 @@ public sealed class Solver : IDisposable
|
||||
private async Task<SolverSolution> SearchStepwiseForked()
|
||||
{
|
||||
var iterCount = Config.Iterations / Config.ForkCount;
|
||||
var maxIterCount = Math.Max(Config.Iterations, Config.MaxIterations) / Config.ForkCount;
|
||||
maxProgress = iterCount * Config.ForkCount;
|
||||
|
||||
var actions = new List<ActionType>();
|
||||
@@ -278,7 +280,7 @@ public sealed class Solver : IDisposable
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(iterCount, ref progress, Token);
|
||||
solver.Search(iterCount, maxIterCount, ref progress, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
@@ -329,7 +331,7 @@ public sealed class Solver : IDisposable
|
||||
var solver = new MCTS(MCTSConfig, state);
|
||||
|
||||
var s = Stopwatch.StartNew();
|
||||
solver.Search(Config.Iterations, ref progress, Token);
|
||||
solver.Search(Config.Iterations, Config.MaxIterations, ref progress, Token);
|
||||
s.Stop();
|
||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
|
||||
|
||||
@@ -350,6 +352,7 @@ public sealed class Solver : IDisposable
|
||||
private async Task<SolverSolution> SearchOneshotForked()
|
||||
{
|
||||
var iterCount = Config.Iterations / Config.ForkCount;
|
||||
var maxIterCount = Math.Max(Config.Iterations, Config.MaxIterations) / Config.ForkCount;
|
||||
maxProgress = iterCount * Config.ForkCount;
|
||||
|
||||
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||
@@ -362,7 +365,7 @@ public sealed class Solver : IDisposable
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(iterCount, ref progress, Token);
|
||||
solver.Search(iterCount, maxIterCount, ref progress, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
@@ -394,7 +397,7 @@ public sealed class Solver : IDisposable
|
||||
var solver = new MCTS(MCTSConfig, State);
|
||||
|
||||
var s = Stopwatch.StartNew();
|
||||
solver.Search(Config.Iterations, ref progress, Token);
|
||||
solver.Search(Config.Iterations, Config.MaxIterations, ref progress, Token);
|
||||
s.Stop();
|
||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ public enum SolverAlgorithm
|
||||
public readonly record struct SolverConfig
|
||||
{
|
||||
public int Iterations { get; init; }
|
||||
public int MaxIterations { get; init; }
|
||||
public float MaxScoreWeightingConstant { get; init; }
|
||||
public float ExplorationConstant { get; init; }
|
||||
public int MaxStepCount { get; init; }
|
||||
@@ -38,6 +39,7 @@ public readonly record struct SolverConfig
|
||||
public SolverConfig()
|
||||
{
|
||||
Iterations = 100_000;
|
||||
MaxIterations = 1_500_000;
|
||||
MaxScoreWeightingConstant = 0.1f;
|
||||
ExplorationConstant = 4;
|
||||
MaxStepCount = 30;
|
||||
|
||||
Reference in New Issue
Block a user