diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index dc352e6..a788cd8 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -10,7 +10,7 @@ internal static class Program { private static void Main() { - //TypeLayout.PrintLayout.Node>(true); + //TypeLayout.PrintLayout>(true); //return; var input = new SimulationInput( diff --git a/Solver/Crafty/Arena.cs b/Solver/Crafty/Arena.cs deleted file mode 100644 index 7c1927b..0000000 --- a/Solver/Crafty/Arena.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System.Runtime.CompilerServices; - -namespace Craftimizer.Solver.Crafty; - -public class Arena where T : struct -{ - public readonly struct Node - { - public readonly T State; - public readonly List Children; - public readonly int Parent; - - public Node(T state, int parent) - { - State = state; - Children = new(); - Parent = parent; - } - } - - private readonly List nodes = new(); - - public Arena(T initialState = default) - { - nodes.Add(new(initialState, -1)); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public int Insert(int parentIndex, T state) - { - var index = nodes.Count; - nodes.Add(new(state, parentIndex)); - nodes[parentIndex].Children.Add(index); - return index; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Node Get(int index) => nodes[index]; -} diff --git a/Solver/Crafty/ArenaNode.cs b/Solver/Crafty/ArenaNode.cs new file mode 100644 index 0000000..8dbf0c9 --- /dev/null +++ b/Solver/Crafty/ArenaNode.cs @@ -0,0 +1,25 @@ +using System.Runtime.CompilerServices; + +namespace Craftimizer.Solver.Crafty; + +public class ArenaNode where T : struct +{ + public readonly T State; + public readonly List> Children; + public readonly ArenaNode? Parent; + + public ArenaNode(T state, ArenaNode? parent = null) + { + State = state; + Children = new(); + Parent = parent; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ArenaNode Add(T state) + { + var node = new ArenaNode(state, this); + Children.Add(node); + return node; + } +} diff --git a/Solver/Crafty/SimulationNode.cs b/Solver/Crafty/SimulationNode.cs index 27e4330..11f0a70 100644 --- a/Solver/Crafty/SimulationNode.cs +++ b/Solver/Crafty/SimulationNode.cs @@ -5,18 +5,26 @@ namespace Craftimizer.Solver.Crafty; public readonly struct SimulationNode { - public SimulationState State { get; init; } - public ActionType? Action { get; init; } - public CompletionState SimulationCompletionState { get; init; } + public readonly SimulationState State; + public readonly ActionType? Action; + public readonly CompletionState SimulationCompletionState; + public readonly NodeData Data; + public CompletionState CompletionState => Data.AvailableActions.Count == 0 && SimulationCompletionState == CompletionState.Incomplete ? CompletionState.NoMoreActions : SimulationCompletionState; - public NodeData Data { get; init; } - public bool IsComplete => CompletionState != CompletionState.Incomplete; + public SimulationNode(SimulationState state, ActionType? action, CompletionState completionState, NodeData data) + { + State = state; + Action = action; + SimulationCompletionState = completionState; + Data = data; + } + public float? CalculateScore() { if (CompletionState != CompletionState.ProgressComplete) diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 8f62fad..5292295 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -3,6 +3,7 @@ using Craftimizer.Simulator.Actions; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; @@ -10,11 +11,11 @@ namespace Craftimizer.Solver.Crafty; public class Solver { public Simulator Simulator; - public Arena Tree; + public Node RootNode; - public Random Random => Simulator.Input.Random; + // public Random Random => Simulator.Input.Random; - public const int Iterations = 50000; + public const int Iterations = 30000; public const float ScoreStorageThreshold = 1f; public const float MaxScoreWeightingConstant = 0.1f; public const float ExplorationConstant = 4f; @@ -23,13 +24,12 @@ public class Solver public Solver(SimulationState state, bool strict) { Simulator = new(state); - Tree = new(new() - { - State = state, - Action = null, - SimulationCompletionState = Simulator.CompletionState, - Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } - }); + RootNode = new(new( + state, + null, + Simulator.CompletionState, + new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } + )); } public Solver(SimulationInput input, bool strict) : this(new SimulationState(input), strict) @@ -39,37 +39,34 @@ public class Solver private SimulationNode Execute(SimulationState state, ActionType action, bool strict) { (_, var newState) = Simulator.Execute(state, action); - return new() - { - State = newState, - Action = action, - SimulationCompletionState = Simulator.CompletionState, - Data = new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } - }; + return new( + newState, + action, + Simulator.CompletionState, + new() { AvailableActions = Simulator.AvailableActionsHeuristic(strict) } + ); } - public (int Index, CompletionState State) ExecuteActions(int startIndex, ReadOnlySpan actions, bool strict = false) + public (Node EndNode, CompletionState State) ExecuteActions(Node startNode, ReadOnlySpan actions, bool strict = false) { - var currentIndex = startIndex; foreach (var action in actions) { - var node = Tree.Get(currentIndex).State; - if (node.IsComplete) - return (currentIndex, node.CompletionState); + var state = startNode.State; + if (state.IsComplete) + return (startNode, state.CompletionState); - if (!node.Data.AvailableActions.HasAction(action)) - return (currentIndex, CompletionState.InvalidAction); - node.Data.AvailableActions.RemoveAction(action); + if (!state.Data.AvailableActions.HasAction(action)) + return (startNode, CompletionState.InvalidAction); + state.Data.AvailableActions.RemoveAction(action); - currentIndex = Tree.Insert(currentIndex, Execute(node.State, action, strict)); + startNode = startNode.Add(Execute(state.State, action, strict)); } - var currentNode = Tree.Get(currentIndex).State; - return (currentIndex, currentNode.CompletionState); + return (startNode, startNode.State.CompletionState); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int RustMaxBy(ReadOnlySpan source, Func into) + private static T RustMaxBy(ReadOnlySpan source, Func into) { var max = 0; var maxV = into(source[0]); @@ -97,7 +94,7 @@ public class Solver (length + (Vector.Count - 1)) & ~(Vector.Count - 1); [MethodImpl(MethodImplOptions.AggressiveInlining)] - private int EvalBestChild(float parentVisits, ReadOnlySpan children) + private Node EvalBestChild(float parentVisits, ReadOnlySpan children) { var length = children.Length; @@ -120,7 +117,7 @@ public class Solver for (var j = 0; j < iterCount; ++j) { - var node = Tree.Get(children[i + j]).State.Data.Scores; + var node = children[i + j].State.Data.Scores; scoreSums[j] = node.ScoreSum; visits[j] = node.Visits; maxScores[j] = node.MaxScore; @@ -140,46 +137,42 @@ public class Solver return children[max]; } - public int Select(int selectedIndex) + public Node Select(Node selectedNode) { while (true) { - var selectedNode = Tree.Get(selectedIndex); - var expandable = selectedNode.State.Data.AvailableActions.Count != 0; var likelyTerminal = selectedNode.Children.Count == 0; if (expandable || likelyTerminal) - { - return selectedIndex; - } + return selectedNode; // select the node with the highest score - selectedIndex = EvalBestChild(selectedNode.State.Data.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children)); + selectedNode = EvalBestChild(selectedNode.State.Data.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children)); } } - public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex) + public (Node ExpandedNode, CompletionState State, float Score) ExpandAndRollout(Node initialNode) { + var initialState = initialNode.State; // expand once - var initialNode = Tree.Get(initialIndex).State; - if (initialNode.IsComplete) - return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0); + if (initialState.IsComplete) + return (initialNode, initialState.CompletionState, initialState.CalculateScore() ?? 0); - var randomIdx = Random.Next(initialNode.Data.AvailableActions.Count); - var randomAction = initialNode.Data.AvailableActions.ElementAt(randomIdx); - initialNode.Data.AvailableActions.RemoveAction(randomAction); - var expandedState = Execute(initialNode.State, randomAction, true); - var expandedIndex = Tree.Insert(initialIndex, expandedState); + var randomIdx = 0;// Random.Next(initialState.Data.AvailableActions.Count); + var randomAction = initialState.Data.AvailableActions.ElementAt(randomIdx); + initialState.Data.AvailableActions.RemoveAction(randomAction); + var expandedState = Execute(initialState.State, randomAction, true); + var expandedNode = initialNode.Add(expandedState); // playout to a terminal state - var currentState = Tree.Get(expandedIndex).State; + var currentState = expandedNode.State; byte actionCount = 0; Span actions = stackalloc ActionType[MaxStepCount]; while (true) { if (currentState.IsComplete) break; - randomIdx = Random.Next(currentState.Data.AvailableActions.Count); + randomIdx = 0;// Random.Next(currentState.Data.AvailableActions.Count); randomAction = currentState.Data.AvailableActions.ElementAt(randomIdx); actions[actionCount++] = randomAction; currentState = Execute(currentState.State, randomAction, true); @@ -189,50 +182,46 @@ public class Solver var score = currentState.CalculateScore() ?? 0; if (currentState.CompletionState == CompletionState.ProgressComplete) { - if (score >= ScoreStorageThreshold && score >= Tree.Get(0).State.Data.Scores.MaxScore) + if (score >= ScoreStorageThreshold && score >= RootNode.State.Data.Scores.MaxScore) { - Console.WriteLine("DONE!"); - (var terminalIndex, _) = ExecuteActions(expandedIndex, actions[..actionCount], true); - return (terminalIndex, currentState.CompletionState, score); + (var terminalNode, _) = ExecuteActions(expandedNode, actions[..actionCount], true); + return (terminalNode, currentState.CompletionState, score); } } - return (expandedIndex, currentState.CompletionState, score); + return (expandedNode, currentState.CompletionState, score); } - public void Backpropagate(int startIndex, int targetIndex, float score) + public static void Backpropagate(Node startNode, Node targetNode, float score) { - var currentIndex = startIndex; while (true) { - var currentNode = Tree.Get(currentIndex); - currentNode.State.Data.Scores.Visit(score); + startNode.State.Data.Scores.Visit(score); - if (currentIndex == targetIndex) + if (startNode == targetNode) break; - currentIndex = currentNode.Parent; + startNode = startNode.Parent!; } } - public void Search(int startIndex) + public void Search(Node startNode) { for (var i = 0; i < Iterations; i++) { - var selectedIndex = Select(startIndex); - var (endIndex, _, score) = ExpandAndRollout(selectedIndex); + var selectedNode = Select(startNode); + var (endNode, _, score) = ExpandAndRollout(selectedNode); - Backpropagate(endIndex, startIndex, score); + Backpropagate(endNode, startNode, score); } } public (List Actions, SimulationNode Node) Solution() { var actions = new List(); - var node = Tree.Get(0); + var node = RootNode; while (node.Children.Count != 0) { - var next_index = RustMaxBy(CollectionsMarshal.AsSpan(node.Children), n => Tree.Get(n).State.Data.Scores.MaxScore); - node = Tree.Get(next_index); + node = RustMaxBy(CollectionsMarshal.AsSpan(node.Children), n => n.State.Data.Scores.MaxScore); if (node.State.Action != null) actions.Add(node.State.Action.Value); } @@ -247,7 +236,7 @@ public class Solver var solver = new Solver(state, true); while (!solver.Simulator.IsComplete) { - solver.Search(0); + solver.Search(solver.RootNode); var (solution_actions, solution_node) = solver.Solution(); if (solution_node.Data.Scores.MaxScore >= 1.0) @@ -271,7 +260,7 @@ public class Solver public static (List Actions, SimulationState State) SearchOneshot(SimulationInput input) { var solver = new Solver(input, false); - solver.Search(0); + solver.Search(solver.RootNode); var (solution_actions, solution_node) = solver.Solution(); return (solution_actions, solution_node.State); }