Fix concurrency, add forked simulations
This commit is contained in:
@@ -46,14 +46,16 @@ internal static class Program
|
|||||||
|
|
||||||
var config = new SolverConfig()
|
var config = new SolverConfig()
|
||||||
{
|
{
|
||||||
Iterations = 1_000_000,
|
Iterations = 100_000,
|
||||||
ThreadCount = 8,
|
ThreadCount = 8,
|
||||||
};
|
};
|
||||||
|
|
||||||
Debugger.Break();
|
Debugger.Break();
|
||||||
var s = Stopwatch.StartNew();
|
var s = Stopwatch.StartNew();
|
||||||
if (true)
|
if (true) {
|
||||||
_ = SolverUtils.SearchStepwise<SolverSingle>(config, input, a => Console.WriteLine(a));
|
(_, var state) = SolverUtils.SearchStepwise<SolverSingle>(config, input, a => Console.WriteLine(a));
|
||||||
|
Console.WriteLine($"Qual: {state.Quality}/{state.Input.Recipe.MaxQuality}");
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
(var actions, _) = SolverUtils.SearchOneshot<SolverConcurrent>(config, input);
|
(var actions, _) = SolverUtils.SearchOneshot<SolverConcurrent>(config, input);
|
||||||
|
|||||||
@@ -20,15 +20,15 @@ public sealed class ArenaNode<T> where T : struct
|
|||||||
Parent = parent;
|
Parent = parent;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ArenaNode<T> ChildAt((int arrayIdx, int subIdx) at) =>
|
public ArenaNode<T>? ChildAt((int arrayIdx, int subIdx) at) =>
|
||||||
Children.Data[at.arrayIdx][at.subIdx];
|
Children.Data?[at.arrayIdx]?[at.subIdx];
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public ArenaNode<T> AddConcurrent(T state)
|
public ArenaNode<T> AddConcurrent(T state)
|
||||||
{
|
{
|
||||||
var node = new ArenaNode<T>(state, this);
|
var node = new ArenaNode<T>(state, this);
|
||||||
Children.AddConcurrent(node);
|
|
||||||
ChildScores.AddConcurrent();
|
ChildScores.AddConcurrent();
|
||||||
|
Children.AddConcurrent(node);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,8 +36,8 @@ public sealed class ArenaNode<T> where T : struct
|
|||||||
public ArenaNode<T> Add(T state)
|
public ArenaNode<T> Add(T state)
|
||||||
{
|
{
|
||||||
var node = new ArenaNode<T>(state, this);
|
var node = new ArenaNode<T>(state, this);
|
||||||
Children.Add(node);
|
|
||||||
ChildScores.Add();
|
ChildScores.Add();
|
||||||
|
Children.Add(node);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,6 +33,11 @@ public sealed class SolverConcurrent : ISolver
|
|||||||
{
|
{
|
||||||
nodeVisits = node.ChildScores.GetVisits(at.Value);
|
nodeVisits = node.ChildScores.GetVisits(at.Value);
|
||||||
node = node.ChildAt(at.Value);
|
node = node.ChildAt(at.Value);
|
||||||
|
if (node == null)
|
||||||
|
{
|
||||||
|
node = rootNode;
|
||||||
|
nodeVisits = rootNodeVisits;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ public readonly record struct SolverConfig
|
|||||||
public float MaxScoreWeightingConstant { get; init; }
|
public float MaxScoreWeightingConstant { get; init; }
|
||||||
public float ExplorationConstant { get; init; }
|
public float ExplorationConstant { get; init; }
|
||||||
public int MaxStepCount { get; init; }
|
public int MaxStepCount { get; init; }
|
||||||
|
public int MaxRolloutStepCount { get; init; }
|
||||||
public int ThreadCount { get; init; }
|
public int ThreadCount { get; init; }
|
||||||
|
|
||||||
public SolverConfig()
|
public SolverConfig()
|
||||||
@@ -17,8 +18,9 @@ public readonly record struct SolverConfig
|
|||||||
Iterations = 300000;
|
Iterations = 300000;
|
||||||
ScoreStorageThreshold = 1f;
|
ScoreStorageThreshold = 1f;
|
||||||
MaxScoreWeightingConstant = 0.1f;
|
MaxScoreWeightingConstant = 0.1f;
|
||||||
ExplorationConstant = 4f;
|
ExplorationConstant = 2;
|
||||||
MaxStepCount = 25;
|
MaxStepCount = 25;
|
||||||
|
MaxRolloutStepCount = 99;
|
||||||
ThreadCount = Environment.ProcessorCount;
|
ThreadCount = Environment.ProcessorCount;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ public sealed class SolverSingle : ISolver
|
|||||||
// select the node with the highest score
|
// select the node with the highest score
|
||||||
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
|
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
|
||||||
nodeVisits = node.ChildScores.GetVisits(at);
|
nodeVisits = node.ChildScores.GetVisits(at);
|
||||||
node = node.ChildAt(at);
|
node = node.ChildAt(at)!;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
using Craftimizer.Simulator.Actions;
|
|
||||||
using Craftimizer.Simulator;
|
using Craftimizer.Simulator;
|
||||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
using Craftimizer.Simulator.Actions;
|
||||||
using System.Diagnostics.Contracts;
|
using System.Diagnostics.Contracts;
|
||||||
using System.Numerics;
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
|
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
public static class SolverUtils
|
public static class SolverUtils
|
||||||
@@ -73,7 +73,7 @@ public static class SolverUtils
|
|||||||
var actions = new List<ActionType>();
|
var actions = new List<ActionType>();
|
||||||
while (node.Children.Count != 0)
|
while (node.Children.Count != 0)
|
||||||
{
|
{
|
||||||
node = node.ChildAt(ChildMaxScore(ref node.ChildScores));
|
node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;
|
||||||
|
|
||||||
if (node.State.Action != null)
|
if (node.State.Action != null)
|
||||||
actions.Add(node.State.Action.Value);
|
actions.Add(node.State.Action.Value);
|
||||||
@@ -82,6 +82,23 @@ public static class SolverUtils
|
|||||||
return (actions, node.State);
|
return (actions, node.State);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculates the best child node to explore next
|
||||||
|
// Exploitation: ((1 - w) * (s / v)) + (w * m)
|
||||||
|
// Exploration: sqrt(c * ln(V) / v)
|
||||||
|
// w = maxScoreWeightingConstant
|
||||||
|
// s = score sum
|
||||||
|
// m = max score
|
||||||
|
// v = visits
|
||||||
|
// V = parentVisits
|
||||||
|
// c = explorationConstant
|
||||||
|
|
||||||
|
// Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||||
|
// Here, w_i = (1-w)*score sum
|
||||||
|
// n_i = visits
|
||||||
|
// max score is tacked onto it
|
||||||
|
// N_i = parent visits
|
||||||
|
// c = exploration constant (but crafty places it inside the sqrt..?)
|
||||||
|
|
||||||
[Pure]
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
|
||||||
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
|
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
|
||||||
@@ -135,8 +152,8 @@ public static class SolverUtils
|
|||||||
var currentActions = expandedNode.State.AvailableActions;
|
var currentActions = expandedNode.State.AvailableActions;
|
||||||
|
|
||||||
byte actionCount = 0;
|
byte actionCount = 0;
|
||||||
Span<ActionType> actions = stackalloc ActionType[config.MaxStepCount - currentState.ActionCount];
|
Span<ActionType> actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)];
|
||||||
while (true)
|
while (actionCount < actions.Length)
|
||||||
{
|
{
|
||||||
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
|
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
|
||||||
break;
|
break;
|
||||||
@@ -197,6 +214,50 @@ public static class SolverUtils
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
|
||||||
|
SearchStepwiseForked<S>(config, forkCount, new SimulationState(input), actionCallback, token);
|
||||||
|
|
||||||
|
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
|
||||||
|
{
|
||||||
|
var actions = new List<ActionType>();
|
||||||
|
var sim = new Simulator(state, config.MaxStepCount);
|
||||||
|
while (!sim.IsComplete)
|
||||||
|
{
|
||||||
|
if (token.IsCancellationRequested)
|
||||||
|
break;
|
||||||
|
|
||||||
|
var tasks = new Task<(float score, List<ActionType> actions, SimulationState state)>[forkCount];
|
||||||
|
for (var i = 0; i < forkCount; ++i)
|
||||||
|
tasks[i] = Task.Run(() =>
|
||||||
|
{
|
||||||
|
var rootNode = CreateRootNode(config, state, true);
|
||||||
|
RootScores rootScores = new();
|
||||||
|
|
||||||
|
S.Search(ref config, rootScores, rootNode, token);
|
||||||
|
var (solution_actions, solution_node) = Solution(rootNode);
|
||||||
|
|
||||||
|
return (rootScores.MaxScore, solution_actions, solution_node.State);
|
||||||
|
}, token);
|
||||||
|
Task.WaitAll(tasks, CancellationToken.None);
|
||||||
|
|
||||||
|
var (score, solution_actions, solution_state) = tasks.Select(t => t.Result).MaxBy(r => r.score);
|
||||||
|
|
||||||
|
if (score >= 1.0)
|
||||||
|
{
|
||||||
|
actions.AddRange(solution_actions);
|
||||||
|
return (actions, solution_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
var chosen_action = solution_actions[0];
|
||||||
|
(_, state) = sim.Execute(state, chosen_action);
|
||||||
|
actions.Add(chosen_action);
|
||||||
|
|
||||||
|
actionCallback?.Invoke(chosen_action);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (actions, state);
|
||||||
|
}
|
||||||
|
|
||||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
|
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
|
||||||
SearchStepwise<S>(config, new SimulationState(input), actionCallback, token);
|
SearchStepwise<S>(config, new SimulationState(input), actionCallback, token);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user