Clean up solver code a bit
This commit is contained in:
@@ -1,59 +0,0 @@
|
||||
using BenchmarkDotNet.Attributes;
|
||||
using BenchmarkDotNet.Jobs;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Numerics;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Craftimizer.Benchmark;
|
||||
|
||||
/// <summary>
/// Benchmarks the scalar and AVX2 implementations of the "index of the horizontal maximum"
/// helper over <see cref="N"/> packed groups of 8 floats.
/// </summary>
[SimpleJob(RuntimeMoniker.Net70)]
[SimpleJob(RuntimeMoniker.NativeAot70)]
public class Bench
{
    // Every group occupies 8 consecutive floats in `data`, matching the Slice(i * 8, 8) reads below.
    private const int GroupSize = 8;

    // Fixed seed so that every job and every run benchmarks exactly the same input data.
    private const int Seed = 42;

    // Flattened groups: GroupSize floats per group; unused trailing lanes hold NaN.
    private float[] data = Array.Empty<float>();

    // Number of valid leading lanes for each group (1..GroupSize).
    private int[] dataLengths = Array.Empty<int>();

    /// <summary>Number of groups processed by each benchmark invocation.</summary>
    [Params(1000, 10000)]
    public int N;

    /// <summary>
    /// Fills <c>data</c> with <see cref="N"/> groups of random floats. Roughly half of the groups
    /// are full; the rest get a random valid length, with the remaining lanes set to NaN.
    /// </summary>
    [GlobalSetup]
    public void Setup()
    {
        var rand = new Random(Seed);
        data = new float[N * GroupSize];
        dataLengths = new int[N];
        for (var i = 0; i < data.Length; i += GroupSize)
        {
            var len = rand.NextSingle() > .5 ? GroupSize : rand.Next(1, GroupSize + 1);
            dataLengths[i / GroupSize] = len;
            for (var j = 0; j < len; ++j)
            {
                data[i + j] = rand.NextSingle();
            }
            for (var j = len; j < GroupSize; ++j)
            {
                data[i + j] = float.NaN;
            }
        }
    }

    /// <summary>Runs the scalar implementation over all <see cref="N"/> groups.</summary>
    /// <returns>The max-index computed for each group.</returns>
    [Benchmark]
    public int[] Scalar()
    {
        var d = new int[N];
        var dataSpan = data.AsSpan();
        for (var i = 0; i < N; ++i)
        {
            d[i] = Solver.Crafty.Solver.HMaxIndexScalar(new Vector<float>(dataSpan.Slice(i * GroupSize, GroupSize)), dataLengths[i]);
        }
        return d;
    }

    /// <summary>
    /// Runs the AVX2 implementation over all <see cref="N"/> groups, so its workload matches
    /// <see cref="Scalar"/> (it previously processed a fixed 128 groups regardless of N).
    /// </summary>
    /// <returns>The max-index computed for each group.</returns>
    [Benchmark]
    public int[] AVX2()
    {
        var d = new int[N];
        var dataSpan = data.AsSpan();
        for (var i = 0; i < N; ++i)
        {
            d[i] = Solver.Crafty.Solver.HMaxIndexAVX2(new Vector<float>(dataSpan.Slice(i * GroupSize, GroupSize)), dataLengths[i]);
        }
        return d;
    }
}
|
||||
@@ -0,0 +1,64 @@
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Intrinsics;
|
||||
using System.Runtime.Intrinsics.X86;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
/// <summary>
/// Helpers for finding the index of the largest element among the first <c>len</c> lanes of a
/// <see cref="Vector{T}"/> of floats, with an AVX2 path and a scalar fallback.
/// </summary>
internal static class Intrinsics
{
    /// <summary>
    /// Computes the horizontal maximum of all 8 lanes of <paramref name="v1"/>.
    /// </summary>
    /// <returns>A 128-bit vector whose 4 lanes all hold the maximum.</returns>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    // https://stackoverflow.com/a/73439472
    private static Vector128<float> HMax(Vector256<float> v1)
    {
        // Swap neighbours within each pair, then max: every lane holds its pair's max.
        var v2 = Avx.Permute(v1, 0b10110001);
        var v3 = Avx.Max(v1, v2);
        // Swap the two pairs within each 128-bit half, then max: every lane holds its half's max.
        var v4 = Avx.Permute(v3, 0b00001010);
        var v5 = Avx.Max(v3, v4);
        // Combine the upper and lower halves.
        var v6 = Avx.ExtractVector128(v5, 1);
        var v7 = Sse.Max(v5.GetLower(), v6);
        return v7;
    }

    /// <summary>
    /// Scalar search for the index of the largest of the first <paramref name="len"/> lanes.
    /// Ties resolve to the highest index because the comparison is <c>&gt;=</c>.
    /// </summary>
    /// <returns>The winning lane index; 0 when <paramref name="len"/> is 0 or 1.</returns>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static int HMaxIndexScalar(Vector<float> v, int len)
    {
        var m = 0;
        for (var i = 1; i < len; ++i)
        {
            if (v[i] >= v[m])
            {
                m = i;
            }
        }
        return m;
    }

    /// <summary>
    /// AVX2 search for the highest index, among the first <paramref name="len"/> lanes, that holds
    /// the maximum of those lanes.
    /// </summary>
    /// <param name="v">Input vector; lanes at index &gt;= <paramref name="len"/> are ignored and may hold NaN.</param>
    /// <param name="len">Number of valid leading lanes. The shift arithmetic below only yields a valid
    /// index for 1 &lt;= len &lt;= 8.</param>
    /// <returns>The winning lane index.</returns>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    // https://stackoverflow.com/a/23592221
    private static int HMaxIndexAVX2(Vector<float> v, int len)
    {
        // Remove NaNs / padding: overwrite lanes at index >= len with -infinity so that they can
        // never exceed a valid lane. (Padding with 0 made the padding win whenever every valid lane
        // was negative, which produced a negative index.)
        var vfilt = Avx.Blend(v.AsVector256(), Vector256.Create(float.NegativeInfinity), (byte)~((1 << len) - 1));

        // Find max value and broadcast to all lanes
        var vmax128 = HMax(vfilt);
        var vmax = Vector256.Create(vmax128, vmax128);

        // Find the highest index with that value, respecting len
        var vcmp = Avx.CompareEqual(vfilt, vmax);
        // Byte-wise movemask: each float lane contributes 4 consecutive bits.
        var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte()));

        // Shift the ignored upper lanes out, then count how many lanes (4 bits each) from the top
        // of the valid range do not match the maximum.
        var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2;

        return len - 1 - inverseIdx;
    }

    /// <summary>
    /// Returns the index of the largest of the first <paramref name="len"/> lanes of
    /// <paramref name="v"/>, preferring the highest index on ties. Uses the AVX2 path when
    /// <see cref="Avx2.IsSupported"/> is true and the scalar path otherwise.
    /// </summary>
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int HMaxIndex(Vector<float> v, int len) =>
        Avx2.IsSupported ?
            HMaxIndexAVX2(v, len) :
            HMaxIndexScalar(v, len);
}
|
||||
+41
-77
@@ -64,78 +64,38 @@ public class Solver
|
||||
return (startNode, startNode.State.CompletionState);
|
||||
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static T RustMaxBy<T>(ReadOnlySpan<T> source, Func<T, float> into)
|
||||
private static Node ChildMaxScore(ReadOnlySpan<Node> children)
|
||||
{
|
||||
var length = children.Length;
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
Span<float> scores = stackalloc float[vecLength];
|
||||
|
||||
var max = 0;
|
||||
var maxV = into(source[0]);
|
||||
for (var i = 1; i < source.Length; ++i)
|
||||
var maxScore = 0f;
|
||||
for (var i = 0; i < length; i += vecLength)
|
||||
{
|
||||
var nextV = into(source[i]);
|
||||
if (maxV <= nextV)
|
||||
var iterCount = i + vecLength > length ?
|
||||
length - i :
|
||||
vecLength;
|
||||
|
||||
for (var j = 0; j < iterCount; ++j)
|
||||
scores[j] = children[i + j].State.Scores.MaxScore;
|
||||
|
||||
var idx = Intrinsics.HMaxIndex(new Vector<float>(scores), iterCount);
|
||||
|
||||
if (scores[idx] >= maxScore)
|
||||
{
|
||||
max = i;
|
||||
maxV = nextV;
|
||||
max = i + idx;
|
||||
maxScore = scores[idx];
|
||||
}
|
||||
}
|
||||
return source[max];
|
||||
|
||||
return children[max];
|
||||
}
|
||||
|
||||
/// <summary>
/// Reduces the 8 lanes of <paramref name="v1"/> to their maximum and returns it replicated
/// across a 128-bit vector.
/// </summary>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
// https://stackoverflow.com/a/73439472
private static Vector128<float> HMax(Vector256<float> v1)
{
    // Max of each neighbouring pair (lanes swapped within pairs).
    var pairMax = Avx.Max(v1, Avx.Permute(v1, 0b10110001));

    // Max across the two pairs of each 128-bit half.
    var halfMax = Avx.Max(pairMax, Avx.Permute(pairMax, 0b00001010));

    // Max of the lower half against the upper half.
    return Sse.Max(halfMax.GetLower(), Avx.ExtractVector128(halfMax, 1));
}
|
||||
|
||||
/// <summary>
/// Scans the first <paramref name="len"/> lanes of <paramref name="v"/> and returns the index
/// of the largest one; equal values resolve to the later lane.
/// </summary>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int HMaxIndexScalar(Vector<float> v, int len)
{
    var best = 0;
    var candidate = 1;
    while (candidate < len)
    {
        // ">=" (not ">") so that a later lane replaces an equal earlier one.
        best = v[candidate] >= v[best] ? candidate : best;
        ++candidate;
    }
    return best;
}
|
||||
|
||||
/// <summary>
/// AVX2 search for the highest index, among the first <paramref name="len"/> lanes, that holds
/// the maximum of those lanes.
/// </summary>
/// <param name="v">Input vector; lanes at index &gt;= <paramref name="len"/> are ignored and may hold NaN.</param>
/// <param name="len">Number of valid leading lanes. The shift arithmetic below only yields a valid
/// index for 1 &lt;= len &lt;= 8.</param>
/// <returns>The winning lane index.</returns>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
// https://stackoverflow.com/a/23592221
public static int HMaxIndexAVX2(Vector<float> v, int len)
{
    // Remove NaNs / padding: overwrite lanes at index >= len with -infinity so that they can
    // never exceed a valid lane. (Padding with 0 made the padding win whenever every valid lane
    // was negative, which produced a negative index.)
    var vfilt = Avx.Blend(v.AsVector256(), Vector256.Create(float.NegativeInfinity), (byte)~((1 << len) - 1));

    // Find max value and broadcast to all lanes
    var vmax128 = HMax(vfilt);
    var vmax = Vector256.Create(vmax128, vmax128);

    // Find the highest index with that value, respecting len
    var vcmp = Avx.CompareEqual(vfilt, vmax);
    // Byte-wise movemask: each float lane contributes 4 consecutive bits.
    var mask = unchecked((uint)Avx2.MoveMask(vcmp.AsByte()));

    // Shift the ignored upper lanes out, then count how many lanes (4 bits each) from the top
    // of the valid range do not match the maximum.
    var inverseIdx = BitOperations.LeadingZeroCount(mask << ((8 - len) << 2)) >> 2;

    return len - 1 - inverseIdx;
}
|
||||
|
||||
/// <summary>
/// Dispatches to the AVX2 implementation when <see cref="Avx2.IsSupported"/> is true and to the
/// scalar implementation otherwise.
/// </summary>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int HMaxIndex(Vector<float> v, int len)
{
    if (!Avx2.IsSupported)
    {
        return HMaxIndexScalar(v, len);
    }

    return HMaxIndexAVX2(v, len);
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private Node EvalBestChild(float parentVisits, ReadOnlySpan<Node> children)
|
||||
@@ -172,7 +132,7 @@ public class Solver
|
||||
var exploration = Vector.SquareRoot(CVector / new Vector<float>(visits));
|
||||
var evalScores = exploitation + exploration;
|
||||
|
||||
var idx = HMaxIndex(evalScores, iterCount);
|
||||
var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
|
||||
|
||||
if (evalScores[idx] >= maxScore)
|
||||
{
|
||||
@@ -184,17 +144,19 @@ public class Solver
|
||||
return children[max];
|
||||
}
|
||||
|
||||
public Node Select(Node selectedNode)
|
||||
[Pure]
|
||||
public Node Select()
|
||||
{
|
||||
var node = RootNode;
|
||||
while (true)
|
||||
{
|
||||
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
||||
var likelyTerminal = selectedNode.Children.Count == 0;
|
||||
var expandable = node.State.AvailableActions.Count != 0;
|
||||
var likelyTerminal = node.Children.Count == 0;
|
||||
if (expandable || likelyTerminal)
|
||||
return selectedNode;
|
||||
return node;
|
||||
|
||||
// select the node with the highest score
|
||||
selectedNode = EvalBestChild(selectedNode.State.Scores.Visits, CollectionsMarshal.AsSpan(selectedNode.Children));
|
||||
node = EvalBestChild(node.State.Scores.Visits, CollectionsMarshal.AsSpan(node.Children));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,40 +202,42 @@ public class Solver
|
||||
return (expandedNode, currentCompletionState, score);
|
||||
}
|
||||
|
||||
public static void Backpropagate(Node startNode, Node targetNode, float score)
|
||||
public void Backpropagate(Node startNode, float score)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
startNode.State.Scores.Visit(score);
|
||||
|
||||
if (startNode == targetNode)
|
||||
if (startNode == RootNode)
|
||||
break;
|
||||
|
||||
startNode = startNode.Parent!;
|
||||
}
|
||||
}
|
||||
|
||||
public void Search(Node startNode, CancellationToken token)
|
||||
public void Search(CancellationToken token)
|
||||
{
|
||||
for (var i = 0; i < Config.Iterations; i++)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
var selectedNode = Select(startNode);
|
||||
var selectedNode = Select();
|
||||
var (endNode, _, score) = ExpandAndRollout(selectedNode);
|
||||
|
||||
Backpropagate(endNode, startNode, score);
|
||||
Backpropagate(endNode, score);
|
||||
}
|
||||
}
|
||||
|
||||
[Pure]
|
||||
public (List<ActionType> Actions, SimulationNode Node) Solution()
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var node = RootNode;
|
||||
while (node.Children.Count != 0)
|
||||
{
|
||||
node = RustMaxBy<Node>(CollectionsMarshal.AsSpan(node.Children), n => n.State.Scores.MaxScore);
|
||||
node = ChildMaxScore(CollectionsMarshal.AsSpan(node.Children));
|
||||
|
||||
if (node.State.Action != null)
|
||||
actions.Add(node.State.Action.Value);
|
||||
}
|
||||
@@ -293,7 +257,7 @@ public class Solver
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
solver.Search(solver.RootNode, token);
|
||||
solver.Search(token);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
|
||||
if (solution_node.Scores.MaxScore >= 1.0)
|
||||
@@ -320,7 +284,7 @@ public class Solver
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default)
|
||||
{
|
||||
var solver = new Solver(config, state, false);
|
||||
solver.Search(solver.RootNode, token);
|
||||
solver.Search(token);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
return (solution_actions, solution_node.State);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user