Vectorize evaluation
This commit is contained in:
@@ -2,6 +2,7 @@ using Craftimizer.Simulator;
|
|||||||
using Craftimizer.Simulator.Actions;
|
using Craftimizer.Simulator.Actions;
|
||||||
using Craftimizer.Solver.Crafty;
|
using Craftimizer.Solver.Crafty;
|
||||||
using ObjectLayoutInspector;
|
using ObjectLayoutInspector;
|
||||||
|
using System.Diagnostics;
|
||||||
|
|
||||||
namespace Craftimizer.Benchmark;
|
namespace Craftimizer.Benchmark;
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ internal static class Program
|
|||||||
};
|
};
|
||||||
|
|
||||||
var actions = new List<ActionType>();
|
var actions = new List<ActionType>();
|
||||||
|
var s = Stopwatch.StartNew();
|
||||||
if (true)
|
if (true)
|
||||||
(actions, _) = Solver.Crafty.Solver.SearchStepwise(input, actions, a => Console.WriteLine(a));
|
(actions, _) = Solver.Crafty.Solver.SearchStepwise(input, actions, a => Console.WriteLine(a));
|
||||||
else
|
else
|
||||||
@@ -41,5 +43,7 @@ internal static class Program
|
|||||||
Console.Write($">{action.IntName()}");
|
Console.Write($">{action.IntName()}");
|
||||||
Console.WriteLine();
|
Console.WriteLine();
|
||||||
}
|
}
|
||||||
|
s.Stop();
|
||||||
|
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-6
@@ -1,3 +1,5 @@
|
|||||||
|
using System.Runtime.CompilerServices;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
|
|
||||||
public class Arena<T> where T : struct
|
public class Arena<T> where T : struct
|
||||||
@@ -10,20 +12,22 @@ public class Arena<T> where T : struct
|
|||||||
public T State { get; init; }
|
public T State { get; init; }
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Node> Nodes { get; } = new();
|
private readonly List<Node> nodes = new();
|
||||||
|
|
||||||
public Arena(T initialState = default)
|
public Arena(T initialState = default)
|
||||||
{
|
{
|
||||||
Nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState });
|
nodes.Add(new() { Parent = null, Index = 0, Children = new(), State = initialState });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
public int Insert(int parentIndex, T state)
|
public int Insert(int parentIndex, T state)
|
||||||
{
|
{
|
||||||
var index = Nodes.Count;
|
var index = nodes.Count;
|
||||||
Nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state });
|
nodes.Add(new() { Parent = parentIndex, Index = index, Children = new(), State = state });
|
||||||
Nodes[parentIndex].Children.Add(index);
|
nodes[parentIndex].Children.Add(index);
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node Get(int index) => Nodes[index];
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
|
public Node Get(int index) => nodes[index];
|
||||||
}
|
}
|
||||||
|
|||||||
+116
-6
@@ -1,7 +1,10 @@
|
|||||||
using Craftimizer.Simulator;
|
using Craftimizer.Simulator;
|
||||||
using Craftimizer.Simulator.Actions;
|
using Craftimizer.Simulator.Actions;
|
||||||
|
using System.ComponentModel;
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
|
using System.Runtime.Intrinsics;
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
|
|
||||||
@@ -100,9 +103,116 @@ public class Solver
|
|||||||
return source[max];
|
return source[max];
|
||||||
}
|
}
|
||||||
|
|
||||||
public int Select(int currentIndex)
|
// Scores one SIMD batch of children. Each lane holds one child and gets:
//   (W * mean score) + (w * max score) + sqrt(C / visits)
// where mean score is scoreSums / visits for that lane.
//   w - weight applied to each lane's max score
//   W - weight applied to each lane's mean score
//   C - numerator of the exploration bonus, divided by each lane's visit count
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<float> EvalBestChildVectorized(float w, float W, Vector<float> C, Vector<float> scoreSums, Vector<float> visits, Vector<float> maxScores)
{
    // Bonus that shrinks as a child's visit count grows.
    var explorationBonus = Vector.SquareRoot(C / visits);

    // Weighted blend of the average score and the best score seen so far.
    var meanScores = scoreSums / visits;
    var weightedScores = (W * meanScores) + (w * maxScores);

    return weightedScores + explorationBonus;
}
|
||||||
|
|
||||||
|
// Evaluates every child slot in batches of Vector<float>.Count and returns one
// score per slot, in the same order as the inputs.
// Requires scoreSums.Length to be a multiple of Vector<float>.Count: each batch
// loads a full vector from all three arrays starting at the same offset.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores)
{
    var maxScoreWeight = MaxScoreWeightingConstant;
    var meanScoreWeight = 1f - maxScoreWeight;

    // The exploration numerator depends only on the parent, so broadcast it once.
    var explorationTerm = new Vector<float>(ExplorationConstant * MathF.Log(parentVisits));

    var evaluations = new float[scoreSums.Length];
    var offset = 0;
    while (offset < evaluations.Length)
    {
        var batch = EvalBestChildVectorized(
            maxScoreWeight,
            meanScoreWeight,
            explorationTerm,
            new Vector<float>(scoreSums, offset),
            new Vector<float>(visits, offset),
            new Vector<float>(maxScores, offset));
        batch.CopyTo(evaluations, offset);

        offset += Vector<float>.Count;
    }

    return evaluations;
}
|
||||||
|
|
||||||
|
// Returns the evaluation of every child in `children`, in list order.
// The returned array is padded up to a multiple of Vector<float>.Count; entries
// at or beyond children.Count are computed from zero-filled padding and should
// be ignored by the caller.
private float[] EvalAllChildrenDbg(float parentVisits, List<int> children)
{
    var lanes = Vector<float>.Count;
    var childCount = children.Count;

    // Round the child count up to the next multiple of the vector width.
    var paddedCount = (childCount + (lanes - 1)) & ~(lanes - 1);

    var sums = new float[paddedCount];
    var visitCounts = new float[paddedCount];
    var bestScores = new float[paddedCount];

    // Gather each child's statistics into the flat arrays the batch evaluator reads.
    var slot = 0;
    while (slot < childCount)
    {
        var scores = Tree.Get(children[slot]).State.Scores;
        sums[slot] = scores.ScoreSum;
        visitCounts[slot] = scores.Visits;
        bestScores[slot] = scores.MaxScore;
        ++slot;
    }

    return EvalBestChildMultiple(parentVisits, sums, visitCounts, bestScores);
}
|
||||||
|
|
||||||
|
// Picks the child with the highest evaluation and returns its entry from
// `children` (the value that is passed to Tree.Get).
// Ties go to the child that appears later in the list.
// NOTE(review): reads evalScores[0] unconditionally, so this presumably expects
// a non-empty `children` list - confirm against callers.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChild(float parentVisits, List<int> children)
{
    var lanes = Vector<float>.Count;
    var childCount = children.Count;

    // Round the child count up to the next multiple of the vector width so the
    // batch evaluator can always load full vectors.
    var paddedCount = (childCount + (lanes - 1)) & ~(lanes - 1);

    var sums = new float[paddedCount];
    var visitCounts = new float[paddedCount];
    var bestScores = new float[paddedCount];

    // Gather each child's statistics; padding slots keep their zero default.
    var slot = 0;
    while (slot < childCount)
    {
        var scores = Tree.Get(children[slot]).State.Scores;
        sums[slot] = scores.ScoreSum;
        visitCounts[slot] = scores.Visits;
        bestScores[slot] = scores.MaxScore;
        ++slot;
    }

    var evaluations = EvalBestChildMultiple(parentVisits, sums, visitCounts, bestScores);

    // Arg-max over the real children only; padding slots are never inspected.
    var bestSlot = 0;
    var bestEvaluation = evaluations[0];
    for (var candidate = 1; candidate < childCount; ++candidate)
    {
        var evaluation = evaluations[candidate];
        if (!(evaluation >= bestEvaluation))
        {
            continue;
        }

        bestSlot = candidate;
        bestEvaluation = evaluation;
    }

    return children[bestSlot];
}
|
||||||
|
|
||||||
|
// Scalar (one child at a time) version of the best-child selection.
// For each child it computes
//   (W * ScoreSum / Visits) + (w * MaxScore) + sqrt(C / Visits)
// and returns the entry of `children` with the highest value; ties go to the
// child that appears later in the list.
// Returns -1 when no child produced a score >= float.MinValue (for example when
// `children` is empty).
private int EvalBestChildScalar(List<int> children, NodeScores parent)
{
    // Exploration numerator depends only on the parent, so compute it once.
    var C = ExplorationConstant * MathF.Log(parent.Visits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;

    var ret = -1;
    var maxV = float.MinValue;
    foreach (var childNode in children)
    {
        var child = Tree.Get(childNode).State.Scores;

        // Weighted blend of the child's average score and its best score.
        var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore);
        // Bonus that shrinks as the child's visit count grows.
        var exploration = MathF.Sqrt(C / child.Visits);
        var score = exploitation + exploration;

        // >= so that a later child wins a tie.
        if (score >= maxV)
        {
            ret = childNode;
            maxV = score;
        }
    }

    return ret;
}
|
||||||
|
|
||||||
|
public int Select(int selectedIndex)
|
||||||
{
|
{
|
||||||
var selectedIndex = currentIndex;
|
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
var selectedNode = Tree.Get(selectedIndex);
|
var selectedNode = Tree.Get(selectedIndex);
|
||||||
@@ -110,13 +220,12 @@ public class Solver
|
|||||||
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
||||||
var likelyTerminal = selectedNode.Children.Count == 0;
|
var likelyTerminal = selectedNode.Children.Count == 0;
|
||||||
if (expandable || likelyTerminal) {
|
if (expandable || likelyTerminal) {
|
||||||
break;
|
return selectedIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
// select the node with the highest score
|
// select the node with the highest score
|
||||||
selectedIndex = RustMaxBy(selectedNode.Children, n => Eval(Tree.Get(n).State.Scores, selectedNode.State.Scores));
|
selectedIndex = EvalBestChild(selectedNode.State.Scores.Visits, selectedNode.Children);
|
||||||
}
|
}
|
||||||
return selectedIndex;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex)
|
public (int Index, CompletionState State, float Score) ExpandAndRollout(int initialIndex)
|
||||||
@@ -164,7 +273,8 @@ public class Solver
|
|||||||
var currentScores = currentNode.State.Scores;
|
var currentScores = currentNode.State.Scores;
|
||||||
currentScores.Visits++;
|
currentScores.Visits++;
|
||||||
currentScores.ScoreSum += score;
|
currentScores.ScoreSum += score;
|
||||||
currentScores.MaxScore = Math.Max(currentScores.MaxScore, score);
|
if (currentScores.MaxScore < score)
|
||||||
|
currentScores.MaxScore = score;
|
||||||
|
|
||||||
if (currentIndex == targetIndex)
|
if (currentIndex == targetIndex)
|
||||||
break;
|
break;
|
||||||
|
|||||||
Reference in New Issue
Block a user