Stop using Arena type since everything can be GC'd easily

This commit is contained in:
Asriel Camora
2023-06-21 09:12:45 -07:00
parent 5189ef5d74
commit 270963dd4a
5 changed files with 97 additions and 114 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ internal static class Program
{ {
private static void Main() private static void Main()
{ {
//TypeLayout.PrintLayout<Arena<SimulationNode>.Node>(true); //TypeLayout.PrintLayout<ArenaNode<SimulationNode>>(true);
//return; //return;
var input = new SimulationInput( var input = new SimulationInput(
-39
View File
@@ -1,39 +0,0 @@
using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty;
public class Arena<T> where T : struct
{
public readonly struct Node
{
public readonly T State;
public readonly List<int> Children;
public readonly int Parent;
public Node(T state, int parent)
{
State = state;
Children = new();
Parent = parent;
}
}
private readonly List<Node> nodes = new();
public Arena(T initialState = default)
{
nodes.Add(new(initialState, -1));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int Insert(int parentIndex, T state)
{
var index = nodes.Count;
nodes.Add(new(state, parentIndex));
nodes[parentIndex].Children.Add(index);
return index;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Node Get(int index) => nodes[index];
}
+25
View File
@@ -0,0 +1,25 @@
using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty;
public class ArenaNode<T> where T : struct
{
public readonly T State;
public readonly List<ArenaNode<T>> Children;
public readonly ArenaNode<T>? Parent;
public ArenaNode(T state, ArenaNode<T>? parent = null)
{
State = state;
Children = new();
Parent = parent;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> Add(T state)
{
var node = new ArenaNode<T>(state, this);
Children.Add(node);
return node;
}
}
+13 -5
View File
@@ -5,18 +5,26 @@ namespace Craftimizer.Solver.Crafty;
public readonly struct SimulationNode public readonly struct SimulationNode
{ {
public SimulationState State { get; init; } public readonly SimulationState State;
public ActionType? Action { get; init; } public readonly ActionType? Action;
public CompletionState SimulationCompletionState { get; init; } public readonly CompletionState SimulationCompletionState;
public readonly NodeData Data;
public CompletionState CompletionState => public CompletionState CompletionState =>
Data.AvailableActions.Count == 0 && SimulationCompletionState == CompletionState.Incomplete ? Data.AvailableActions.Count == 0 && SimulationCompletionState == CompletionState.Incomplete ?
CompletionState.NoMoreActions : CompletionState.NoMoreActions :
SimulationCompletionState; SimulationCompletionState;
public NodeData Data { get; init; }
public bool IsComplete => CompletionState != CompletionState.Incomplete; public bool IsComplete => CompletionState != CompletionState.Incomplete;
public SimulationNode(SimulationState state, ActionType? action, CompletionState completionState, NodeData data)
{
State = state;
Action = action;
SimulationCompletionState = completionState;
Data = data;
}
public float? CalculateScore() public float? CalculateScore()
{ {
if (CompletionState != CompletionState.ProgressComplete) if (CompletionState != CompletionState.ProgressComplete)
+58 -69
View File
@@ -3,6 +3,7 @@ using Craftimizer.Simulator.Actions;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
@@ -10,11 +11,11 @@ namespace Craftimizer.Solver.Crafty;
public class Solver public class Solver
{ {
public Simulator Simulator; public Simulator Simulator;
public Arena<SimulationNode> Tree; public Node RootNode;
public Random Random => Simulator.Input.Random; // public Random Random => Simulator.Input.Random;
public const int Iterations = 50000; public const int Iterations = 30000;
public const float ScoreStorageThreshold = 1f; public const float ScoreStorageThreshold = 1f;
public const float MaxScoreWeightingConstant = 0.1f; public const float MaxScoreWeightingConstant = 0.1f;
public const float ExplorationConstant = 4f; public const float ExplorationConstant = 4f;
@@ -23,13 +24,12 @@ public class Solver
public Solver(SimulationState state, bool strict) public Solver(SimulationState state, bool strict)
{ {
Simulator = new(state); Simulator = new(state);
Tree = new(new() RootNode = new(new(
{ state,
State = state, null,
Action = null, Simulator.CompletionState,
SimulationCompletionState = Simulator.CompletionState, new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } ));
});
} }
public Solver(SimulationInput input, bool strict) : this(new SimulationState(input), strict) public Solver(SimulationInput input, bool strict) : this(new SimulationState(input), strict)
@@ -39,37 +39,34 @@ public class Solver
private SimulationNode Execute(SimulationState state, ActionType action, bool strict) private SimulationNode Execute(SimulationState state, ActionType action, bool strict)
{ {
(_, var newState) = Simulator.Execute(state, action); (_, var newState) = Simulator.Execute(state, action);
return new() return new(
{ newState,
State = newState, action,
Action = action, Simulator.CompletionState,
SimulationCompletionState = Simulator.CompletionState, new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } );
};
} }
public (int Index, CompletionState State) ExecuteActions(int startIndex, ReadOnlySpan<ActionType> actions, bool strict = false) public (Node EndNode, CompletionState State) ExecuteActions(Node startNode, ReadOnlySpan<ActionType> actions, bool strict = false)
{ {
var currentIndex = startIndex;
foreach (var action in actions) foreach (var action in actions)
{ {
var node = Tree.Get(currentIndex).State; var state = startNode.State;
if (node.IsComplete) if (state.IsComplete)
return (currentIndex, node.CompletionState); return (startNode, state.CompletionState);
if (!node.Data.AvailableActions.HasAction(action)) if (!state.Data.AvailableActions.HasAction(action))
return (currentIndex, CompletionState.InvalidAction); return (startNode, CompletionState.InvalidAction);
node.Data.AvailableActions.RemoveAction(action); state.Data.AvailableActions.RemoveAction(action);
currentIndex = Tree.Insert(currentIndex, Execute(node.State, action, strict)); startNode = startNode.Add(Execute(state.State, action, strict));
} }
var currentNode = Tree.Get(currentIndex).State; return (startNode, startNode.State.CompletionState);
return (currentIndex, currentNode.CompletionState);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int RustMaxBy(ReadOnlySpan<int> source, Func<int, float> into) private static T RustMaxBy<T>(ReadOnlySpan<T> source, Func<T, float> into)
{ {
var max = 0; var max = 0;
var maxV = into(source[0]); var maxV = into(source[0]);
@@ -97,7 +94,7 @@ public class Solver
(length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1); (length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChild(float parentVisits, ReadOnlySpan<int> children) private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
{ {
var length = children.Length; var length = children.Length;
@@ -120,7 +117,7 @@ public class Solver
for (var j = 0; j < iterCount; ++j) for (var j = 0; j < iterCount; ++j)
{ {
var node = Tree.Get(children[i + j]).State.Data.Scores; var node = children[i + j].State.Data.Scores;
scoreSums[j] = node.ScoreSum; scoreSums[j] = node.ScoreSum;
visits[j] = node.Visits; visits[j] = node.Visits;
maxScores[j] = node.MaxScore; maxScores[j] = node.MaxScore;
@@ -140,46 +137,42 @@ public class Solver
return children[max]; return children[max];
} }
public int Select(int selectedIndex) public Node Select(Node selectedNode)
{ {
while (true) while (true)
{ {
var selectedNode = Tree.Get(selectedIndex);
var expandable = selectedNode.State.Data.AvailableActions.Count != 0; var expandable = selectedNode.State.Data.AvailableActions.Count != 0;
var likelyTerminal = selectedNode.Children.Count == 0; var likelyTerminal = selectedNode.Children.Count == 0;
if (expandable || likelyTerminal) if (expandable || likelyTerminal)
{ return selectedNode;
return selectedIndex;
}
// select the node with the highest score // select the node with the highest score
selectedIndex = EvalBestChild(selectedNode.State.Data.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children)); selectedNode = EvalBestChild(selectedNode.State.Data.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children));
} }
} }
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex) public (Node ExpandedNode, CompletionState State, float Score) ExpandAndRollout(Node initialNode)
{ {
var initialState = initialNode.State;
// expand once // expand once
var initialNode = Tree.Get(initialIndex).State; if (initialState.IsComplete)
if (initialNode.IsComplete) return (initialNode, initialState.CompletionState, initialState.CalculateScore() ?? 0);
return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0);
var randomIdx = Random.Next(initialNode.Data.AvailableActions.Count); var randomIdx = 0;// Random.Next(initialState.Data.AvailableActions.Count);
var randomAction = initialNode.Data.AvailableActions.ElementAt(randomIdx); var randomAction = initialState.Data.AvailableActions.ElementAt(randomIdx);
initialNode.Data.AvailableActions.RemoveAction(randomAction); initialState.Data.AvailableActions.RemoveAction(randomAction);
var expandedState = Execute(initialNode.State, randomAction, true); var expandedState = Execute(initialState.State, randomAction, true);
var expandedIndex = Tree.Insert(initialIndex, expandedState); var expandedNode = initialNode.Add(expandedState);
// playout to a terminal state // playout to a terminal state
var currentState = Tree.Get(expandedIndex).State; var currentState = expandedNode.State;
byte actionCount = 0; byte actionCount = 0;
Span<ActionType> actions = stackalloc ActionType[MaxStepCount]; Span<ActionType> actions = stackalloc ActionType[MaxStepCount];
while (true) while (true)
{ {
if (currentState.IsComplete) if (currentState.IsComplete)
break; break;
randomIdx = Random.Next(currentState.Data.AvailableActions.Count); randomIdx = 0;// Random.Next(currentState.Data.AvailableActions.Count);
randomAction = currentState.Data.AvailableActions.ElementAt(randomIdx); randomAction = currentState.Data.AvailableActions.ElementAt(randomIdx);
actions[actionCount++] = randomAction; actions[actionCount++] = randomAction;
currentState = Execute(currentState.State, randomAction, true); currentState = Execute(currentState.State, randomAction, true);
@@ -189,50 +182,46 @@ public class Solver
var score = currentState.CalculateScore() ?? 0; var score = currentState.CalculateScore() ?? 0;
if (currentState.CompletionState == CompletionState.ProgressComplete) if (currentState.CompletionState == CompletionState.ProgressComplete)
{ {
if (score >= ScoreStorageThreshold && score >= Tree.Get(0).State.Data.Scores.MaxScore) if (score >= ScoreStorageThreshold && score >= RootNode.State.Data.Scores.MaxScore)
{ {
Console.WriteLine("DONE!"); (var terminalNode, _) = ExecuteActions(expandedNode, actions[..actionCount], true);
(var terminalIndex, _) = ExecuteActions(expandedIndex, actions[..actionCount], true); return (terminalNode, currentState.CompletionState, score);
return (terminalIndex, currentState.CompletionState, score);
} }
} }
return (expandedIndex, currentState.CompletionState, score); return (expandedNode, currentState.CompletionState, score);
} }
public void Backpropagate(int startIndex, int targetIndex, float score) public static void Backpropagate(Node startNode, Node targetNode, float score)
{ {
var currentIndex = startIndex;
while (true) while (true)
{ {
var currentNode = Tree.Get(currentIndex); startNode.State.Data.Scores.Visit(score);
currentNode.State.Data.Scores.Visit(score);
if (currentIndex == targetIndex) if (startNode == targetNode)
break; break;
currentIndex = currentNode.Parent; startNode = startNode.Parent!;
} }
} }
public void Search(int startIndex) public void Search(Node startNode)
{ {
for (var i = 0; i < Iterations; i++) for (var i = 0; i < Iterations; i++)
{ {
var selectedIndex = Select(startIndex); var selectedNode = Select(startNode);
var (endIndex, _, score) = ExpandAndRollout(selectedIndex); var (endNode, _, score) = ExpandAndRollout(selectedNode);
Backpropagate(endIndex, startIndex, score); Backpropagate(endNode, startNode, score);
} }
} }
public (List<ActionType> Actions, SimulationNode Node) Solution() public (List<ActionType> Actions, SimulationNode Node) Solution()
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var node = Tree.Get(0); var node = RootNode;
while (node.Children.Count != 0) while (node.Children.Count != 0)
{ {
var next_index = RustMaxBy(CollectionsMarshal.AsSpan(node.Children), n => Tree.Get(n).State.Data.Scores.MaxScore); node = RustMaxBy<Node>(CollectionsMarshal.AsSpan(node.Children), n => n.State.Data.Scores.MaxScore);
node = Tree.Get(next_index);
if (node.State.Action != null) if (node.State.Action != null)
actions.Add(node.State.Action.Value); actions.Add(node.State.Action.Value);
} }
@@ -247,7 +236,7 @@ public class Solver
var solver = new Solver(state, true); var solver = new Solver(state, true);
while (!solver.Simulator.IsComplete) while (!solver.Simulator.IsComplete)
{ {
solver.Search(0); solver.Search(solver.RootNode);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
if (solution_node.Data.Scores.MaxScore >= 1.0) if (solution_node.Data.Scores.MaxScore >= 1.0)
@@ -271,7 +260,7 @@ public class Solver
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SimulationInput input) public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SimulationInput input)
{ {
var solver = new Solver(input, false); var solver = new Solver(input, false);
solver.Search(0); solver.Search(solver.RootNode);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
return (solution_actions, solution_node.State); return (solution_actions, solution_node.State);
} }