Use ref int instead of callback for progress

This commit is contained in:
Asriel Camora
2023-11-14 02:50:36 -08:00
parent 930cbcce25
commit 1bdf4412c8
4 changed files with 43 additions and 57 deletions
+13 -15
View File
@@ -196,8 +196,7 @@ public sealed class MacroEditor : Window, IDisposable
private int? SolverStartStepCount { get; set; } private int? SolverStartStepCount { get; set; }
private object? SolverQueueLock { get; set; } private object? SolverQueueLock { get; set; }
private List<SimulatedActionStep>? SolverQueuedSteps { get; set; } private List<SimulatedActionStep>? SolverQueuedSteps { get; set; }
private int solverProgress; private Solver.Solver? SolverObject { get; set; }
private int maxSolverProgress;
private bool SolverRunning => SolverTokenSource != null; private bool SolverRunning => SolverTokenSource != null;
private IDalamudTextureWrap ExpertBadge { get; } private IDalamudTextureWrap ExpertBadge { get; }
@@ -1310,15 +1309,15 @@ public sealed class MacroEditor : Window, IDisposable
ImGui.Dummy(default); ImGui.Dummy(default);
ImGui.GetWindowDrawList().AddLine(pos, pos + new Vector2(availSpace, 0), ImGui.GetColorU32(ImGuiCol.Border)); ImGui.GetWindowDrawList().AddLine(pos, pos + new Vector2(availSpace, 0), ImGui.GetColorU32(ImGuiCol.Border));
ImGui.Dummy(default); ImGui.Dummy(default);
if (SolverRunning) if (SolverRunning && SolverObject is { } solver)
{ {
var percentWidth = ImGui.CalcTextSize("100%").X; var percentWidth = ImGui.CalcTextSize("100%").X;
var progressWidth = availSpace - percentWidth - spacing; var progressWidth = availSpace - percentWidth - spacing;
var fraction = Math.Clamp((float)solverProgress / maxSolverProgress, 0, 1); var fraction = Math.Clamp((float)solver.ProgressValue / solver.ProgressMax, 0, 1);
using (var color = ImRaii.PushColor(ImGuiCol.PlotHistogram, ImGuiColors.DalamudGrey3)) using (var color = ImRaii.PushColor(ImGuiCol.PlotHistogram, ImGuiColors.DalamudGrey3))
ImGui.ProgressBar(fraction, new(progressWidth, ImGui.GetFrameHeight()), string.Empty); ImGui.ProgressBar(fraction, new(progressWidth, ImGui.GetFrameHeight()), string.Empty);
if (ImGui.IsItemHovered()) if (ImGui.IsItemHovered())
ImGui.SetTooltip($"Solver Progress: {solverProgress} / {maxSolverProgress}"); ImGui.SetTooltip($"Solver Progress: {solver.ProgressValue} / {solver.ProgressMax}");
ImGui.SameLine(0, spacing); ImGui.SameLine(0, spacing);
ImGui.AlignTextToFramePadding(); ImGui.AlignTextToFramePadding();
ImGuiUtils.TextRight($"{fraction * 100:0}%", percentWidth); ImGuiUtils.TextRight($"{fraction * 100:0}%", percentWidth);
@@ -1617,7 +1616,6 @@ public sealed class MacroEditor : Window, IDisposable
} }
SolverQueueLock = new(); SolverQueueLock = new();
SolverQueuedSteps ??= new(); SolverQueuedSteps ??= new();
solverProgress = 0;
RevertPreviousMacro(); RevertPreviousMacro();
@@ -1636,7 +1634,10 @@ public sealed class MacroEditor : Window, IDisposable
_ = task.ContinueWith(t => _ = task.ContinueWith(t =>
{ {
if (token == SolverTokenSource.Token) if (token == SolverTokenSource.Token)
{
SolverTokenSource = null; SolverTokenSource = null;
SolverObject = null;
}
}); });
_ = task.ContinueWith(t => _ = task.ContinueWith(t =>
{ {
@@ -1661,16 +1662,13 @@ public sealed class MacroEditor : Window, IDisposable
token.ThrowIfCancellationRequested(); token.ThrowIfCancellationRequested();
var solver = new Solver.Solver(config, state) { Token = token }; using (SolverObject = new Solver.Solver(config, state) { Token = token })
solver.OnLog += Log.Debug;
solver.OnNewAction += QueueSolverStep;
solver.OnProgress += (p, m) =>
{ {
Interlocked.Exchange(ref solverProgress, p); SolverObject.OnLog += Log.Debug;
Interlocked.Exchange(ref maxSolverProgress, m); SolverObject.OnNewAction += QueueSolverStep;
}; SolverObject.Start();
solver.Start(); _ = SolverObject.GetTask().GetAwaiter().GetResult();
_ = solver.GetTask().GetAwaiter().GetResult(); }
token.ThrowIfCancellationRequested(); token.ThrowIfCancellationRequested();
} }
+5 -3
View File
@@ -281,12 +281,13 @@ public sealed class MCTS
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Search(int iterations, CancellationToken token, Action? progressCallback) public void Search(int iterations, ref int progress, CancellationToken token)
{ {
Simulator simulator = new(config.MaxStepCount); Simulator simulator = new(config.MaxStepCount);
var random = rootNode.State.State.Input.Random; var random = rootNode.State.State.Input.Random;
var staleCounter = 0; var staleCounter = 0;
for (var i = 0; i < iterations || MaxScore == 0; i++) var i = 0;
for (; i < iterations || MaxScore == 0; i++)
{ {
token.ThrowIfCancellationRequested(); token.ThrowIfCancellationRequested();
@@ -314,8 +315,9 @@ public sealed class MCTS
Backpropagate(endNode, score); Backpropagate(endNode, score);
if ((i & (ProgressUpdateFrequency - 1)) == ProgressUpdateFrequency - 1) if ((i & (ProgressUpdateFrequency - 1)) == ProgressUpdateFrequency - 1)
progressCallback?.Invoke(); Interlocked.Add(ref progress, ProgressUpdateFrequency);
} }
Interlocked.Add(ref progress, i & (ProgressUpdateFrequency - 1));
} }
[Pure] [Pure]
+25 -37
View File
@@ -22,20 +22,18 @@ public sealed class Solver : IDisposable
private int progress; private int progress;
private int maxProgress; private int maxProgress;
// In iterative algorithms, the value can be reset back to 0.
// In other algorithms, the value increases monotonically.
public int ProgressValue => progress;
public int ProgressMax => maxProgress;
public delegate void LogDelegate(string text); public delegate void LogDelegate(string text);
public delegate void WorkerProgressDelegate(SolverSolution solution, float score);
public delegate void ProgressDelegate(int value, int maxValue);
public delegate void NewActionDelegate(ActionType action); public delegate void NewActionDelegate(ActionType action);
public delegate void SolutionDelegate(SolverSolution solution); public delegate void SolutionDelegate(SolverSolution solution);
// Print to console or plugin log. // Print to console or plugin log.
public event LogDelegate? OnLog; public event LogDelegate? OnLog;
// Always called in some form in every algorithm.
// In iterative algorithms, the value can be reset back to 0.
// In other algorithms, the value increases monotonically.
public event ProgressDelegate? OnProgress;
// Always called when a new step is generated. // Always called when a new step is generated.
public event NewActionDelegate? OnNewAction; public event NewActionDelegate? OnNewAction;
@@ -115,27 +113,9 @@ public sealed class Solver : IDisposable
OnNewAction?.Invoke(sanitizedAction); OnNewAction?.Invoke(sanitizedAction);
} }
private void IncrementProgress() => private void ResetProgress()
IncrementProgressBy(MCTS.ProgressUpdateFrequency);
private void IncrementRemainingProgress(int iterations) =>
IncrementProgressBy(iterations & MCTS.ProgressUpdateFrequency);
private void IncrementProgressBy(int value)
{
OnProgress?.Invoke(Interlocked.Add(ref progress, value), maxProgress);
}
private void IncrementProgressSequence()
{ {
Interlocked.Exchange(ref progress, 0); Interlocked.Exchange(ref progress, 0);
OnProgress?.Invoke(0, maxProgress);
}
private void SearchWithIncrement(MCTS mcts, int iterations)
{
mcts.Search(iterations, Token, IncrementProgress);
IncrementRemainingProgress(iterations);
} }
private async Task<SolverSolution> SearchStepwiseFurcated() private async Task<SolverSolution> SearchStepwiseFurcated()
@@ -167,7 +147,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false); await semaphore.WaitAsync(Token).ConfigureAwait(false);
try try
{ {
SearchWithIncrement(solver, iterCount); solver.Search(iterCount, ref progress, Token);
} }
finally finally
{ {
@@ -184,7 +164,7 @@ public sealed class Solver : IDisposable
semaphore.Release(Config.MaxThreadCount); semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
Token.ThrowIfCancellationRequested(); Token.ThrowIfCancellationRequested();
@@ -252,7 +232,7 @@ public sealed class Solver : IDisposable
} }
} }
IncrementProgressSequence(); ResetProgress();
activeStates = newStates; activeStates = newStates;
} }
@@ -292,7 +272,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false); await semaphore.WaitAsync(Token).ConfigureAwait(false);
try try
{ {
SearchWithIncrement(solver, iterCount); solver.Search(iterCount, ref progress, Token);
} }
finally finally
{ {
@@ -308,7 +288,7 @@ public sealed class Solver : IDisposable
semaphore.Release(Config.MaxThreadCount); semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
Token.ThrowIfCancellationRequested(); Token.ThrowIfCancellationRequested();
@@ -328,7 +308,7 @@ public sealed class Solver : IDisposable
(_, state) = sim.Execute(state, chosenAction); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction); actions.Add(chosenAction);
IncrementProgressSequence(); ResetProgress();
} }
return new(actions, state); return new(actions, state);
@@ -351,9 +331,9 @@ public sealed class Solver : IDisposable
var solver = new MCTS(MCTSConfig, state); var solver = new MCTS(MCTSConfig, state);
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
SearchWithIncrement(solver, Config.Iterations); solver.Search(Config.Iterations, ref progress, Token);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s"); OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
var solution = solver.Solution(); var solution = solver.Solution();
@@ -371,7 +351,7 @@ public sealed class Solver : IDisposable
(_, state) = sim.Execute(state, chosenAction); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction); actions.Add(chosenAction);
IncrementProgressSequence(); ResetProgress();
} }
return Task.FromResult(new SolverSolution(actions, state)); return Task.FromResult(new SolverSolution(actions, state));
@@ -383,6 +363,7 @@ public sealed class Solver : IDisposable
maxProgress = iterCount * Config.ForkCount; maxProgress = iterCount * Config.ForkCount;
using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount); using var semaphore = new SemaphoreSlim(0, Config.MaxThreadCount);
var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < Config.ForkCount; ++i) for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(async () => tasks[i] = Task.Run(async () =>
@@ -391,7 +372,7 @@ public sealed class Solver : IDisposable
await semaphore.WaitAsync(Token).ConfigureAwait(false); await semaphore.WaitAsync(Token).ConfigureAwait(false);
try try
{ {
SearchWithIncrement(solver, iterCount); solver.Search(iterCount, ref progress, Token);
} }
finally finally
{ {
@@ -406,6 +387,8 @@ public sealed class Solver : IDisposable
}, Token); }, Token);
semaphore.Release(Config.MaxThreadCount); semaphore.Release(Config.MaxThreadCount);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {(float)progress / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution; var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
foreach (var action in solution.Actions) foreach (var action in solution.Actions)
@@ -419,7 +402,12 @@ public sealed class Solver : IDisposable
maxProgress = Config.Iterations; maxProgress = Config.Iterations;
var solver = new MCTS(MCTSConfig, State); var solver = new MCTS(MCTSConfig, State);
SearchWithIncrement(solver, Config.Iterations);
var s = Stopwatch.StartNew();
solver.Search(Config.Iterations, ref progress, Token);
s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {progress / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
var solution = solver.Solution(); var solution = solver.Solution();
foreach (var action in solution.Actions) foreach (var action in solution.Actions)
InvokeNewAction(action); InvokeNewAction(action);
-2
View File
@@ -200,8 +200,6 @@ public class SimulatorTests
[TestMethod] [TestMethod]
public void TestCompletedCraft2() public void TestCompletedCraft2()
{ {
Console.WriteLine($"{Input2.BaseProgressGain} {Input2.BaseProgressGain * (3.6f * 2.5f)}");
Console.WriteLine($"{(int)(Input2.BaseProgressGain * (3.6f * 2.5f))} {(int)MathF.Floor(Input2.BaseProgressGain * (3.6f * 2.5f))}");
AssertCraft( AssertCraft(
Input2, Input2,
new[] { new[] {