Remove all concurrency code
Muddled the code too much, and only gave a marginal performance improvement in the grand scheme of things. Other ways to parallelize MCTS will be nicer to implement and could yield better results.
This commit is contained in:
@@ -7,6 +7,8 @@ namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
public struct ActionSet
|
||||
{
|
||||
private const bool IsDeterministic = true;
|
||||
|
||||
private uint bits;
|
||||
|
||||
[Pure]
|
||||
@@ -19,24 +21,6 @@ public struct ActionSet
|
||||
// Builds the single-bit mask for an action. The bit position is one past the
// index returned by FromAction (addition binds tighter than the shift).
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint ToMask(ActionType action)
{
    var shift = FromAction(action) + 1;
    return 1u << shift;
}
|
||||
|
||||
// Atomically sets the action's bit with Interlocked.Or.
// Returns true when the bit was clear beforehand, i.e. the action was newly added.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AddActionConcurrent(ActionType action)
{
    var flag = ToMask(action);
    // Interlocked.Or hands back the value held before the OR was applied.
    var previous = Interlocked.Or(ref bits, flag);
    var wasAbsent = (previous & flag) == 0;
    return wasAbsent;
}
|
||||
|
||||
// Atomically clears the action's bit with Interlocked.And.
// Returns true when the bit was set beforehand, i.e. the action was actually removed.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool RemoveActionConcurrent(ActionType action)
{
    var flag = ToMask(action);
    // Interlocked.And hands back the value held before the AND was applied.
    var previous = Interlocked.And(ref bits, ~flag);
    var wasPresent = (previous & flag) != 0;
    return wasPresent;
}
|
||||
|
||||
// Return true if action was newly added and not there before.
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public bool AddAction(ActionType action)
|
||||
@@ -71,52 +55,17 @@ public struct ActionSet
|
||||
public readonly bool IsEmpty => bits == 0;
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public readonly ActionType SelectRandom(Random random) => ElementAt(random.Next(Count));
|
||||
|
||||
// Atomically removes one action, picked at a random index among the currently set bits,
// and returns it. Returns null when the set is observed to be empty.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopRandomConcurrent(Random random)
{
    while (true)
    {
        uint observed = bits;
        if (observed == 0)
            return null;

        // Pick the n-th set bit of the snapshot, n drawn from [0, number of set bits).
        var pick = random.Next(BitOperations.PopCount(observed));
        ActionType chosen = ToAction(Intrinsics.NthBitSet(observed, pick) - 1);
        uint updated = observed & ~ToMask(chosen);

        // Publish only if the set still equals the snapshot; otherwise take a new snapshot and retry.
        if (Interlocked.CompareExchange(ref bits, updated, observed) == observed)
            return chosen;
    }
}
|
||||
|
||||
// Atomically removes the action found by Intrinsics.NthBitSet(snapshot, 0) and returns it.
// Returns null when the set is observed to be empty.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopFirstConcurrent()
{
    while (true)
    {
        uint observed = bits;
        if (observed == 0)
            return null;

        ActionType chosen = ToAction(Intrinsics.NthBitSet(observed, 0) - 1);
        uint updated = observed & ~ToMask(chosen);

        // Publish only if the set still equals the snapshot; otherwise take a new snapshot and retry.
        if (Interlocked.CompareExchange(ref bits, updated, observed) == observed)
            return chosen;
    }
}
|
||||
public readonly ActionType SelectRandom(Random random) =>
|
||||
IsDeterministic ?
|
||||
First() :
|
||||
ElementAt(random.Next(Count));
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public ActionType PopRandom(Random random)
|
||||
{
|
||||
if (IsDeterministic)
|
||||
return PopFirst();
|
||||
|
||||
var action = ElementAt(random.Next(Count));
|
||||
RemoveAction(action);
|
||||
return action;
|
||||
|
||||
@@ -11,41 +11,20 @@ public struct ArenaBuffer<T> where T : struct
|
||||
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
|
||||
private const int MaxSize = 24;
|
||||
|
||||
private static int BatchSize = Vector<float>.Count;
|
||||
private static int BatchSizeBits = int.Log2(BatchSize);
|
||||
private static int BatchSizeMask = BatchSize - 1;
|
||||
private static readonly int BatchSize = Vector<float>.Count;
|
||||
private static readonly int BatchSizeBits = int.Log2(BatchSize);
|
||||
private static readonly int BatchSizeMask = BatchSize - 1;
|
||||
|
||||
private static int BatchCount = MaxSize / BatchSize;
|
||||
private static readonly int BatchCount = MaxSize / BatchSize;
|
||||
|
||||
public ArenaNode<T>[][] Data;
|
||||
private int index; // Unused in single threaded workload
|
||||
private int count;
|
||||
|
||||
public readonly int Count => count;
|
||||
|
||||
// Appends a node using interlocked operations: a slot is reserved by incrementing `index`,
// the backing arrays are created lazily through compare-exchange, the node is stored,
// and only then is `count` incremented.
public void AddConcurrent(ArenaNode<T> node)
{
    // Create the outer array on first use; if another caller installed one first, theirs is kept.
    if (Data is null)
        Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);

    var slot = Interlocked.Increment(ref index) - 1;
    var (batchIdx, offset) = GetArrayIndex(slot);

    // Same lazy creation for the batch that holds this slot.
    ref var batch = ref Data[batchIdx];
    if (batch is null)
        Interlocked.CompareExchange(ref batch, new ArenaNode<T>[BatchSize], null);

    // Record where the node lives before storing it.
    node.ChildIdx = (batchIdx, offset);
    batch[offset] = node;

    Interlocked.Increment(ref count);
}
|
||||
public int Count { get; private set; }
|
||||
|
||||
public void Add(ArenaNode<T> node)
|
||||
{
|
||||
Data ??= new ArenaNode<T>[BatchCount][];
|
||||
|
||||
var idx = count++;
|
||||
var idx = Count++;
|
||||
|
||||
var (arrayIdx, subIdx) = GetArrayIndex(idx);
|
||||
|
||||
|
||||
@@ -23,15 +23,6 @@ public sealed class ArenaNode<T> where T : struct
|
||||
// Looks up the child stored at the given (batch, offset) position.
// Returns null when the child storage or the addressed batch has not been allocated,
// or when the slot itself is still empty.
public ArenaNode<T>? ChildAt((int arrayIdx, int subIdx) at)
{
    var batch = Children.Data?[at.arrayIdx];
    return batch?[at.subIdx];
}
|
||||
|
||||
// Creates a child node for `state` with this node as its parent, registers a score slot
// and then the node itself through the concurrent buffers, and returns the new child.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> AddConcurrent(T state)
{
    ArenaNode<T> child = new(state, this);

    // The score slot is added before the node, in the same order as the original code.
    ChildScores.AddConcurrent();
    Children.AddConcurrent(child);

    return child;
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public ArenaNode<T> Add(T state)
|
||||
{
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
// Contract for solver strategies. Both members are static abstract, so callers invoke
// them through a generic type parameter constrained to ISolver.
public interface ISolver
{
    // Runs a single search iteration against the tree rooted at rootNode.
    // Returns false when the iteration did not complete and should not be counted by the caller.
    static abstract bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator);

    // Runs the strategy's full search against the tree rooted at rootNode.
    static abstract void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token);
}
|
||||
@@ -124,28 +124,4 @@ internal static class Intrinsics
|
||||
result[i] = MathF.ReciprocalSqrtEstimate(data[i]);
|
||||
return new(result);
|
||||
}
|
||||
|
||||
// Atomically raises `location` to `newValue` when `newValue` is greater than the stored value.
// Does nothing when the stored value is already >= newValue.
//
// Success of the compare-exchange is detected by comparing raw bit patterns rather than
// using float `!=`: float comparison treats -0.0 and +0.0 as equal (a failed exchange could
// be mistaken for a successful one and the update lost) and treats NaN as unequal to itself.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void CASMax(ref float location, float newValue)
{
    while (true)
    {
        var snapshot = Volatile.Read(ref location);
        if (snapshot >= newValue)
            return;

        var witnessed = Interlocked.CompareExchange(ref location, newValue, snapshot);

        // The exchange happened exactly when the value we got back is bit-identical to the snapshot.
        if (BitConverter.SingleToInt32Bits(witnessed) == BitConverter.SingleToInt32Bits(snapshot))
            return;
    }
}
|
||||
|
||||
// Atomically adds `value` to `location` using a compare-exchange retry loop.
//
// Success of the compare-exchange is detected by comparing raw bit patterns rather than
// using float `!=`. With float comparison the loop could never terminate once `location`
// held NaN (NaN != NaN is always true), and a failed exchange between -0.0 and +0.0 would
// be mistaken for success, silently dropping the addition.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void CASAdd(ref float location, float value)
{
    while (true)
    {
        var snapshot = Volatile.Read(ref location);
        var newValue = snapshot + value;

        var witnessed = Interlocked.CompareExchange(ref location, newValue, snapshot);

        // The exchange happened exactly when the value we got back is bit-identical to the snapshot.
        if (BitConverter.SingleToInt32Bits(witnessed) == BitConverter.SingleToInt32Bits(snapshot))
            return;
    }
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
using System;
|
||||
using System.ComponentModel;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
@@ -28,51 +25,26 @@ public struct NodeScoresBuffer
|
||||
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
|
||||
private const int MaxSize = 24;
|
||||
|
||||
private static int BatchSize = Vector<float>.Count;
|
||||
private static int BatchSizeBits = int.Log2(BatchSize);
|
||||
private static int BatchSizeMask = BatchSize - 1;
|
||||
private static readonly int BatchSize = Vector<float>.Count;
|
||||
private static readonly int BatchSizeBits = int.Log2(BatchSize);
|
||||
private static readonly int BatchSizeMask = BatchSize - 1;
|
||||
|
||||
private static int BatchCount = MaxSize / BatchSize;
|
||||
private static readonly int BatchCount = MaxSize / BatchSize;
|
||||
|
||||
public ScoresBatch[] Data;
|
||||
private int index;
|
||||
private int count;
|
||||
|
||||
public readonly int Count => count;
|
||||
|
||||
// Reserves one more score slot using interlocked operations: a slot number is claimed by
// incrementing `index`, the batch array and the batch covering that slot are created
// lazily through compare-exchange, and finally `count` is incremented.
public void AddConcurrent()
{
    // Create the batch array on first use; if another caller installed one first, theirs is kept.
    if (Data is null)
        Interlocked.CompareExchange(ref Data, new ScoresBatch[BatchCount], null);

    var slot = Interlocked.Increment(ref index) - 1;
    var (batchIdx, _) = GetArrayIndex(slot);

    // Same lazy creation for the batch that covers this slot.
    ref var batch = ref Data[batchIdx];
    if (batch == null)
        Interlocked.CompareExchange(ref batch, new ScoresBatch(), null);

    Interlocked.Increment(ref count);
}
|
||||
public int Count { get; private set; }
|
||||
|
||||
public void Add()
|
||||
{
|
||||
Data ??= new ScoresBatch[BatchCount];
|
||||
|
||||
var idx = count++;
|
||||
var idx = Count++;
|
||||
|
||||
var (arrayIdx, _) = GetArrayIndex(idx);
|
||||
|
||||
Data[arrayIdx] ??= new();
|
||||
}
|
||||
|
||||
// Records one visit with the given score for the child at `at`, using atomic updates:
// adds to the score sum, raises the max score if needed, and increments the visit count.
public readonly void VisitConcurrent((int arrayIdx, int subIdx) at, float score)
{
    var (batchIdx, lane) = at;
    var batch = Data[batchIdx];

    Intrinsics.CASAdd(ref batch.ScoreSum.Span[lane], score);
    Intrinsics.CASMax(ref batch.MaxScore.Span[lane], score);
    Interlocked.Increment(ref batch.Visits.Span[lane]);
}
|
||||
|
||||
public readonly void Visit((int arrayIdx, int subIdx) at, float score)
|
||||
{
|
||||
Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score;
|
||||
|
||||
@@ -9,13 +9,6 @@ public sealed class RootScores
|
||||
public float MaxScore;
|
||||
public int Visits;
|
||||
|
||||
// Records one visit with the given score on the root using atomic updates.
// The three fields are updated one after another, not as a single atomic unit.
public void VisitConcurrent(float score)
{
    // Accumulate the running sum of scores.
    Intrinsics.CASAdd(ref ScoreSum, score);

    // Keep the best score seen so far.
    Intrinsics.CASMax(ref MaxScore, score);

    // Count the visit.
    Interlocked.Increment(ref Visits);
}
|
||||
|
||||
public void Visit(float score)
|
||||
{
|
||||
ScoreSum += score;
|
||||
|
||||
@@ -1,14 +1,35 @@
|
||||
using Craftimizer.Simulator;
|
||||
using Craftimizer.Simulator.Actions;
|
||||
using Craftimizer.Simulator;
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
public static class SolverUtils
|
||||
|
||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||
public sealed class Solver
|
||||
{
|
||||
public static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
|
||||
private SolverConfig config;
|
||||
private Node rootNode;
|
||||
private RootScores rootScores;
|
||||
|
||||
public float MaxScore => rootScores.MaxScore;
|
||||
|
||||
// Creates a solver whose search tree is rooted at `state`.
// The root node carries no action; its completion state and available actions come from
// a simulator constructed for that state, with `strict` forwarded to the action heuristic.
public Solver(SolverConfig config, SimulationState state, bool strict)
{
    this.config = config;

    var simulator = new Simulator(state, config.MaxStepCount);
    var rootState = new SimulationNode(
        state,
        null,
        simulator.CompletionState,
        simulator.AvailableActionsHeuristic(strict)
    );

    rootNode = new Node(rootState);
    rootScores = new RootScores();
}
|
||||
|
||||
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
|
||||
{
|
||||
(_, var newState) = simulator.Execute(state, action);
|
||||
return new(
|
||||
@@ -19,27 +40,50 @@ public static class SolverUtils
|
||||
);
|
||||
}
|
||||
|
||||
public static (Node EndNode, CompletionState State) ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict = false)
|
||||
private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict)
|
||||
{
|
||||
foreach (var action in actions)
|
||||
{
|
||||
var state = startNode.State;
|
||||
if (state.IsComplete)
|
||||
return (startNode, state.CompletionState);
|
||||
return startNode;
|
||||
|
||||
if (!state.AvailableActions.HasAction(action))
|
||||
return (startNode, CompletionState.InvalidAction);
|
||||
return startNode;
|
||||
state.AvailableActions.RemoveAction(action);
|
||||
|
||||
startNode = startNode.Add(Execute(simulator, state.State, action, strict));
|
||||
}
|
||||
|
||||
return (startNode, startNode.State.CompletionState);
|
||||
return startNode;
|
||||
}
|
||||
|
||||
// Walks from the root down the tree, at each level following the child chosen by
// ChildMaxScore, until a node without children is reached.
// Returns the actions taken along that path together with the state of the final node.
//
// The earlier debug-only ref locals (sum / max / visits read through ParentScores) were
// removed: nothing consumed them, and they dereferenced ParentScores!.Value even when the
// walk ended on the root itself (a tree with no children). Backpropagate never reads
// ParentScores on the root, so the root presumably has none and the read would throw.
[Pure]
private (List<ActionType> Actions, SimulationNode Node) Solution()
{
    var actions = new List<ActionType>();
    var node = rootNode;

    while (node.Children.Count != 0)
    {
        node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;

        if (node.State.Action != null)
            actions.Add(node.State.Action.Value);
    }

    return (actions, node.State);
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
|
||||
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
@@ -67,21 +111,6 @@ public static class SolverUtils
|
||||
return max;
|
||||
}
|
||||
|
||||
// Starting at `node`, repeatedly descends into the child chosen by ChildMaxScore until a
// node without children is reached. Returns the actions collected along the way and the
// state of the node where the walk stopped.
[Pure]
public static (List<ActionType> Actions, SimulationNode Node) Solution(Node node)
{
    List<ActionType> path = new();
    var current = node;

    while (true)
    {
        if (current.Children.Count == 0)
            return (path, current.State);

        current = current.ChildAt(ChildMaxScore(ref current.ChildScores))!;

        // Nodes without an associated action contribute nothing to the path.
        if (current.State.Action is { } taken)
            path.Add(taken);
    }
}
|
||||
|
||||
// Calculates the best child node to explore next
|
||||
// Exploitation: ((1 - w) * (s / v)) + (w * m)
|
||||
// Exploration: sqrt(c * ln(V) / v)
|
||||
@@ -98,10 +127,9 @@ public static class SolverUtils
|
||||
// max score is tacked onto it
|
||||
// N_i = parent visits
|
||||
// c = exploration constant (but crafty places it inside the sqrt..?)
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
|
||||
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
|
||||
private (int arrayIdx, int subIdx) EvalBestChild(int parentVisits, ref NodeScoresBuffer scores)
|
||||
{
|
||||
var length = scores.Count;
|
||||
var vecLength = Vector<float>.Count;
|
||||
@@ -143,9 +171,36 @@ public static class SolverUtils
|
||||
return max;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
|
||||
// Selection step: descends from the root, following EvalBestChild at each level, and
// returns the first node that either still has unexpanded actions or has no children.
[Pure]
public Node Select()
{
    var current = rootNode;
    var visits = rootScores.Visits;

    // Keep descending only while the node is fully expanded and has children to choose from.
    while (current.State.AvailableActions.IsEmpty && current.Children.Count != 0)
    {
        var best = EvalBestChild(visits, ref current.ChildScores);
        visits = current.ChildScores.GetVisits(best);
        current = current.ChildAt(best)!;
    }

    return current;
}
|
||||
|
||||
public (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode)
|
||||
{
|
||||
ref var initialState = ref initialNode.State;
|
||||
// expand once
|
||||
if (initialState.IsComplete)
|
||||
return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0);
|
||||
|
||||
var poppedAction = initialState.AvailableActions.PopRandom(random);
|
||||
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true));
|
||||
|
||||
// playout to a terminal state
|
||||
var currentState = expandedNode.State.State;
|
||||
var currentCompletionState = expandedNode.State.SimulationCompletionState;
|
||||
@@ -168,56 +223,51 @@ public static class SolverUtils
|
||||
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
|
||||
if (currentCompletionState == CompletionState.ProgressComplete)
|
||||
{
|
||||
if (score >= config.ScoreStorageThreshold && score >= maxScore)
|
||||
if (score >= config.ScoreStorageThreshold && score >= MaxScore)
|
||||
{
|
||||
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
|
||||
var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
|
||||
return (terminalNode, score);
|
||||
}
|
||||
}
|
||||
return (expandedNode, score);
|
||||
}
|
||||
|
||||
// Backpropagation step: records `score` on every node from `startNode` up to the root.
// Non-root nodes are recorded in their parent's score buffer; the root is recorded in rootScores.
public void Backpropagate(Node startNode, float score)
{
    for (var current = startNode; ; current = current.Parent!)
    {
        if (current == rootNode)
        {
            rootScores.Visit(score);
            return;
        }

        current.ParentScores!.Value.Visit(current.ChildIdx, score);
    }
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
|
||||
private void Search(CancellationToken token)
|
||||
{
|
||||
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
|
||||
var random = rootNode.State.State.Input.Random;
|
||||
for (var i = 0; i < iterations; i++)
|
||||
for (var i = 0; i < config.Iterations; i++)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
|
||||
{
|
||||
// Retry, count this iteration as moot
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
var selectedNode = Select();
|
||||
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
|
||||
|
||||
Backpropagate(endNode, score);
|
||||
}
|
||||
}
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Node CreateRootNode(SolverConfig config, SimulationInput input, bool strict) =>
|
||||
CreateRootNode(config, new SimulationState(input), strict);
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) =>
|
||||
SearchStepwiseForked(config, forkCount, new SimulationState(input), actionCallback, token);
|
||||
|
||||
[Pure]
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Node CreateRootNode(SolverConfig config, SimulationState state, bool strict)
|
||||
{
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
return new(new(
|
||||
state,
|
||||
null,
|
||||
sim.CompletionState,
|
||||
sim.AvailableActionsHeuristic(strict)
|
||||
));
|
||||
}
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
|
||||
SearchStepwiseForked<S>(config, forkCount, new SimulationState(input), actionCallback, token);
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default)
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
@@ -230,13 +280,11 @@ public static class SolverUtils
|
||||
for (var i = 0; i < forkCount; ++i)
|
||||
tasks[i] = Task.Run(() =>
|
||||
{
|
||||
var rootNode = CreateRootNode(config, state, true);
|
||||
RootScores rootScores = new();
|
||||
var solver = new Solver(config, state, true);
|
||||
solver.Search(token);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
|
||||
S.Search(ref config, rootScores, rootNode, token);
|
||||
var (solution_actions, solution_node) = Solution(rootNode);
|
||||
|
||||
return (rootScores.MaxScore, solution_actions, solution_node.State);
|
||||
return (solver.MaxScore, solution_actions, solution_node.State);
|
||||
}, token);
|
||||
Task.WaitAll(tasks, CancellationToken.None);
|
||||
|
||||
@@ -258,24 +306,23 @@ public static class SolverUtils
|
||||
return (actions, state);
|
||||
}
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
|
||||
SearchStepwise<S>(config, new SimulationState(input), actionCallback, token);
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) =>
|
||||
SearchStepwise(config, new SimulationState(input), actionCallback, token);
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default)
|
||||
{
|
||||
var actions = new List<ActionType>();
|
||||
var sim = new Simulator(state, config.MaxStepCount);
|
||||
var rootNode = CreateRootNode(config, state, true);
|
||||
RootScores rootScores = new();
|
||||
var solver = new Solver(config, state, true);
|
||||
while (!sim.IsComplete)
|
||||
{
|
||||
if (token.IsCancellationRequested)
|
||||
break;
|
||||
|
||||
S.Search(ref config, rootScores, rootNode, token);
|
||||
var (solution_actions, solution_node) = Solution(rootNode);
|
||||
solver.Search(token);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
|
||||
if (rootScores.MaxScore >= 1.0)
|
||||
if (solver.MaxScore >= 1.0)
|
||||
{
|
||||
actions.AddRange(solution_actions);
|
||||
return (actions, solution_node.State);
|
||||
@@ -287,21 +334,20 @@ public static class SolverUtils
|
||||
|
||||
actionCallback?.Invoke(chosen_action);
|
||||
|
||||
rootNode = CreateRootNode(config, state, true);
|
||||
solver = new Solver(config, state, true);
|
||||
}
|
||||
|
||||
return (actions, state);
|
||||
}
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationInput input, CancellationToken token = default) where S : ISolver =>
|
||||
SearchOneshot<S>(config, new SimulationState(input), token);
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) =>
|
||||
SearchOneshot(config, new SimulationState(input), token);
|
||||
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
|
||||
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default)
|
||||
{
|
||||
var rootNode = CreateRootNode(config, state, false);
|
||||
RootScores rootScores = new();
|
||||
S.Search(ref config, rootScores, rootNode, token);
|
||||
var (solution_actions, solution_node) = Solution(rootNode);
|
||||
var solver = new Solver(config, state, false);
|
||||
solver.Search(token);
|
||||
var (solution_actions, solution_node) = solver.Solution();
|
||||
return (solution_actions, solution_node.State);
|
||||
}
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Runtime.CompilerServices;
|
||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||
// Solver strategy that runs config.ThreadCount worker tasks against one shared tree,
// using the *Concurrent members of the tree and score types.
public sealed class SolverConcurrent : ISolver
{
    // Picks the child to descend into, or null when the parent has no recorded visits.
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children)
    {
        if (parentVisits == 0)
            return null;

        return SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children);
    }

    // Selection step: descends from the root and returns the first node that still has
    // unexpanded actions or has no children. Whenever no child can be chosen, or the chosen
    // child slot is still empty, the descent restarts from the root.
    [Pure]
    public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode)
    {
        var node = rootNode;
        var nodeVisits = rootNodeVisits;

        while (true)
        {
            if (!node.State.AvailableActions.IsEmpty || node.Children.Count == 0)
                return node;

            // select the node with the highest score
            if (EvalBestChild(ref config, nodeVisits, ref node.ChildScores) is { } at)
            {
                var childVisits = node.ChildScores.GetVisits(at);
                var child = node.ChildAt(at);
                if (child != null)
                {
                    node = child;
                    nodeVisits = childVisits;
                    continue;
                }
            }

            // current node is invalid & not backpropagated just yet: try again from root
            node = rootNode;
            nodeVisits = rootNodeVisits;
        }
    }

    // Expansion + rollout step. A completed node is scored directly. Otherwise one available
    // action is popped and executed, and the rollout continues from the new child.
    // Returns null when no action could be popped.
    public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
    {
        ref var initialState = ref initialNode.State;

        // expand once
        if (initialState.IsComplete)
            return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0);

        if (initialState.AvailableActions.PopRandomConcurrent(random) is not { } action)
            return null;

        var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, action, true));

        return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
    }

    // Backpropagation step: records `score` on every node from `startNode` up to the root
    // using the concurrent visit operations.
    public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
    {
        for (var current = startNode; ; current = current.Parent!)
        {
            if (current == rootNode)
            {
                rootScores.VisitConcurrent(score);
                return;
            }

            current.ParentScores!.Value.VisitConcurrent(current.ChildIdx, score);
        }
    }

    // Runs one select / expand+rollout / backpropagate iteration.
    // Returns false when the expansion produced nothing, so the iteration should not be counted.
    public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
    {
        var selectedNode = Select(ref config, rootScores.Visits, rootNode);
        var outcome = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);

        if (outcome is not { } rolledOut)
            return false;

        Backpropagate(rootScores, rootNode, rolledOut.ExpandedNode, rolledOut.Score);
        return true;
    }

    // Body of one worker: runs this worker's share of iterations (Iterations / ThreadCount,
    // integer division) through the shared search loop.
    public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
    {
        SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token);
    }

    // Starts config.ThreadCount worker tasks on the shared tree and blocks until all have finished.
    public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
    {
        // A ref parameter cannot be captured by a lambda, so the workers receive a local copy.
        var workerConfig = config;

        var workers = new Task[config.ThreadCount];
        for (var i = 0; i < config.ThreadCount; ++i)
            workers[i] = Task.Run(() => SearchThread(workerConfig, rootScores, rootNode, token), token);

        Task.WaitAll(workers, CancellationToken.None);
    }
}
|
||||
@@ -1,71 +0,0 @@
|
||||
using System.Diagnostics.Contracts;
|
||||
using System.Runtime.CompilerServices;
|
||||
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
|
||||
|
||||
namespace Craftimizer.Solver.Crafty;
|
||||
|
||||
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
|
||||
// Solver strategy that runs every search iteration on the calling thread, using the
// non-concurrent members of the tree and score types.
public sealed class SolverSingle : ISolver
{
    // Picks the child to descend into by delegating to the shared SolverUtils implementation.
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children)
    {
        return SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children);
    }

    // Selection step: descends from `node`, following EvalBestChild at each level, and
    // returns the first node that still has unexpanded actions or has no children.
    [Pure]
    public static Node Select(ref SolverConfig config, int nodeVisits, Node node)
    {
        // Keep descending only while the node is fully expanded and has children to choose from.
        while (node.State.AvailableActions.IsEmpty && node.Children.Count != 0)
        {
            // select the node with the highest score
            var best = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
            nodeVisits = node.ChildScores.GetVisits(best);
            node = node.ChildAt(best)!;
        }

        return node;
    }

    // Expansion + rollout step. A completed node is scored directly. Otherwise one available
    // action is popped and executed, and the rollout continues from the new child.
    public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
    {
        ref var initialState = ref initialNode.State;

        // expand once
        if (initialState.IsComplete)
        {
            var finalScore = initialState.CalculateScore(config.MaxStepCount) ?? 0;
            return (initialNode, finalScore);
        }

        var action = initialState.AvailableActions.PopRandom(random);
        var childState = SolverUtils.Execute(simulator, initialState.State, action, true);
        var expandedNode = initialNode.Add(childState);

        return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
    }

    // Backpropagation step: records `score` on every node from `startNode` up to the root.
    public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
    {
        for (var current = startNode; ; current = current.Parent!)
        {
            if (current == rootNode)
            {
                rootScores.Visit(score);
                return;
            }

            current.ParentScores!.Value.Visit(current.ChildIdx, score);
        }
    }

    // Runs one select / expand+rollout / backpropagate iteration. Always returns true.
    public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
    {
        var selected = Select(ref config, rootScores.Visits, rootNode);
        var (endNode, score) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selected);

        Backpropagate(rootScores, rootNode, endNode, score);
        return true;
    }

    // Runs config.Iterations iterations through the shared search loop.
    public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
    {
        SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootScores, rootNode, token);
    }
}
|
||||
Reference in New Issue
Block a user