From ec596f400d3feee1d57ceb49bbecf1190dbc6d25 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Mon, 3 Jul 2023 18:38:59 +0300 Subject: [PATCH] Clean up solver code a bit --- Benchmark/Bench.cs | 59 ------------------ Solver/Crafty/Intrinsics.cs | 64 +++++++++++++++++++ Solver/Crafty/Solver.cs | 118 +++++++++++++----------------------- 3 files changed, 105 insertions(+), 136 deletions(-) delete mode 100644 Benchmark/Bench.cs create mode 100644 Solver/Crafty/Intrinsics.cs diff --git a/Benchmark/Bench.cs b/Benchmark/Bench.cs deleted file mode 100644 index 1ee12a8..0000000 --- a/Benchmark/Bench.cs +++ /dev/null @@ -1,59 +0,0 @@ -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Numerics; -using System.Security.Cryptography; -using System.Text; -using System.Threading.Tasks; - -namespace Craftimizer.Benchmark; - -[SimpleJob(RuntimeMoniker.Net70)] -[SimpleJob(RuntimeMoniker.NativeAot70)] -public class Bench -{ - private float[] data; - private int[] dataLengths; - - [Params(1000, 10000)] - public int N; - - [GlobalSetup] - public void Setup() - { - var rand = new Random(); - data = new float[N * 8]; - dataLengths = new int[N]; - for (var i = 0; i < data.Length; i += 8) - { - var len = rand.NextSingle() > .5 ? 8 : rand.Next(1, 9); - dataLengths[i / 8] = len; - for (var j = 0; j < len; ++j) - data[i + j] = rand.NextSingle(); - for (var j = len; j < 8; ++j) - data[i + j] = float.NaN; - } - } - - [Benchmark] - public int[] Scalar() - { - var d = new int[N]; - var dataSpan = data.AsSpan(); - for (var i = 0; i < N; ++i) - d[i] = Solver.Crafty.Solver.HMaxIndexScalar(new Vector(dataSpan.Slice(i * 8, 8)), dataLengths[i]); - return d; - } - - [Benchmark] - public int[] AVX2() - { - var d = new int[128]; - var dataSpan = data.AsSpan(); - for (var i = 0; i < 128; ++i) - d[i] = Solver.Crafty.Solver.HMaxIndexAVX2(new Vector(dataSpan.Slice(i * 8, 8)), dataLengths[i]); - return d; - } -} diff --git a/Solver/Crafty/Intrinsics.cs b/Solver/Crafty/Intrinsics.cs new file mode 100644 index 0000000..8772854 --- /dev/null +++ b/Solver/Crafty/Intrinsics.cs @@ -0,0 +1,64 @@ +using System.Diagnostics.Contracts; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace Craftimizer.Solver.Crafty; +internal static class Intrinsics +{ + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + // https://stackoverflow.com/a/73439472 + private static Vector128 HMax(Vector256 v1) + { + var v2 = Avx.Permute(v1, 0b10110001); + var v3 = Avx.Max(v1, v2); + var v4 = Avx.Permute(v3, 0b00001010); + var v5 = Avx.Max(v3, v4); + var v6 = Avx.ExtractVector128(v5, 1); + var v7 = Sse.Max(v5.GetLower(), v6); + return v7; + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int HMaxIndexScalar(Vector v, int len) + { + var m = 0; + for (var i = 1; i < len; ++i) + { + if (v[i] >= v[m]) + m = i; + } + return m; + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + // https://stackoverflow.com/a/23592221 + private static int HMaxIndexAVX2(Vector v, int len) + { + // Remove NaNs + var vfilt = Avx.Blend(v.AsVector256(), Vector256.Zero, (byte)~((1 << len) - 1)); + + // Find max value and broadcast to all lanes + var vmax128 = HMax(vfilt); + var vmax = Vector256.Create(vmax128, vmax128); + + // Find the highest index with that value, respecting len + var vcmp = Avx.CompareEqual(vfilt, vmax); + var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte())); + + var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2; + + return len - 1 - inverseIdx; + } + + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int HMaxIndex(Vector v, int len) => + Avx2.IsSupported ? + HMaxIndexAVX2(v, len) : + HMaxIndexScalar(v, len); +} diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 69da71d..e079f23 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -64,78 +64,38 @@ public class Solver return (startNode, startNode.State.CompletionState); } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static T RustMaxBy(ReadOnlySpan source, Func into) + private static Node ChildMaxScore(ReadOnlySpan children) { + var length = children.Length; + var vecLength = Vector.Count; + + Span scores = stackalloc float[vecLength]; + var max = 0; - var maxV = into(source[0]); - for (var i = 1; i < source.Length; ++i) + var maxScore = 0f; + for (var i = 0; i < length; i += vecLength) { - var nextV = into(source[i]); - if (maxV <= nextV) + var iterCount = i + vecLength > length ? + length - i : + vecLength; + + for (var j = 0; j < iterCount; ++j) + scores[j] = children[i + j].State.Scores.MaxScore; + + var idx = Intrinsics.HMaxIndex(new Vector(scores), iterCount); + + if (scores[idx] >= maxScore) { - max = i; - maxV = nextV; + max = i + idx; + maxScore = scores[idx]; } } - return source[max]; + + return children[max]; } - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - // https://stackoverflow.com/a/73439472 - private static Vector128 HMax(Vector256 v1) - { - var v2 = Avx.Permute(v1, 0b10110001); - var v3 = Avx.Max(v1, v2); - var v4 = Avx.Permute(v3, 0b00001010); - var v5 = Avx.Max(v3, v4); - var v6 = Avx.ExtractVector128(v5, 1); - var v7 = Sse.Max(v5.GetLower(), v6); - return v7; - } - - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static int HMaxIndexScalar(Vector v, int len) - { - var m = 0; - for (var i = 1; i < len; ++i) - { - if (v[i] >= v[m]) - m = i; - } - return m; - } - - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - // https://stackoverflow.com/a/23592221 - public static int HMaxIndexAVX2(Vector v, int len) - { - // Remove NaNs - var vfilt = Avx.Blend(v.AsVector256(), Vector256.Zero, (byte)~((1 << len) - 1)); - - // Find max value and broadcast to all lanes - var vmax128 = HMax(vfilt); - var vmax = Vector256.Create(vmax128, vmax128); - - // Find the highest index with that value, respecting len - var vcmp = Avx.CompareEqual(vfilt, vmax); - var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte())); - - var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2; - - return len - 1 - inverseIdx; - } - - [Pure] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int HMaxIndex(Vector v, int len) => - Avx2.IsSupported ? - HMaxIndexAVX2(v, len) : - HMaxIndexScalar(v, len); - [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] private Node EvalBestChild(float parentVisits, ReadOnlySpan children) @@ -172,7 +132,7 @@ public class Solver var exploration = Vector.SquareRoot(CVector / new Vector(visits)); var evalScores = exploitation + exploration; - var idx = HMaxIndex(evalScores, iterCount); + var idx = Intrinsics.HMaxIndex(evalScores, iterCount); if (evalScores[idx] >= maxScore) { @@ -184,17 +144,19 @@ public class Solver return children[max]; } - public Node Select(Node selectedNode) + [Pure] + public Node Select() { + var node = RootNode; while (true) { - var expandable = selectedNode.State.AvailableActions.Count != 0; - var likelyTerminal = selectedNode.Children.Count == 0; + var expandable = node.State.AvailableActions.Count != 0; + var likelyTerminal = node.Children.Count == 0; if (expandable || likelyTerminal) - return selectedNode; + return node; // select the node with the highest score - selectedNode = EvalBestChild(selectedNode.State.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children)); + node = EvalBestChild(node.State.Scores.Visits, CollectionsMarshal.AsSpan(node.Children)); } } @@ -240,40 +202,42 @@ public class Solver return (expandedNode, currentCompletionState, score); } - public static void Backpropagate(Node startNode, Node targetNode, float score) + public void Backpropagate(Node startNode, float score) { while (true) { startNode.State.Scores.Visit(score); - if (startNode == targetNode) + if (startNode == RootNode) break; startNode = startNode.Parent!; } } - public void Search(Node startNode, CancellationToken token) + public void Search(CancellationToken token) { for (var i = 0; i < Config.Iterations; i++) { if (token.IsCancellationRequested) break; - var selectedNode = Select(startNode); + var selectedNode = Select(); var (endNode, _, score) = ExpandAndRollout(selectedNode); - Backpropagate(endNode, startNode, score); + Backpropagate(endNode, score); } } + [Pure] public (List Actions, SimulationNode Node) Solution() { var actions = new List(); var node = RootNode; while (node.Children.Count != 0) { - node = RustMaxBy(CollectionsMarshal.AsSpan(node.Children), n => n.State.Scores.MaxScore); + node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children)); + if (node.State.Action != null) actions.Add(node.State.Action.Value); } @@ -293,7 +257,7 @@ public class Solver if (token.IsCancellationRequested) break; - solver.Search(solver.RootNode, token); + solver.Search(token); var (solution_actions, solution_node) = solver.Solution(); if (solution_node.Scores.MaxScore >= 1.0) @@ -320,7 +284,7 @@ public class Solver public static (List Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) { var solver = new Solver(config, state, false); - solver.Search(solver.RootNode, token); + solver.Search(token); var (solution_actions, solution_node) = solver.Solution(); return (solution_actions, solution_node.State); }