Implement max thread count setting using semaphores
This commit is contained in:
@@ -240,6 +240,22 @@ public class Settings : Window
|
|||||||
ref isDirty
|
ref isDirty
|
||||||
);
|
);
|
||||||
|
|
||||||
|
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
|
||||||
|
DrawOption(
|
||||||
|
"Max Core Count",
|
||||||
|
"The number of cores to use when solving. You should use as many\n" +
|
||||||
|
"as you can. If it's too high, it will have an effect on your gameplay\n" +
|
||||||
|
$"experience. A good estimate would be 1 or 2 cores less than your\n" +
|
||||||
|
$"system (FYI, you have {Environment.ProcessorCount} cores,) but\n" +
|
||||||
|
$"make sure to accomodate for any other tasks you have in the\n" +
|
||||||
|
$"background, if you have any.\n" +
|
||||||
|
"(Only used in the Forked and Furcated algorithms)",
|
||||||
|
config.MaxThreadCount,
|
||||||
|
v => config = config with { MaxThreadCount = v },
|
||||||
|
ref isDirty
|
||||||
|
);
|
||||||
|
ImGui.EndDisabled();
|
||||||
|
|
||||||
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
|
ImGui.BeginDisabled(config.Algorithm is not (SolverAlgorithm.OneshotForked or SolverAlgorithm.StepwiseForked or SolverAlgorithm.StepwiseFurcated));
|
||||||
DrawOption(
|
DrawOption(
|
||||||
"Fork Count",
|
"Fork Count",
|
||||||
|
|||||||
+1
-2
@@ -285,8 +285,7 @@ public sealed class MCTS
|
|||||||
var n = 0;
|
var n = 0;
|
||||||
for (var i = 0; i < iterations || MaxScore == 0; i++)
|
for (var i = 0; i < iterations || MaxScore == 0; i++)
|
||||||
{
|
{
|
||||||
if (token.IsCancellationRequested)
|
token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
var selectedNode = Select();
|
var selectedNode = Select();
|
||||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
|
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ namespace Craftimizer.Solver;
|
|||||||
[StructLayout(LayoutKind.Auto)]
|
[StructLayout(LayoutKind.Auto)]
|
||||||
public readonly record struct MCTSConfig
|
public readonly record struct MCTSConfig
|
||||||
{
|
{
|
||||||
|
public int MaxThreadCount { get; init; }
|
||||||
|
|
||||||
public int MaxStepCount { get; init; }
|
public int MaxStepCount { get; init; }
|
||||||
public int MaxRolloutStepCount { get; init; }
|
public int MaxRolloutStepCount { get; init; }
|
||||||
public bool StrictActions { get; init; }
|
public bool StrictActions { get; init; }
|
||||||
|
|||||||
+39
-15
@@ -64,8 +64,7 @@ public sealed class Solver : IDisposable
|
|||||||
|
|
||||||
private async Task RunTask()
|
private async Task RunTask()
|
||||||
{
|
{
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
return;
|
|
||||||
|
|
||||||
Solution = await SearchFunc().ConfigureAwait(false);
|
Solution = await SearchFunc().ConfigureAwait(false);
|
||||||
}
|
}
|
||||||
@@ -120,30 +119,38 @@ public sealed class Solver : IDisposable
|
|||||||
|
|
||||||
while (activeStates.Count != 0)
|
while (activeStates.Count != 0)
|
||||||
{
|
{
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
|
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||||
var s = Stopwatch.StartNew();
|
var s = Stopwatch.StartNew();
|
||||||
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
|
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
|
||||||
for (var i = 0; i < Config.ForkCount; i++)
|
for (var i = 0; i < Config.ForkCount; i++)
|
||||||
{
|
{
|
||||||
var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
|
var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
|
||||||
tasks[i] = Task.Run(() =>
|
tasks[i] = Task.Run(async () =>
|
||||||
{
|
{
|
||||||
var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
|
var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
|
||||||
|
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||||
|
try
|
||||||
|
{
|
||||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
semaphore.Release();
|
||||||
|
}
|
||||||
var solution = solver.Solution();
|
var solution = solver.Solution();
|
||||||
var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
|
var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
|
||||||
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
|
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
|
||||||
return (solver.MaxScore, stateIdx, solution);
|
return (solver.MaxScore, stateIdx, solution);
|
||||||
}, Token);
|
}, Token);
|
||||||
}
|
}
|
||||||
|
semaphore.Release(Config.MaxThreadCount);
|
||||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||||
s.Stop();
|
s.Stop();
|
||||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
||||||
|
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
|
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
|
||||||
|
|
||||||
@@ -229,29 +236,37 @@ public sealed class Solver : IDisposable
|
|||||||
var sim = new Simulator(state, Config.MaxStepCount);
|
var sim = new Simulator(state, Config.MaxStepCount);
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
if (sim.IsComplete)
|
if (sim.IsComplete)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||||
var s = Stopwatch.StartNew();
|
var s = Stopwatch.StartNew();
|
||||||
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
||||||
for (var i = 0; i < Config.ForkCount; ++i)
|
for (var i = 0; i < Config.ForkCount; ++i)
|
||||||
tasks[i] = Task.Run(() =>
|
tasks[i] = Task.Run(async () =>
|
||||||
{
|
{
|
||||||
var solver = new MCTS(MCTSConfig, state);
|
var solver = new MCTS(MCTSConfig, state);
|
||||||
|
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||||
|
try
|
||||||
|
{
|
||||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
semaphore.Release();
|
||||||
|
}
|
||||||
var solution = solver.Solution();
|
var solution = solver.Solution();
|
||||||
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
||||||
return (solver.MaxScore, solution);
|
return (solver.MaxScore, solution);
|
||||||
}, Token);
|
}, Token);
|
||||||
|
semaphore.Release(Config.MaxThreadCount);
|
||||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||||
s.Stop();
|
s.Stop();
|
||||||
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
|
||||||
|
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
|
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
|
||||||
|
|
||||||
@@ -280,8 +295,7 @@ public sealed class Solver : IDisposable
|
|||||||
var sim = new Simulator(state, Config.MaxStepCount);
|
var sim = new Simulator(state, Config.MaxStepCount);
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
if (Token.IsCancellationRequested)
|
Token.ThrowIfCancellationRequested();
|
||||||
break;
|
|
||||||
|
|
||||||
if (sim.IsComplete)
|
if (sim.IsComplete)
|
||||||
break;
|
break;
|
||||||
@@ -315,16 +329,26 @@ public sealed class Solver : IDisposable
|
|||||||
|
|
||||||
private async Task<SolverSolution> SearchOneshotForked()
|
private async Task<SolverSolution> SearchOneshotForked()
|
||||||
{
|
{
|
||||||
|
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
|
||||||
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
|
||||||
for (var i = 0; i < Config.ForkCount; ++i)
|
for (var i = 0; i < Config.ForkCount; ++i)
|
||||||
tasks[i] = Task.Run(() =>
|
tasks[i] = Task.Run(async () =>
|
||||||
{
|
{
|
||||||
var solver = new MCTS(MCTSConfig, State);
|
var solver = new MCTS(MCTSConfig, State);
|
||||||
|
await semaphore.WaitAsync(Token).ConfigureAwait(false);
|
||||||
|
try
|
||||||
|
{
|
||||||
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
solver.Search(Config.Iterations / Config.ForkCount, Token);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
semaphore.Release();
|
||||||
|
}
|
||||||
var solution = solver.Solution();
|
var solution = solver.Solution();
|
||||||
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
OnWorkerProgress?.Invoke(solution, solver.MaxScore);
|
||||||
return (solver.MaxScore, solution);
|
return (solver.MaxScore, solution);
|
||||||
}, Token);
|
}, Token);
|
||||||
|
semaphore.Release(Config.MaxThreadCount);
|
||||||
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
|
||||||
|
|
||||||
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
|
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ public readonly record struct SolverConfig
|
|||||||
public float ExplorationConstant { get; init; }
|
public float ExplorationConstant { get; init; }
|
||||||
public int MaxStepCount { get; init; }
|
public int MaxStepCount { get; init; }
|
||||||
public int MaxRolloutStepCount { get; init; }
|
public int MaxRolloutStepCount { get; init; }
|
||||||
|
public int MaxThreadCount { get; init; }
|
||||||
public int ForkCount { get; init; }
|
public int ForkCount { get; init; }
|
||||||
public int FurcatedActionCount { get; init; }
|
public int FurcatedActionCount { get; init; }
|
||||||
public bool StrictActions { get; init; }
|
public bool StrictActions { get; init; }
|
||||||
|
|||||||
Reference in New Issue
Block a user