Implement max thread count setting using semaphores

This commit is contained in:
Asriel Camora
2023-10-02 23:02:11 -07:00
parent 42a1bc1117
commit 6bfff06ac4
5 changed files with 62 additions and 20 deletions
+42 -18
View File
@@ -64,8 +64,7 @@ public sealed class Solver : IDisposable
private async Task RunTask()
{
if (Token.IsCancellationRequested)
return;
Token.ThrowIfCancellationRequested();
Solution = await SearchFunc().ConfigureAwait(false);
}
@@ -120,30 +119,38 @@ public sealed class Solver : IDisposable
while (activeStates.Count != 0)
{
if (Token.IsCancellationRequested)
break;
Token.ThrowIfCancellationRequested();
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; i++)
{
var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
tasks[i] = Task.Run(() =>
tasks[i] = Task.Run(async () =>
{
var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
solver.Search(Config.Iterations / Config.ForkCount, Token);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution();
var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
return (solver.MaxScore, stateIdx, solution);
}, Token);
}
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (Token.IsCancellationRequested)
break;
Token.ThrowIfCancellationRequested();
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
@@ -229,29 +236,37 @@ public sealed class Solver : IDisposable
var sim = new Simulator(state, Config.MaxStepCount);
while (true)
{
if (Token.IsCancellationRequested)
break;
Token.ThrowIfCancellationRequested();
if (sim.IsComplete)
break;
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() =>
tasks[i] = Task.Run(async () =>
{
var solver = new MCTS(MCTSConfig, state);
solver.Search(Config.Iterations / Config.ForkCount, Token);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution();
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
return (solver.MaxScore, solution);
}, Token);
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (Token.IsCancellationRequested)
break;
Token.ThrowIfCancellationRequested();
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
@@ -280,8 +295,7 @@ public sealed class Solver : IDisposable
var sim = new Simulator(state, Config.MaxStepCount);
while (true)
{
if (Token.IsCancellationRequested)
break;
Token.ThrowIfCancellationRequested();
if (sim.IsComplete)
break;
@@ -315,16 +329,26 @@ public sealed class Solver : IDisposable
private async Task<SolverSolution> SearchOneshotForked()
{
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() =>
tasks[i] = Task.Run(async () =>
{
var solver = new MCTS(MCTSConfig, State);
solver.Search(Config.Iterations / Config.ForkCount, Token);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution();
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
return (solver.MaxScore, solution);
}, Token);
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;