Use TPL when solving with all algorithms

This commit is contained in:
Asriel Camora
2023-10-02 22:28:12 -07:00
parent c0f579f23b
commit 42a1bc1117
9 changed files with 241 additions and 144 deletions
+3
View File
@@ -0,0 +1,3 @@
# Auto detect text files and perform LF normalization
* text eol=crlf
*.png binary
+6 -6
View File
@@ -7,7 +7,7 @@ namespace Craftimizer.Benchmark;
internal static class Program internal static class Program
{ {
private static void Main() private static async Task Main()
{ {
//var summary = BenchmarkRunner.Run<Bench>(); //var summary = BenchmarkRunner.Run<Bench>();
//return; //return;
@@ -77,12 +77,12 @@ internal static class Program
Console.WriteLine($"{state.Quality} {state.CP} {state.Progress} {state.Durability}"); Console.WriteLine($"{state.Quality} {state.CP} {state.Progress} {state.Durability}");
//return; //return;
var (_, s) = config.Invoke(state, a => Console.WriteLine(a))!.Value; var solver = new Solver.Solver(config, state);
solver.OnLog += s => Console.WriteLine(s);
solver.OnNewAction += s => Console.WriteLine(s);
solver.Start();
var (_, s) = await solver.GetTask().ConfigureAwait(false);
Console.WriteLine($"Qual: {s.Quality}/{s.Input.Recipe.MaxQuality}"); Console.WriteLine($"Qual: {s.Quality}/{s.Input.Recipe.MaxQuality}");
return;
config.Invoke(new(input));
//Benchmark(() => );
} }
private static void Benchmark(Func<SolverSolution> search) private static void Benchmark(Func<SolverSolution> search)
+2 -1
View File
@@ -215,7 +215,8 @@ public sealed unsafe partial class Craft : Window, IDisposable
public void Dispose() public void Dispose()
{ {
StopSolve(); StopSolve();
SolverTask?.Wait(); SolverTaskToken?.Cancel();
SolverTask?.TryWait();
SolverTask?.Dispose(); SolverTask?.Dispose();
SolverTaskToken?.Dispose(); SolverTaskToken?.Dispose();
+6 -2
View File
@@ -1,6 +1,7 @@
using Craftimizer.Simulator; using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions; using Craftimizer.Simulator.Actions;
using Dalamud.Interface.Windowing; using Dalamud.Interface.Windowing;
using Dalamud.Logging;
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
@@ -12,7 +13,7 @@ namespace Craftimizer.Plugin.Windows;
public sealed unsafe partial class Craft : Window, IDisposable public sealed unsafe partial class Craft : Window, IDisposable
{ {
private SimulationState? SolverState { get; set; } private SimulationState? SolverState { get; set; }
private Task? SolverTask { get; set; } private Solver.Solver? SolverTask { get; set; }
private CancellationTokenSource? SolverTaskToken { get; set; } private CancellationTokenSource? SolverTaskToken { get; set; }
private ConcurrentQueue<ActionType> SolverActionQueue { get; } = new(); private ConcurrentQueue<ActionType> SolverActionQueue { get; } = new();
@@ -48,7 +49,10 @@ public sealed unsafe partial class Craft : Window, IDisposable
SolverSim = new(state); SolverSim = new(state);
SolverTaskToken = new(); SolverTaskToken = new();
SolverTask = Task.Run(() => Config.SynthHelperSolverConfig.Invoke(state, SolverActionQueue.Enqueue, SolverTaskToken.Token)); SolverTask = new(Config.SynthHelperSolverConfig, state) { Token = SolverTaskToken.Token };
SolverTask.OnLog += s => PluginLog.Debug(s);
SolverTask.OnNewAction += SolverActionQueue.Enqueue;
SolverTask.Start();
} }
private void SolveTick() private void SolveTick()
+20 -20
View File
@@ -312,32 +312,32 @@ public class Settings : Window
DrawOption( DrawOption(
"Progress", "Progress",
"Amount of weight to give to the craft's progress.", "Amount of weight to give to the craft's progress.",
config.ScoreProgressBonus, config.ScoreProgress,
v => config = config with { ScoreProgressBonus = v }, v => config = config with { ScoreProgress = v },
ref isDirty ref isDirty
); );
DrawOption( DrawOption(
"Quality", "Quality",
"Amount of weight to give to the craft's quality.", "Amount of weight to give to the craft's quality.",
config.ScoreQualityBonus, config.ScoreQuality,
v => config = config with { ScoreQualityBonus = v }, v => config = config with { ScoreQuality = v },
ref isDirty ref isDirty
); );
DrawOption( DrawOption(
"Durability", "Durability",
"Amount of weight to give to the craft's remaining durability.", "Amount of weight to give to the craft's remaining durability.",
config.ScoreDurabilityBonus, config.ScoreDurability,
v => config = config with { ScoreDurabilityBonus = v }, v => config = config with { ScoreDurability = v },
ref isDirty ref isDirty
); );
DrawOption( DrawOption(
"CP", "CP",
"Amount of weight to give to the craft's remaining CP.", "Amount of weight to give to the craft's remaining CP.",
config.ScoreCPBonus, config.ScoreCP,
v => config = config with { ScoreCPBonus = v }, v => config = config with { ScoreCP = v },
ref isDirty ref isDirty
); );
@@ -345,25 +345,25 @@ public class Settings : Window
"Steps", "Steps",
"Amount of weight to give to the craft's number of steps. The lower\n" + "Amount of weight to give to the craft's number of steps. The lower\n" +
"the step count, the higher the score.", "the step count, the higher the score.",
config.ScoreFewerStepsBonus, config.ScoreSteps,
v => config = config with { ScoreFewerStepsBonus = v }, v => config = config with { ScoreSteps = v },
ref isDirty ref isDirty
); );
if (ImGui.Button("Normalize Weights", OptionButtonSize)) if (ImGui.Button("Normalize Weights", OptionButtonSize))
{ {
var total = config.ScoreProgressBonus + var total = config.ScoreProgress +
config.ScoreQualityBonus + config.ScoreQuality +
config.ScoreDurabilityBonus + config.ScoreDurability +
config.ScoreCPBonus + config.ScoreCP +
config.ScoreFewerStepsBonus; config.ScoreSteps;
config = config with config = config with
{ {
ScoreProgressBonus = config.ScoreProgressBonus / total, ScoreProgress = config.ScoreProgress / total,
ScoreQualityBonus = config.ScoreQualityBonus / total, ScoreQuality = config.ScoreQuality / total,
ScoreDurabilityBonus = config.ScoreDurabilityBonus / total, ScoreDurability = config.ScoreDurability / total,
ScoreCPBonus = config.ScoreCPBonus / total, ScoreCP = config.ScoreCP / total,
ScoreFewerStepsBonus = config.ScoreFewerStepsBonus / total ScoreSteps = config.ScoreSteps / total
}; };
isDirty = true; isDirty = true;
} }
+8 -3
View File
@@ -1,6 +1,7 @@
using Craftimizer.Simulator; using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions; using Craftimizer.Simulator.Actions;
using Dalamud.Interface.Windowing; using Dalamud.Interface.Windowing;
using Dalamud.Logging;
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Threading; using System.Threading;
@@ -10,7 +11,7 @@ namespace Craftimizer.Plugin.Windows;
public sealed partial class Simulator : Window, IDisposable public sealed partial class Simulator : Window, IDisposable
{ {
private Task? SolverTask { get; set; } private Solver.Solver? SolverTask { get; set; }
private CancellationTokenSource? SolverTaskToken { get; set; } private CancellationTokenSource? SolverTaskToken { get; set; }
private ConcurrentQueue<ActionType> SolverActionQueue { get; } = new(); private ConcurrentQueue<ActionType> SolverActionQueue { get; } = new();
private int SolverInitialActionCount { get; set; } private int SolverInitialActionCount { get; set; }
@@ -83,13 +84,17 @@ public sealed partial class Simulator : Window, IDisposable
SolverInitialActionCount = Actions.Count; SolverInitialActionCount = Actions.Count;
SolverTaskToken = new(); SolverTaskToken = new();
SolverTask = Task.Run(() => Config.SimulatorSolverConfig.Invoke(solverState, SolverActionQueue.Enqueue, SolverTaskToken.Token)); SolverTask = new(Config.SimulatorSolverConfig, solverState) { Token = SolverTaskToken.Token };
SolverTask.OnLog += s => PluginLog.Debug(s);
SolverTask.OnNewAction += SolverActionQueue.Enqueue;
SolverTask.Start();
} }
public void Dispose() public void Dispose()
{ {
StopSolveMacro(); StopSolveMacro();
SolverTask?.Wait(); SolverTaskToken?.Cancel();
SolverTask?.TryWait();
SolverTask?.Dispose(); SolverTask?.Dispose();
SolverTaskToken?.Dispose(); SolverTaskToken?.Dispose();
} }
+5 -5
View File
@@ -29,10 +29,10 @@ public readonly record struct MCTSConfig
ExplorationConstant = config.ExplorationConstant; ExplorationConstant = config.ExplorationConstant;
ScoreStorageThreshold = config.ScoreStorageThreshold; ScoreStorageThreshold = config.ScoreStorageThreshold;
ScoreProgress = config.ScoreProgressBonus; ScoreProgress = config.ScoreProgress;
ScoreQuality = config.ScoreQualityBonus; ScoreQuality = config.ScoreQuality;
ScoreDurability = config.ScoreDurabilityBonus; ScoreDurability = config.ScoreDurability;
ScoreCP = config.ScoreCPBonus; ScoreCP = config.ScoreCP;
ScoreSteps = config.ScoreFewerStepsBonus; ScoreSteps = config.ScoreSteps;
} }
} }
+180 -79
View File
@@ -4,57 +4,165 @@ using System.Diagnostics;
namespace Craftimizer.Solver; namespace Craftimizer.Solver;
public static class Solver public sealed class Solver : IDisposable
{ {
private static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token) public SolverConfig Config { get; }
public SimulationState State { get; }
public CancellationToken Token { get; init; }
public SolverSolution? Solution { get; private set; }
public bool IsStarted => CompletionTask != null;
public bool IsCompletedSuccessfully => Solution != null;
public bool IsCompleted => CompletionTask?.IsCompleted ?? false;
private Func<Task<SolverSolution>> SearchFunc { get; }
private MCTSConfig MCTSConfig => new(Config);
private Task? CompletionTask { get; set; }
public delegate void LogDelegate(string text);
public delegate void WorkerProgressDelegate(SolverSolution solution, float score);
public delegate void NewActionDelegate(ActionType action);
public delegate void SolutionDelegate(SolverSolution solution);
// Print to console or plugin log.
public event LogDelegate? OnLog;
// Isn't always called. This is just meant to show as an indicator to the user.
// Solution contains the best terminal state, and its actions to get there exclude the ones provided by OnNewAction.
// For example, to get to the terminal state, execute all OnNewAction actions, then execute all Solution actions.
public event WorkerProgressDelegate? OnWorkerProgress;
// Always called when a new step is generated.
public event NewActionDelegate? OnNewAction;
// Always called when the solver is fully complete.
public event SolutionDelegate? OnSolution;
public Solver(SolverConfig config, SimulationState state)
{
Config = config;
State = state;
SearchFunc = Config.Algorithm switch
{
SolverAlgorithm.Oneshot => SearchOneshot,
SolverAlgorithm.OneshotForked => SearchOneshotForked,
SolverAlgorithm.Stepwise => SearchStepwise,
SolverAlgorithm.StepwiseForked => SearchStepwiseForked,
SolverAlgorithm.StepwiseFurcated => SearchStepwiseFurcated,
_ => throw new ArgumentOutOfRangeException(nameof(config), config, $"Invalid algorithm: {config.Algorithm}")
};
}
public void Start()
{
if (IsStarted)
throw new InvalidOperationException("Solver has already started.");
CompletionTask = RunTask();
}
private async Task RunTask()
{
if (Token.IsCancellationRequested)
return;
Solution = await SearchFunc().ConfigureAwait(false);
}
public async Task<SolverSolution> GetTask()
{
if (!IsStarted)
throw new InvalidOperationException("Solver has not started.");
await CompletionTask!.ConfigureAwait(false);
return Solution!.Value;
}
public async Task<SolverSolution?> GetSafeTask()
{
try
{
return await GetTask().ConfigureAwait(false);
}
catch (AggregateException e)
{
e.Handle(ex => ex is OperationCanceledException);
}
catch (OperationCanceledException)
{
}
return null;
}
public void TryWait()
{
if (IsStarted && !IsCompleted)
GetSafeTask().Wait();
}
public void Dispose()
{
CompletionTask?.Dispose();
}
private async Task<SolverSolution> SearchStepwiseFurcated()
{ {
var definiteActionCount = 0; var definiteActionCount = 0;
var bestSims = new List<(float Score, SolverSolution Result)>(); var bestSims = new List<(float Score, SolverSolution Result)>();
var sim = new Simulator(state, config.MaxStepCount); var state = State;
var sim = new Simulator(state, Config.MaxStepCount);
var activeStates = new List<SolverSolution>() { new(new(), state) }; var activeStates = new List<SolverSolution>() { new(new(), state) };
while (activeStates.Count != 0) while (activeStates.Count != 0)
{ {
if (token.IsCancellationRequested) if (Token.IsCancellationRequested)
break; break;
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[config.ForkCount]; var tasks = new Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < config.ForkCount; i++) for (var i = 0; i < Config.ForkCount; i++)
{ {
var stateIdx = (int)((float)i / config.ForkCount * activeStates.Count); var stateIdx = (int)((float)i / Config.ForkCount * activeStates.Count);
var st = activeStates[stateIdx];
tasks[i] = Task.Run(() => tasks[i] = Task.Run(() =>
{ {
var solver = new MCTS(new(config), activeStates[stateIdx].State); var solver = new MCTS(MCTSConfig, activeStates[stateIdx].State);
solver.Search(config.Iterations / config.ForkCount, token); solver.Search(Config.Iterations / Config.ForkCount, Token);
return (solver.MaxScore, stateIdx, solver.Solution()); var solution = solver.Solution();
}, token); var progressActions = activeStates[stateIdx].Actions.Concat(solution.Actions).Skip(definiteActionCount).ToList();
OnWorkerProgress?.Invoke(solution with { Actions = progressActions }, solver.MaxScore);
return (solver.MaxScore, stateIdx, solution);
}, Token);
} }
Task.WaitAll(tasks, token); await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (token.IsCancellationRequested) if (Token.IsCancellationRequested)
break; break;
var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(config.FurcatedActionCount).ToArray(); var bestActions = tasks.Select(t => t.Result).OrderByDescending(r => r.MaxScore).Take(Config.FurcatedActionCount).ToArray();
var bestAction = bestActions[0]; var bestAction = bestActions[0];
if (bestAction.MaxScore >= config.ScoreStorageThreshold) if (bestAction.MaxScore >= Config.ScoreStorageThreshold)
{ {
var (maxScore, furcatedActionIdx, solution) = bestAction; var (_, furcatedActionIdx, solution) = bestAction;
var (activeActions, activeState) = activeStates[furcatedActionIdx]; var (activeActions, _) = activeStates[furcatedActionIdx];
activeActions.AddRange(solution.Actions); activeActions.AddRange(solution.Actions);
foreach (var action in activeActions.Skip(definiteActionCount))
OnNewAction?.Invoke(action);
return solution with { Actions = activeActions }; return solution with { Actions = activeActions };
} }
var newStates = new List<SolverSolution>(config.FurcatedActionCount); var newStates = new List<SolverSolution>(Config.FurcatedActionCount);
for (var i = 0; i < bestActions.Length; ++i) for (var i = 0; i < bestActions.Length; ++i)
{ {
var (maxScore, furcatedActionIdx, (solutionActions, solutionNode)) = bestActions[i]; var (maxScore, furcatedActionIdx, (solutionActions, _)) = bestActions[i];
if (solutionActions.Count == 0) if (solutionActions.Count == 0)
continue; continue;
@@ -94,67 +202,69 @@ public static class Solver
} }
if (definiteCount != equalCount) if (definiteCount != equalCount)
{ {
for (var i = definiteCount; i < equalCount; ++i) foreach(var action in refActions.Take(equalCount).Skip(definiteCount))
actionCallback?.Invoke(refActions[i]); OnNewAction?.Invoke(action);
definiteActionCount = equalCount; definiteActionCount = equalCount;
} }
} }
activeStates = newStates; activeStates = newStates;
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
} }
if (bestSims.Count == 0) if (bestSims.Count == 0)
return new(new(), state); return new(new(), state);
var result = bestSims.MaxBy(s => s.Score).Result; var result = bestSims.MaxBy(s => s.Score).Result;
for (var i = definiteActionCount; i < result.Actions.Count; ++i) foreach (var action in result.Actions.Skip(definiteActionCount))
actionCallback?.Invoke(result.Actions[i]); OnNewAction?.Invoke(action);
return result; return result;
} }
private static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token) private async Task<SolverSolution> SearchStepwiseForked()
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount); var state = State;
var sim = new Simulator(state, Config.MaxStepCount);
while (true) while (true)
{ {
if (token.IsCancellationRequested) if (Token.IsCancellationRequested)
break; break;
if (sim.IsComplete) if (sim.IsComplete)
break; break;
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < config.ForkCount; ++i) for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(() =>
{ {
var solver = new MCTS(new(config), state); var solver = new MCTS(MCTSConfig, state);
solver.Search(config.Iterations / config.ForkCount, token); solver.Search(Config.Iterations / Config.ForkCount, Token);
return (solver.MaxScore, solver.Solution()); var solution = solver.Solution();
}, token); OnWorkerProgress?.Invoke(solution, solver.MaxScore);
Task.WaitAll(tasks, token); return (solver.MaxScore, solution);
}, Token);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / Config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
if (token.IsCancellationRequested) if (Token.IsCancellationRequested)
break; break;
var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore); var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
if (maxScore >= config.ScoreStorageThreshold) if (maxScore >= Config.ScoreStorageThreshold)
{ {
actions.AddRange(solution.Actions); actions.AddRange(solution.Actions);
foreach (var action in solution.Actions)
OnNewAction?.Invoke(action);
return solution with { Actions = actions }; return solution with { Actions = actions };
} }
var chosenAction = solution.Actions[0]; var chosenAction = solution.Actions[0];
actionCallback?.Invoke(chosenAction); OnNewAction?.Invoke(chosenAction);
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
(_, state) = sim.Execute(state, chosenAction); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction); actions.Add(chosenAction);
@@ -163,84 +273,75 @@ public static class Solver
return new(actions, state); return new(actions, state);
} }
private static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token) private Task<SolverSolution> SearchStepwise()
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount); var state = State;
var sim = new Simulator(state, Config.MaxStepCount);
while (true) while (true)
{ {
if (token.IsCancellationRequested) if (Token.IsCancellationRequested)
break; break;
if (sim.IsComplete) if (sim.IsComplete)
break; break;
var solver = new MCTS(new(config), state); var solver = new MCTS(MCTSConfig, State);
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
solver.Search(config.Iterations, token); solver.Search(Config.Iterations, Token);
s.Stop(); s.Stop();
OnLog?.Invoke($"{s.Elapsed.TotalMilliseconds:0.00}ms {Config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
var solution = solver.Solution(); var solution = solver.Solution();
if (solver.MaxScore >= config.ScoreStorageThreshold) if (solver.MaxScore >= Config.ScoreStorageThreshold)
{ {
actions.AddRange(solution.Actions); actions.AddRange(solution.Actions);
return solution with { Actions = actions }; foreach (var action in solution.Actions)
OnNewAction?.Invoke(action);
return Task.FromResult(solution with { Actions = actions });
} }
var chosenAction = solution.Actions[0]; var chosenAction = solution.Actions[0];
actionCallback?.Invoke(chosenAction); OnNewAction?.Invoke(chosenAction);
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
(_, state) = sim.Execute(state, chosenAction); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosenAction); actions.Add(chosenAction);
} }
return new(actions, state); return Task.FromResult(new SolverSolution(actions, state));
} }
private static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token) private async Task<SolverSolution> SearchOneshotForked()
{ {
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[Config.ForkCount];
for (var i = 0; i < config.ForkCount; ++i) for (var i = 0; i < Config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(() =>
{ {
var solver = new MCTS(new(config), state); var solver = new MCTS(MCTSConfig, State);
solver.Search(config.Iterations / config.ForkCount, token); solver.Search(Config.Iterations / Config.ForkCount, Token);
return (solver.MaxScore, solver.Solution()); var solution = solver.Solution();
}, token); OnWorkerProgress?.Invoke(solution, solver.MaxScore);
Task.WaitAll(tasks, CancellationToken.None); return (solver.MaxScore, solution);
}, Token);
await Task.WhenAll(tasks).WaitAsync(Token).ConfigureAwait(false);
var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution; var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
foreach (var action in solution.Actions) foreach (var action in solution.Actions)
actionCallback?.Invoke(action); OnNewAction?.Invoke(action);
return solution; return solution;
} }
private static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token) private Task<SolverSolution> SearchOneshot()
{ {
var solver = new MCTS(new(config), state); var solver = new MCTS(MCTSConfig, State);
solver.Search(config.Iterations, token); solver.Search(Config.Iterations, Token);
var solution = solver.Solution(); var solution = solver.Solution();
foreach (var action in solution.Actions) foreach (var action in solution.Actions)
actionCallback?.Invoke(action); OnNewAction?.Invoke(action);
return solution; return Task.FromResult(solution);
}
public static SolverSolution Search(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
{
Func<SolverConfig, SimulationState, Action<ActionType>?, CancellationToken, SolverSolution> func = config.Algorithm switch
{
SolverAlgorithm.Oneshot => SearchOneshot,
SolverAlgorithm.OneshotForked => SearchOneshotForked,
SolverAlgorithm.Stepwise => SearchStepwise,
SolverAlgorithm.StepwiseForked => SearchStepwiseForked,
SolverAlgorithm.StepwiseFurcated => SearchStepwiseFurcated,
_ => throw new ArgumentOutOfRangeException(nameof(config), config, $"Invalid algorithm: {config.Algorithm}")
};
return func(config, state, actionCallback, token);
} }
} }
+10 -27
View File
@@ -26,11 +26,11 @@ public readonly record struct SolverConfig
public int FurcatedActionCount { get; init; } public int FurcatedActionCount { get; init; }
public bool StrictActions { get; init; } public bool StrictActions { get; init; }
public float ScoreProgressBonus { get; init; } public float ScoreProgress { get; init; }
public float ScoreQualityBonus { get; init; } public float ScoreQuality { get; init; }
public float ScoreDurabilityBonus { get; init; } public float ScoreDurability { get; init; }
public float ScoreCPBonus { get; init; } public float ScoreCP { get; init; }
public float ScoreFewerStepsBonus { get; init; } public float ScoreSteps { get; init; }
public SolverAlgorithm Algorithm { get; init; } public SolverAlgorithm Algorithm { get; init; }
@@ -46,11 +46,11 @@ public readonly record struct SolverConfig
FurcatedActionCount = ForkCount / 2; FurcatedActionCount = ForkCount / 2;
StrictActions = true; StrictActions = true;
ScoreProgressBonus = .20f; ScoreProgress = .20f;
ScoreQualityBonus = .65f; ScoreQuality = .65f;
ScoreDurabilityBonus = .05f; ScoreDurability = .05f;
ScoreCPBonus = .05f; ScoreCP = .05f;
ScoreFewerStepsBonus = .05f; ScoreSteps = .05f;
Algorithm = SolverAlgorithm.StepwiseFurcated; Algorithm = SolverAlgorithm.StepwiseFurcated;
} }
@@ -67,21 +67,4 @@ public readonly record struct SolverConfig
FurcatedActionCount = Environment.ProcessorCount / 2, FurcatedActionCount = Environment.ProcessorCount / 2,
Algorithm = SolverAlgorithm.StepwiseForked Algorithm = SolverAlgorithm.StepwiseForked
}; };
public SolverSolution? Invoke(SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{
try
{
return Solver.Search(this, state, actionCallback, token);
}
catch (AggregateException e)
{
e.Handle(ex => ex is OperationCanceledException);
}
catch (OperationCanceledException)
{
}
return null;
}
} }