More minor optimizations

This commit is contained in:
Asriel Camora
2023-06-21 06:16:57 -07:00
parent fc0ffc11f3
commit 0b2b80d6b5
10 changed files with 76 additions and 64 deletions
+14 -20
View File
@@ -28,8 +28,7 @@ public class Solver
State = state,
Action = null,
SimulationCompletionState = Simulator.CompletionState,
AvailableActions = Simulator.AvailableActionsHeuristic(strict),
Scores = new()
Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
});
}
@@ -45,8 +44,7 @@ public class Solver
State = newState,
Action = action,
SimulationCompletionState = Simulator.CompletionState,
AvailableActions = Simulator.AvailableActionsHeuristic(strict),
Scores = new()
Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
};
}
@@ -59,9 +57,9 @@ public class Solver
if (node.IsComplete)
return (currentIndex, node.CompletionState);
if (!node.AvailableActions.HasAction(action))
if (!node.Data.AvailableActions.HasAction(action))
return (currentIndex, CompletionState.InvalidAction);
node.AvailableActions.RemoveAction(action);
node.Data.AvailableActions.RemoveAction(action);
currentIndex = Tree.Insert(currentIndex, Execute(node.State, action, strict));
}
@@ -122,7 +120,7 @@ public class Solver
for (var j = 0; j < iterCount; ++j)
{
var node = Tree.Get(children[i + j]).State.Scores;
var node = Tree.Get(children[i + j]).State.Data.Scores;
scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits;
maxScores[j] = node.MaxScore;
@@ -148,7 +146,7 @@ public class Solver
{
var selectedNode = Tree.Get(selectedIndex);
var expandable = selectedNode.State.AvailableActions.Count != 0;
var expandable = selectedNode.State.Data.AvailableActions.Count != 0;
var likelyTerminal = selectedNode.Children.Count == 0;
if (expandable || likelyTerminal)
{
@@ -156,7 +154,7 @@ public class Solver
}
// select the node with the highest score
selectedIndex = EvalBestChild(selectedNode.State.Scores.Visits, selectedNode.Children);
selectedIndex = EvalBestChild(selectedNode.State.Data.Scores.Visits, selectedNode.Children);
}
}
@@ -167,8 +165,8 @@ public class Solver
if (initialNode.IsComplete)
return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0);
var randomAction = initialNode.AvailableActions.ElementAt(0);
initialNode.AvailableActions.RemoveAction(randomAction);
var randomAction = initialNode.Data.AvailableActions.ElementAt(0);
initialNode.Data.AvailableActions.RemoveAction(randomAction);
var expandedState = Execute(initialNode.State, randomAction, true);
var expandedIndex = Tree.Insert(initialIndex, expandedState);
@@ -179,7 +177,7 @@ public class Solver
{
if (currentState.IsComplete)
break;
randomAction = currentState.AvailableActions.ElementAt(0);
randomAction = currentState.Data.AvailableActions.ElementAt(0);
actions.Add(randomAction);
currentState = Execute(currentState.State, randomAction, true);
}
@@ -188,7 +186,7 @@ public class Solver
var score = currentState.CalculateScore() ?? 0;
if (currentState.CompletionState == CompletionState.ProgressComplete)
{
if (score >= ScoreStorageThreshold && score >= Tree.Get(0).State.Scores.MaxScore)
if (score >= ScoreStorageThreshold && score >= Tree.Get(0).State.Data.Scores.MaxScore)
{
(var terminalIndex, _) = ExecuteActions(expandedIndex, actions, true);
return (terminalIndex, currentState.CompletionState, score);
@@ -203,11 +201,7 @@ public class Solver
while (true)
{
var currentNode = Tree.Get(currentIndex);
var currentScores = currentNode.State.Scores;
currentScores.Visits++;
currentScores.ScoreSum += score;
if (currentScores.MaxScore < score)
currentScores.MaxScore = score;
currentNode.State.Data.Scores.Visit(score);
if (currentIndex == targetIndex)
break;
@@ -233,7 +227,7 @@ public class Solver
var node = Tree.Get(0);
while (node.Children.Count != 0)
{
var next_index = RustMaxBy(node.Children, n => Tree.Get(n).State.Scores.MaxScore);
var next_index = RustMaxBy(node.Children, n => Tree.Get(n).State.Data.Scores.MaxScore);
node = Tree.Get(next_index);
if (node.State.Action != null)
actions.Add(node.State.Action.Value);
@@ -264,7 +258,7 @@ public class Solver
solver.Search(0);
var (solution_actions, solution_node) = solver.Solution();
if (solution_node.Scores.MaxScore >= 1.0)
if (solution_node.Data.Scores.MaxScore >= 1.0)
{
actions.AddRange(solution_actions);
return (actions, solution_node.State);