Some multithreading fixes, including deadlocks caused by reading a List while another thread is adding to it

This commit is contained in:
Asriel Camora
2023-07-04 17:19:55 +02:00
parent 75306ca020
commit 4d96fd173f
8 changed files with 90 additions and 29 deletions
+31 -20
View File
@@ -19,7 +19,7 @@ public class Solver
public Solver(SolverConfig config, SimulationState state, bool strict)
{
Config = config;
Simulator sim = new(state, config.MaxStepCount);
var sim = new Simulator(state, config.MaxStepCount);
RootNode = new(new(
state,
null,
@@ -98,7 +98,12 @@ public class Solver
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
{
var length = children.Length;
if (parentVisits == 0)
{
Console.WriteLine("no visits");
return null;
}
var vecLength = Vector<float>.Count;
var C = MathF.Sqrt(Config.ExplorationConstant * MathF.Log(parentVisits));
@@ -152,30 +157,28 @@ public class Solver
var node = RootNode;
while (true)
{
if (!Monitor.TryEnter(node, 5))
return Select();
var expandable = node.State.AvailableActions.Count != 0;
var expandable = !node.State.AvailableActions.IsEmpty;
var likelyTerminal = node.Children.Count == 0;
if (expandable || likelyTerminal)
return node;
// select the node with the highest score
var n = EvalBestChild(node.State.Scores.Visits, CollectionsMarshal.AsSpan(node.Children));
Monitor.Exit(node);
node = n;
// if null (current node is invalid & not backpropagated just yet), try again from root
node = EvalBestChild(node.State.Scores.Visits, node.Children) ?? RootNode;
}
}
public (Node ExpandedNode, CompletionState State, float Score) ExpandAndRollout(Simulator simulator, Node initialNode)
public (Node ExpandedNode, float Score)? ExpandAndRollout(Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
if (initialState.IsComplete)
return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
return (initialNode, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
var randomAction = initialState.AvailableActions.SelectRandom(Random);
initialState.AvailableActions.RemoveAction(randomAction);
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, randomAction, true));
var poppedAction = initialState.AvailableActions.PopRandom(Random);
if (!poppedAction.HasValue)
return null;
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction.Value, true));
// playout to a terminal state
var currentState = expandedNode.State.State;
@@ -188,9 +191,9 @@ public class Solver
{
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
break;
randomAction = currentActions.SelectRandom(Random);
actions[actionCount++] = randomAction;
(_, currentState) = simulator.Execute(currentState, randomAction);
var nextAction = currentActions.SelectRandom(Random);
actions[actionCount++] = nextAction;
(_, currentState) = simulator.Execute(currentState, nextAction);
currentCompletionState = simulator.CompletionState;
currentActions = simulator.AvailableActionsHeuristic(true);
}
@@ -202,10 +205,10 @@ public class Solver
if (score >= Config.ScoreStorageThreshold && score >= RootNode.State.Scores.MaxScore)
{
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
return (terminalNode, currentCompletionState, score);
return (terminalNode, score);
}
}
return (expandedNode, currentCompletionState, score);
return (expandedNode, score);
}
public void Backpropagate(Node startNode, float score)
@@ -230,9 +233,17 @@ public class Solver
break;
var selectedNode = Select();
var (endNode, _, score) = ExpandAndRollout(simulator, selectedNode);
Monitor.Exit(selectedNode);
var rolledOut = ExpandAndRollout(simulator, selectedNode);
//Monitor.Exit(selectedNode);
if (!rolledOut.HasValue)
{
Console.WriteLine("Retry");
// Retry, count this iteration as moot
i--;
continue;
}
var (endNode, score) = rolledOut.Value;
Backpropagate(endNode, score);
}
}