Split solver into MCTS and algorithms
This commit is contained in:
@@ -158,6 +158,7 @@ dotnet_diagnostic.CA1841.severity = suggestion
|
||||
dotnet_diagnostic.CA1845.severity = suggestion
|
||||
dotnet_diagnostic.MA0011.severity = silent
|
||||
dotnet_diagnostic.MA0076.severity = silent
|
||||
dotnet_diagnostic.MA0046.severity = silent
|
||||
dotnet_diagnostic.MA0002.severity = silent
|
||||
csharp_style_prefer_switch_expression = true:suggestion
|
||||
csharp_style_prefer_pattern_matching = true:silent
|
||||
|
||||
@@ -77,11 +77,11 @@ internal static class Program
|
||||
|
||||
Console.WriteLine($"{state.Quality} {state.CP} {state.Progress} {state.Durability}");
|
||||
//return;
|
||||
var (_, s) = Solver.Solver.SearchStepwiseFurcated(config, state, a => Console.WriteLine(a), default);
|
||||
var (_, s) = config.Invoke(state, a => Console.WriteLine(a))!.Value;
|
||||
Console.WriteLine($"Qual: {s.Quality}/{s.Input.Recipe.MaxQuality}");
|
||||
return;
|
||||
|
||||
Solver.Solver.SearchStepwiseFurcated(config, new(input), null, default);
|
||||
config.Invoke(new(input));
|
||||
//Benchmark(() => );
|
||||
}
|
||||
|
||||
|
||||
@@ -15,25 +15,6 @@ public class Macro
|
||||
public List<ActionType> Actions { get; set; } = new();
|
||||
}
|
||||
|
||||
/// <summary>
/// Convenience helpers for running the solver from a <see cref="SolverConfig"/>.
/// </summary>
public static class AlgorithmUtils
{
    /// <summary>
    /// Runs <c>Solver.Solver.Search</c> with this configuration and treats cancellation as a
    /// normal way for the search to end.
    /// </summary>
    /// <param name="me">Configuration handed to the search.</param>
    /// <param name="state">State the search starts from.</param>
    /// <param name="actionCallback">Optional callback forwarded to the search.</param>
    /// <param name="token">Token used to cancel the search.</param>
    /// <exception cref="AggregateException">
    /// Rethrown by <see cref="AggregateException.Handle"/> when the search fails with anything
    /// other than cancellation.
    /// </exception>
    public static void Invoke(this SolverConfig me, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
    {
        static bool IsCancellation(Exception ex) => ex is OperationCanceledException;

        try
        {
            Solver.Solver.Search(me, state, actionCallback, token);
        }
        catch (OperationCanceledException)
        {
            // Deliberately ignored: a cancelled search is not an error for the caller.
        }
        catch (AggregateException e)
        {
            // Swallow cancellations only; any other inner exception is rethrown by Handle.
            e.Handle(IsCancellation);
        }
    }
}
|
||||
|
||||
[Serializable]
|
||||
public class Configuration : IPluginConfiguration
|
||||
{
|
||||
|
||||
@@ -6,5 +6,7 @@ public enum CompletionState : byte
|
||||
ProgressComplete,
|
||||
NoMoreDurability,
|
||||
|
||||
Other
|
||||
InvalidAction,
|
||||
MaxActionCountReached,
|
||||
NoMoreActions
|
||||
}
|
||||
|
||||
+11
-11
@@ -21,8 +21,17 @@ public class Simulator
|
||||
|
||||
public bool IsFirstStep => State.StepCount == 0;
|
||||
|
||||
public CompletionState CompletionState => CalculateCompletionState(State);
|
||||
public virtual bool IsComplete => CompletionState != CompletionState.Incomplete;
|
||||
/// <summary>
/// Completion state of the current craft. Progress is checked before durability, so a craft
/// that reaches full progress while running out of durability counts as
/// <see cref="CompletionState.ProgressComplete"/>.
/// </summary>
public virtual CompletionState CompletionState =>
    Progress >= Input.Recipe.MaxProgress
        ? CompletionState.ProgressComplete
        : Durability <= 0
            ? CompletionState.NoMoreDurability
            : CompletionState.Incomplete;
|
||||
public bool IsComplete => CompletionState != CompletionState.Incomplete;
|
||||
|
||||
public IEnumerable<ActionType> AvailableActions => ActionUtils.AvailableActions(this);
|
||||
|
||||
@@ -278,13 +287,4 @@ public class Simulator
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public void IncreaseQuality(float efficiency) =>
|
||||
IncreaseQualityRaw(CalculateQualityGain(efficiency, false));
|
||||
|
||||
/// <summary>
/// Derives the completion state of <paramref name="state"/>. Progress is checked before
/// durability.
/// </summary>
/// <param name="state">State to classify.</param>
/// <returns>
/// <see cref="CompletionState.ProgressComplete"/> once progress reaches the recipe maximum,
/// otherwise <see cref="CompletionState.NoMoreDurability"/> when durability is zero or less,
/// otherwise <see cref="CompletionState.Incomplete"/>.
/// </returns>
public static CompletionState CalculateCompletionState(SimulationState state) =>
    state.Progress >= state.Input.Recipe.MaxProgress
        ? CompletionState.ProgressComplete
        : state.Durability <= 0
            ? CompletionState.NoMoreDurability
            : CompletionState.Incomplete;
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
using CompState = Craftimizer.Simulator.CompletionState;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
/// <summary>
/// Solver-side completion state. Values are cast directly to and from the simulator's
/// completion state (see <c>CompletionStateUtils.IntoBase</c>), so the numeric values matter.
/// </summary>
public enum CompletionState : byte
{
    Incomplete = 0,
    ProgressComplete = 1,
    NoMoreDurability = 2,

    // Solver-specific states; IntoBase maps values at or above the simulator's "Other" to "Other".
    InvalidAction = 3,
    MaxActionCountReached = 4,
    NoMoreActions = 5
}
|
||||
|
||||
/// <summary>
/// Conversions between the solver's completion state and the simulator's.
/// </summary>
internal static class CompletionStateUtils
{
    /// <summary>
    /// Converts a solver completion state to the simulator's enum. Any value at or above the
    /// simulator's <c>Other</c> member becomes <c>Other</c>; lower values keep their numeric value.
    /// </summary>
    public static CompState IntoBase(this CompletionState me)
    {
        var converted = (CompState)me;
        if (converted >= CompState.Other)
            return CompState.Other;
        return converted;
    }
}
|
||||
+338
@@ -0,0 +1,338 @@
|
||||
using Craftimizer.Simulator.Actions;
|
||||
using Craftimizer.Simulator;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||
public sealed class MCTS
|
||||
{
|
||||
private readonly MCTSConfig config;
|
||||
private readonly Node rootNode;
|
||||
private readonly RootScores rootScores;
|
||||
|
||||
public float MaxScore => rootScores.MaxScore;
|
||||
|
||||
/// <summary>
/// Creates a search tree whose root represents <paramref name="state"/>.
/// </summary>
/// <param name="config">Search and scoring parameters.</param>
/// <param name="state">State the search starts from.</param>
public MCTS(MCTSConfig config, SimulationState state)
{
    this.config = config;
    rootScores = new();

    // A throwaway simulator is used only to classify the starting state and list its actions.
    var rootSimulator = new Simulator(state, config.MaxStepCount);
    var rootCompletion = rootSimulator.CompletionState;
    var rootActions = rootSimulator.AvailableActionsHeuristic(config.StrictActions);

    // The root has no originating action, hence the null.
    rootNode = new(new(state, null, rootCompletion, rootActions));
}
|
||||
|
||||
/// <summary>
/// Applies <paramref name="action"/> to <paramref name="state"/> and packages the outcome as a
/// tree node payload.
/// </summary>
/// <param name="simulator">Simulator used to run the action; it is left at the resulting state.</param>
/// <param name="state">State to run the action from.</param>
/// <param name="action">Action to execute.</param>
/// <param name="strict">Forwarded to <c>AvailableActionsHeuristic</c> for the resulting state.</param>
/// <returns>The resulting state, the action taken, its completion state and its candidate actions.</returns>
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
{
    var (_, resultingState) = simulator.Execute(state, action);

    // Both reads below reflect the simulator after the action has been executed.
    var completion = simulator.CompletionState;
    var candidateActions = simulator.AvailableActionsHeuristic(strict);

    return new(resultingState, action, completion, candidateActions);
}
|
||||
|
||||
/// <summary>
/// Replays <paramref name="actions"/> starting at <paramref name="startNode"/>, adding one child
/// node per action, and returns the deepest node reached.
/// </summary>
/// <param name="simulator">Simulator used to execute each action.</param>
/// <param name="startNode">Node the first action is applied to.</param>
/// <param name="actions">Actions to replay, in order.</param>
/// <param name="strict">Forwarded to <c>AvailableActionsHeuristic</c> for each new node.</param>
/// <returns>
/// The node for the last action that could be applied. Replay stops early at a complete node or
/// when the next action is not among the node's available actions.
/// </returns>
private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict)
{
    foreach (var action in actions)
    {
        // Bind by reference, as ExpandAndRollout does. SimulationNode is a struct, so reading
        // startNode.State by value would make RemoveAction below mutate a temporary copy and
        // could leave the action marked as untried on the node stored in the tree.
        ref var state = ref startNode.State;
        if (state.IsComplete)
            return startNode;

        if (!state.AvailableActions.HasAction(action))
            return startNode;

        // The action is about to become a child, so it is no longer available for expansion.
        state.AvailableActions.RemoveAction(action);

        startNode = startNode.Add(Execute(simulator, state.State, action, strict));
    }

    return startNode;
}
|
||||
|
||||
/// <summary>
/// Finds the child with the highest recorded max score.
/// </summary>
/// <param name="scores">Score buffer of the parent's children, stored in vector-sized chunks.</param>
/// <returns>
/// The chunk index and the index within that chunk of the best child. Returns (0, 0) when there
/// are no children or no score reaches zero. A later chunk wins a tie against an earlier one.
/// </returns>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
{
    var best = (0, 0);
    var bestScore = 0f;

    var remaining = scores.Count;
    var chunkIdx = 0;
    while (remaining > 0)
    {
        // The last chunk may be only partially filled.
        var lanes = Math.Min(Vector<float>.Count, remaining);

        ref var chunk = ref scores.Data[chunkIdx];
        var maxScores = new Vector<float>(chunk.MaxScore.Span);

        var lane = Intrinsics.HMaxIndex(maxScores, lanes);
        var candidate = maxScores[lane];
        if (candidate >= bestScore)
        {
            best = (chunkIdx, lane);
            bestScore = candidate;
        }

        remaining -= lanes;
        ++chunkIdx;
    }

    return best;
}
|
||||
|
||||
// Picks the child node to explore next by maximizing exploitation + exploration, where
//   exploitation = ((1 - w) * (s / v)) + (w * m)
//   exploration  = sqrt(c * ln(V) / v)
// with
//   w = maxScoreWeightingConstant
//   s = score sum
//   m = max score
//   v = visits (treated as at least 1)
//   V = parentVisits
//   c = explorationConstant
//
// Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
// In that article's notation:
//   w_i = (1 - w) * score sum, with the max score tacked onto it
//   n_i = visits
//   N_i = parent visits
//   c   = exploration constant (but crafty places it inside the sqrt..?)
/// <summary>
/// Evaluates every child in <paramref name="scores"/> and returns the position of the one with
/// the highest combined exploitation and exploration score.
/// </summary>
/// <param name="explorationConstant">The constant c in the exploration term.</param>
/// <param name="maxScoreWeightingConstant">The weight w given to a child's max score.</param>
/// <param name="parentVisits">Visit count V of the parent node.</param>
/// <param name="scores">Score buffer of the parent's children, stored in vector-sized chunks.</param>
/// <returns>
/// The chunk index and the index within that chunk of the chosen child. Returns (0, 0) when no
/// evaluated score reaches zero.
/// </returns>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
private static (int arrayIdx, int subIdx) EvalBestChild(
    float explorationConstant,
    float maxScoreWeightingConstant,
    int parentVisits,
    ref NodeScoresBuffer scores)
{
    // sqrt(c * ln(V)) is the same for every child, so it is computed once and broadcast.
    var explorationScale = new Vector<float>(MathF.Sqrt(explorationConstant * MathF.Log(parentVisits)));
    var maxWeight = maxScoreWeightingConstant;
    var meanWeight = 1f - maxWeight;

    var best = (0, 0);
    var bestScore = 0f;

    var remaining = scores.Count;
    var chunkIdx = 0;
    while (remaining > 0)
    {
        // The last chunk may be only partially filled.
        var lanes = Math.Min(Vector<float>.Count, remaining);

        ref var chunk = ref scores.Data[chunkIdx];
        var scoreSums = new Vector<float>(chunk.ScoreSum.Span);
        var rawVisits = new Vector<int>(chunk.Visits.Span);
        var maxScores = new Vector<float>(chunk.MaxScore.Span);

        // Clamp to one visit so unvisited children do not divide by zero.
        var visits = Vector.ConvertToSingle(Vector.Max(rawVisits, Vector<int>.One));

        var exploitation = meanWeight * (scoreSums / visits) + maxWeight * maxScores;
        var exploration = explorationScale * Intrinsics.ReciprocalSqrt(visits);
        var evalScores = exploitation + exploration;

        var lane = Intrinsics.HMaxIndex(evalScores, lanes);
        var candidate = evalScores[lane];
        if (candidate >= bestScore)
        {
            best = (chunkIdx, lane);
            bestScore = candidate;
        }

        remaining -= lanes;
        ++chunkIdx;
    }

    return best;
}
|
||||
|
||||
/// <summary>
/// Walks down from the root, following the best-scoring child at each level, until it reaches a
/// node that still has untried actions or that has no children.
/// </summary>
/// <returns>The node to expand next.</returns>
[Pure]
private Node Select()
{
    var explorationConstant = config.ExplorationConstant;
    var maxScoreWeightingConstant = config.MaxScoreWeightingConstant;

    var current = rootNode;
    var currentVisits = rootScores.Visits;

    // Keep descending only while the node is fully expanded (no untried actions) and has children.
    while (current.State.AvailableActions.IsEmpty && current.Children.Count != 0)
    {
        var bestChild = EvalBestChild(explorationConstant, maxScoreWeightingConstant, currentVisits, ref current.ChildScores);
        currentVisits = current.ChildScores.GetVisits(bestChild);
        current = current.ChildAt(bestChild)!;
    }

    return current;
}
|
||||
|
||||
/// <summary>
/// Expands <paramref name="initialNode"/> with one randomly chosen untried action, then plays
/// random actions from the new child until the craft completes, no actions remain, or the
/// rollout buffer is full.
/// </summary>
/// <param name="random">Random source used to pick the expansion and rollout actions.</param>
/// <param name="simulator">Simulator used to execute actions.</param>
/// <param name="initialNode">Node chosen by selection.</param>
/// <returns>
/// The node to backpropagate from and the rollout's score (0 when the rollout did not finish the
/// craft). If <paramref name="initialNode"/> is already complete it is returned unchanged.
/// Otherwise the expanded child is returned, unless the rollout completed with a score that
/// meets both the storage threshold and the current max score; then the rollout's actions are
/// added to the tree and the deepest node reached is returned.
/// </returns>
private (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode)
{
    ref var initialState = ref initialNode.State;

    // A complete node cannot be expanded; score it as it stands.
    if (initialState.IsComplete)
        return (initialNode, initialState.CalculateScore(config) ?? 0);

    // Expand once. PopRandom also removes the action from the node's untried set.
    var poppedAction = initialState.AvailableActions.PopRandom(random);
    var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true));

    // Play out to a terminal state.
    var currentState = expandedNode.State.State;
    var currentCompletionState = expandedNode.State.SimulationCompletionState;
    var currentActions = expandedNode.State.AvailableActions;

    // The counter is an int, not a byte: the buffer length comes from int configuration values,
    // and a byte would wrap to 0 after 255 actions and corrupt the recorded rollout.
    var actionCount = 0;
    Span<ActionType> actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)];
    while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete &&
        actionCount < actions.Length)
    {
        var nextAction = currentActions.SelectRandom(random);
        actions[actionCount++] = nextAction;
        (_, currentState) = simulator.Execute(currentState, nextAction);
        currentCompletionState = simulator.CompletionState;
        if (currentCompletionState != CompletionState.Incomplete)
            break;
        currentActions = simulator.AvailableActionsHeuristic(true);
    }

    // Store the rollout in the tree only if it finished the craft with a new best (or equal) score.
    var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config) ?? 0;
    if (currentCompletionState == CompletionState.ProgressComplete &&
        score >= config.ScoreStorageThreshold &&
        score >= MaxScore)
    {
        var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
        return (terminalNode, score);
    }

    return (expandedNode, score);
}
|
||||
|
||||
/// <summary>
/// Records a visit with <paramref name="score"/> on <paramref name="startNode"/> and on each of
/// its ancestors, ending with the root.
/// </summary>
/// <param name="startNode">Node the rollout ended on.</param>
/// <param name="score">Score to record along the path.</param>
private void Backpropagate(Node startNode, float score)
{
    // Every non-root node's statistics live in its parent's score buffer.
    for (var current = startNode; ; current = current.Parent!)
    {
        if (current == rootNode)
        {
            // The root has no parent buffer; its statistics are kept separately.
            rootScores.Visit(score);
            return;
        }

        current.ParentScores!.Value.Visit(current.ChildIdx, score);
    }
}
|
||||
|
||||
/// <summary>
/// Debug helper: writes the whole tree to the console, one action per line, indented by depth
/// (the root is indented by one space).
/// </summary>
private void ShowAllNodes()
{
    static void AppendSubtree(StringBuilder output, Node node, int depth)
    {
        output.AppendLine($"{new string(' ', depth)}{node.State.Action}");

        // NOTE(review): (i >> 3, i & 7) assumes children are stored 8 per chunk; confirm against ArenaNode.
        for (var childIdx = 0; childIdx < node.Children.Count; ++childIdx)
        {
            var child = node.ChildAt((childIdx >> 3, childIdx & 7))!;
            AppendSubtree(output, child, depth + 1);
        }
    }

    var dump = new StringBuilder();
    AppendSubtree(dump, rootNode, 1);
    Console.WriteLine(dump.ToString());
}
|
||||
|
||||
/// <summary>
/// Reports whether the tree has nothing left to expand, i.e. no childless node still has
/// untried actions.
/// </summary>
/// <returns>
/// <c>true</c> when every childless node has an empty set of available actions; otherwise
/// <c>false</c>. Nodes that already have children are not checked for untried actions.
/// </returns>
private bool AllNodesComplete()
{
    // The previous version also pushed every visited node onto a Stack that was never read and
    // was not popped on the leaf and early-return paths; that bookkeeping is removed.
    static bool HasExpandableLeaf(Node node)
    {
        var childCount = node.Children.Count;
        if (childCount == 0)
            return !node.State.AvailableActions.IsEmpty;

        // NOTE(review): (i >> 3, i & 7) assumes children are stored 8 per chunk; confirm against ArenaNode.
        for (var i = 0; i < childCount; ++i)
        {
            var child = node.ChildAt((i >> 3, i & 7))!;
            if (HasExpandableLeaf(child))
                return true;
        }

        return false;
    }

    return !HasExpandableLeaf(rootNode);
}
|
||||
|
||||
/// <summary>
/// Runs select, expand/rollout and backpropagate rounds on this tree.
/// </summary>
/// <param name="iterations">
/// Minimum number of rounds. The loop keeps going past this count for as long as the max score
/// is still 0.
/// </param>
/// <param name="token">Stops the search at the start of the next round when cancelled.</param>
/// <remarks>
/// While the max score is 0, the search counts consecutive rounds in which nothing was expanded.
/// Once that count exceeds the check interval, it verifies whether the tree is exhausted and
/// returns if it is.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Search(int iterations, CancellationToken token)
{
    const int StallCheckInterval = 5000;

    Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
    var random = rootNode.State.State.Input.Random;

    var stalledRounds = 0;
    var round = 0;
    while ((round < iterations || MaxScore == 0) && !token.IsCancellationRequested)
    {
        var selectedNode = Select();
        var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);

        if (MaxScore == 0)
        {
            if (endNode != selectedNode)
            {
                // Something was expanded, so the tree is still growing.
                stalledRounds = 0;
            }
            else if (stalledRounds++ > StallCheckInterval)
            {
                stalledRounds = 0;

                // Every node is solved for and no valid solution was found; give up.
                if (AllNodesComplete())
                    return;
            }
        }

        Backpropagate(endNode, score);
        round++;
    }
}
|
||||
|
||||
/// <summary>
/// Extracts the best action sequence found so far by following, from the root, the child with
/// the highest max score until a childless node is reached.
/// </summary>
/// <returns>The actions along that path and the state of the node the path ends on.</returns>
[Pure]
public SolverSolution Solution()
{
    var bestPath = new List<ActionType>();

    var current = rootNode;
    while (current.Children.Count != 0)
    {
        var bestChild = ChildMaxScore(ref current.ChildScores);
        current = current.ChildAt(bestChild)!;

        if (current.State.Action is { } action)
            bestPath.Add(action);
    }

    return new(bestPath, current.State.State);
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
/// <summary>
/// The subset of <see cref="SolverConfig"/> that a single MCTS tree search reads.
/// </summary>
[StructLayout(LayoutKind.Auto)]
public readonly record struct MCTSConfig
{
    /// <summary>Step limit passed to the simulator and used to scale the fewer-steps score.</summary>
    public int MaxStepCount { get; init; }

    /// <summary>Upper bound on the number of actions recorded during one rollout.</summary>
    public int MaxRolloutStepCount { get; init; }

    /// <summary>Passed to the available-actions heuristic when the root node is built.</summary>
    public bool StrictActions { get; init; }

    /// <summary>Weight of a child's max score (versus its mean score) during selection.</summary>
    public float MaxScoreWeightingConstant { get; init; }

    /// <summary>Scales the exploration term during selection.</summary>
    public float ExplorationConstant { get; init; }

    /// <summary>Minimum score a completed rollout needs before its path is stored in the tree.</summary>
    public float ScoreStorageThreshold { get; init; }

    /// <summary>Score weight for progress.</summary>
    public float ScoreProgress { get; init; }

    /// <summary>Score weight for quality.</summary>
    public float ScoreQuality { get; init; }

    /// <summary>Score weight for remaining durability.</summary>
    public float ScoreDurability { get; init; }

    /// <summary>Score weight for remaining CP.</summary>
    public float ScoreCP { get; init; }

    /// <summary>Score weight for finishing in fewer steps.</summary>
    public float ScoreSteps { get; init; }

    /// <summary>
    /// Copies the search-related values out of <paramref name="config"/>.
    /// </summary>
    /// <param name="config">Solver configuration to read from.</param>
    public MCTSConfig(SolverConfig config)
    {
        // Search limits.
        MaxStepCount = config.MaxStepCount;
        MaxRolloutStepCount = config.MaxRolloutStepCount;
        StrictActions = config.StrictActions;

        // Selection and storage tuning.
        MaxScoreWeightingConstant = config.MaxScoreWeightingConstant;
        ExplorationConstant = config.ExplorationConstant;
        ScoreStorageThreshold = config.ScoreStorageThreshold;

        // Scoring weights; the solver config calls these "bonuses".
        ScoreProgress = config.ScoreProgressBonus;
        ScoreQuality = config.ScoreQualityBonus;
        ScoreDurability = config.ScoreDurabilityBonus;
        ScoreCP = config.ScoreCPBonus;
        ScoreSteps = config.ScoreFewerStepsBonus;
    }
}
|
||||
@@ -30,7 +30,7 @@ public struct SimulationNode
|
||||
CompletionState.NoMoreActions :
|
||||
simCompletionState;
|
||||
|
||||
public readonly float? CalculateScore(SolverConfig config) =>
|
||||
public readonly float? CalculateScore(MCTSConfig config) =>
|
||||
CalculateScoreForState(State, SimulationCompletionState, config);
|
||||
|
||||
private static bool CanByregot(SimulationState state)
|
||||
@@ -41,7 +41,7 @@ public struct SimulationNode
|
||||
return BaseComboAction.VerifyDurability2(state, 10);
|
||||
}
|
||||
|
||||
public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, SolverConfig config)
|
||||
public static float? CalculateScoreForState(SimulationState state, CompletionState completionState, MCTSConfig config)
|
||||
{
|
||||
if (completionState != CompletionState.ProgressComplete)
|
||||
return null;
|
||||
@@ -50,32 +50,32 @@ public struct SimulationNode
|
||||
bonus * Math.Min(1f, value / target);
|
||||
|
||||
var progressScore = Apply(
|
||||
config.ScoreProgressBonus,
|
||||
config.ScoreProgress,
|
||||
state.Progress,
|
||||
state.Input.Recipe.MaxProgress
|
||||
);
|
||||
|
||||
var byregotBonus = CanByregot(state) ? (state.ActiveEffects.InnerQuiet * .2f + 1) * state.Input.BaseQualityGain : 0;
|
||||
var qualityScore = Apply(
|
||||
config.ScoreQualityBonus,
|
||||
config.ScoreQuality,
|
||||
state.Quality + byregotBonus,
|
||||
state.Input.Recipe.MaxQuality
|
||||
);
|
||||
|
||||
var durabilityScore = Apply(
|
||||
config.ScoreDurabilityBonus,
|
||||
config.ScoreDurability,
|
||||
state.Durability,
|
||||
state.Input.Recipe.MaxDurability
|
||||
);
|
||||
|
||||
var cpScore = Apply(
|
||||
config.ScoreCPBonus,
|
||||
config.ScoreCP,
|
||||
state.CP,
|
||||
state.Input.Stats.CP
|
||||
);
|
||||
|
||||
var fewerStepsScore =
|
||||
config.ScoreFewerStepsBonus * (1f - (float)(state.ActionCount + 1) / config.MaxStepCount);
|
||||
config.ScoreSteps * (1f - (float)(state.ActionCount + 1) / config.MaxStepCount);
|
||||
|
||||
return progressScore + qualityScore + durabilityScore + cpScore + fewerStepsScore;
|
||||
}
|
||||
|
||||
+10
-6
@@ -9,8 +9,16 @@ public sealed class Simulator : SimulatorNoRandom
|
||||
{
|
||||
private readonly int maxStepCount;
|
||||
|
||||
public new CompletionState CompletionState => CalculateCompletionState(State, maxStepCount);
|
||||
public override bool IsComplete => CompletionState != CompletionState.Incomplete;
|
||||
/// <summary>
/// Completion state including the solver's step limit. The base result is returned unchanged
/// unless it is <see cref="CompletionState.Incomplete"/> and the next action would reach
/// <c>maxStepCount</c>; in that case <see cref="CompletionState.MaxActionCountReached"/> is returned.
/// </summary>
public override CompletionState CompletionState
{
    get
    {
        var baseState = base.CompletionState;

        // A finished (or failed) craft keeps its own state; only an unfinished one can hit the limit.
        return baseState == CompletionState.Incomplete && (ActionCount + 1) >= maxStepCount
            ? CompletionState.MaxActionCountReached
            : baseState;
    }
}
|
||||
|
||||
public Simulator(SimulationState state, int maxStepCount) : base(state)
|
||||
{
|
||||
@@ -187,8 +195,4 @@ public sealed class Simulator : SimulatorNoRandom
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// <summary>
/// Derives the completion state of <paramref name="state"/>, with the step limit taking
/// precedence over every other state.
/// </summary>
/// <param name="state">State to classify.</param>
/// <param name="maxStepCount">Step limit.</param>
/// <returns>
/// <see cref="CompletionState.MaxActionCountReached"/> when the next action would reach the
/// limit; otherwise the result of the single-argument overload, cast to this enum.
/// </returns>
public static CompletionState CalculateCompletionState(SimulationState state, int maxStepCount)
{
    if (state.ActionCount + 1 >= maxStepCount)
        return CompletionState.MaxActionCountReached;

    return (CompletionState)CalculateCompletionState(state);
}
|
||||
}
|
||||
|
||||
+14
-339
@@ -1,338 +1,12 @@
|
||||
using Craftimizer.Simulator;
|
||||
using Craftimizer.Simulator.Actions;
|
||||
using System.Diagnostics;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using Node = Craftimizer.Solver.ArenaNode<Craftimizer.Solver.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
|
||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||
public sealed class Solver
|
||||
public static class Solver
|
||||
{
|
||||
private SolverConfig config;
|
||||
private Node rootNode;
|
||||
private RootScores rootScores;
|
||||
|
||||
public float MaxScore => rootScores.MaxScore;
|
||||
|
||||
public Solver(SolverConfig config, SimulationState state)
|
||||
{
|
||||
this.config = config;
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
rootNode = new(new(
|
||||
state,
|
||||
null,
|
||||
sim.CompletionState,
|
||||
sim.AvailableActionsHeuristic(config.StrictActions)
|
||||
));
|
||||
rootScores = new();
|
||||
}
|
||||
|
||||
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
|
||||
{
|
||||
(_, var newState) = simulator.Execute(state, action);
|
||||
return new(
|
||||
newState,
|
||||
action,
|
||||
simulator.CompletionState,
|
||||
simulator.AvailableActionsHeuristic(strict)
|
||||
);
|
||||
}
|
||||
|
||||
private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict)
|
||||
{
|
||||
foreach (var action in actions)
|
||||
{
|
||||
var state = startNode.State;
|
||||
if (state.IsComplete)
|
||||
return startNode;
|
||||
|
||||
if (!state.AvailableActions.HasAction(action))
|
||||
return startNode;
|
||||
state.AvailableActions.RemoveAction(action);
|
||||
|
||||
startNode = startNode.Add(Execute(simulator, state.State, action, strict));
|
||||
}
|
||||
|
||||
return startNode;
|
||||
}
|
||||
|
||||
[Pure]
|
||||
private SolverSolution Solution()
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var node = rootNode;
|
||||
|
||||
while (node.Children.Count != 0)
|
||||
{
|
||||
node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;
|
||||
|
||||
if (node.State.Action != null)
|
||||
actions.Add(node.State.Action.Value);
|
||||
}
|
||||
|
||||
//var at = node.ChildIdx;
|
||||
//ref var sum = ref node.ParentScores!.Value.Data[at.arrayIdx].ScoreSum.Span[at.subIdx];
|
||||
//ref var max = ref node.ParentScores!.Value.Data[at.arrayIdx].MaxScore.Span[at.subIdx];
|
||||
//ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx];
|
||||
//Console.WriteLine($"{sum} {max} {visits}");
|
||||
|
||||
return new(actions, node.State.State);
|
||||
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
var max = (0, 0);
|
||||
var maxScore = 0f;
|
||||
for (var i = 0; length > 0; ++i)
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
|
||||
var idx = Intrinsics.HMaxIndex(m, iterCount);
|
||||
|
||||
if (m[idx] >= maxScore)
|
||||
{
|
||||
max = (i, idx);
|
||||
maxScore = m[idx];
|
||||
}
|
||||
|
||||
length -= iterCount;
|
||||
}
|
||||
|
||||
return max;
|
||||
}
|
||||
|
||||
// Calculates the best child node to explore next
|
||||
// Exploitation: ((1 - w) * (s / v)) + (w * m)
|
||||
// Exploration: sqrt(c * ln(V) / v)
|
||||
// w = maxScoreWeightingConstant
|
||||
// s = score sum
|
||||
// m = max score
|
||||
// v = visits
|
||||
// V = parentVisits
|
||||
// c = explorationConstant
|
||||
|
||||
// Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation
|
||||
// Here, w_i = (1-w)*score sum
|
||||
// n_i = visits
|
||||
// max score is tacked onto it
|
||||
// N_i = parent visits
|
||||
// c = exploration constant (but crafty places it inside the sqrt..?)
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
|
||||
private (int arrayIdx, int subIdx) EvalBestChild(int parentVisits, ref NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
|
||||
var C = MathF.Sqrt(config.ExplorationConstant * MathF.Log(parentVisits));
|
||||
var w = config.MaxScoreWeightingConstant;
|
||||
var W = 1f - w;
|
||||
var CVector = new Vector<float>(C);
|
||||
|
||||
var max = (0, 0);
|
||||
var maxScore = 0f;
|
||||
for (var i = 0; length > 0; ++i)
|
||||
{
|
||||
var iterCount = Math.Min(vecLength, length);
|
||||
|
||||
ref var chunk = ref scores.Data[i];
|
||||
var s = new Vector<float>(chunk.ScoreSum.Span);
|
||||
var vInt = new Vector<int>(chunk.Visits.Span);
|
||||
var m = new Vector<float>(chunk.MaxScore.Span);
|
||||
|
||||
vInt = Vector.Max(vInt, Vector<int>.One);
|
||||
var v = Vector.ConvertToSingle(vInt);
|
||||
|
||||
var exploitation = W * (s / v) + w * m;
|
||||
var exploration = CVector * Intrinsics.ReciprocalSqrt(v);
|
||||
var evalScores = exploitation + exploration;
|
||||
|
||||
var idx = Intrinsics.HMaxIndex(evalScores, iterCount);
|
||||
|
||||
if (evalScores[idx] >= maxScore)
|
||||
{
|
||||
max = (i, idx);
|
||||
maxScore = evalScores[idx];
|
||||
}
|
||||
|
||||
length -= iterCount;
|
||||
}
|
||||
|
||||
return max;
|
||||
}
|
||||
|
||||
[Pure]
|
||||
public Node Select()
|
||||
{
|
||||
var node = rootNode;
|
||||
var nodeVisits = rootScores.Visits;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var expandable = !node.State.AvailableActions.IsEmpty;
|
||||
var likelyTerminal = node.Children.Count == 0;
|
||||
if (expandable || likelyTerminal)
|
||||
return node;
|
||||
|
||||
// select the node with the highest score
|
||||
var at = EvalBestChild(nodeVisits, ref node.ChildScores);
|
||||
nodeVisits = node.ChildScores.GetVisits(at);
|
||||
node = node.ChildAt(at)!;
|
||||
}
|
||||
}
|
||||
|
||||
public (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode)
|
||||
{
|
||||
ref var initialState = ref initialNode.State;
|
||||
// expand once
|
||||
if (initialState.IsComplete)
|
||||
return (initialNode, initialState.CalculateScore(config) ?? 0);
|
||||
|
||||
var poppedAction = initialState.AvailableActions.PopRandom(random);
|
||||
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true));
|
||||
|
||||
// playout to a terminal state
|
||||
var currentState = expandedNode.State.State;
|
||||
var currentCompletionState = expandedNode.State.SimulationCompletionState;
|
||||
var currentActions = expandedNode.State.AvailableActions;
|
||||
|
||||
|
||||
byte actionCount = 0;
|
||||
Span<ActionType> actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)];
|
||||
while (SimulationNode.GetCompletionState(currentCompletionState, currentActions) == CompletionState.Incomplete &&
|
||||
actionCount < actions.Length)
|
||||
{
|
||||
var nextAction = currentActions.SelectRandom(random);
|
||||
actions[actionCount++] = nextAction;
|
||||
(_, currentState) = simulator.Execute(currentState, nextAction);
|
||||
currentCompletionState = simulator.CompletionState;
|
||||
if (currentCompletionState != CompletionState.Incomplete)
|
||||
break;
|
||||
currentActions = simulator.AvailableActionsHeuristic(true);
|
||||
}
|
||||
|
||||
// store the result if a max score was reached
|
||||
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config) ?? 0;
|
||||
if (currentCompletionState == CompletionState.ProgressComplete)
|
||||
{
|
||||
if (score >= config.ScoreStorageThreshold && score >= MaxScore)
|
||||
{
|
||||
var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
|
||||
return (terminalNode, score);
|
||||
}
|
||||
}
|
||||
return (expandedNode, score);
|
||||
}
|
||||
|
||||
public void Backpropagate(Node startNode, float score)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
if (startNode == rootNode)
|
||||
{
|
||||
rootScores.Visit(score);
|
||||
break;
|
||||
}
|
||||
startNode.ParentScores!.Value.Visit(startNode.ChildIdx, score);
|
||||
|
||||
startNode = startNode.Parent!;
|
||||
}
|
||||
}
|
||||
|
||||
private void ShowAllNodes()
|
||||
{
|
||||
static void ShowNodes(StringBuilder b, Node node, Stack<Node> path)
|
||||
{
|
||||
path.Push(node);
|
||||
b.AppendLine($"{new string(' ', path.Count)}{node.State.Action}");
|
||||
{
|
||||
for (var i = 0; i < node.Children.Count; ++i)
|
||||
{
|
||||
var n = node.ChildAt((i >> 3, i & 7))!;
|
||||
ShowNodes(b, n, path);
|
||||
}
|
||||
path.Pop();
|
||||
}
|
||||
}
|
||||
var b = new StringBuilder();
|
||||
ShowNodes(b, rootNode, new());
|
||||
Console.WriteLine(b.ToString());
|
||||
}
|
||||
|
||||
private bool AllNodesComplete()
|
||||
{
|
||||
static bool NodesIncomplete(Node node, Stack<Node> path)
|
||||
{
|
||||
path.Push(node);
|
||||
if (node.Children.Count == 0)
|
||||
{
|
||||
if (!node.State.AvailableActions.IsEmpty)
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var i = 0; i < node.Children.Count; ++i)
|
||||
{
|
||||
var n = node.ChildAt((i >> 3, i & 7))!;
|
||||
if (NodesIncomplete(n, path))
|
||||
return true;
|
||||
}
|
||||
path.Pop();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return !NodesIncomplete(rootNode, new());
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private void Search(int iterations, CancellationToken token)
|
||||
{
|
||||
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
|
||||
var random = rootNode.State.State.Input.Random;
|
||||
var n = 0;
|
||||
for (var i = 0; i < iterations || MaxScore == 0; i++)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
var selectedNode = Select();
|
||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
|
||||
if (MaxScore == 0)
|
||||
{
|
||||
if (endNode == selectedNode)
|
||||
{
|
||||
if (n++ > 5000)
|
||||
{
|
||||
n = 0;
|
||||
if (AllNodesComplete())
|
||||
{
|
||||
//Console.WriteLine("All nodes solved for. Can't find a valid solution.");
|
||||
//ShowAllNodes();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
n = 0;
|
||||
}
|
||||
|
||||
Backpropagate(endNode, score);
|
||||
}
|
||||
}
|
||||
|
||||
public static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
private static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
var definiteActionCount = 0;
|
||||
var bestSims = new List<(float Score, SolverSolution Result)>();
|
||||
@@ -354,7 +28,7 @@ public sealed class Solver
|
||||
var st = activeStates[stateIdx];
|
||||
tasks[i] = Task.Run(() =>
|
||||
{
|
||||
var solver = new Solver(config, activeStates[stateIdx].State);
|
||||
var solver = new MCTS(new(config), activeStates[stateIdx].State);
|
||||
solver.Search(config.Iterations / config.ForkCount, token);
|
||||
return (solver.MaxScore, stateIdx, solver.Solution());
|
||||
}, token);
|
||||
@@ -442,7 +116,7 @@ public sealed class Solver
|
||||
return result;
|
||||
}
|
||||
|
||||
public static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
private static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
@@ -460,7 +134,7 @@ public sealed class Solver
|
||||
for (var i = 0; i < config.ForkCount; ++i)
|
||||
tasks[i] = Task.Run(() =>
|
||||
{
|
||||
var solver = new Solver(config, state);
|
||||
var solver = new MCTS(new(config), state);
|
||||
solver.Search(config.Iterations / config.ForkCount, token);
|
||||
return (solver.MaxScore, solver.Solution());
|
||||
}, token);
|
||||
@@ -489,7 +163,7 @@ public sealed class Solver
|
||||
return new(actions, state);
|
||||
}
|
||||
|
||||
public static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
private static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
@@ -501,7 +175,7 @@ public sealed class Solver
|
||||
if (sim.IsComplete)
|
||||
break;
|
||||
|
||||
var solver = new Solver(config, state);
|
||||
var solver = new MCTS(new(config), state);
|
||||
|
||||
var s = Stopwatch.StartNew();
|
||||
solver.Search(config.Iterations, token);
|
||||
@@ -526,13 +200,13 @@ public sealed class Solver
|
||||
return new(actions, state);
|
||||
}
|
||||
|
||||
public static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
private static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount];
|
||||
for (var i = 0; i < config.ForkCount; ++i)
|
||||
tasks[i] = Task.Run(() =>
|
||||
{
|
||||
var solver = new Solver(config, state);
|
||||
var solver = new MCTS(new(config), state);
|
||||
solver.Search(config.Iterations / config.ForkCount, token);
|
||||
return (solver.MaxScore, solver.Solution());
|
||||
}, token);
|
||||
@@ -545,9 +219,9 @@ public sealed class Solver
|
||||
return solution;
|
||||
}
|
||||
|
||||
public static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
private static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
var solver = new Solver(config, state);
|
||||
var solver = new MCTS(new(config), state);
|
||||
solver.Search(config.Iterations, token);
|
||||
var solution = solver.Solution();
|
||||
foreach (var action in solution.Actions)
|
||||
@@ -556,7 +230,7 @@ public sealed class Solver
|
||||
return solution;
|
||||
}
|
||||
|
||||
public static SolverSolution Search(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
|
||||
public static SolverSolution Search(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token)
|
||||
{
|
||||
Func<SolverConfig, SimulationState, Action<ActionType>?, CancellationToken, SolverSolution> func = config.Algorithm switch
|
||||
{
|
||||
@@ -564,7 +238,8 @@ public sealed class Solver
|
||||
SolverAlgorithm.OneshotForked => SearchOneshotForked,
|
||||
SolverAlgorithm.Stepwise => SearchStepwise,
|
||||
SolverAlgorithm.StepwiseForked => SearchStepwiseForked,
|
||||
SolverAlgorithm.StepwiseFurcated or _ => SearchStepwiseFurcated,
|
||||
SolverAlgorithm.StepwiseFurcated => SearchStepwiseFurcated,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(config), config, $"Invalid algorithm: {config.Algorithm}")
|
||||
};
|
||||
return func(config, state, actionCallback, token);
|
||||
}
|
||||
|
||||
+20
-1
@@ -1,3 +1,5 @@
|
||||
using Craftimizer.Simulator.Actions;
|
||||
using Craftimizer.Simulator;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Craftimizer.Solver;
|
||||
@@ -65,4 +67,21 @@ public readonly record struct SolverConfig
|
||||
FurcatedActionCount = Environment.ProcessorCount / 2,
|
||||
Algorithm = SolverAlgorithm.StepwiseForked
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
/// Runs <c>Solver.Search</c> with this configuration and returns its solution,
/// treating cancellation as a non-error outcome.
/// </summary>
/// <param name="state">Simulation state the search starts from.</param>
/// <param name="actionCallback">Optional callback forwarded unchanged to <c>Solver.Search</c>.</param>
/// <param name="token">Cancellation token forwarded unchanged to <c>Solver.Search</c>.</param>
/// <returns>
/// The solution produced by the search, or <see langword="null"/> when the search
/// ended with an <see cref="OperationCanceledException"/>, either thrown directly
/// or as the only kind of inner exception in an <see cref="AggregateException"/>.
/// </returns>
/// <exception cref="AggregateException">
/// Rethrown by <see cref="AggregateException.Handle"/> when at least one inner
/// exception is not an <see cref="OperationCanceledException"/>.
/// </exception>
public SolverSolution? Invoke(SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{
    try
    {
        return Solver.Search(this, state, actionCallback, token);
    }
    catch (OperationCanceledException)
    {
        // Cancellation surfaced directly: report "no solution" rather than an error.
        return null;
    }
    catch (AggregateException aggregate)
    {
        // Handle marks cancellation exceptions as handled and rethrows the rest,
        // so reaching the return below means every inner exception was a cancellation.
        aggregate.Handle(inner => inner is OperationCanceledException);
        return null;
    }
}
|
||||
}
|
||||
Reference in New Issue
Block a user