Pass large structs byref instead
This commit is contained in:
+6
-6
@@ -17,7 +17,7 @@ public sealed class MCTS
|
||||
|
||||
public float MaxScore => rootScores.MaxScore;
|
||||
|
||||
public MCTS(MCTSConfig config, SimulationState state)
|
||||
public MCTS(in MCTSConfig config, in SimulationState state)
|
||||
{
|
||||
this.config = config;
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
@@ -30,7 +30,7 @@ public sealed class MCTS
|
||||
rootScores = new();
|
||||
}
|
||||
|
||||
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
|
||||
private static SimulationNode Execute(Simulator simulator, in SimulationState state, ActionType action, bool strict)
|
||||
{
|
||||
(_, var newState) = simulator.Execute(state, action);
|
||||
return new(
|
||||
@@ -61,7 +61,7 @@ public sealed class MCTS
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
|
||||
private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
@@ -111,7 +111,7 @@ public sealed class MCTS
|
||||
float explorationConstant,
|
||||
float maxScoreWeightingConstant,
|
||||
int parentVisits,
|
||||
ref NodeScoresBuffer scores)
|
||||
in NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
@@ -168,7 +168,7 @@ public sealed class MCTS
|
||||
return node;
|
||||
|
||||
// select the node with the highest score
|
||||
var at = EvalBestChild(explorationConstant, maxScoreWeightingConstant, nodeVisits, ref node.ChildScores);
|
||||
var at = EvalBestChild(explorationConstant, maxScoreWeightingConstant, nodeVisits, in node.ChildScores);
|
||||
nodeVisits = node.ChildScores.GetVisits(at);
|
||||
node = node.ChildAt(at)!;
|
||||
}
|
||||
@@ -320,7 +320,7 @@ public sealed class MCTS
|
||||
|
||||
while (node.Children.Count != 0)
|
||||
{
|
||||
node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;
|
||||
node = node.ChildAt(ChildMaxScore(in node.ChildScores))!;
|
||||
|
||||
if (node.State.Action != null)
|
||||
actions.Add(node.State.Action.Value);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
@@ -21,7 +21,7 @@ public readonly record struct MCTSConfig
|
||||
public float ScoreCP { get; init; }
|
||||
public float ScoreSteps { get; init; }
|
||||
|
||||
public MCTSConfig(SolverConfig config)
|
||||
public MCTSConfig(in SolverConfig config)
|
||||
{
|
||||
MaxStepCount = config.MaxStepCount;
|
||||
MaxRolloutStepCount = config.MaxRolloutStepCount;
|
||||
|
||||
@@ -19,7 +19,7 @@ public struct SimulationNode
|
||||
|
||||
public readonly bool IsComplete => CompletionState != CompletionState.Incomplete;
|
||||
|
||||
public SimulationNode(SimulationState state, ActionType? action, CompletionState completionState, ActionSet actions)
|
||||
public SimulationNode(in SimulationState state, ActionType? action, CompletionState completionState, ActionSet actions)
|
||||
{
|
||||
State = state;
|
||||
Action = action;
|
||||
@@ -32,10 +32,10 @@ public struct SimulationNode
|
||||
CompletionState.NoMoreActions :
|
||||
simCompletionState;
|
||||
|
||||
public readonly float? CalculateScore(MCTSConfig config) =>
|
||||
public readonly float? CalculateScore(in MCTSConfig config) =>
|
||||
CalculateScoreForState(State, SimulationCompletionState, config);
|
||||
|
||||
public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, MCTSConfig config)
|
||||
public static float? CalculateScoreForState(in SimulationState state, CompletionState completionState, MCTSConfig config)
|
||||
{
|
||||
if (completionState != CompletionState.ProgressComplete)
|
||||
return null;
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@ internal sealed class Simulator : SimulatorNoRandom
|
||||
}
|
||||
}
|
||||
|
||||
public Simulator(SimulationState state, int maxStepCount) : base(state)
|
||||
public Simulator(in SimulationState state, int maxStepCount) : base(state)
|
||||
{
|
||||
this.maxStepCount = maxStepCount;
|
||||
}
|
||||
|
||||
+1
-1
@@ -35,7 +35,7 @@ public sealed class Solver : IDisposable
|
||||
// Always called when a new step is generated.
|
||||
public event NewActionDelegate? OnNewAction;
|
||||
|
||||
public Solver(SolverConfig config, SimulationState state)
|
||||
public Solver(in SolverConfig config, in SimulationState state)
|
||||
{
|
||||
Config = config;
|
||||
State = state;
|
||||
|
||||
@@ -9,7 +9,7 @@ public readonly record struct SolverSolution {
|
||||
public readonly IEnumerable<ActionType> ActionEnumerable { init => actions = SanitizeCombos(value).ToList(); }
|
||||
public readonly SimulationState State { get; init; }
|
||||
|
||||
public SolverSolution(IEnumerable<ActionType> actions, SimulationState state)
|
||||
public SolverSolution(IEnumerable<ActionType> actions, in SimulationState state)
|
||||
{
|
||||
ActionEnumerable = actions;
|
||||
State = state;
|
||||
|
||||
Reference in New Issue
Block a user