From 836e983eb2f71a821e8a91e713ad9ba8e84b9d44 Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Fri, 15 Mar 2024 00:55:12 -0700 Subject: [PATCH] Fix ConstantExpected simd call --- Solver/Intrinsics.cs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/Solver/Intrinsics.cs b/Solver/Intrinsics.cs index d901f7f..6eca3b9 100644 --- a/Solver/Intrinsics.cs +++ b/Solver/Intrinsics.cs @@ -34,13 +34,22 @@ internal static class Intrinsics return m; } + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ClearLastN(Vector256 data, int len) + { + var threshold = Vector256.Create(len); + var index = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7); + return Avx.And(Avx2.CompareGreaterThan(threshold, index).AsSingle(), data); + } + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] // https://stackoverflow.com/a/23592221 private static int HMaxIndexAVX2(Vector v, int len) { // Remove NaNs - var vfilt = Avx.Blend(v.AsVector256(), Vector256.Zero, (byte)~((1 << len) - 1)); + var vfilt = ClearLastN(v.AsVector256(), len); // Find max value and broadcast to all lanes var vmax128 = HMax(vfilt); @@ -50,7 +59,7 @@ internal static class Intrinsics var vcmp = Avx.CompareEqual(vfilt, vmax); var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte())); - var inverseIdx = BitOperations.LeadingZeroCount(mask << (8 - len << 2)) >> 2; + var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2; return len - 1 - inverseIdx; } @@ -158,6 +167,7 @@ internal static class Intrinsics [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining)] + [SkipLocalsInit] public static Vector ReciprocalSqrt(Vector data) { if (Avx.IsSupported && Vector.Count >= Vector256.Count)