diff --git a/Solver/Crafty/Solver.cs b/Solver/Crafty/Solver.cs index 4be2de3..1893b8b 100644 --- a/Solver/Crafty/Solver.cs +++ b/Solver/Crafty/Solver.cs @@ -1,9 +1,13 @@ using Craftimizer.Simulator; using Craftimizer.Simulator.Actions; using System; +using System.Diagnostics; +using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; @@ -79,38 +83,63 @@ public class Solver return source[max]; } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector EvalBestChildVectorized(float w, float W, Vector C, Vector scoreSums, Vector visits, Vector maxScores) + // https://stackoverflow.com/a/73439472 + private static Vector128 HMax(Vector256 v1) { - var exploitation = W * (scoreSums / visits) + w * maxScores; - var exploration = Vector.SquareRoot(C / visits); - return exploitation + exploration; + var v2 = Avx.Permute(v1, 0b10110001); + var v3 = Avx.Max(v1, v2); + var v4 = Avx.Permute(v3, 0b00001010); + var v5 = Avx.Max(v3, v4); + var v6 = Avx.ExtractVector128(v5, 1); + var v7 = Sse.Max(v5.GetLower(), v6); + return v7; } - private static int AlignToVectorLength(int length) => - (length + (Vector.Count - 1)) & ~(Vector.Count - 1); + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + // https://stackoverflow.com/a/23592221 + private static (int, uint) HMaxIndex(Vector256 v, int len) + { + var vfilt = Avx.Blend(v, Vector256.Zero, (byte)~((1 << len) - 1)); + var vmax128 = HMax(vfilt); + var vmax = Vector256.Create(vmax128, vmax128); + + var vcmp = Avx.CompareEqual(v, vmax); + var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte())); + mask <<= (8 - len) << 2; + + var inverseIdx = BitOperations.LeadingZeroCount(mask) >> 2; + + return (len - 1 - inverseIdx, mask); + } + + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] private Node EvalBestChild(float parentVisits, ReadOnlySpan children) { var length = children.Length; + var vecLength = Vector.Count; var C = Config.ExplorationConstant * MathF.Log(parentVisits); var w = Config.MaxScoreWeightingConstant; var W = 1f - w; var CVector = new Vector(C); - Span scoreSums = stackalloc float[Vector.Count]; - Span visits = stackalloc float[Vector.Count]; - Span maxScores = stackalloc float[Vector.Count]; - + Span scoreSums = stackalloc float[vecLength]; + Span visits = stackalloc float[vecLength]; + Span maxScores = stackalloc float[vecLength]; + var max = 0; var maxScore = 0f; - for (var i = 0; i < length; i += Vector.Count) + + for (var i = 0; i < length; i += vecLength) { - var iterCount = i + Vector.Count > length ? + var iterCount = i + vecLength > length ? length - i : - Vector.Count; + vecLength; for (var j = 0; j < iterCount; ++j) { @@ -119,15 +148,17 @@ public class Solver visits[j] = node.Visits; maxScores[j] = node.MaxScore; } - var evalScores = EvalBestChildVectorized(w, W, CVector, new(scoreSums), new(visits), new(maxScores)); - for (var j = 0; j < iterCount; ++j) + var exploitation = (W * (new Vector(scoreSums) / new Vector(visits))) + (w * new Vector(maxScores)); + var exploration = Vector.SquareRoot(CVector / new Vector(visits)); + var evalScores = exploitation + exploration; + + var (idx, mask) = HMaxIndex(evalScores.AsVector256(), iterCount); + + if (evalScores[idx] >= maxScore) { - if (evalScores[j] >= maxScore) - { - max = i + j; - maxScore = evalScores[j]; - } + max = i + idx; + maxScore = evalScores[idx]; } }