Implement max thread count setting using semaphores
This commit is contained in:
@@ -240,6 +240,22 @@ public class Settings : Window
|
||||
ref isDirty
|
||||
);
|
||||
|
||||
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
|
||||
DrawOption(
|
||||
"Max Core Count",
|
||||
"The number of cores to use when solving. You should use as many\n" +
|
||||
"as you can. If it's too high, it will have an effect on your gameplay\n" +
|
||||
$"experience. A good estimate would be 1 or 2 cores less than your\n" +
|
||||
$"system (FYI, you have {Environment.ProcessorCount} cores,) but\n" +
|
||||
$"make sure to accomodate for any other tasks you have in the\n" +
|
||||
$"background, if you have any.\n" +
|
||||
"(Only used in the Forked and Furcated algorithms)",
|
||||
config.MaxThreadCount,
|
||||
v => config = config with { MaxThreadCount = v },
|
||||
ref isDirty
|
||||
);
|
||||
ImGui.EndDisabled();
|
||||
|
||||
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
|
||||
DrawOption(
|
||||
"Fork Count",
|
||||
|
||||
+1
-2
@@ -285,8 +285,7 @@ public sealed class MCTS
|
||||
var n = 0;
|
||||
for (var i = 0; i < iterations || MaxScore == 0; i++)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
token.ThrowIfCancellationRequested();
|
||||
|
||||
var selectedNode = Select();
|
||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
|
||||
|
||||
@@ -5,6 +5,8 @@ namespace Craftimizer.Solver;
|
||||
[StructLayout(LayoutKind.Auto)]
|
||||
public readonly record struct MCTSConfig
|
||||
{
|
||||
public int MaxThreadCount { get; init; }
|
||||
|
||||
public int MaxStepCount { get; init; }
|
||||
public int MaxRolloutStepCount { get; init; }
|
||||
public bool StrictActions { get; init; }
|
||||
|
||||
+39
-15
@@ -64,8 +64,7 @@ public sealed class Solver : IDisposable
|
||||
|
||||
private async Task RunTask()
|
||||
{
|
||||
if (Token.IsCancellationRequested)
|
||||
return;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
Solution = await SearchFunc().ConfigureAwait(false);
|
||||
}
|
||||
@@ -120,30 +119,38 @@ public sealed class Solver : IDisposable
|
||||
|
||||
while (activeStates.Count != 0)
|
||||
{
|
||||
if (Token.IsCancellationRequested)
|
||||
break;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||
var s = Stopwatch.StartNew();
|
||||
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
|
||||
for (var i = 0; i < Config.ForkCount; i++)
|
||||
{
|
||||
var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
|
||||
tasks[i] = Task.Run(() =>
|
||||
tasks[i] = Task.Run(async () =>
|
||||
{
|
||||
var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
semaphore.Release();
|
||||
}
|
||||
var solution = solver.Solution();
|
||||
var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
|
||||
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
|
||||
return (solver.MaxScore, stateIdx, solution);
|
||||
}, Token);
|
||||
}
|
||||
semaphore.Release(Config.MaxThreadCount);
|
||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||
s.Stop();
|
||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
||||
|
||||
if (Token.IsCancellationRequested)
|
||||
break;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
|
||||
|
||||
@@ -229,29 +236,37 @@ public sealed class Solver : IDisposable
|
||||
var sim = new Simulator(state, Config.MaxStepCount);
|
||||
while (true)
|
||||
{
|
||||
if (Token.IsCancellationRequested)
|
||||
break;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
if (sim.IsComplete)
|
||||
break;
|
||||
|
||||
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||
var s = Stopwatch.StartNew();
|
||||
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
||||
for (var i = 0; i < Config.ForkCount; ++i)
|
||||
tasks[i] = Task.Run(() =>
|
||||
tasks[i] = Task.Run(async () =>
|
||||
{
|
||||
var solver = new MCTS(MCTSConfig, state);
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
semaphore.Release();
|
||||
}
|
||||
var solution = solver.Solution();
|
||||
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
||||
return (solver.MaxScore, solution);
|
||||
}, Token);
|
||||
semaphore.Release(Config.MaxThreadCount);
|
||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||
s.Stop();
|
||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
||||
|
||||
if (Token.IsCancellationRequested)
|
||||
break;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
|
||||
|
||||
@@ -280,8 +295,7 @@ public sealed class Solver : IDisposable
|
||||
var sim = new Simulator(state, Config.MaxStepCount);
|
||||
while (true)
|
||||
{
|
||||
if (Token.IsCancellationRequested)
|
||||
break;
|
||||
Token.ThrowIfCancellationRequested();
|
||||
|
||||
if (sim.IsComplete)
|
||||
break;
|
||||
@@ -315,16 +329,26 @@ public sealed class Solver : IDisposable
|
||||
|
||||
private async Task<SolverSolution> SearchOneshotForked()
|
||||
{
|
||||
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
||||
for (var i = 0; i < Config.ForkCount; ++i)
|
||||
tasks[i] = Task.Run(() =>
|
||||
tasks[i] = Task.Run(async () =>
|
||||
{
|
||||
var solver = new MCTS(MCTSConfig, State);
|
||||
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
semaphore.Release();
|
||||
}
|
||||
var solution = solver.Solution();
|
||||
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
||||
return (solver.MaxScore, solution);
|
||||
}, Token);
|
||||
semaphore.Release(Config.MaxThreadCount);
|
||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||
|
||||
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
|
||||
|
||||
@@ -22,6 +22,7 @@ public readonly record struct SolverConfig
|
||||
public float ExplorationConstant { get; init; }
|
||||
public int MaxStepCount { get; init; }
|
||||
public int MaxRolloutStepCount { get; init; }
|
||||
public int MaxThreadCount { get; init; }
|
||||
public int ForkCount { get; init; }
|
||||
public int FurcatedActionCount { get; init; }
|
||||
public bool StrictActions { get; init; }
|
||||
|
||||
Reference in New Issue
Block a user