Fancy simd horizontal max
This commit is contained in:
+51
-20
@@ -1,9 +1,13 @@
|
|||||||
using Craftimizer.Simulator;
|
using Craftimizer.Simulator;
|
||||||
using Craftimizer.Simulator.Actions;
|
using Craftimizer.Simulator.Actions;
|
||||||
using System;
|
using System;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Diagnostics.Contracts;
|
||||||
using System.Numerics;
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using System.Runtime.InteropServices;
|
using System.Runtime.InteropServices;
|
||||||
|
using System.Runtime.Intrinsics;
|
||||||
|
using System.Runtime.Intrinsics.X86;
|
||||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
@@ -79,38 +83,63 @@ public class Solver
|
|||||||
return source[max];
|
return source[max];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static Vector<float> EvalBestChildVectorized(float w, float W, Vector<float> C, Vector<float> scoreSums, Vector<float> visits, Vector<float> maxScores)
|
// https://stackoverflow.com/a/73439472
|
||||||
|
private static Vector128<float> HMax(Vector256<float> v1)
|
||||||
{
|
{
|
||||||
var exploitation = W * (scoreSums / visits) + w * maxScores;
|
var v2 = Avx.Permute(v1, 0b10110001);
|
||||||
var exploration = Vector.SquareRoot(C / visits);
|
var v3 = Avx.Max(v1, v2);
|
||||||
return exploitation + exploration;
|
var v4 = Avx.Permute(v3, 0b00001010);
|
||||||
|
var v5 = Avx.Max(v3, v4);
|
||||||
|
var v6 = Avx.ExtractVector128(v5, 1);
|
||||||
|
var v7 = Sse.Max(v5.GetLower(), v6);
|
||||||
|
return v7;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int AlignToVectorLength(int length) =>
|
[Pure]
|
||||||
(length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
|
// https://stackoverflow.com/a/23592221
|
||||||
|
private static (int, uint) HMaxIndex(Vector256<float> v, int len)
|
||||||
|
{
|
||||||
|
var vfilt = Avx.Blend(v, Vector256<float>.Zero, (byte)~((1 << len) - 1));
|
||||||
|
|
||||||
|
var vmax128 = HMax(vfilt);
|
||||||
|
var vmax = Vector256.Create(vmax128, vmax128);
|
||||||
|
|
||||||
|
var vcmp = Avx.CompareEqual(v, vmax);
|
||||||
|
var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte()));
|
||||||
|
mask <<= (8 - len) << 2;
|
||||||
|
|
||||||
|
var inverseIdx = BitOperations.LeadingZeroCount(mask) >> 2;
|
||||||
|
|
||||||
|
return (len - 1 - inverseIdx, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Pure]
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
||||||
{
|
{
|
||||||
var length = children.Length;
|
var length = children.Length;
|
||||||
|
var vecLength = Vector<float>.Count;
|
||||||
|
|
||||||
var C = Config.ExplorationConstant * MathF.Log(parentVisits);
|
var C = Config.ExplorationConstant * MathF.Log(parentVisits);
|
||||||
var w = Config.MaxScoreWeightingConstant;
|
var w = Config.MaxScoreWeightingConstant;
|
||||||
var W = 1f - w;
|
var W = 1f - w;
|
||||||
var CVector = new Vector<float>(C);
|
var CVector = new Vector<float>(C);
|
||||||
|
|
||||||
Span<float> scoreSums = stackalloc float[Vector<float>.Count];
|
Span<float> scoreSums = stackalloc float[vecLength];
|
||||||
Span<float> visits = stackalloc float[Vector<float>.Count];
|
Span<float> visits = stackalloc float[vecLength];
|
||||||
Span<float> maxScores = stackalloc float[Vector<float>.Count];
|
Span<float> maxScores = stackalloc float[vecLength];
|
||||||
|
|
||||||
var max = 0;
|
var max = 0;
|
||||||
var maxScore = 0f;
|
var maxScore = 0f;
|
||||||
for (var i = 0; i < length; i += Vector<float>.Count)
|
|
||||||
|
for (var i = 0; i < length; i += vecLength)
|
||||||
{
|
{
|
||||||
var iterCount = i + Vector<float>.Count > length ?
|
var iterCount = i + vecLength > length ?
|
||||||
length - i :
|
length - i :
|
||||||
Vector<float>.Count;
|
vecLength;
|
||||||
|
|
||||||
for (var j = 0; j < iterCount; ++j)
|
for (var j = 0; j < iterCount; ++j)
|
||||||
{
|
{
|
||||||
@@ -119,15 +148,17 @@ public class Solver
|
|||||||
visits[j] = node.Visits;
|
visits[j] = node.Visits;
|
||||||
maxScores[j] = node.MaxScore;
|
maxScores[j] = node.MaxScore;
|
||||||
}
|
}
|
||||||
var evalScores = EvalBestChildVectorized(w, W, CVector, new(scoreSums), new(visits), new(maxScores));
|
|
||||||
|
|
||||||
for (var j = 0; j < iterCount; ++j)
|
var exploitation = (W * (new Vector<float>(scoreSums) / new Vector<float>(visits))) + (w * new Vector<float>(maxScores));
|
||||||
|
var exploration = Vector.SquareRoot(CVector / new Vector<float>(visits));
|
||||||
|
var evalScores = exploitation + exploration;
|
||||||
|
|
||||||
|
var (idx, mask) = HMaxIndex(evalScores.AsVector256(), iterCount);
|
||||||
|
|
||||||
|
if (evalScores[idx] >= maxScore)
|
||||||
{
|
{
|
||||||
if (evalScores[j] >= maxScore)
|
max = i + idx;
|
||||||
{
|
maxScore = evalScores[idx];
|
||||||
max = i + j;
|
|
||||||
maxScore = evalScores[j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user