From 1d0d4cf8ce0d4e23114f600803e9f1718f04c42d Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Mon, 19 Jun 2023 17:02:09 -0700 Subject: [PATCH] Vectorize evaluation --- Benchmark/Program.cs | 4 ++ Solver/Crafty/Arena.cs | 16 ++++-- Solver/Crafty/Solver.cs | 122 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 130 insertions(+), 12 deletions(-) diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index 32eb7cc..688eec1 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -2,6 +2,7 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; using Craftimizer.Solver.Crafty; using ObjectLayoutInspector; +using System.Diagnostics; namespace Craftimizer.Benchmark; @@ -32,6 +33,7 @@ internal static class Program }; var actions = new List(); + var s = Stopwatch.StartNew(); if (true) (actions, _) = Solver.Crafty.Solver.SearchStepwise(input, actions, a => Console.WriteLine(a)); else @@ -41,5 +43,7 @@ internal static class Program Console.Write($">{action.IntName()}"); Console.WriteLine(); } + s.Stop(); + Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}"); } } diff --git a/Solver/Crafty/Arena.cs b/Solver/Crafty/Arena.cs index 26537f3..8661b78 100644 --- a/Solver/Crafty/Arena.cs +++ b/Solver/Crafty/Arena.cs @@ -1,3 +1,5 @@ +using System.Runtime.CompilerServices; + namespace Craftimizer.Solver.Crafty; public class Arena where T : struct @@ -10,20 +12,22 @@ public class Arena where T : struct public T State { get; init; } } - public List Nodes { get; } = new(); + private readonly List nodes = new(); public Arena(T initialState = default) { - Nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState }); + nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState }); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public int Insert(int parentIndex, T state) { - var index = Nodes.Count; - Nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state }); - Nodes[parentIndex].Children.Add(index); + var index = nodes.Count; + nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state }); + nodes[parentIndex].Children.Add(index); return index; } - public Node Get(int index) => Nodes[index]; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Node Get(int index) => nodes[index]; } diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index f05738b..7f23b8d 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -1,7 +1,10 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; +using System.ComponentModel; using System.Diagnostics; +using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; namespace Craftimizer.Solver.Crafty; @@ -100,9 +103,116 @@ public class Solver return source[max]; } - public int Select(int currentIndex) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector EvalBestChildVectorized(float w, float W, Vector C, Vector scoreSums, Vector visits, Vector maxScores) + { + var exploitation = W * (scoreSums / visits) + w * maxScores; + var exploration = Vector.SquareRoot(C / visits); + return exploitation + exploration; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + // Requires a multiple of Vector.Count + private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores) + { + var C = ExplorationConstant * MathF.Log(parentVisits); + var w = MaxScoreWeightingConstant; + var W = 1f - w; + var CVector = new Vector(C); + + var length = scoreSums.Length; + var result = new float[length]; + + for (var i = 0; i < length; i += Vector.Count) + { + var scoreSumsVector = new Vector(scoreSums, i); + var visitsVector = new Vector(visits, i); + var maxScoresVector = new Vector(maxScores, i); + var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector); + evalVector.CopyTo(result, i); + } + + return result; + } + + private float[] EvalAllChildrenDbg(float parentVisits, List children) + { + var length = children.Count; + var alignedLength = (length + (Vector.Count - 1)) & ~(Vector.Count - 1); + var scoreSums = new float[alignedLength]; + var visits = new float[alignedLength]; + var maxScores = new float[alignedLength]; + + + for (var i = 0; i < length; ++i) + { + var node = Tree.Get(children[i]).State.Scores; + scoreSums[i] = node.ScoreSum; + visits[i] = node.Visits; + maxScores[i] = node.MaxScore; + } + + return EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int EvalBestChild(float parentVisits, List children) + { + var length = children.Count; + var alignedLength = (length + (Vector.Count - 1)) & ~(Vector.Count - 1); + var scoreSums = new float[alignedLength]; + var visits = new float[alignedLength]; + var maxScores = new float[alignedLength]; + + + for (var i = 0; i < length; ++i) + { + var node = Tree.Get(children[i]).State.Scores; + scoreSums[i] = node.ScoreSum; + visits[i] = node.Visits; + maxScores[i] = node.MaxScore; + } + + var evalScores = EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores); + var maxIdx = 0; + var max = evalScores[0]; + for(var i = 1; i < length; ++i) + { + if (evalScores[i] >= max) + { + maxIdx = i; + max = evalScores[i]; + } + } + return children[maxIdx]; + } + + private int EvalBestChildScalar(List children, NodeScores parent) + { + Console.WriteLine(children.Count); + var C = ExplorationConstant * MathF.Log(parent.Visits); + var w = MaxScoreWeightingConstant; + var W = 1f - w; + + var ret = -1; + var maxV = float.MinValue; + foreach (var childNode in children) + { + var child = Tree.Get(childNode).State.Scores; + var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore); + var exploration = MathF.Sqrt(C / child.Visits); + var score = exploitation + exploration; + if (score >= maxV) + { + ret = childNode; + maxV = score; + } + } + return ret; + } + + public int Select(int selectedIndex) { - var selectedIndex = currentIndex; while (true) { var selectedNode = Tree.Get(selectedIndex); @@ -110,13 +220,12 @@ public class Solver var expandable = selectedNode.State.AvailableActions.Count != 0; var likelyTerminal = selectedNode.Children.Count == 0; if (expandable || likelyTerminal) { - break; + return selectedIndex; } // select the node with the highest score - selectedIndex = RustMaxBy(selectedNode.Children, n => Eval(Tree.Get(n).State.Scores, selectedNode.State.Scores)); + selectedIndex = EvalBestChild(selectedNode.State.Scores.Visits, selectedNode.Children); } - return selectedIndex; } public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex) @@ -164,7 +273,8 @@ public class Solver var currentScores = currentNode.State.Scores; currentScores.Visits++; currentScores.ScoreSum += score; - currentScores.MaxScore = Math.Max(currentScores.MaxScore, score); + if (currentScores.MaxScore < score) + currentScores.MaxScore = score; if (currentIndex == targetIndex) break;