Add benchmarks
This commit is contained in:
+32
-11
@@ -99,23 +99,45 @@ public class Solver
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
// https://stackoverflow.com/a/23592221
|
||||
private static (int, uint) HMaxIndex(Vector256<float> v, int len)
|
||||
public static int HMaxIndexScalar(Vector<float> v, int len)
|
||||
{
|
||||
var vfilt = Avx.Blend(v, Vector256<float>.Zero, (byte)~((1 << len) - 1));
|
||||
var m = 0;
|
||||
for (var i = 1; i < len; ++i)
|
||||
{
|
||||
if (v[i] >= v[m])
|
||||
m = i;
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
// https://stackoverflow.com/a/23592221
|
||||
public static int HMaxIndexAVX2(Vector<float> v, int len)
|
||||
{
|
||||
// Remove NaNs
|
||||
var vfilt = Avx.Blend(v.AsVector256(), Vector256<float>.Zero, (byte)~((1 << len) - 1));
|
||||
|
||||
// Find max value and broadcast to all lanes
|
||||
var vmax128 = HMax(vfilt);
|
||||
var vmax = Vector256.Create(vmax128, vmax128);
|
||||
|
||||
var vcmp = Avx.CompareEqual(v, vmax);
|
||||
// Find the highest index with that value, respecting len
|
||||
var vcmp = Avx.CompareEqual(vfilt, vmax);
|
||||
var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte()));
|
||||
mask <<= (8 - len) << 2;
|
||||
|
||||
var inverseIdx = BitOperations.LeadingZeroCount(mask) >> 2;
|
||||
var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2;
|
||||
|
||||
return (len - 1 - inverseIdx, mask);
|
||||
return len - 1 - inverseIdx;
|
||||
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static int HMaxIndex(Vector<float> v, int len) =>
|
||||
Avx2.IsSupported ?
|
||||
HMaxIndexAVX2(v, len) :
|
||||
HMaxIndexScalar(v, len);
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
||||
@@ -134,7 +156,6 @@ public class Solver
|
||||
|
||||
var max = 0;
|
||||
var maxScore = 0f;
|
||||
|
||||
for (var i = 0; i < length; i += vecLength)
|
||||
{
|
||||
var iterCount = i + vecLength > length ?
|
||||
@@ -153,7 +174,7 @@ public class Solver
|
||||
var exploration = Vector.SquareRoot(CVector / new Vector<float>(visits));
|
||||
var evalScores = exploitation + exploration;
|
||||
|
||||
var (idx, mask) = HMaxIndex(evalScores.AsVector256(), iterCount);
|
||||
var idx = HMaxIndex(evalScores, iterCount);
|
||||
|
||||
if (evalScores[idx] >= maxScore)
|
||||
{
|
||||
@@ -186,7 +207,7 @@ public class Solver
|
||||
if (initialState.IsComplete)
|
||||
return (initialNode, initialState.CompletionState, initialState.CalculateScore(Config.MaxStepCount) ?? 0);
|
||||
|
||||
var randomAction = initialState.AvailableActions.SelectRandom(Random);
|
||||
var randomAction = initialState.AvailableActions.First();//.SelectRandom(Random);
|
||||
initialState.AvailableActions.RemoveAction(randomAction);
|
||||
var expandedNode = initialNode.Add(Execute(initialState.State, randomAction, true));
|
||||
|
||||
@@ -201,7 +222,7 @@ public class Solver
|
||||
{
|
||||
if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete)
|
||||
break;
|
||||
randomAction = currentActions.SelectRandom(Random);
|
||||
randomAction = currentActions.First();//.SelectRandom(Random);
|
||||
actions[actionCount++] = randomAction;
|
||||
(_, currentState) = Simulator.Execute(currentState, randomAction);
|
||||
currentCompletionState = Simulator.CompletionState;
|
||||
|
||||
Reference in New Issue
Block a user