From b7393b5c65b7852470fe29b71e7c6e4c14995043 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Tue, 20 Jun 2023 22:38:26 -0700 Subject: [PATCH] Improve simd code --- Solver/Crafty/Solver.cs | 126 +++++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 692f00e..e9551fc 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -1,10 +1,8 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; -using System.ComponentModel; using System.Diagnostics; using System.Numerics; using System.Runtime.CompilerServices; -using System.Runtime.Intrinsics; namespace Craftimizer.Solver.Crafty; @@ -85,7 +83,7 @@ public class Solver return exploitation + exploration; } - + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int RustMaxBy(List source, Func into) { @@ -111,9 +109,12 @@ public class Solver return exploitation + exploration; } + private static int AlignToVectorLength(int length) => + (length + (Vector.Count - 1)) & ~(Vector.Count - 1); + [MethodImpl(MethodImplOptions.AggressiveInlining)] // Requires a multiple of Vector.Count - private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores) + private static void EvalBestChildMultiple(float parentVisits, ReadOnlySpan scoreSums, ReadOnlySpan visits, ReadOnlySpan maxScores, Span evalScores) { var C = ExplorationConstant * MathF.Log(parentVisits); var w = MaxScoreWeightingConstant; @@ -121,28 +122,25 @@ public class Solver var CVector = new Vector(C); var length = scoreSums.Length; - var result = new float[length]; - for (var i = 0; i < length; i += Vector.Count) { - var scoreSumsVector = new Vector(scoreSums, i); - var visitsVector = new Vector(visits, i); - var maxScoresVector = new Vector(maxScores, i); + var scoreSumsVector = new Vector(scoreSums[i..(i + Vector.Count)]); + var visitsVector = new Vector(visits[i..(i + Vector.Count)]); + var maxScoresVector = new Vector(maxScores[i..(i + Vector.Count)]); var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector); - evalVector.CopyTo(result, i); + evalVector.CopyTo(evalScores[i..(i + Vector.Count)]); } - - return result; } - private float[] EvalAllChildrenDbg(float parentVisits, List children) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int EvalBestChildAlternative(float parentVisits, List children) { var length = children.Count; - var alignedLength = (length + (Vector.Count - 1)) & ~(Vector.Count - 1); - var scoreSums = new float[alignedLength]; - var visits = new float[alignedLength]; - var maxScores = new float[alignedLength]; - + var alignedLength = AlignToVectorLength(length); + Span scoreSums = stackalloc float[alignedLength]; + Span visits = stackalloc float[alignedLength]; + Span maxScores = stackalloc float[alignedLength]; + Span evalScores = stackalloc float[alignedLength]; for (var i = 0; i < length; ++i) { @@ -152,63 +150,56 @@ public class Solver maxScores[i] = node.MaxScore; } - return EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores); + EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores, evalScores); + var max = 0; + for (var i = 1; i < length; ++i) + if (evalScores[i] >= evalScores[max]) + max = i; + return children[max]; } [MethodImpl(MethodImplOptions.AggressiveInlining)] private int EvalBestChild(float parentVisits, List children) { var length = children.Count; - var alignedLength = (length + (Vector.Count - 1)) & ~(Vector.Count - 1); - var scoreSums = new float[alignedLength]; - var visits = new float[alignedLength]; - var maxScores = new float[alignedLength]; - - for (var i = 0; i < length; ++i) - { - var node = Tree.Get(children[i]).State.Scores; - scoreSums[i] = node.ScoreSum; - visits[i] = node.Visits; - maxScores[i] = node.MaxScore; - } - - var evalScores = EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores); - var maxIdx = 0; - var max = evalScores[0]; - for(var i = 1; i < length; ++i) - { - if (evalScores[i] >= max) - { - maxIdx = i; - max = evalScores[i]; - } - } - return children[maxIdx]; - } - - private int EvalBestChildScalar(List children, NodeScores parent) - { - Console.WriteLine(children.Count); - var C = ExplorationConstant * MathF.Log(parent.Visits); + var C = ExplorationConstant * MathF.Log(parentVisits); var w = MaxScoreWeightingConstant; var W = 1f - w; + var CVector = new Vector(C); - var ret = -1; - var maxV = float.MinValue; - foreach (var childNode in children) + Span scoreSums = stackalloc float[Vector.Count]; + Span visits = stackalloc float[Vector.Count]; + Span maxScores = stackalloc float[Vector.Count]; + + var max = 0; + var maxScore = 0f; + for (var i = 0; i < length; i += Vector.Count) { - var child = Tree.Get(childNode).State.Scores; - var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore); - var exploration = MathF.Sqrt(C / child.Visits); - var score = exploitation + exploration; - if (score >= maxV) + var iterCount = i + Vector.Count > length ? + length - i : + Vector.Count; + + for (var j = 0; j < iterCount; ++j) { - ret = childNode; - maxV = score; + var node = Tree.Get(children[i + j]).State.Scores; + scoreSums[j] = node.ScoreSum; + visits[j] = node.Visits; + maxScores[j] = node.MaxScore; + } + var evalScores = EvalBestChildVectorized(w, W, CVector, new(scoreSums), new(visits), new(maxScores)); + + for (var j = 0; j < iterCount; ++j) + { + if (evalScores[j] >= maxScore) + { + max = i + j; + maxScore = evalScores[j]; + } } } - return ret; + + return children[max]; } public int Select(int selectedIndex) @@ -219,7 +210,8 @@ public class Solver var expandable = selectedNode.State.AvailableActions.Count != 0; var likelyTerminal = selectedNode.Children.Count == 0; - if (expandable || likelyTerminal) { + if (expandable || likelyTerminal) + { return selectedIndex; } @@ -280,7 +272,7 @@ public class Solver if (currentIndex == targetIndex) break; - currentIndex = currentNode.Parent!.Value; + currentIndex = currentNode.Parent; } } @@ -320,7 +312,8 @@ public class Solver public static (List Actions, SimulationState State) SearchStepwise(SimulationInput input, List actions, Action? actionCallback) { var (state, result) = Simulate(input, actions); - if (result != CompletionState.Incomplete) { + if (result != CompletionState.Incomplete) + { return (actions, state); } @@ -331,7 +324,8 @@ public class Solver solver.Search(0); var (solution_actions, solution_node) = solver.Solution(); - if (solution_node.Scores.MaxScore >= 1.0) { + if (solution_node.Scores.MaxScore >= 1.0) + { actions.AddRange(solution_actions); return (actions, solution_node.State); } @@ -344,7 +338,7 @@ public class Solver solver = new Solver(state, true); } - //Debugger.Break(); + Debugger.Break(); return (actions, state); }