From 6bfff06ac4c56b339acb2ca1808c8e25ec5b1ce2 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Mon, 2 Oct 2023 23:02:11 -0700 Subject: [PATCH] Implement max thread count setting using semaphores --- Craftimizer/Windows/Settings.cs | 16 +++++++++ Solver/MCTS.cs | 3 +- Solver/MCTSConfig.cs | 2 ++ Solver/Solver.cs | 60 +++++++++++++++++++++++---------- Solver/SolverConfig.cs | 1 + 5 files changed, 62 insertions(+), 20 deletions(-) diff --git a/Craftimizer/Windows/Settings.cs b/Craftimizer/Windows/Settings.cs index 23bdd9c..87d7e7a 100644 --- a/Craftimizer/Windows/Settings.cs +++ b/Craftimizer/Windows/Settings.cs @@ -240,6 +240,22 @@ public class Settings : Window ref isDirty ); + ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated)); + DrawOption( + "Max Core Count", + "The number of cores to use when solving. You should use as many\n" + + "as you can. If it's too high, it will have an effect on your gameplay\n" + + $"experience. A good estimate would be 1 or 2 cores less than your\n" + + $"system (FYI, you have {Environment.ProcessorCount} cores,) but\n" + + $"make sure to accomodate for any other tasks you have in the\n" + + $"background, if you have any.\n" + + "(Only used in the Forked and Furcated algorithms)", + config.MaxThreadCount, + v => config = config with { MaxThreadCount = v }, + ref isDirty + ); + ImGui.EndDisabled(); + ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated)); DrawOption( "Fork Count", diff --git a/Solver/MCTS.cs b/Solver/MCTS.cs index 4f976fb..8e40a6b 100644 --- a/Solver/MCTS.cs +++ b/Solver/MCTS.cs @@ -285,8 +285,7 @@ public sealed class MCTS var n = 0; for (var i = 0; i < iterations || MaxScore == 0; i++) { - if (token.IsCancellationRequested) - break; + token.ThrowIfCancellationRequested(); var selectedNode = Select(); var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode); diff --git a/Solver/MCTSConfig.cs b/Solver/MCTSConfig.cs index c0d057d..a0e554d 100644 --- a/Solver/MCTSConfig.cs +++ b/Solver/MCTSConfig.cs @@ -5,6 +5,8 @@ namespace Craftimizer.Solver; [StructLayout(LayoutKind.Auto)] public readonly record struct MCTSConfig { + public int MaxThreadCount { get; init; } + public int MaxStepCount { get; init; } public int MaxRolloutStepCount { get; init; } public bool StrictActions { get; init; } diff --git a/Solver/Solver.cs b/Solver/Solver.cs index a34abe9..080c5bd 100644 --- a/Solver/Solver.cs +++ b/Solver/Solver.cs @@ -64,8 +64,7 @@ public sealed class Solver : IDisposable private async Task RunTask() { - if (Token.IsCancellationRequested) - return; + Token.ThrowIfCancellationRequested(); Solution = await SearchFunc().ConfigureAwait(false); } @@ -120,30 +119,38 @@ public sealed class Solver : IDisposable while (activeStates.Count != 0) { - if (Token.IsCancellationRequested) - break; + Token.ThrowIfCancellationRequested(); + using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount); var s = Stopwatch.StartNew(); var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount]; for (var i = 0; i < Config.ForkCount; i++) { var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count); - tasks[i] = Task.Run(() => + tasks[i] = Task.Run(async () => { var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State); - solver.Search(Config.Iterations / Config.ForkCount, Token); + await semaphore.WaitAsync(Token).ConfigureAwait(false); + try + { + solver.Search(Config.Iterations / Config.ForkCount, Token); + } + finally + { + semaphore.Release(); + } var solution = solver.Solution(); var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList(); OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore); return (solver.MaxScore, stateIdx, solution); }, Token); } + semaphore.Release(Config.MaxThreadCount); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); s.Stop(); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); - if (Token.IsCancellationRequested) - break; + Token.ThrowIfCancellationRequested(); var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray(); @@ -229,29 +236,37 @@ public sealed class Solver : IDisposable var sim = new Simulator(state, Config.MaxStepCount); while (true) { - if (Token.IsCancellationRequested) - break; + Token.ThrowIfCancellationRequested(); if (sim.IsComplete) break; + using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount); var s = Stopwatch.StartNew(); var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount]; for (var i = 0; i < Config.ForkCount; ++i) - tasks[i] = Task.Run(() => + tasks[i] = Task.Run(async () => { var solver = new MCTS(MCTSConfig, state); - solver.Search(Config.Iterations / Config.ForkCount, Token); + await semaphore.WaitAsync(Token).ConfigureAwait(false); + try + { + solver.Search(Config.Iterations / Config.ForkCount, Token); + } + finally + { + semaphore.Release(); + } var solution = solver.Solution(); OnWorkerProgress?.Invoke(solution, solver.MaxScore); return (solver.MaxScore, solution); }, Token); + semaphore.Release(Config.MaxThreadCount); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); s.Stop(); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); - if (Token.IsCancellationRequested) - break; + Token.ThrowIfCancellationRequested(); var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore); @@ -280,8 +295,7 @@ public sealed class Solver : IDisposable var sim = new Simulator(state, Config.MaxStepCount); while (true) { - if (Token.IsCancellationRequested) - break; + Token.ThrowIfCancellationRequested(); if (sim.IsComplete) break; @@ -315,16 +329,26 @@ public sealed class Solver : IDisposable private async Task SearchOneshotForked() { + using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount); var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount]; for (var i = 0; i < Config.ForkCount; ++i) - tasks[i] = Task.Run(() => + tasks[i] = Task.Run(async () => { var solver = new MCTS(MCTSConfig, State); - solver.Search(Config.Iterations / Config.ForkCount, Token); + await semaphore.WaitAsync(Token).ConfigureAwait(false); + try + { + solver.Search(Config.Iterations / Config.ForkCount, Token); + } + finally + { + semaphore.Release(); + } var solution = solver.Solution(); OnWorkerProgress?.Invoke(solution, solver.MaxScore); return (solver.MaxScore, solution); }, Token); + semaphore.Release(Config.MaxThreadCount); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution; diff --git a/Solver/SolverConfig.cs b/Solver/SolverConfig.cs index cbb7c5a..5122f51 100644 --- a/Solver/SolverConfig.cs +++ b/Solver/SolverConfig.cs @@ -22,6 +22,7 @@ public readonly record struct SolverConfig public float ExplorationConstant { get; init; } public int MaxStepCount { get; init; } public int MaxRolloutStepCount { get; init; } + public int MaxThreadCount { get; init; } public int ForkCount { get; init; } public int FurcatedActionCount { get; init; } public bool StrictActions { get; init; }