Pass large structs byref instead

This commit is contained in:
Asriel Camora
2023-11-10 18:41:58 -08:00
parent 3edb156d97
commit 036cbb2fb4
9 changed files with 21 additions and 27 deletions
+2 -8
View File
@@ -10,7 +10,7 @@ public abstract class BaseComboAction : BaseAction
protected bool BaseCanUse(Simulator s) => protected bool BaseCanUse(Simulator s) =>
base.CanUse(s); base.CanUse(s);
private static bool VerifyDurability2(int durabilityA, int durability, Effects effects) private static bool VerifyDurability2(int durabilityA, int durability, in Effects effects)
{ {
var wasteNots = effects.HasEffect(EffectType.WasteNot) || effects.HasEffect(EffectType.WasteNot2); var wasteNots = effects.HasEffect(EffectType.WasteNot) || effects.HasEffect(EffectType.WasteNot2);
// -A // -A
@@ -23,13 +23,10 @@ public abstract class BaseComboAction : BaseAction
return true; return true;
} }
public static bool VerifyDurability2(SimulationState s, int durabilityA) =>
VerifyDurability2(durabilityA, s.Durability, s.ActiveEffects);
public static bool VerifyDurability2(Simulator s, int durabilityA) => public static bool VerifyDurability2(Simulator s, int durabilityA) =>
VerifyDurability2(durabilityA, s.Durability, s.ActiveEffects); VerifyDurability2(durabilityA, s.Durability, s.ActiveEffects);
public static bool VerifyDurability3(int durabilityA, int durabilityB, int durability, Effects effects) public static bool VerifyDurability3(int durabilityA, int durabilityB, int durability, in Effects effects)
{ {
var wasteNots = Math.Max(effects.GetDuration(EffectType.WasteNot), effects.GetDuration(EffectType.WasteNot2)); var wasteNots = Math.Max(effects.GetDuration(EffectType.WasteNot), effects.GetDuration(EffectType.WasteNot2));
var manips = effects.HasEffect(EffectType.Manipulation); var manips = effects.HasEffect(EffectType.Manipulation);
@@ -56,7 +53,4 @@ public abstract class BaseComboAction : BaseAction
public static bool VerifyDurability3(Simulator s, int durabilityA, int durabilityB) => public static bool VerifyDurability3(Simulator s, int durabilityA, int durabilityB) =>
VerifyDurability3(durabilityA, durabilityB, s.Durability, s.ActiveEffects); VerifyDurability3(durabilityA, durabilityB, s.Durability, s.ActiveEffects);
public static bool VerifyDurability3(SimulationState s, int durabilityA, int durabilityB) =>
VerifyDurability3(durabilityA, durabilityB, s.Durability, s.ActiveEffects);
} }
+4 -4
View File
@@ -35,17 +35,17 @@ public class Simulator
public IEnumerable<ActionType> AvailableActions => ActionUtils.AvailableActions(this); public IEnumerable<ActionType> AvailableActions => ActionUtils.AvailableActions(this);
public Simulator(SimulationState state) public Simulator(in SimulationState state)
{ {
State = state; State = state;
} }
public void SetState(SimulationState state) public void SetState(in SimulationState state)
{ {
State = state; State = state;
} }
public (ActionResponse Response, SimulationState NewState) Execute(SimulationState state, ActionType action) public (ActionResponse Response, SimulationState NewState) Execute(in SimulationState state, ActionType action)
{ {
State = state; State = state;
return (Execute(action), State); return (Execute(action), State);
@@ -75,7 +75,7 @@ public class Simulator
return ActionResponse.UsedAction; return ActionResponse.UsedAction;
} }
public (ActionResponse Response, SimulationState NewState, int FailedActionIdx) ExecuteMultiple(SimulationState state, IEnumerable<ActionType> actions) public (ActionResponse Response, SimulationState NewState, int FailedActionIdx) ExecuteMultiple(in SimulationState state, IEnumerable<ActionType> actions)
{ {
State = state; State = state;
var i = 0; var i = 0;
+1 -1
View File
@@ -2,7 +2,7 @@ namespace Craftimizer.Simulator;
public class SimulatorNoRandom : Simulator public class SimulatorNoRandom : Simulator
{ {
public SimulatorNoRandom(SimulationState state) : base(state) public SimulatorNoRandom(in SimulationState state) : base(state)
{ {
} }
+6 -6
View File
@@ -17,7 +17,7 @@ public sealed class MCTS
public float MaxScore => rootScores.MaxScore; public float MaxScore => rootScores.MaxScore;
public MCTS(MCTSConfig config, SimulationState state) public MCTS(in MCTSConfig config, in SimulationState state)
{ {
this.config = config; this.config = config;
var sim = new Simulator(state, config.MaxStepCount); var sim = new Simulator(state, config.MaxStepCount);
@@ -30,7 +30,7 @@ public sealed class MCTS
rootScores = new(); rootScores = new();
} }
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict) private static SimulationNode Execute(Simulator simulator, in SimulationState state, ActionType action, bool strict)
{ {
(_, var newState) = simulator.Execute(state, action); (_, var newState) = simulator.Execute(state, action);
return new( return new(
@@ -61,7 +61,7 @@ public sealed class MCTS
[Pure] [Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores) private static (int arrayIdx, int subIdx) ChildMaxScore(in NodeScoresBuffer scores)
{ {
var length = scores.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
@@ -111,7 +111,7 @@ public sealed class MCTS
float explorationConstant, float explorationConstant,
float maxScoreWeightingConstant, float maxScoreWeightingConstant,
int parentVisits, int parentVisits,
ref NodeScoresBuffer scores) in NodeScoresBuffer scores)
{ {
var length = scores.Count; var length = scores.Count;
var vecLength = Vector<float>.Count; var vecLength = Vector<float>.Count;
@@ -168,7 +168,7 @@ public sealed class MCTS
return node; return node;
// select the node with the highest score // select the node with the highest score
var at = EvalBestChild(explorationConstant, maxScoreWeightingConstant, nodeVisits, ref node.ChildScores); var at = EvalBestChild(explorationConstant, maxScoreWeightingConstant, nodeVisits, in node.ChildScores);
nodeVisits = node.ChildScores.GetVisits(at); nodeVisits = node.ChildScores.GetVisits(at);
node = node.ChildAt(at)!; node = node.ChildAt(at)!;
} }
@@ -320,7 +320,7 @@ public sealed class MCTS
while (node.Children.Count != 0) while (node.Children.Count != 0)
{ {
node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; node = node.ChildAt(ChildMaxScore(in node.ChildScores))!;
if (node.State.Action != null) if (node.State.Action != null)
actions.Add(node.State.Action.Value); actions.Add(node.State.Action.Value);
+2 -2
View File
@@ -1,4 +1,4 @@
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Craftimizer.Solver; namespace Craftimizer.Solver;
@@ -21,7 +21,7 @@ public readonly record struct MCTSConfig
public float ScoreCP { get; init; } public float ScoreCP { get; init; }
public float ScoreSteps { get; init; } public float ScoreSteps { get; init; }
public MCTSConfig(SolverConfig config) public MCTSConfig(in SolverConfig config)
{ {
MaxStepCount = config.MaxStepCount; MaxStepCount = config.MaxStepCount;
MaxRolloutStepCount = config.MaxRolloutStepCount; MaxRolloutStepCount = config.MaxRolloutStepCount;
+3 -3
View File
@@ -19,7 +19,7 @@ public struct SimulationNode
public readonly bool IsComplete => CompletionState != CompletionState.Incomplete; public readonly bool IsComplete => CompletionState != CompletionState.Incomplete;
public SimulationNode(SimulationState state, ActionType? action, CompletionState completionState, ActionSet actions) public SimulationNode(in SimulationState state, ActionType? action, CompletionState completionState, ActionSet actions)
{ {
State = state; State = state;
Action = action; Action = action;
@@ -32,10 +32,10 @@ public struct SimulationNode
CompletionState.NoMoreActions : CompletionState.NoMoreActions :
simCompletionState; simCompletionState;
public readonly float? CalculateScore(MCTSConfig config) => public readonly float? CalculateScore(in MCTSConfig config) =>
CalculateScoreForState(State, SimulationCompletionState, config); CalculateScoreForState(State, SimulationCompletionState, config);
public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, MCTSConfig config) public static float? CalculateScoreForState(in SimulationState state, CompletionState completionState, MCTSConfig config)
{ {
if (completionState != CompletionState.ProgressComplete) if (completionState != CompletionState.ProgressComplete)
return null; return null;
+1 -1
View File
@@ -20,7 +20,7 @@ internal sealed class Simulator : SimulatorNoRandom
} }
} }
public Simulator(SimulationState state, int maxStepCount) : base(state) public Simulator(in SimulationState state, int maxStepCount) : base(state)
{ {
this.maxStepCount = maxStepCount; this.maxStepCount = maxStepCount;
} }
+1 -1
View File
@@ -35,7 +35,7 @@ public sealed class Solver : IDisposable
// Always called when a new step is generated. // Always called when a new step is generated.
public event NewActionDelegate? OnNewAction; public event NewActionDelegate? OnNewAction;
public Solver(SolverConfig config, SimulationState state) public Solver(in SolverConfig config, in SimulationState state)
{ {
Config = config; Config = config;
State = state; State = state;
+1 -1
View File
@@ -9,7 +9,7 @@ public readonly record struct SolverSolution {
public readonly IEnumerable<ActionType> ActionEnumerable { init => actions = SanitizeCombos(value).ToList(); } public readonly IEnumerable<ActionType> ActionEnumerable { init => actions = SanitizeCombos(value).ToList(); }
public readonly SimulationState State { get; init; } public readonly SimulationState State { get; init; }
public SolverSolution(IEnumerable<ActionType> actions, SimulationState state) public SolverSolution(IEnumerable<ActionType> actions, in SimulationState state)
{ {
ActionEnumerable = actions; ActionEnumerable = actions;
State = state; State = state;