Fix ConstantExpected simd call

This commit is contained in:
Asriel Camora
2024-03-15 00:55:12 -07:00
parent 41e6722c43
commit 836e983eb2
+12 -2
View File
@@ -34,13 +34,22 @@ internal static class Intrinsics
return m; return m;
} }
/// <summary>
/// Keeps the leading lanes of <paramref name="data"/> and zeroes the trailing ones.
/// </summary>
/// <param name="data">Eight-lane single-precision vector to mask.</param>
/// <param name="len">
/// Number of leading lanes to keep. Lanes whose index is greater than or equal to
/// this value are replaced with zero bits, so a value of 0 or less clears every lane
/// and a value of 8 or more returns <paramref name="data"/> unchanged.
/// </param>
/// <returns>
/// <paramref name="data"/> with every lane at index <paramref name="len"/> or above set to zero.
/// </returns>
/// <remarks>
/// The lane count is a runtime value, so the mask is built with an integer compare
/// rather than an instruction that needs a compile-time immediate.
/// </remarks>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector256<float> ClearLastN(Vector256<float> data, int len)
{
    // Lane numbers 0..7, compared against the broadcast lane count.
    var laneIds = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7);
    var limit = Vector256.Create<int>(len);

    // Signed per-lane compare: all bits set where len > lane number, zero elsewhere.
    var keepMask = Avx2.CompareGreaterThan(limit, laneIds);

    // Bitwise AND passes kept lanes through untouched and turns the rest into +0.0f.
    return Avx.And(keepMask.AsSingle(), data);
}
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
// https://stackoverflow.com/a/23592221 // https://stackoverflow.com/a/23592221
private static int HMaxIndexAVX2(Vector<float> v, int len) private static int HMaxIndexAVX2(Vector<float> v, int len)
{ {
// Remove NaNs // Remove NaNs
var vfilt = Avx.Blend(v.AsVector256(), Vector256<float>.Zero, (byte)~((1 << len) - 1)); var vfilt = ClearLastN(v.AsVector256(), len);
// Find max value and broadcast to all lanes // Find max value and broadcast to all lanes
var vmax128 = HMax(vfilt); var vmax128 = HMax(vfilt);
@@ -50,7 +59,7 @@ internal static class Intrinsics
var vcmp = Avx.CompareEqual(vfilt, vmax); var vcmp = Avx.CompareEqual(vfilt, vmax);
var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte())); var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte()));
var inverseIdx = BitOperations.LeadingZeroCount(mask << (8 - len << 2)) >> 2; var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2;
return len - 1 - inverseIdx; return len - 1 - inverseIdx;
} }
@@ -158,6 +167,7 @@ internal static class Intrinsics
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[SkipLocalsInit]
public static Vector<float> ReciprocalSqrt(Vector<float> data) public static Vector<float> ReciprocalSqrt(Vector<float> data)
{ {
if (Avx.IsSupported && Vector<float>.Count >= Vector256<float>.Count) if (Avx.IsSupported && Vector<float>.Count >= Vector256<float>.Count)