Improve simd code
This commit is contained in:
+60
-66
@@ -1,10 +1,8 @@
|
|||||||
using Craftimizer.Simulator;
|
using Craftimizer.Simulator;
|
||||||
using Craftimizer.Simulator.Actions;
|
using Craftimizer.Simulator.Actions;
|
||||||
using System.ComponentModel;
|
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
using System.Numerics;
|
using System.Numerics;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using System.Runtime.Intrinsics;
|
|
||||||
|
|
||||||
namespace Craftimizer.Solver.Crafty;
|
namespace Craftimizer.Solver.Crafty;
|
||||||
|
|
||||||
@@ -85,7 +83,7 @@ public class Solver
|
|||||||
|
|
||||||
return exploitation + exploration;
|
return exploitation + exploration;
|
||||||
}
|
}
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private static int RustMaxBy(List<int> source, Func<int, float> into)
|
private static int RustMaxBy(List<int> source, Func<int, float> into)
|
||||||
{
|
{
|
||||||
@@ -111,9 +109,12 @@ public class Solver
|
|||||||
return exploitation + exploration;
|
return exploitation + exploration;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static int AlignToVectorLength(int length) =>
|
||||||
|
(length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
// Requires a multiple of Vector<float>.Count
|
// Requires a multiple of Vector<float>.Count
|
||||||
private static float[] EvalBestChildMultiple(float parentVisits, float[] scoreSums, float[] visits, float[] maxScores)
|
private static void EvalBestChildMultiple(float parentVisits, ReadOnlySpan<float> scoreSums, ReadOnlySpan<float> visits, ReadOnlySpan<float> maxScores, Span<float> evalScores)
|
||||||
{
|
{
|
||||||
var C = ExplorationConstant * MathF.Log(parentVisits);
|
var C = ExplorationConstant * MathF.Log(parentVisits);
|
||||||
var w = MaxScoreWeightingConstant;
|
var w = MaxScoreWeightingConstant;
|
||||||
@@ -121,28 +122,25 @@ public class Solver
|
|||||||
var CVector = new Vector<float>(C);
|
var CVector = new Vector<float>(C);
|
||||||
|
|
||||||
var length = scoreSums.Length;
|
var length = scoreSums.Length;
|
||||||
var result = new float[length];
|
|
||||||
|
|
||||||
for (var i = 0; i < length; i += Vector<float>.Count)
|
for (var i = 0; i < length; i += Vector<float>.Count)
|
||||||
{
|
{
|
||||||
var scoreSumsVector = new Vector<float>(scoreSums, i);
|
var scoreSumsVector = new Vector<float>(scoreSums[i..(i + Vector<float>.Count)]);
|
||||||
var visitsVector = new Vector<float>(visits, i);
|
var visitsVector = new Vector<float>(visits[i..(i + Vector<float>.Count)]);
|
||||||
var maxScoresVector = new Vector<float>(maxScores, i);
|
var maxScoresVector = new Vector<float>(maxScores[i..(i + Vector<float>.Count)]);
|
||||||
var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector);
|
var evalVector = EvalBestChildVectorized(w, W, CVector, scoreSumsVector, visitsVector, maxScoresVector);
|
||||||
evalVector.CopyTo(result, i);
|
evalVector.CopyTo(evalScores[i..(i + Vector<float>.Count)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private float[] EvalAllChildrenDbg(float parentVisits, List<int> children)
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
|
private int EvalBestChildAlternative(float parentVisits, List<int> children)
|
||||||
{
|
{
|
||||||
var length = children.Count;
|
var length = children.Count;
|
||||||
var alignedLength = (length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
|
var alignedLength = AlignToVectorLength(length);
|
||||||
var scoreSums = new float[alignedLength];
|
Span<float> scoreSums = stackalloc float[alignedLength];
|
||||||
var visits = new float[alignedLength];
|
Span<float> visits = stackalloc float[alignedLength];
|
||||||
var maxScores = new float[alignedLength];
|
Span<float> maxScores = stackalloc float[alignedLength];
|
||||||
|
Span<float> evalScores = stackalloc float[alignedLength];
|
||||||
|
|
||||||
for (var i = 0; i < length; ++i)
|
for (var i = 0; i < length; ++i)
|
||||||
{
|
{
|
||||||
@@ -152,63 +150,56 @@ public class Solver
|
|||||||
maxScores[i] = node.MaxScore;
|
maxScores[i] = node.MaxScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
return EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores);
|
EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores, evalScores);
|
||||||
|
var max = 0;
|
||||||
|
for (var i = 1; i < length; ++i)
|
||||||
|
if (evalScores[i] >= evalScores[max])
|
||||||
|
max = i;
|
||||||
|
return children[max];
|
||||||
}
|
}
|
||||||
|
|
||||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||||
private int EvalBestChild(float parentVisits, List<int> children)
|
private int EvalBestChild(float parentVisits, List<int> children)
|
||||||
{
|
{
|
||||||
var length = children.Count;
|
var length = children.Count;
|
||||||
var alignedLength = (length + (Vector<float>.Count - 1)) & ~(Vector<float>.Count - 1);
|
|
||||||
var scoreSums = new float[alignedLength];
|
|
||||||
var visits = new float[alignedLength];
|
|
||||||
var maxScores = new float[alignedLength];
|
|
||||||
|
|
||||||
|
var C = ExplorationConstant * MathF.Log(parentVisits);
|
||||||
for (var i = 0; i < length; ++i)
|
|
||||||
{
|
|
||||||
var node = Tree.Get(children[i]).State.Scores;
|
|
||||||
scoreSums[i] = node.ScoreSum;
|
|
||||||
visits[i] = node.Visits;
|
|
||||||
maxScores[i] = node.MaxScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
var evalScores = EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores);
|
|
||||||
var maxIdx = 0;
|
|
||||||
var max = evalScores[0];
|
|
||||||
for(var i = 1; i < length; ++i)
|
|
||||||
{
|
|
||||||
if (evalScores[i] >= max)
|
|
||||||
{
|
|
||||||
maxIdx = i;
|
|
||||||
max = evalScores[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return children[maxIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
private int EvalBestChildScalar(List<int> children, NodeScores parent)
|
|
||||||
{
|
|
||||||
Console.WriteLine(children.Count);
|
|
||||||
var C = ExplorationConstant * MathF.Log(parent.Visits);
|
|
||||||
var w = MaxScoreWeightingConstant;
|
var w = MaxScoreWeightingConstant;
|
||||||
var W = 1f - w;
|
var W = 1f - w;
|
||||||
|
var CVector = new Vector<float>(C);
|
||||||
|
|
||||||
var ret = -1;
|
Span<float> scoreSums = stackalloc float[Vector<float>.Count];
|
||||||
var maxV = float.MinValue;
|
Span<float> visits = stackalloc float[Vector<float>.Count];
|
||||||
foreach (var childNode in children)
|
Span<float> maxScores = stackalloc float[Vector<float>.Count];
|
||||||
|
|
||||||
|
var max = 0;
|
||||||
|
var maxScore = 0f;
|
||||||
|
for (var i = 0; i < length; i += Vector<float>.Count)
|
||||||
{
|
{
|
||||||
var child = Tree.Get(childNode).State.Scores;
|
var iterCount = i + Vector<float>.Count > length ?
|
||||||
var exploitation = (W * (child.ScoreSum / child.Visits)) + (w * child.MaxScore);
|
length - i :
|
||||||
var exploration = MathF.Sqrt(C / child.Visits);
|
Vector<float>.Count;
|
||||||
var score = exploitation + exploration;
|
|
||||||
if (score >= maxV)
|
for (var j = 0; j < iterCount; ++j)
|
||||||
{
|
{
|
||||||
ret = childNode;
|
var node = Tree.Get(children[i + j]).State.Scores;
|
||||||
maxV = score;
|
scoreSums[j] = node.ScoreSum;
|
||||||
|
visits[j] = node.Visits;
|
||||||
|
maxScores[j] = node.MaxScore;
|
||||||
|
}
|
||||||
|
var evalScores = EvalBestChildVectorized(w, W, CVector, new(scoreSums), new(visits), new(maxScores));
|
||||||
|
|
||||||
|
for (var j = 0; j < iterCount; ++j)
|
||||||
|
{
|
||||||
|
if (evalScores[j] >= maxScore)
|
||||||
|
{
|
||||||
|
max = i + j;
|
||||||
|
maxScore = evalScores[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ret;
|
|
||||||
|
return children[max];
|
||||||
}
|
}
|
||||||
|
|
||||||
public int Select(int selectedIndex)
|
public int Select(int selectedIndex)
|
||||||
@@ -219,7 +210,8 @@ public class Solver
|
|||||||
|
|
||||||
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
var expandable = selectedNode.State.AvailableActions.Count != 0;
|
||||||
var likelyTerminal = selectedNode.Children.Count == 0;
|
var likelyTerminal = selectedNode.Children.Count == 0;
|
||||||
if (expandable || likelyTerminal) {
|
if (expandable || likelyTerminal)
|
||||||
|
{
|
||||||
return selectedIndex;
|
return selectedIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +272,7 @@ public class Solver
|
|||||||
if (currentIndex == targetIndex)
|
if (currentIndex == targetIndex)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
currentIndex = currentNode.Parent!.Value;
|
currentIndex = currentNode.Parent;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,7 +312,8 @@ public class Solver
|
|||||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback)
|
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback)
|
||||||
{
|
{
|
||||||
var (state, result) = Simulate(input, actions);
|
var (state, result) = Simulate(input, actions);
|
||||||
if (result != CompletionState.Incomplete) {
|
if (result != CompletionState.Incomplete)
|
||||||
|
{
|
||||||
return (actions, state);
|
return (actions, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,7 +324,8 @@ public class Solver
|
|||||||
solver.Search(0);
|
solver.Search(0);
|
||||||
var (solution_actions, solution_node) = solver.Solution();
|
var (solution_actions, solution_node) = solver.Solution();
|
||||||
|
|
||||||
if (solution_node.Scores.MaxScore >= 1.0) {
|
if (solution_node.Scores.MaxScore >= 1.0)
|
||||||
|
{
|
||||||
actions.AddRange(solution_actions);
|
actions.AddRange(solution_actions);
|
||||||
return (actions, solution_node.State);
|
return (actions, solution_node.State);
|
||||||
}
|
}
|
||||||
@@ -344,7 +338,7 @@ public class Solver
|
|||||||
|
|
||||||
solver = new Solver(state, true);
|
solver = new Solver(state, true);
|
||||||
}
|
}
|
||||||
//Debugger.Break();
|
Debugger.Break();
|
||||||
|
|
||||||
return (actions, state);
|
return (actions, state);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user