Improve simd code

This commit is contained in:
Asriel Camora
2023-06-20 22:38:26 -07:00
parent 47c9339d56
commit b7393b5c65
+59 -65
View File
@@ -1,10 +1,8 @@
using Craftimizer.Simulator; using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions; using Craftimizer.Simulator.Actions;
using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
namespace Craftimizer.Solver.Crafty; namespace Craftimizer.Solver.Crafty;
@@ -111,9 +109,12 @@ public class Solver
return exploitation + exploration; return exploitation + exploration;
} }
/// <summary>
/// Rounds <paramref name="length"/> up to the next multiple of the number of
/// <c>float</c> lanes in a <c>Vector&lt;float&gt;</c>.
/// </summary>
/// <param name="length">Element count to round up.</param>
/// <returns>The smallest multiple of the lane count that is not below <paramref name="length"/>.</returns>
private static int AlignToVectorLength(int length)
{
    // Add (lanes - 1), then clear the low bits. The mask trick is only
    // correct when the lane count is a power of two.
    var laneMask = Vector<float>.Count - 1;
    return (length + laneMask) & ~laneMask;
}
/// <summary>
/// Scores a batch of children one vector-width chunk at a time and writes the
/// resulting scores into <paramref name="evalScores"/>.
/// </summary>
/// <remarks>
/// The length of <paramref name="scoreSums"/> drives the loop and must be a
/// multiple of <c>Vector&lt;float&gt;.Count</c>. Every chunk is sliced at full
/// vector width from all four spans, so the other three must be at least as long.
/// </remarks>
/// <param name="parentVisits">Visit count of the parent; its logarithm scales the exploration constant.</param>
/// <param name="scoreSums">Per-child score sums.</param>
/// <param name="visits">Per-child visit counts.</param>
/// <param name="maxScores">Per-child maximum scores.</param>
/// <param name="evalScores">Destination for the per-child scores.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EvalBestChildMultiple(float parentVisits, ReadOnlySpan<float> scoreSums, ReadOnlySpan<float> visits, ReadOnlySpan<float> maxScores, Span<float> evalScores)
{
    var lanes = Vector<float>.Count;

    var weight = MaxScoreWeightingConstant;
    // NOTE(review): this statement sits in a gap between two diff hunks; it is
    // reconstructed from the identical w / W pair in EvalBestChild. Confirm.
    var inverseWeight = 1f - weight;
    var explorationVector = new Vector<float>(ExplorationConstant * MathF.Log(parentVisits));

    for (var offset = 0; offset < scoreSums.Length; offset += lanes)
    {
        var chunkScores = EvalBestChildVectorized(
            weight,
            inverseWeight,
            explorationVector,
            new Vector<float>(scoreSums.Slice(offset, lanes)),
            new Vector<float>(visits.Slice(offset, lanes)),
            new Vector<float>(maxScores.Slice(offset, lanes)));

        chunkScores.CopyTo(evalScores.Slice(offset, lanes));
    }
}
/// <summary>
/// Picks the best child by gathering every child's scores into padded buffers,
/// scoring them in one batch with <c>EvalBestChildMultiple</c>, and scanning for
/// the maximum.
/// </summary>
/// <param name="parentVisits">Visit count of the parent, forwarded to the batch scorer.</param>
/// <param name="children">Tree indices of the candidate children; must not be empty, because <c>children[0]</c> is the fallback result.</param>
/// <returns>The entry of <paramref name="children"/> with the highest score; on ties the later child wins.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChildAlternative(float parentVisits, List<int> children)
{
    // Upper bound (in floats) for scratch space taken from the stack. Larger
    // requests fall back to the heap so an unusually wide node cannot
    // overflow the stack.
    const int MaxStackFloats = 512;

    var length = children.Count;
    var alignedLength = AlignToVectorLength(length);

    // One scratch block carved into the four equally sized buffers.
    var totalFloats = alignedLength * 4;
    Span<float> scratch = totalFloats <= MaxStackFloats
        ? stackalloc float[totalFloats]
        : new float[totalFloats];

    var scoreSums = scratch.Slice(0, alignedLength);
    var visits = scratch.Slice(alignedLength, alignedLength);
    var maxScores = scratch.Slice(2 * alignedLength, alignedLength);
    var evalScores = scratch.Slice(3 * alignedLength, alignedLength);

    for (var i = 0; i < length; ++i)
    {
        // NOTE(review): the first three statements of this loop body fall in a
        // gap between diff hunks; they are reconstructed from the identical
        // gather loop in the previous EvalBestChild. Confirm.
        var node = Tree.Get(children[i]).State.Scores;
        scoreSums[i] = node.ScoreSum;
        visits[i] = node.Visits;
        maxScores[i] = node.MaxScore;
    }

    EvalBestChildMultiple(parentVisits, scoreSums, visits, maxScores, evalScores);

    // Only the first `length` entries are real children; the padding lanes
    // beyond them are scored but never inspected.
    var max = 0;
    for (var i = 1; i < length; ++i)
    {
        if (evalScores[i] >= evalScores[max])
        {
            max = i;
        }
    }

    return children[max];
}
/// <summary>
/// Picks the child with the highest selection score, evaluating the children
/// one vector-width chunk at a time with small fixed-size scratch buffers.
/// </summary>
/// <param name="parentVisits">Visit count of the parent; its logarithm scales the exploration constant.</param>
/// <param name="children">Tree indices of the candidate children; must not be empty, because <c>children[0]</c> is the fallback result.</param>
/// <returns>The entry of <paramref name="children"/> with the highest score; on ties the later child wins.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int EvalBestChild(float parentVisits, List<int> children)
{
    var length = children.Count;
    var lanes = Vector<float>.Count;

    var C = ExplorationConstant * MathF.Log(parentVisits);
    var w = MaxScoreWeightingConstant;
    var W = 1f - w;
    var CVector = new Vector<float>(C);

    // Vector-sized scratch buffers, refilled for every chunk.
    Span<float> scoreSums = stackalloc float[lanes];
    Span<float> visits = stackalloc float[lanes];
    Span<float> maxScores = stackalloc float[lanes];

    var max = 0;
    // Seed below every comparable score. A seed of 0 would let no child win
    // when all scores are negative, silently returning children[0] and
    // disagreeing with EvalBestChildAlternative, which seeds from the first
    // real score.
    var maxScore = float.NegativeInfinity;

    for (var i = 0; i < length; i += lanes)
    {
        // The final chunk may be only partially filled.
        var iterCount = Math.Min(lanes, length - i);

        for (var j = 0; j < iterCount; ++j)
        {
            var node = Tree.Get(children[i + j]).State.Scores;
            scoreSums[j] = node.ScoreSum;
            visits[j] = node.Visits;
            maxScores[j] = node.MaxScore;
        }

        // Lanes at or beyond iterCount still hold whatever was there before;
        // they are evaluated but never compared below.
        var evalScores = EvalBestChildVectorized(w, W, CVector, new(scoreSums), new(visits), new(maxScores));

        for (var j = 0; j < iterCount; ++j)
        {
            if (evalScores[j] >= maxScore)
            {
                max = i + j;
                maxScore = evalScores[j];
            }
        }
    }

    return children[max];
}
public int Select(int selectedIndex) public int Select(int selectedIndex)
@@ -219,7 +210,8 @@ public class Solver
var expandable = selectedNode.State.AvailableActions.Count != 0; var expandable = selectedNode.State.AvailableActions.Count != 0;
var likelyTerminal = selectedNode.Children.Count == 0; var likelyTerminal = selectedNode.Children.Count == 0;
if (expandable || likelyTerminal) { if (expandable || likelyTerminal)
{
return selectedIndex; return selectedIndex;
} }
@@ -280,7 +272,7 @@ public class Solver
if (currentIndex == targetIndex) if (currentIndex == targetIndex)
break; break;
currentIndex = currentNode.Parent!.Value; currentIndex = currentNode.Parent;
} }
} }
@@ -320,7 +312,8 @@ public class Solver
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback) public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback)
{ {
var (state, result) = Simulate(input, actions); var (state, result) = Simulate(input, actions);
if (result != CompletionState.Incomplete) { if (result != CompletionState.Incomplete)
{
return (actions, state); return (actions, state);
} }
@@ -331,7 +324,8 @@ public class Solver
solver.Search(0); solver.Search(0);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
if (solution_node.Scores.MaxScore >= 1.0) { if (solution_node.Scores.MaxScore >= 1.0)
{
actions.AddRange(solution_actions); actions.AddRange(solution_actions);
return (actions, solution_node.State); return (actions, solution_node.State);
} }
@@ -344,7 +338,7 @@ public class Solver
solver = new Solver(state, true); solver = new Solver(state, true);
} }
//Debugger.Break(); Debugger.Break();
return (actions, state); return (actions, state);
} }