Vectorize evaluation

This commit is contained in:
Asriel Camora
2023-06-19 17:02:09 -07:00
parent 05ead22448
commit 1d0d4cf8ce
3 changed files with 130 additions and 12 deletions
+4
View File
@@ -2,6 +2,7 @@ using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions; using Craftimizer.Simulator.Actions;
using Craftimizer.Solver.Crafty; using Craftimizer.Solver.Crafty;
using ObjectLayoutInspector; using ObjectLayoutInspector;
using System.Diagnostics;
namespace Craftimizer.Benchmark; namespace Craftimizer.Benchmark;
@@ -32,6 +33,7 @@ internal static class Program
}; };
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var s = Stopwatch.StartNew();
if (true) if (true)
(actions, _) = Solver.Crafty.Solver.SearchStepwise(input, actions, a => Console.WriteLine(a)); (actions, _) = Solver.Crafty.Solver.SearchStepwise(input, actions, a => Console.WriteLine(a));
else else
@@ -41,5 +43,7 @@ internal static class Program
Console.Write($">{action.IntName()}"); Console.Write($">{action.IntName()}");
Console.WriteLine(); Console.WriteLine();
} }
s.Stop();
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}");
} }
} }
+10 -6
View File
@@ -1,3 +1,5 @@
using System.Runtime.CompilerServices;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
public class Arena<T> where T : struct public class Arena<T> where T : struct
@@ -10,20 +12,22 @@ public class Arena<T> where T : struct
public T State { get; init; } public T State { get; init; }
} }
public List<Node> Nodes { get; } = new(); private readonly List<Node> nodes = new();
public Arena(T initialState = default) public Arena(T initialState = default)
{ {
Nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState }); nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState });
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int Insert(int parentIndex, T state) public int Insert(int parentIndex, T state)
{ {
var index = Nodes.Count; var index = nodes.Count;
Nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state }); nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state });
Nodes[parentIndex].Children.Add(index); nodes[parentIndex].Children.Add(index);
return index; return index;
} }
public Node Get(int index) => Nodes[index]; [MethodImpl(MethodImplOptions.AggressiveInlining)]
public Node Get(int index) => nodes[index];
} }
+116 -6
View File
@@ -1,7 +1,10 @@
using Craftimizer.Simulator; using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions; using Craftimizer.Simulator.Actions;
using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
@@ -100,9 +103,116 @@ public class Solver
return source[max]; return source[max];
} }
public int Select(int currentIndex) [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<float> EvalBestChildVectorized(float w, float W, Vector<float> C, Vector<float> scoreSums, Vector<float> visits, Vector<float> maxScores)
{
    // Lane-wise score for a batch of children:
    //   W * (scoreSums / visits) + w * maxScores + sqrt(C / visits)
    // w and W are scalar weights; C, scoreSums, visits and maxScores hold one value per lane.
    var meanScores = scoreSums / visits;
    var weightedMean = W * meanScores;
    var weightedMax = w * maxScores;
    var explorationBonus = Vector.SquareRoot(C / visits);
    return (weightedMean + weightedMax) + explorationBonus;
}
// Scores every entry of the parallel input arrays and returns the scores in a new array
// of scoreSums.Length elements, index-aligned with the inputs.
//
// Whole Vector<float>.Count-wide chunks go through EvalBestChildVectorized; any leftover
// elements are scored one at a time with the same formula, so the input length does not
// have to be a multiple of Vector<float>.Count (padded inputs keep working unchanged).
//
// parentVisits: fed to MathF.Log to build the exploration constant C.
// scoreSums, visits, maxScores: per-child values; visits and maxScores must be at least
// as long as scoreSums.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores)
{
    var C = ExplorationConstant * MathF.Log(parentVisits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;
    var CVector = new Vector<float>(C);

    var length = scoreSums.Length;
    var result = new float[length];

    // Vectorized part: only chunks that have a full Vector<float>.Count elements available,
    // because the Vector<float>(array, index) constructor throws on a short read.
    var i = 0;
    var lastVectorStart = length - Vector<float>.Count;
    for (; i <= lastVectorStart; i += Vector<float>.Count)
    {
        var scoreSumsVector = new Vector<float>(scoreSums, i);
        var visitsVector = new Vector<float>(visits, i);
        var maxScoresVector = new Vector<float>(maxScores, i);
        var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector);
        evalVector.CopyTo(result, i);
    }

    // Scalar tail: same operations in the same order as EvalBestChildVectorized.
    for (; i < length; ++i)
    {
        var exploitation = (W * (scoreSums[i] / visits[i])) + (w * maxScores[i]);
        var exploration = MathF.Sqrt(C / visits[i]);
        result[i] = exploitation + exploration;
    }

    return result;
}
// Returns the EvalBestChildMultiple scores for all of the given children (the "Dbg" suffix
// suggests this is a debugging aid - confirm with callers).
//
// The ScoreSum, Visits and MaxScore of each child are copied into arrays whose length is
// children.Count rounded up to a multiple of Vector<float>.Count; the padding slots keep
// their default value of 0. The returned array has that padded length.
private float[] EvalAllChildrenDbg(float parentVisits, List<int> children)
{
    var childCount = children.Count;

    // Round up with a bit mask; relies on Vector<float>.Count being a power of two.
    var laneMask = Vector<float>.Count - 1;
    var paddedLength = (childCount + laneMask) & ~laneMask;

    var sums = new float[paddedLength];
    var visitCounts = new float[paddedLength];
    var maxima = new float[paddedLength];

    var slot = 0;
    foreach (var childIndex in children)
    {
        var scores = Tree.Get(childIndex).State.Scores;
        sums[slot] = scores.ScoreSum;
        visitCounts[slot] = scores.Visits;
        maxima[slot] = scores.MaxScore;
        ++slot;
    }

    return EvalBestChildMultiple(parentVisits, sums, visitCounts, maxima);
}
// Picks the child with the highest EvalBestChildMultiple score and returns its entry from
// `children` (the value stored in the list, not the position within it).
//
// The ScoreSum, Visits and MaxScore of each child are gathered into arrays padded up to a
// multiple of Vector<float>.Count; only the first children.Count scores are compared.
// Ties go to the later child because the comparison is >=.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChild(float parentVisits, List<int> children)
{
    var childCount = children.Count;

    // Round up with a bit mask; relies on Vector<float>.Count being a power of two.
    var laneMask = Vector<float>.Count - 1;
    var paddedLength = (childCount + laneMask) & ~laneMask;

    var sums = new float[paddedLength];
    var visitCounts = new float[paddedLength];
    var maxima = new float[paddedLength];

    var slot = 0;
    foreach (var childIndex in children)
    {
        var scores = Tree.Get(childIndex).State.Scores;
        sums[slot] = scores.ScoreSum;
        visitCounts[slot] = scores.Visits;
        maxima[slot] = scores.MaxScore;
        ++slot;
    }

    var evaluated = EvalBestChildMultiple(parentVisits, sums, visitCounts, maxima);

    // Seed with the first child, then scan the rest.
    var bestSlot = 0;
    var bestScore = evaluated[0];
    var candidate = 1;
    while (candidate < childCount)
    {
        var candidateScore = evaluated[candidate];
        if (candidateScore >= bestScore)
        {
            bestSlot = candidate;
            bestScore = candidateScore;
        }
        ++candidate;
    }

    return children[bestSlot];
}
// Scalar (non-vectorized) child selection: scores each child as
//   W * (ScoreSum / Visits) + w * MaxScore + sqrt(C / Visits)
// with C = ExplorationConstant * log(parent.Visits), w = MaxScoreWeightingConstant and
// W = 1 - w, and returns the entry of `children` with the highest score.
//
// Ties go to the later child because the comparison is >=. Returns -1 when `children` is
// empty or when no score compares >= float.MinValue (for example, every score is NaN).
private int EvalBestChildScalar(List<int> children, NodeScores parent)
{
    var C = ExplorationConstant * MathF.Log(parent.Visits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;

    var ret = -1;
    var maxV = float.MinValue;
    foreach (var childNode in children)
    {
        var child = Tree.Get(childNode).State.Scores;

        var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore);
        var exploration = MathF.Sqrt(C / child.Visits);
        var score = exploitation + exploration;

        if (score >= maxV)
        {
            ret = childNode;
            maxV = score;
        }
    }

    return ret;
}
public int Select(int selectedIndex)
{ {
var selectedIndex = currentIndex;
while (true) while (true)
{ {
var selectedNode = Tree.Get(selectedIndex); var selectedNode = Tree.Get(selectedIndex);
@@ -110,13 +220,12 @@ public class Solver
var expandable = selectedNode.State.AvailableActions.Count != 0; var expandable = selectedNode.State.AvailableActions.Count != 0;
var likelyTerminal = selectedNode.Children.Count == 0; var likelyTerminal = selectedNode.Children.Count == 0;
if (expandable || likelyTerminal) { if (expandable || likelyTerminal) {
break; return selectedIndex;
} }
// select the node with the highest score // select the node with the highest score
selectedIndex = RustMaxBy(selectedNode.Children, n => Eval(Tree.Get(n).State.Scores, selectedNode.State.Scores)); selectedIndex = EvalBestChild(selectedNode.State.Scores.Visits, selectedNode.Children);
} }
return selectedIndex;
} }
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex) public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex)
@@ -164,7 +273,8 @@ public class Solver
var currentScores = currentNode.State.Scores; var currentScores = currentNode.State.Scores;
currentScores.Visits++; currentScores.Visits++;
currentScores.ScoreSum += score; currentScores.ScoreSum += score;
currentScores.MaxScore = Math.Max(currentScores.MaxScore, score); if (currentScores.MaxScore < score)
currentScores.MaxScore = score;
if (currentIndex == targetIndex) if (currentIndex == targetIndex)
break; break;