Fixed ActionSet behavior

IndexOf might be killing the performance though..
This commit is contained in:
Asriel Camora
2023-06-19 08:52:25 -07:00
parent 5edec27977
commit 6d61e022b6
5 changed files with 175 additions and 246 deletions
-127
View File
@@ -1,127 +0,0 @@
using Craftimizer.Simulator.Actions;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics.X86;
namespace Craftimizer.Simulator;
public record ActionSet
{
public ulong Bits { get; set; }
public List<ActionType> SavedActions { get; set; } = new();
private bool HasFlagA(ActionType action) => (Bits & (1ul << ((int)action + 1))) != 0;
private bool HasFlagB(ActionType action) => SavedActions.Contains(action);
public bool HasFlag(ActionType action)
{
var a = HasFlagA(action);
var b = HasFlagB(action);
if (a != b)
throw new Exception($"Action {action} has different flags: {a} vs {b}");
return a;
}
private void SetFlagA(ActionType action) => Bits |= 1ul << ((int)action + 1);
private void SetFlagB(ActionType action)
{
if (!SavedActions.Contains(action))
SavedActions.Add(action);
}
public void SetFlag(ActionType action)
{
SetFlagA(action);
SetFlagB(action);
}
private void ClearFlagA(ActionType action) => Bits &= ~(1ul << ((int)action + 1));
private void ClearFlagB(ActionType action) => SavedActions.RemoveAll(a => a == action);
public void ClearFlag(ActionType action)
{
ClearFlagA(action);
ClearFlagB(action);
}
public IEnumerable<ActionType> Actions => GetActions();
private IEnumerable<ActionType> GetActions()
{
foreach (var action in Enum.GetValues<ActionType>())
if (HasFlag(action))
yield return action;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSet(ulong value, int n)
{
if (Bmi2.X64.IsSupported)
return BitOperations.TrailingZeroCount(Bmi2.X64.ParallelBitDeposit(1ul << n, value));
ulong mask = 0x0000FFFFFFFFu;
var size = 32;
var _base = 0;
if (n++ >= BitOperations.PopCount(value))
return 64;
while (size > 0)
{
var count = BitOperations.PopCount(value & mask);
if (n > count)
{
_base += size;
size >>= 1;
mask |= mask << size;
}
else
{
size >>= 1;
mask >>= size;
}
}
return _base;
}
private ActionType ActionAtA(int index) => Actions.ElementAt(index);//(ActionType)(NthBitSet(Bits, index) - 1);
private ActionType ActionAtB(int index) => SavedActions.ElementAt(index);
public ActionType ActionAt(int index)
{
return ActionAtB(index);
var a = ActionAtA(index);
var a2 = (ActionType)(NthBitSet(Bits, index) - 1);
var b = ActionAtB(index);
if (a != a2)
throw new Exception($"A2: Action {index} has different flags: {a} vs {a2}");
if (a != b)
throw new Exception($"Action {index} has different flags: {a} vs {b}");
return a;
}
private int ActionCountA => BitOperations.PopCount(Bits);
private int ActionCountB => SavedActions.Count;
public int ActionCount { get
{
return ActionCountB;
var a = ActionCountA;
var b = ActionCountB;
if (a != b)
throw new Exception($"Action count has different flags: {a} vs {b}");
return a;
} }
private bool IsEmptyA => Bits == 0;
private bool IsEmptyB => SavedActions.Count == 0;
public bool IsEmpty
{
get
{
return IsEmptyB;
var a = IsEmptyA;
var b = IsEmptyB;
if (a != b)
throw new Exception($"IsEmpty has different flags: {a} vs {b}");
return a;
}
}
}
+59
View File
@@ -0,0 +1,59 @@
using Craftimizer.Simulator.Actions;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics.X86;
namespace Craftimizer.Solver.Crafty;
public sealed class ActionSet
{
private uint Bits { get; set; } = 0;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NthBitSet(uint value, int n)
{
if (Bmi2.IsSupported)
return BitOperations.TrailingZeroCount(Bmi2.ParallelBitDeposit(1u << n, value));
var mask = 0x0000FFFFu;
var size = 16;
var _base = 0;
if (n++ >= BitOperations.PopCount(value))
return 32;
while (size > 0)
{
var count = BitOperations.PopCount(value & mask);
if (n > count)
{
_base += size;
size >>= 1;
mask |= mask << size;
}
else
{
size >>= 1;
mask >>= size;
}
}
return _base;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int FromAction(ActionType action) => Array.IndexOf(Simulator.AcceptedActions, action);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ActionType ToAction(int index) => Simulator.AcceptedActions[index];
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool HasAction(ActionType action) => (Bits & (1u << (FromAction(action) + 1))) != 0;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void AddAction(ActionType action) => Bits |= 1u << (FromAction(action) + 1);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void RemoveAction(ActionType action) => Bits &= ~(1u << (FromAction(action) + 1));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType ElementAt(int index) => ToAction(NthBitSet(Bits, index) - 1);
public int Count => BitOperations.PopCount(Bits);
}
+1 -1
View File
@@ -10,7 +10,7 @@ public readonly record struct SimulationNode
public ActionSet AvailableActions { get; init; } public ActionSet AvailableActions { get; init; }
public CompletionState SimulationCompletionState { get; init; } public CompletionState SimulationCompletionState { get; init; }
public CompletionState CompletionState => public CompletionState CompletionState =>
AvailableActions.IsEmpty && SimulationCompletionState == CompletionState.Incomplete ? AvailableActions.Count == 0 && SimulationCompletionState == CompletionState.Incomplete ?
CompletionState.NoMoreActions : CompletionState.NoMoreActions :
SimulationCompletionState; SimulationCompletionState;
+15 -12
View File
@@ -20,7 +20,7 @@ public class Simulator : Sim
public override bool RollSuccessRaw(float successRate) => successRate == 1; public override bool RollSuccessRaw(float successRate) => successRate == 1;
public override void StepCondition() { } public override void StepCondition() { }
private static readonly ActionType[] AcceptedActions = new[] public static readonly ActionType[] AcceptedActions = new[]
{ {
ActionType.TrainedFinesse, ActionType.TrainedFinesse,
ActionType.PrudentSynthesis, ActionType.PrudentSynthesis,
@@ -49,14 +49,8 @@ public class Simulator : Sim
ActionType.BasicTouch, ActionType.BasicTouch,
}; };
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/craft_state.rs#L137 // https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/craft_state.rs#L146
public ActionSet AvailableActionsHeuristic(bool strict) private bool CanUseAction(ActionType action, bool strict)
{
if (IsComplete)
return new();
ActionUtils.SetSimulation(this);
var a = AcceptedActions.Where(action =>
{ {
var baseAction = action.WithUnsafe(); var baseAction = action.WithUnsafe();
@@ -152,10 +146,19 @@ public class Simulator : Sim
} }
return true; return true;
}); }
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/craft_state.rs#L137
public ActionSet AvailableActionsHeuristic(bool strict)
{
if (IsComplete)
return new();
ActionUtils.SetSimulation(this);
var ret = new ActionSet(); var ret = new ActionSet();
foreach (var ac in a) foreach (var action in AcceptedActions)
ret.SetFlag(ac); if (CanUseAction(action, strict))
ret.AddAction(action);
return ret; return ret;
} }
} }
+11 -17
View File
@@ -12,7 +12,7 @@ public class Solver
//public Random Random => Simulator.Input.Random; //public Random Random => Simulator.Input.Random;
public const int Iterations = 1000; public const int Iterations = 30000;
public const float ScoreStorageThreshold = 1f; public const float ScoreStorageThreshold = 1f;
public const float MaxScoreWeightingConstant = 0.1f; public const float MaxScoreWeightingConstant = 0.1f;
public const float ExplorationConstant = 4f; public const float ExplorationConstant = 4f;
@@ -57,9 +57,9 @@ public class Solver
if (node.IsComplete) if (node.IsComplete)
return (currentIndex, node.CompletionState); return (currentIndex, node.CompletionState);
if (!node.AvailableActions.HasFlag(action)) if (!node.AvailableActions.HasAction(action))
return (currentIndex, CompletionState.InvalidAction); return (currentIndex, CompletionState.InvalidAction);
node.AvailableActions.ClearFlag(action); node.AvailableActions.RemoveAction(action);
currentIndex = Tree.Insert(currentIndex, Execute(node.State, action, strict)); currentIndex = Tree.Insert(currentIndex, Execute(node.State, action, strict));
} }
@@ -106,10 +106,9 @@ public class Solver
{ {
var selectedNode = Tree.Get(selectedIndex); var selectedNode = Tree.Get(selectedIndex);
var expandable = !selectedNode.State.AvailableActions.IsEmpty; var expandable = selectedNode.State.AvailableActions.Count != 0;
var likelyTerminal = selectedNode.Children.Count == 0; var likelyTerminal = selectedNode.Children.Count == 0;
if (expandable || likelyTerminal) if (expandable || likelyTerminal) {
{
break; break;
} }
@@ -126,9 +125,8 @@ public class Solver
if (initialNode.IsComplete) if (initialNode.IsComplete)
return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0); return (initialIndex, initialNode.CompletionState, initialNode.CalculateScore() ?? 0);
var randomIdx = 0;// random.Next(initialNode.AvailableActions.ActionCount); var randomAction = initialNode.AvailableActions.ElementAt(0);
var randomAction = initialNode.AvailableActions.ActionAt(randomIdx); initialNode.AvailableActions.RemoveAction(randomAction);
initialNode.AvailableActions.ClearFlag(randomAction);
var expandedState = Execute(initialNode.State, randomAction, true); var expandedState = Execute(initialNode.State, randomAction, true);
var expandedIndex = Tree.Insert(initialIndex, expandedState); var expandedIndex = Tree.Insert(initialIndex, expandedState);
@@ -139,8 +137,7 @@ public class Solver
{ {
if (currentState.IsComplete) if (currentState.IsComplete)
break; break;
randomIdx = 0;// random.Next(currentState.AvailableActions.ActionCount); randomAction = currentState.AvailableActions.ElementAt(0);
randomAction = currentState.AvailableActions.ActionAt(randomIdx);
currentState = Execute(currentState.State, randomAction, true); currentState = Execute(currentState.State, randomAction, true);
} }
@@ -190,8 +187,7 @@ public class Solver
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var node = Tree.Get(0); var node = Tree.Get(0);
while (node.Children.Count != 0) while (node.Children.Count != 0) {
{
var next_index = RustMaxBy(node.Children, n => Tree.Get(n).State.Scores.MaxScore); var next_index = RustMaxBy(node.Children, n => Tree.Get(n).State.Scores.MaxScore);
node = Tree.Get(next_index); node = Tree.Get(next_index);
if (node.State.Action != null) if (node.State.Action != null)
@@ -213,8 +209,7 @@ public class Solver
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback) public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SimulationInput input, List<ActionType> actions, Action<ActionType>? actionCallback)
{ {
var (state, result) = Simulate(input, actions); var (state, result) = Simulate(input, actions);
if (result != CompletionState.Incomplete) if (result != CompletionState.Incomplete) {
{
return (actions, state); return (actions, state);
} }
@@ -224,8 +219,7 @@ public class Solver
solver.Search(0); solver.Search(0);
var (solution_actions, solution_node) = solver.Solution(); var (solution_actions, solution_node) = solver.Solution();
if (solution_node.Scores.MaxScore >= 1.0) if (solution_node.Scores.MaxScore >= 1.0) {
{
actions.AddRange(solution_actions); actions.AddRange(solution_actions);
return (actions, solution_node.State); return (actions, solution_node.State);
} }