Implement max thread count setting using semaphores

This commit is contained in:
Asriel Camora
2023-10-02 23:02:11 -07:00
parent 42a1bc1117
commit 6bfff06ac4
5 changed files with 62 additions and 20 deletions
+16
View File
@@ -240,6 +240,22 @@ public class Settings : Window
ref isDirty ref isDirty
); );
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
DrawOption(
"Max Core Count",
"The number of cores to use when solving. You should use as many\n" +
"as you can. If it's too high, it will have an effect on your gameplay\n" +
$"experience. A good estimate would be 1 or 2 cores less than your\n" +
$"system (FYI, you have {Environment.ProcessorCount} cores,) but\n" +
$"make sure to accomodate for any other tasks you have in the\n" +
$"background, if you have any.\n" +
"(Only used in the Forked and Furcated algorithms)",
config.MaxThreadCount,
v => config = config with { MaxThreadCount = v },
ref isDirty
);
ImGui.EndDisabled();
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated)); ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
DrawOption( DrawOption(
"Fork Count", "Fork Count",
+1 -2
View File
@@ -285,8 +285,7 @@ public sealed class MCTS
var n = 0; var n = 0;
for (var i = 0; i < iterations || MaxScore == 0; i++) for (var i = 0; i < iterations || MaxScore == 0; i++)
{ {
if (token.IsCancellationRequested) token.ThrowIfCancellationRequested();
break;
var selectedNode = Select(); var selectedNode = Select();
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
+2
View File
@@ -5,6 +5,8 @@ namespace Craftimizer.Solver;
[StructLayout(LayoutKind.Auto)] [StructLayout(LayoutKind.Auto)]
public readonly record struct MCTSConfig public readonly record struct MCTSConfig
{ {
public int MaxThreadCount { get; init; }
public int MaxStepCount { get; init; } public int MaxStepCount { get; init; }
public int MaxRolloutStepCount { get; init; } public int MaxRolloutStepCount { get; init; }
public bool StrictActions { get; init; } public bool StrictActions { get; init; }
+39 -15
View File
@@ -64,8 +64,7 @@ public sealed class Solver : IDisposable
private async Task RunTask() private async Task RunTask()
{ {
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
return;
Solution = await SearchFunc().ConfigureAwait(false); Solution = await SearchFunc().ConfigureAwait(false);
} }
@@ -120,30 +119,38 @@ public sealed class Solver : IDisposable
while (activeStates.Count != 0) while (activeStates.Count != 0)
{ {
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
break;
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount]; var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; i++) for (var i = 0; i < Config.ForkCount; i++)
{ {
var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count); var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
tasks[i] = Task.Run(() => tasks[i] = Task.Run(async () =>
{ {
var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State); var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token); solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution(); var solution = solver.Solution();
var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList(); var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore); OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
return (solver.MaxScore, stateIdx, solution); return (solver.MaxScore, stateIdx, solution);
}, Token); }, Token);
} }
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
break;
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray(); var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
@@ -229,29 +236,37 @@ public sealed class Solver : IDisposable
var sim = new Simulator(state, Config.MaxStepCount); var sim = new Simulator(state, Config.MaxStepCount);
while (true) while (true)
{ {
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
break;
if (sim.IsComplete) if (sim.IsComplete)
break; break;
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i) for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(async () =>
{ {
var solver = new MCTS(MCTSConfig, state); var solver = new MCTS(MCTSConfig, state);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token); solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution(); var solution = solver.Solution();
OnWorkerProgress?.Invoke(solution, solver.MaxScore); OnWorkerProgress?.Invoke(solution, solver.MaxScore);
return (solver.MaxScore, solution); return (solver.MaxScore, solution);
}, Token); }, Token);
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
break;
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore); var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
@@ -280,8 +295,7 @@ public sealed class Solver : IDisposable
var sim = new Simulator(state, Config.MaxStepCount); var sim = new Simulator(state, Config.MaxStepCount);
while (true) while (true)
{ {
if (Token.IsCancellationRequested) Token.ThrowIfCancellationRequested();
break;
if (sim.IsComplete) if (sim.IsComplete)
break; break;
@@ -315,16 +329,26 @@ public sealed class Solver : IDisposable
private async Task<SolverSolution> SearchOneshotForked() private async Task<SolverSolution> SearchOneshotForked()
{ {
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i) for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(async () =>
{ {
var solver = new MCTS(MCTSConfig, State); var solver = new MCTS(MCTSConfig, State);
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
solver.Search(Config.Iterations / Config.ForkCount, Token); solver.Search(Config.Iterations / Config.ForkCount, Token);
}
finally
{
semaphore.Release();
}
var solution = solver.Solution(); var solution = solver.Solution();
OnWorkerProgress?.Invoke(solution, solver.MaxScore); OnWorkerProgress?.Invoke(solution, solver.MaxScore);
return (solver.MaxScore, solution); return (solver.MaxScore, solution);
}, Token); }, Token);
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution; var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
+1
View File
@@ -22,6 +22,7 @@ public readonly record struct SolverConfig
public float ExplorationConstant { get; init; } public float ExplorationConstant { get; init; }
public int MaxStepCount { get; init; } public int MaxStepCount { get; init; }
public int MaxRolloutStepCount { get; init; } public int MaxRolloutStepCount { get; init; }
public int MaxThreadCount { get; init; }
public int ForkCount { get; init; } public int ForkCount { get; init; }
public int FurcatedActionCount { get; init; } public int FurcatedActionCount { get; init; }
public bool StrictActions { get; init; } public bool StrictActions { get; init; }