Some multithreaded fixes, deadlocks due to List reads during adds
This commit is contained in:
+31
-20
@@ -19,7 +19,7 @@ public class Solver
|
||||
public Solver(SolverConfig config, SimulationState state, bool strict)
|
||||
{
|
||||
Config = config;
|
||||
Simulator sim = new(state, config.MaxStepCount);
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
RootNode = new(new(
|
||||
state,
|
||||
null,
|
||||
@@ -98,7 +98,12 @@ public class Solver
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
||||
{
|
||||
var length = children.Length;
|
||||
if (parentVisits == 0)
|
||||
{
|
||||
Console.WriteLine("no visits");
|
||||
return null;
|
||||
}
|
||||
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
|
||||
@@ -152,30 +157,28 @@ public class Solver
|
||||
var node = RootNode;
|
||||
while (true)
|
||||
{
|
||||
if (!Monitor.TryEnter(node, 5))
|
||||
return Select();
|
||||
var expandable = node.State.AvailableActions.Count != 0;
|
||||
var expandable = !node.State.AvailableActions.IsEmpty;
|
||||
var likelyTerminal = node.Children.Count == 0;
|
||||
if (expandable || likelyTerminal)
|
||||
return node;
|
||||
|
||||
// select the node with the highest score
|
||||
var n = EvalBestChild(node.State.Scores.Visits, CollectionsMarshal.AsSpan(node.Children));
|
||||
Monitor.Exit(node);
|
||||
node = n;
|
||||
// if null (current node is invalid & not backpropagated just yet), try again from root
|
||||
node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode;
|
||||
}
|
||||
}
|
||||
|
||||
public (Node ExpandedNode, CompletionState State, float Score) ExpandAndRollout(Simulator simulator, Node initialNode)
|
||||
public (Node ExpandedNode, float Score)? ExpandAndRollout(Simulator simulator, Node initialNode)
|
||||
{
|
||||
ref var initialState = ref initialNode.State;
|
||||
// expand once
|
||||
if (initialState.IsComplete)
|
||||
return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
|
||||
return (initialNode, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
|
||||
|
||||
var randomAction = initialState.AvailableActions.SelectRandom(Random);
|
||||
initialState.AvailableActions.RemoveAction(randomAction);
|
||||
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, randomAction, true));
|
||||
var poppedAction = initialState.AvailableActions.PopRandom(Random);
|
||||
if (!poppedAction.HasValue)
|
||||
return null;
|
||||
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction.Value, true));
|
||||
|
||||
// playout to a terminal state
|
||||
var currentState = expandedNode.State.State;
|
||||
@@ -188,9 +191,9 @@ public class Solver
|
||||
{
|
||||
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
|
||||
break;
|
||||
randomAction = currentActions.SelectRandom(Random);
|
||||
actions[actionCount++] = randomAction;
|
||||
(_, currentState) = simulator.Execute(currentState, randomAction);
|
||||
var nextAction = currentActions.SelectRandom(Random);
|
||||
actions[actionCount++] = nextAction;
|
||||
(_, currentState) = simulator.Execute(currentState, nextAction);
|
||||
currentCompletionState = simulator.CompletionState;
|
||||
currentActions = simulator.AvailableActionsHeuristic(true);
|
||||
}
|
||||
@@ -202,10 +205,10 @@ public class Solver
|
||||
if (score >= Config.ScoreStorageThreshold && score >= RootNode.State.Scores.MaxScore)
|
||||
{
|
||||
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
|
||||
return (terminalNode, currentCompletionState, score);
|
||||
return (terminalNode, score);
|
||||
}
|
||||
}
|
||||
return (expandedNode, currentCompletionState, score);
|
||||
return (expandedNode, score);
|
||||
}
|
||||
|
||||
public void Backpropagate(Node startNode, float score)
|
||||
@@ -230,9 +233,17 @@ public class Solver
|
||||
break;
|
||||
|
||||
var selectedNode = Select();
|
||||
var (endNode, _, score) = ExpandAndRollout(simulator, selectedNode);
|
||||
Monitor.Exit(selectedNode);
|
||||
var rolledOut = ExpandAndRollout(simulator, selectedNode);
|
||||
//Monitor.Exit(selectedNode);
|
||||
if (!rolledOut.HasValue)
|
||||
{
|
||||
Console.WriteLine("Retry");
|
||||
// Retry, count this iteration as moot
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
|
||||
var (endNode, score) = rolledOut.Value;
|
||||
Backpropagate(endNode, score);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user