Fix concurrency, add forked simulations

This commit is contained in:
Asriel Camora
2023-07-07 18:13:27 +02:00
parent 7d8fc9ff8f
commit 3ab50d389e
6 changed files with 84 additions and 14 deletions
+5 -3
View File
@@ -46,14 +46,16 @@ internal static class Program
var config = new SolverConfig() var config = new SolverConfig()
{ {
Iterations = 1_000_000, Iterations = 100_000,
ThreadCount = 8, ThreadCount = 8,
}; };
Debugger.Break(); Debugger.Break();
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
if (true) if (true) {
_ = SolverUtils.SearchStepwise<SolverSingle>(config, input, a => Console.WriteLine(a)); (_, var state) = SolverUtils.SearchStepwise<SolverSingle>(config, input, a => Console.WriteLine(a));
Console.WriteLine($"Qual: {state.Quality}/{state.Input.Recipe.MaxQuality}");
}
else else
{ {
(var actions, _) = SolverUtils.SearchOneshot<SolverConcurrent>(config, input); (var actions, _) = SolverUtils.SearchOneshot<SolverConcurrent>(config, input);
+4 -4
View File
@@ -20,15 +20,15 @@ public sealed class ArenaNode<T> where T : struct
Parent = parent; Parent = parent;
} }
public ArenaNode<T> ChildAt((int arrayIdx, int subIdx) at) => public ArenaNode<T>? ChildAt((int arrayIdx, int subIdx) at) =>
Children.Data[at.arrayIdx][at.subIdx]; Children.Data?[at.arrayIdx]?[at.subIdx];
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> AddConcurrent(T state) public ArenaNode<T> AddConcurrent(T state)
{ {
var node = new ArenaNode<T>(state, this); var node = new ArenaNode<T>(state, this);
Children.AddConcurrent(node);
ChildScores.AddConcurrent(); ChildScores.AddConcurrent();
Children.AddConcurrent(node);
return node; return node;
} }
@@ -36,8 +36,8 @@ public sealed class ArenaNode<T> where T : struct
public ArenaNode<T> Add(T state) public ArenaNode<T> Add(T state)
{ {
var node = new ArenaNode<T>(state, this); var node = new ArenaNode<T>(state, this);
Children.Add(node);
ChildScores.Add(); ChildScores.Add();
Children.Add(node);
return node; return node;
} }
} }
+5
View File
@@ -33,6 +33,11 @@ public sealed class SolverConcurrent : ISolver
{ {
nodeVisits = node.ChildScores.GetVisits(at.Value); nodeVisits = node.ChildScores.GetVisits(at.Value);
node = node.ChildAt(at.Value); node = node.ChildAt(at.Value);
if (node == null)
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
} }
else else
{ {
+3 -1
View File
@@ -10,6 +10,7 @@ public readonly record struct SolverConfig
public float MaxScoreWeightingConstant { get; init; } public float MaxScoreWeightingConstant { get; init; }
public float ExplorationConstant { get; init; } public float ExplorationConstant { get; init; }
public int MaxStepCount { get; init; } public int MaxStepCount { get; init; }
public int MaxRolloutStepCount { get; init; }
public int ThreadCount { get; init; } public int ThreadCount { get; init; }
public SolverConfig() public SolverConfig()
@@ -17,8 +18,9 @@ public readonly record struct SolverConfig
Iterations = 300000; Iterations = 300000;
ScoreStorageThreshold = 1f; ScoreStorageThreshold = 1f;
MaxScoreWeightingConstant = 0.1f; MaxScoreWeightingConstant = 0.1f;
ExplorationConstant = 4f; ExplorationConstant = 2;
MaxStepCount = 25; MaxStepCount = 25;
MaxRolloutStepCount = 99;
ThreadCount = Environment.ProcessorCount; ThreadCount = Environment.ProcessorCount;
} }
} }
+1 -1
View File
@@ -25,7 +25,7 @@ public sealed class SolverSingle : ISolver
// select the node with the highest score // select the node with the highest score
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores); var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
nodeVisits = node.ChildScores.GetVisits(at); nodeVisits = node.ChildScores.GetVisits(at);
node = node.ChildAt(at); node = node.ChildAt(at)!;
} }
} }
+66 -5
View File
@@ -1,9 +1,9 @@
using Craftimizer.Simulator.Actions;
using Craftimizer.Simulator; using Craftimizer.Simulator;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>; using Craftimizer.Simulator.Actions;
using System.Diagnostics.Contracts; using System.Diagnostics.Contracts;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
public static class SolverUtils public static class SolverUtils
@@ -73,7 +73,7 @@ public static class SolverUtils
var actions = new List<ActionType>(); var actions = new List<ActionType>();
while (node.Children.Count != 0) while (node.Children.Count != 0)
{ {
node = node.ChildAt(ChildMaxScore(ref node.ChildScores)); node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;
if (node.State.Action != null) if (node.State.Action != null)
actions.Add(node.State.Action.Value); actions.Add(node.State.Action.Value);
@@ -82,6 +82,23 @@ public static class SolverUtils
return (actions, node.State); return (actions, node.State);
} }
// Calculates the best child node to explore next
// Exploitation: ((1 - w) * (s / v)) + (w * m)
// Exploration: sqrt(c * ln(V) / v)
// w = maxScoreWeightingConstant
// s = score sum
// m = max score
// v = visits
// V = parentVisits
// c = explorationConstant
// Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
// Here, w_i = (1-w)*score sum
// n_i = visits
// max score is tacked onto it
// N_i = parent visits
// c = exploration constant (but crafty places it inside the sqrt..?)
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
@@ -135,8 +152,8 @@ public static class SolverUtils
var currentActions = expandedNode.State.AvailableActions; var currentActions = expandedNode.State.AvailableActions;
byte actionCount = 0; byte actionCount = 0;
Span<ActionType> actions = stackalloc ActionType[config.MaxStepCount - currentState.ActionCount]; Span<ActionType> actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)];
while (true) while (actionCount < actions.Length)
{ {
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete) if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
break; break;
@@ -197,6 +214,50 @@ public static class SolverUtils
)); ));
} }
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
SearchStepwiseForked<S>(config, forkCount, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
{
var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount);
while (!sim.IsComplete)
{
if (token.IsCancellationRequested)
break;
var tasks = new Task<(float score, List<ActionType> actions, SimulationState state)>[forkCount];
for (var i = 0; i < forkCount; ++i)
tasks[i] = Task.Run(() =>
{
var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
return (rootScores.MaxScore, solution_actions, solution_node.State);
}, token);
Task.WaitAll(tasks, CancellationToken.None);
var (score, solution_actions, solution_state) = tasks.Select(t => t.Result).MaxBy(r => r.score);
if (score >= 1.0)
{
actions.AddRange(solution_actions);
return (actions, solution_state);
}
var chosen_action = solution_actions[0];
(_, state) = sim.Execute(state, chosen_action);
actions.Add(chosen_action);
actionCallback?.Invoke(chosen_action);
}
return (actions, state);
}
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver => public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
SearchStepwise<S>(config, new SimulationState(input), actionCallback, token); SearchStepwise<S>(config, new SimulationState(input), actionCallback, token);