Vectorize evaluation
This commit is contained in:
+10
-6
@@ -1,3 +1,5 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
public class Arena<T> where T : struct
|
||||
@@ -10,20 +12,22 @@ public class Arena<T> where T : struct
|
||||
public T State { get; init; }
|
||||
}
|
||||
|
||||
public List<Node> Nodes { get; } = new();
|
||||
private readonly List<Node> nodes = new();
|
||||
|
||||
public Arena(T initialState = default)
|
||||
{
|
||||
Nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState });
|
||||
nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState });
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public int Insert(int parentIndex, T state)
|
||||
{
|
||||
var index = Nodes.Count;
|
||||
Nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state });
|
||||
Nodes[parentIndex].Children.Add(index);
|
||||
var index = nodes.Count;
|
||||
nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state });
|
||||
nodes[parentIndex].Children.Add(index);
|
||||
return index;
|
||||
}
|
||||
|
||||
public Node Get(int index) => Nodes[index];
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public Node Get(int index) => nodes[index];
|
||||
}
|
||||
|
||||
+116
-6
@@ -1,7 +1,10 @@
|
||||
using Craftimizer.Simulator;
|
||||
using Craftimizer.Simulator.Actions;
|
||||
using System.ComponentModel;
|
||||
using System.Diagnostics;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Intrinsics;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
@@ -100,9 +103,116 @@ public class Solver
|
||||
return source[max];
|
||||
}
|
||||
|
||||
public int Select(int currentIndex)
|
||||
/// <summary>
/// Scores one SIMD vector's worth of children at once: a weighted blend of each
/// lane's mean score and max score, plus a square-root exploration term.
/// </summary>
/// <param name="w">Weight applied to <paramref name="maxScores"/>.</param>
/// <param name="W">Weight applied to the per-lane mean (<paramref name="scoreSums"/> divided by <paramref name="visits"/>).</param>
/// <param name="C">Numerator of the exploration term, one value per lane.</param>
/// <param name="scoreSums">Per-lane accumulated score.</param>
/// <param name="visits">Per-lane visit count; used as a divisor in both terms.</param>
/// <param name="maxScores">Per-lane maximum score.</param>
/// <returns>Per-lane sum of the exploitation and exploration terms.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<float> EvalBestChildVectorized(float w, float W, Vector<float> C, Vector<float> scoreSums, Vector<float> visits, Vector<float> maxScores)
{
    // Exploitation: W * mean + w * max, evaluated in the same order as a scalar expression would be.
    var meanScores = scoreSums / visits;
    var weightedMean = W * meanScores;
    var weightedMax = w * maxScores;

    // Exploration: sqrt(C / visits).
    var explorationBonus = Vector.SquareRoot(C / visits);

    return (weightedMean + weightedMax) + explorationBonus;
}
|
||||
|
||||
/// <summary>
/// Computes the selection score of every child from parallel arrays of per-child statistics.
/// </summary>
/// <remarks>
/// Whole blocks of <c>Vector&lt;float&gt;.Count</c> elements are scored with
/// <see cref="EvalBestChildVectorized"/>. Any leftover elements are scored one at a time
/// with the same formula, so the input length no longer has to be a multiple of the vector width.
/// </remarks>
/// <param name="parentVisits">Visit count of the parent; its natural log scales the exploration term.</param>
/// <param name="scoreSums">Accumulated score per child. Its length determines the length of the result.</param>
/// <param name="visits">Visit count per child; must be at least as long as <paramref name="scoreSums"/>.</param>
/// <param name="maxScores">Maximum score per child; must be at least as long as <paramref name="scoreSums"/>.</param>
/// <returns>A newly allocated array holding one score per element of <paramref name="scoreSums"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores)
{
    var C = ExplorationConstant * MathF.Log(parentVisits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;
    var CVector = new Vector<float>(C);

    var length = scoreSums.Length;
    var result = new float[length];

    // SIMD pass over every complete block.
    var i = 0;
    for (var lastBlockStart = length - Vector<float>.Count; i <= lastBlockStart; i += Vector<float>.Count)
    {
        var scoreSumsVector = new Vector<float>(scoreSums, i);
        var visitsVector = new Vector<float>(visits, i);
        var maxScoresVector = new Vector<float>(maxScores, i);
        var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector);
        evalVector.CopyTo(result, i);
    }

    // Scalar pass over the tail; a no-op when the length is a multiple of the vector width.
    for (; i < length; ++i)
    {
        var exploitation = (W * (scoreSums[i] / visits[i])) + (w * maxScores[i]);
        var exploration = MathF.Sqrt(C / visits[i]);
        result[i] = exploitation + exploration;
    }

    return result;
}
|
||||
|
||||
/// <summary>
/// Copies each child's score statistics out of the tree into padded arrays and returns
/// every score computed by <see cref="EvalBestChildMultiple"/>.
/// </summary>
/// <param name="parentVisits">Visit count of the parent node, forwarded unchanged.</param>
/// <param name="children">Tree indices of the children to score.</param>
/// <returns>
/// One score per padded slot. Entries at positions beyond <c>children.Count</c> are computed
/// from zero-initialised statistics and do not correspond to any child.
/// </returns>
private float[] EvalAllChildrenDbg(float parentVisits, List<int> children)
{
    var childCount = children.Count;

    // Round the count up to the vector width (the mask relies on the width being a power of two).
    var lanes = Vector<float>.Count;
    var paddedCount = (childCount + (lanes - 1)) & ~(lanes - 1);

    var sums = new float[paddedCount];
    var visitCounts = new float[paddedCount];
    var maxima = new float[paddedCount];

    var slot = 0;
    foreach (var childIndex in children)
    {
        var stats = Tree.Get(childIndex).State.Scores;
        sums[slot] = stats.ScoreSum;
        visitCounts[slot] = stats.Visits;
        maxima[slot] = stats.MaxScore;
        ++slot;
    }

    return EvalBestChildMultiple(parentVisits, sums, visitCounts, maxima);
}
|
||||
|
||||
/// <summary>
/// Picks the child with the highest selection score, scoring all children in bulk
/// through <see cref="EvalBestChildMultiple"/>.
/// </summary>
/// <param name="parentVisits">Visit count of the parent node, forwarded to the scorer.</param>
/// <param name="children">Tree indices of the candidate children; indexing the first score throws if this is empty.</param>
/// <returns>The tree index (an element of <paramref name="children"/>) of the winning child.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChild(float parentVisits, List<int> children)
{
    var childCount = children.Count;

    // Pad the gather buffers up to the vector width (the mask relies on the width being a power of two).
    var lanes = Vector<float>.Count;
    var paddedCount = (childCount + (lanes - 1)) & ~(lanes - 1);

    var sums = new float[paddedCount];
    var visitCounts = new float[paddedCount];
    var maxima = new float[paddedCount];

    for (var slot = 0; slot < childCount; ++slot)
    {
        var stats = Tree.Get(children[slot]).State.Scores;
        sums[slot] = stats.ScoreSum;
        visitCounts[slot] = stats.Visits;
        maxima[slot] = stats.MaxScore;
    }

    var scores = EvalBestChildMultiple(parentVisits, sums, visitCounts, maxima);

    // Linear scan over the real children only (padding slots are ignored).
    // The comparison is >=, so among equal scores the last child wins.
    var bestSlot = 0;
    var bestScore = scores[0];
    for (var slot = 1; slot < childCount; ++slot)
    {
        var candidate = scores[slot];
        if (candidate >= bestScore)
        {
            bestSlot = slot;
            bestScore = candidate;
        }
    }

    return children[bestSlot];
}
|
||||
|
||||
/// <summary>
/// One-child-at-a-time counterpart of <see cref="EvalBestChild"/>: scores each child with
/// the same exploitation-plus-exploration formula and returns the best one.
/// </summary>
/// <param name="children">Tree indices of the candidate children.</param>
/// <param name="parent">Score statistics of the parent; only its visit count is read.</param>
/// <returns>
/// The tree index of the highest-scoring child (the last one among equal scores), or
/// <c>-1</c> when <paramref name="children"/> is empty or no score compares greater than or
/// equal to <see cref="float.MinValue"/> (for example when every score is NaN).
/// </returns>
private int EvalBestChildScalar(List<int> children, NodeScores parent)
{
    // The per-call Console.WriteLine of the child count was debug output and has been removed.
    var C = ExplorationConstant * MathF.Log(parent.Visits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;

    var ret = -1;
    var maxV = float.MinValue;
    foreach (var childNode in children)
    {
        var child = Tree.Get(childNode).State.Scores;
        var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore);
        var exploration = MathF.Sqrt(C / child.Visits);
        var score = exploitation + exploration;
        if (score >= maxV)
        {
            ret = childNode;
            maxV = score;
        }
    }
    return ret;
}
|
||||
|
||||
public int Select(int selectedIndex)
|
||||
{
|
||||
var selectedIndex = currentIndex;
|
||||
while (true)
|
||||
{
|
||||
var selectedNode = Tree.Get(selectedIndex);
|
||||
@@ -110,13 +220,12 @@ public class Solver
|
||||
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
||||
var likelyTerminal = selectedNode.Children.Count == 0;
|
||||
if (expandable || likelyTerminal) {
|
||||
break;
|
||||
return selectedIndex;
|
||||
}
|
||||
|
||||
// select the node with the highest score
|
||||
selectedIndex = RustMaxBy(selectedNode.Children, n => Eval(Tree.Get(n).State.Scores, selectedNode.State.Scores));
|
||||
selectedIndex = EvalBestChild(selectedNode.State.Scores.Visits, selectedNode.Children);
|
||||
}
|
||||
return selectedIndex;
|
||||
}
|
||||
|
||||
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex)
|
||||
@@ -164,7 +273,8 @@ public class Solver
|
||||
var currentScores = currentNode.State.Scores;
|
||||
currentScores.Visits++;
|
||||
currentScores.ScoreSum += score;
|
||||
currentScores.MaxScore = Math.Max(currentScores.MaxScore, score);
|
||||
if (currentScores.MaxScore < score)
|
||||
currentScores.MaxScore = score;
|
||||
|
||||
if (currentIndex == targetIndex)
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user