Use ref int instead of callback for progress

This commit is contained in:
Asriel Camora
2023-11-14 02:50:36 -08:00
parent 930cbcce25
commit 1bdf4412c8
4 changed files with 43 additions and 57 deletions
+25 -37
View File
@@ -22,20 +22,18 @@ public sealed class Solver : IDisposable
private int progress;
private int maxProgress;
// In iterative algorithms, the value can be reset back to 0.
// In other algorithms, the value increases monotonically.
public int ProgressValue => progress;
public int ProgressMax => maxProgress;
public delegate void LogDelegate(string text);
public delegate void WorkerProgressDelegate(SolverSolution solution, float score);
public delegate void ProgressDelegate(int value, int maxValue);
public delegate void NewActionDelegate(ActionType action);
public delegate void SolutionDelegate(SolverSolution solution);
// Print to console or plugin log.
public event LogDelegate? OnLog;
// Always called in some form in every algorithm.
// In iterative algorithms, the value can be reset back to 0.
// In other algorithms, the value increases monotonically.
public event ProgressDelegate? OnProgress;
// Always called when a new step is generated.
public event NewActionDelegate? OnNewAction;
@@ -115,27 +113,9 @@ public sealed class Solver : IDisposable
OnNewAction?.Invoke(sanitizedAction);
}
private void IncrementProgress() =>
IncrementProgressBy(MCTS.ProgressUpdateFrequency);
private void IncrementRemainingProgress(int iterations) =>
IncrementProgressBy(iterations & MCTS.ProgressUpdateFrequency);
private void IncrementProgressBy(int value)
{
OnProgress?.Invoke(Interlocked.Add(ref progress, value), maxProgress);
}
private void IncrementProgressSequence()
private void ResetProgress()
{
Interlocked.Exchange(ref progress, 0);
OnProgress?.Invoke(0, maxProgress);
}
private void SearchWithIncrement(MCTS mcts, int iterations)
{
mcts.Search(iterations, Token, IncrementProgress);
IncrementRemainingProgress(iterations);
}
private async Task<SolverSolution> SearchStepwiseFurcated()
@@ -167,7 +147,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
SearchWithIncrement(solver, iterCount);
solver.Search(iterCount, ref progress, Token);
}
finally
{
@@ -184,7 +164,7 @@ public sealed class Solver : IDisposable
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
Token.ThrowIfCancellationRequested();
@@ -252,7 +232,7 @@ public sealed class Solver : IDisposable
}
}
IncrementProgressSequence();
ResetProgress();
activeStates = newStates;
}
@@ -292,7 +272,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
SearchWithIncrement(solver, iterCount);
solver.Search(iterCount, ref progress, Token);
}
finally
{
@@ -308,7 +288,7 @@ public sealed class Solver : IDisposable
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
Token.ThrowIfCancellationRequested();
@@ -328,7 +308,7 @@ public sealed class Solver : IDisposable
(_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction);
IncrementProgressSequence();
ResetProgress();
}
return new(actions, state);
@@ -351,9 +331,9 @@ public sealed class Solver : IDisposable
var solver = new MCTS(MCTSConfig, state);
var s = Stopwatch.StartNew();
SearchWithIncrement(solver, Config.Iterations);
solver.Search(Config.Iterations, ref progress, Token);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
var solution = solver.Solution();
@@ -371,7 +351,7 @@ public sealed class Solver : IDisposable
(_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction);
IncrementProgressSequence();
ResetProgress();
}
return Task.FromResult(new SolverSolution(actions, state));
@@ -383,6 +363,7 @@ public sealed class Solver : IDisposable
maxProgress = iterCount * Config.ForkCount;
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(async () =>
@@ -391,7 +372,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false);
try
{
SearchWithIncrement(solver, iterCount);
solver.Search(iterCount, ref progress, Token);
}
finally
{
@@ -406,6 +387,8 @@ public sealed class Solver : IDisposable
}, Token);
semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
foreach (var action in solution.Actions)
@@ -419,7 +402,12 @@ public sealed class Solver : IDisposable
maxProgress = Config.Iterations;
var solver = new MCTS(MCTSConfig, State);
SearchWithIncrement(solver, Config.Iterations);
var s = Stopwatch.StartNew();
solver.Search(Config.Iterations, ref progress, Token);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
var solution = solver.Solution();
foreach (var action in solution.Actions)
InvokeNewAction(action);