Stop using Arena type since everything can be GC'd easily
This commit is contained in:
@@ -1,39 +0,0 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
/// <summary>
/// Append-only tree kept in a single flat list. Nodes refer to their parent and
/// children by integer index into that list; index 0 is the root created by the constructor.
/// </summary>
/// <typeparam name="T">Value type stored in every node.</typeparam>
public class Arena<T> where T : struct
{
    /// <summary>
    /// One entry of the arena: the payload plus the indices linking it into the tree.
    /// The struct is copied on <see cref="Get"/>, but <see cref="Children"/> is a reference,
    /// so every copy shares the same child list.
    /// </summary>
    public readonly struct Node
    {
        /// <summary>Payload stored for this node.</summary>
        public readonly T State;

        /// <summary>Arena indices of this node's children, in insertion order.</summary>
        public readonly List<int> Children;

        /// <summary>Arena index of the parent node; the root is created with -1.</summary>
        public readonly int Parent;

        /// <summary>Creates a node with an empty child list.</summary>
        /// <param name="state">Payload to store.</param>
        /// <param name="parent">Arena index of the parent node.</param>
        public Node(T state, int parent)
        {
            State = state;
            Children = new();
            Parent = parent;
        }
    }

    private readonly List<Node> nodes = new();

    /// <summary>Creates an arena whose root (index 0, parent -1) holds <paramref name="initialState"/>.</summary>
    /// <param name="initialState">Payload of the root node.</param>
    public Arena(T initialState = default)
    {
        nodes.Add(new(initialState, -1));
    }

    /// <summary>
    /// Appends a new node holding <paramref name="state"/> as a child of the node at
    /// <paramref name="parentIndex"/>.
    /// </summary>
    /// <param name="parentIndex">Index of an already existing node.</param>
    /// <param name="state">Payload of the new node.</param>
    /// <returns>The index assigned to the new node.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="parentIndex"/> does not refer to an existing node. The arena is left unchanged.
    /// </exception>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public int Insert(int parentIndex, T state)
    {
        var index = nodes.Count;

        // Validate before appending: once the node is added, parentIndex == index would
        // resolve to the new node itself (a self-parented cycle), and any other bad index
        // would throw only after an orphan node had already been stored.
        if ((uint)parentIndex >= (uint)index)
            throw new ArgumentOutOfRangeException(nameof(parentIndex), parentIndex, "Parent index must refer to an existing node.");

        nodes.Add(new(state, parentIndex));
        nodes[parentIndex].Children.Add(index);
        return index;
    }

    /// <summary>Returns a copy of the node stored at <paramref name="index"/>.</summary>
    /// <param name="index">Index previously returned by <see cref="Insert"/>, or 0 for the root.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public Node Get(int index) => nodes[index];
}
|
||||
@@ -0,0 +1,25 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
/// <summary>
/// Tree node that links directly to its parent and children by reference.
/// </summary>
/// <typeparam name="T">Value type stored in the node.</typeparam>
public class ArenaNode<T> where T : struct
{
    /// <summary>Payload stored in this node.</summary>
    public readonly T State;

    /// <summary>Child nodes in the order they were added.</summary>
    public readonly List<ArenaNode<T>> Children = new();

    /// <summary>Node this one was added to, or <c>null</c> when none was given.</summary>
    public readonly ArenaNode<T>? Parent;

    /// <summary>Creates a node with no children.</summary>
    /// <param name="state">Payload to store.</param>
    /// <param name="parent">Owning node; omit for a root.</param>
    public ArenaNode(T state, ArenaNode<T>? parent = null)
    {
        Parent = parent;
        State = state;
    }

    /// <summary>
    /// Creates a node holding <paramref name="state"/> whose parent is this node,
    /// appends it to <see cref="Children"/> and returns it.
    /// </summary>
    /// <param name="state">Payload of the new child.</param>
    /// <returns>The newly created child node.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public ArenaNode<T> Add(T state)
    {
        ArenaNode<T> child = new(state, this);
        Children.Add(child);
        return child;
    }
}
|
||||
@@ -5,18 +5,26 @@ namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
public readonly struct SimulationNode
|
||||
{
|
||||
public SimulationState State { get; init; }
|
||||
public ActionType? Action { get; init; }
|
||||
public CompletionState SimulationCompletionState { get; init; }
|
||||
public readonly SimulationState State;
|
||||
public readonly ActionType? Action;
|
||||
public readonly CompletionState SimulationCompletionState;
|
||||
public readonly NodeData Data;
|
||||
|
||||
public CompletionState CompletionState =>
|
||||
Data.AvailableActions.Count == 0 && SimulationCompletionState == CompletionState.Incomplete ?
|
||||
CompletionState.NoMoreActions :
|
||||
SimulationCompletionState;
|
||||
|
||||
public NodeData Data { get; init; }
|
||||
|
||||
public bool IsComplete => CompletionState != CompletionState.Incomplete;
|
||||
|
||||
/// <summary>Initializes every field of the node from the supplied values.</summary>
/// <param name="state">Simulation state stored in <c>State</c>.</param>
/// <param name="action">Value stored in <c>Action</c>; may be null.</param>
/// <param name="completionState">Value stored in <c>SimulationCompletionState</c>.</param>
/// <param name="data">Value stored in <c>Data</c>.</param>
public SimulationNode(SimulationState state, ActionType? action, CompletionState completionState, NodeData data)
{
    (State, Action) = (state, action);
    (SimulationCompletionState, Data) = (completionState, data);
}
|
||||
|
||||
public float? CalculateScore()
|
||||
{
|
||||
if (CompletionState != CompletionState.ProgressComplete)
|
||||
|
||||
+58
-69
@@ -3,6 +3,7 @@ using Craftimizer.Simulator.Actions;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
@@ -10,11 +11,11 @@ namespace Craftimizer.Solver.Crafty;
|
||||
public class Solver
|
||||
{
|
||||
public Simulator Simulator;
|
||||
public Arena<SimulationNode> Tree;
|
||||
public Node RootNode;
|
||||
|
||||
public Random Random => Simulator.Input.Random;
|
||||
// public Random Random => Simulator.Input.Random;
|
||||
|
||||
public const int Iterations = 50000;
|
||||
public const int Iterations = 30000;
|
||||
public const float ScoreStorageThreshold = 1f;
|
||||
public const float MaxScoreWeightingConstant = 0.1f;
|
||||
public const float ExplorationConstant = 4f;
|
||||
@@ -23,13 +24,12 @@ public class Solver
|
||||
/// <summary>
/// Creates a solver for <paramref name="state"/>: builds the simulator and a root node
/// that stores the state, no action, the simulator's completion state and the actions
/// returned by <c>Simulator.AvailableActionsHeuristic</c>.
/// </summary>
/// <param name="state">Simulation state the search starts from.</param>
/// <param name="strict">Forwarded to <c>AvailableActionsHeuristic</c>.</param>
public Solver(SimulationState state, bool strict)
{
    Simulator = new(state);
    RootNode = new(new(
        state,
        null,
        Simulator.CompletionState,
        new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
    ));
}
|
||||
|
||||
public Solver(SimulationInput input, bool strict) : this(new SimulationState(input), strict)
|
||||
@@ -39,37 +39,34 @@ public class Solver
|
||||
/// <summary>
/// Runs <paramref name="action"/> on <paramref name="state"/> through the simulator and
/// wraps the outcome in a new <see cref="SimulationNode"/>.
/// </summary>
/// <param name="state">State the action is applied to.</param>
/// <param name="action">Action to execute; it is recorded in the returned node.</param>
/// <param name="strict">Forwarded to <c>Simulator.AvailableActionsHeuristic</c>.</param>
/// <returns>
/// A node holding the new state, the action, <c>Simulator.CompletionState</c> and the
/// available actions reported by the simulator after the action was executed.
/// </returns>
private SimulationNode Execute(SimulationState state, ActionType action, bool strict)
{
    (_, var newState) = Simulator.Execute(state, action);
    return new(
        newState,
        action,
        Simulator.CompletionState,
        new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) }
    );
}
|
||||
|
||||
/// <summary>
/// Replays <paramref name="actions"/> in order starting at <paramref name="startNode"/>,
/// adding one child node per executed action.
/// </summary>
/// <param name="startNode">Node the first action is applied to.</param>
/// <param name="actions">Actions to execute, in order.</param>
/// <param name="strict">Forwarded to <c>Execute</c> for every action.</param>
/// <returns>
/// The last node reached and its completion state. Replay stops early at a node that is
/// already complete, or with <c>CompletionState.InvalidAction</c> at a node whose
/// available actions do not contain the next action.
/// </returns>
public (Node EndNode, CompletionState State) ExecuteActions(Node startNode, ReadOnlySpan<ActionType> actions, bool strict = false)
{
    foreach (var action in actions)
    {
        // SimulationNode is a struct, so this is a copy of the node's payload.
        var state = startNode.State;
        if (state.IsComplete)
            return (startNode, state.CompletionState);

        if (!state.Data.AvailableActions.HasAction(action))
            return (startNode, CompletionState.InvalidAction);
        // NOTE(review): this only updates the node stored in the tree if NodeData (or the
        // action set it holds) has reference semantics — confirm against NodeData's definition.
        state.Data.AvailableActions.RemoveAction(action);

        startNode = startNode.Add(Execute(state.State, action, strict));
    }

    return (startNode, startNode.State.CompletionState);
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static int RustMaxBy(ReadOnlySpan<int> source, Func<int, float> into)
|
||||
private static T RustMaxBy<T>(ReadOnlySpan<T> source, Func<T, float> into)
|
||||
{
|
||||
var max = 0;
|
||||
var maxV = into(source[0]);
|
||||
@@ -97,7 +94,7 @@ public class Solver
|
||||
(length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private int EvalBestChild(float parentVisits, ReadOnlySpan<int> children)
|
||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
||||
{
|
||||
var length = children.Length;
|
||||
|
||||
@@ -120,7 +117,7 @@ public class Solver
|
||||
|
||||
for (var j = 0; j < iterCount; ++j)
|
||||
{
|
||||
var node = Tree.Get(children[i + j]).State.Data.Scores;
|
||||
var node = children[i + j].State.Data.Scores;
|
||||
scoreSums[j] = node.ScoreSum;
|
||||
visits[j] = node.Visits;
|
||||
maxScores[j] = node.MaxScore;
|
||||
@@ -140,46 +137,42 @@ public class Solver
|
||||
return children[max];
|
||||
}
|
||||
|
||||
/// <summary>
/// Walks down the tree from <paramref name="selectedNode"/> and returns the first node that
/// either still has available actions or has no children.
/// </summary>
/// <param name="selectedNode">Node the descent starts from.</param>
/// <returns>The node chosen for expansion.</returns>
public Node Select(Node selectedNode)
{
    while (true)
    {
        var expandable = selectedNode.State.Data.AvailableActions.Count != 0;
        var likelyTerminal = selectedNode.Children.Count == 0;
        if (expandable || likelyTerminal)
            return selectedNode;

        // select the node with the highest score
        selectedNode = EvalBestChild(selectedNode.State.Data.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children));
    }
}
|
||||
|
||||
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex)
|
||||
public (Node ExpandedNode, CompletionState State, float Score) ExpandAndRollout(Node initialNode)
|
||||
{
|
||||
var initialState = initialNode.State;
|
||||
// expand once
|
||||
var initialNode = Tree.Get(initialIndex).State;
|
||||
if (initialNode.IsComplete)
|
||||
return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0);
|
||||
if (initialState.IsComplete)
|
||||
return (initialNode, initialState.CompletionState, initialState.CalculateScore() ?? 0);
|
||||
|
||||
var randomIdx = Random.Next(initialNode.Data.AvailableActions.Count);
|
||||
var randomAction = initialNode.Data.AvailableActions.ElementAt(randomIdx);
|
||||
initialNode.Data.AvailableActions.RemoveAction(randomAction);
|
||||
var expandedState = Execute(initialNode.State, randomAction, true);
|
||||
var expandedIndex = Tree.Insert(initialIndex, expandedState);
|
||||
var randomIdx = 0;// Random.Next(initialState.Data.AvailableActions.Count);
|
||||
var randomAction = initialState.Data.AvailableActions.ElementAt(randomIdx);
|
||||
initialState.Data.AvailableActions.RemoveAction(randomAction);
|
||||
var expandedState = Execute(initialState.State, randomAction, true);
|
||||
var expandedNode = initialNode.Add(expandedState);
|
||||
|
||||
// playout to a terminal state
|
||||
var currentState = Tree.Get(expandedIndex).State;
|
||||
var currentState = expandedNode.State;
|
||||
byte actionCount = 0;
|
||||
Span<ActionType> actions = stackalloc ActionType[MaxStepCount];
|
||||
while (true)
|
||||
{
|
||||
if (currentState.IsComplete)
|
||||
break;
|
||||
randomIdx = Random.Next(currentState.Data.AvailableActions.Count);
|
||||
randomIdx = 0;// Random.Next(currentState.Data.AvailableActions.Count);
|
||||
randomAction = currentState.Data.AvailableActions.ElementAt(randomIdx);
|
||||
actions[actionCount++] = randomAction;
|
||||
currentState = Execute(currentState.State, randomAction, true);
|
||||
@@ -189,50 +182,46 @@ public class Solver
|
||||
var score = currentState.CalculateScore() ?? 0;
|
||||
if (currentState.CompletionState == CompletionState.ProgressComplete)
|
||||
{
|
||||
if (score >= ScoreStorageThreshold && score >= Tree.Get(0).State.Data.Scores.MaxScore)
|
||||
if (score >= ScoreStorageThreshold && score >= RootNode.State.Data.Scores.MaxScore)
|
||||
{
|
||||
Console.WriteLine("DONE!");
|
||||
(var terminalIndex, _) = ExecuteActions(expandedIndex, actions[..actionCount], true);
|
||||
return (terminalIndex, currentState.CompletionState, score);
|
||||
(var terminalNode, _) = ExecuteActions(expandedNode, actions[..actionCount], true);
|
||||
return (terminalNode, currentState.CompletionState, score);
|
||||
}
|
||||
}
|
||||
return (expandedIndex, currentState.CompletionState, score);
|
||||
return (expandedNode, currentState.CompletionState, score);
|
||||
}
|
||||
|
||||
/// <summary>
/// Records <paramref name="score"/> on <paramref name="startNode"/> and on each node along
/// its parent chain, up to and including <paramref name="targetNode"/>.
/// </summary>
/// <param name="startNode">Node where propagation begins.</param>
/// <param name="targetNode">Node where propagation ends; must be <paramref name="startNode"/> or one of its ancestors.</param>
/// <param name="score">Score passed to each visited node's <c>Scores.Visit</c>.</param>
/// <exception cref="ArgumentException">
/// The parent chain ended without reaching <paramref name="targetNode"/>. Nodes walked
/// before the end of the chain have already been visited when this is thrown.
/// </exception>
public static void Backpropagate(Node startNode, Node targetNode, float score)
{
    Node? current = startNode;
    while (current is not null)
    {
        current.State.Data.Scores.Visit(score);

        // Node is a class without an equality operator, so this matches the identity comparison of ==.
        if (ReferenceEquals(current, targetNode))
            return;

        current = current.Parent;
    }

    // Reaching a null parent means targetNode is not on the chain; fail explicitly
    // instead of dereferencing null.
    throw new ArgumentException("Target node is not the start node or one of its ancestors.", nameof(targetNode));
}
|
||||
|
||||
/// <summary>
/// Runs <c>Iterations</c> rounds of select, expand-and-rollout and backpropagate,
/// all rooted at <paramref name="startNode"/>.
/// </summary>
/// <param name="startNode">Node each selection starts from and each backpropagation stops at.</param>
public void Search(Node startNode)
{
    for (var i = 0; i < Iterations; i++)
    {
        var selectedNode = Select(startNode);
        var (endNode, _, score) = ExpandAndRollout(selectedNode);

        Backpropagate(endNode, startNode, score);
    }
}
|
||||
|
||||
public (List<ActionType> Actions, SimulationNode Node) Solution()
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var node = Tree.Get(0);
|
||||
var node = RootNode;
|
||||
while (node.Children.Count != 0)
|
||||
{
|
||||
var next_index = RustMaxBy(CollectionsMarshal.AsSpan(node.Children), n => Tree.Get(n).State.Data.Scores.MaxScore);
|
||||
node = Tree.Get(next_index);
|
||||
node = RustMaxBy<Node>(CollectionsMarshal.AsSpan(node.Children), n => n.State.Data.Scores.MaxScore);
|
||||
if (node.State.Action != null)
|
||||
actions.Add(node.State.Action.Value);
|
||||
}
|
||||
@@ -247,7 +236,7 @@ public class Solver
|
||||
var solver = new Solver(state, true);
|
||||
while (!solver.Simulator.IsComplete)
|
||||
{
|
||||
solver.Search(0);
|
||||
solver.Search(solver.RootNode);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
|
||||
if (solution_node.Data.Scores.MaxScore >= 1.0)
|
||||
@@ -271,7 +260,7 @@ public class Solver
|
||||
/// <summary>
/// Builds a non-strict solver for <paramref name="input"/>, runs a single
/// <see cref="Search"/> from its root and returns the resulting solution.
/// </summary>
/// <param name="input">Simulation input to solve.</param>
/// <returns>The solution's action list and the simulation state of the solution node.</returns>
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SimulationInput input)
{
    var solver = new Solver(input, false);
    solver.Search(solver.RootNode);
    var (solutionActions, solutionNode) = solver.Solution();
    return (solutionActions, solutionNode.State);
}
|
||||
|
||||
Reference in New Issue
Block a user