Remove all concurrency code

The concurrency code muddled the codebase too much and gave only a marginal performance improvement in the grand scheme of things. Other ways to parallelize MCTS will be nicer to implement and could yield better results.
This commit is contained in:
Asriel Camora
2023-07-07 20:17:35 +02:00
parent 3ab50d389e
commit 636501ab86
11 changed files with 153 additions and 431 deletions
+9 -60
View File
@@ -7,6 +7,8 @@ namespace Craftimizer.Solver.Crafty;
public struct ActionSet
{
private const bool IsDeterministic = true;
private uint bits;
[Pure]
@@ -19,24 +21,6 @@ public struct ActionSet
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint ToMask(ActionType action) => 1u << FromAction(action) + 1;
// Atomically sets the bit belonging to `action`.
// Returns true only when that bit was clear beforehand, i.e. this call is the one that added it.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AddActionConcurrent(ActionType action)
{
    var actionBit = ToMask(action);
    // Interlocked.Or hands back the value that was stored before the OR was applied.
    var previousBits = Interlocked.Or(ref bits, actionBit);
    var wasPresent = (previousBits & actionBit) != 0;
    return !wasPresent;
}
// Atomically clears the bit belonging to `action`.
// Returns true only when that bit was set beforehand, i.e. this call is the one that removed it.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool RemoveActionConcurrent(ActionType action)
{
    var actionBit = ToMask(action);
    // Interlocked.And hands back the value that was stored before the AND was applied.
    var previousBits = Interlocked.And(ref bits, ~actionBit);
    var wasAbsent = (previousBits & actionBit) == 0;
    return !wasAbsent;
}
// Return true if action was newly added and not there before.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AddAction(ActionType action)
@@ -71,52 +55,17 @@ public struct ActionSet
public readonly bool IsEmpty => bits == 0;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly ActionType SelectRandom(Random random) => ElementAt(random.Next(Count));
// Removes one randomly chosen action from the set using a compare-and-swap retry loop.
// Returns null as soon as the set is observed to be empty; otherwise returns the action
// whose bit this call managed to clear.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopRandomConcurrent(Random random)
{
    while (true)
    {
        var observed = bits;
        if (observed == 0)
            return null;

        // Pick uniformly among the bits that were set in the observed value.
        var available = BitOperations.PopCount(observed);
        var pick = random.Next(available);
        var chosen = ToAction(Intrinsics.NthBitSet(observed, pick) - 1);
        var remaining = observed & ~ToMask(chosen);

        // Only commit if nobody changed `bits` since it was observed; otherwise start over.
        if (Interlocked.CompareExchange(ref bits, remaining, observed) == observed)
            return chosen;
    }
}
// Removes the action at position 0 (as located by Intrinsics.NthBitSet) using a
// compare-and-swap retry loop. Returns null as soon as the set is observed to be empty.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType? PopFirstConcurrent()
{
    while (true)
    {
        var observed = bits;
        if (observed == 0)
            return null;

        var chosen = ToAction(Intrinsics.NthBitSet(observed, 0) - 1);
        var remaining = observed & ~ToMask(chosen);

        // Only commit if nobody changed `bits` since it was observed; otherwise start over.
        if (Interlocked.CompareExchange(ref bits, remaining, observed) == observed)
            return chosen;
    }
}
public readonly ActionType SelectRandom(Random random) =>
IsDeterministic ?
First() :
ElementAt(random.Next(Count));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType PopRandom(Random random)
{
if (IsDeterministic)
return PopFirst();
var action = ElementAt(random.Next(Count));
RemoveAction(action);
return action;
+6 -27
View File
@@ -11,41 +11,20 @@ public struct ArenaBuffer<T> where T : struct
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
private const int MaxSize = 24;
private static int BatchSize = Vector<float>.Count;
private static int BatchSizeBits = int.Log2(BatchSize);
private static int BatchSizeMask = BatchSize - 1;
private static readonly int BatchSize = Vector<float>.Count;
private static readonly int BatchSizeBits = int.Log2(BatchSize);
private static readonly int BatchSizeMask = BatchSize - 1;
private static int BatchCount = MaxSize / BatchSize;
private static readonly int BatchCount = MaxSize / BatchSize;
public ArenaNode<T>[][] Data;
private int index; // Unused in single threaded workload
private int count;
public readonly int Count => count;
// Concurrent variant of Add: appends `node` using Interlocked operations instead of plain reads/writes.
// NOTE(review): no bounds check against the buffer capacity is visible here — confirm callers never exceed it.
public void AddConcurrent(ArenaNode<T> node)
{
// Lazily create the outer array; CompareExchange only stores the new array if Data is still null.
if (Data == null)
Interlocked.CompareExchange(ref Data, new ArenaNode<T>[BatchCount][], null);
// Reserve a slot: Increment returns the incremented value, so subtract one to get this caller's index.
var idx = Interlocked.Increment(ref index) - 1;
var (arrayIdx, subIdx) = GetArrayIndex(idx);
// Lazily create the inner batch the same way as the outer array.
if (Data[arrayIdx] == null)
Interlocked.CompareExchange(ref Data[arrayIdx], new ArenaNode<T>[BatchSize], null);
// Record the node's position on the node itself before storing it in its slot.
node.ChildIdx = (arrayIdx, subIdx);
Data[arrayIdx][subIdx] = node;
// The count is bumped only after the node has been written.
Interlocked.Increment(ref count);
}
public int Count { get; private set; }
public void Add(ArenaNode<T> node)
{
Data ??= new ArenaNode<T>[BatchCount][];
var idx = count++;
var idx = Count++;
var (arrayIdx, subIdx) = GetArrayIndex(idx);
-9
View File
@@ -23,15 +23,6 @@ public sealed class ArenaNode<T> where T : struct
public ArenaNode<T>? ChildAt((int arrayIdx, int subIdx) at) =>
Children.Data?[at.arrayIdx]?[at.subIdx];
// Concurrent variant of Add: creates a child node for `state` with this node as its parent,
// registers it through the AddConcurrent methods of ChildScores and Children, and returns it.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> AddConcurrent(T state)
{
var node = new ArenaNode<T>(state, this);
// The score slot is added before the node itself is added to Children.
ChildScores.AddConcurrent();
Children.AddConcurrent(node);
return node;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ArenaNode<T> Add(T state)
{
-10
View File
@@ -1,10 +0,0 @@
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty;
// Static contract shared by the solver implementations.
public interface ISolver
{
    // Runs a complete search over the tree rooted at `rootNode`, accumulating results in `rootScores`.
    static abstract void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token);

    // Runs a single search iteration. A false result tells the caller the iteration
    // did not count and should be retried.
    static abstract bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator);
}
-24
View File
@@ -124,28 +124,4 @@ internal static class Intrinsics
result[i] = MathF.ReciprocalSqrtEstimate(data[i]);
return new(result);
}
// Atomically raises `location` to `newValue` when `newValue` is larger; otherwise leaves it untouched.
//
// The retry decision compares raw bit patterns instead of using float `!=`.
// With float comparison, a NaN returned by CompareExchange never equals the snapshot, so the
// loop could spin forever when both the stored value and `newValue` are NaN; likewise a
// +0/-0 mismatch would be mistaken for a successful swap.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void CASMax(ref float location, float newValue)
{
    while (true)
    {
        var snapshot = location;
        // Nothing to do when the stored value is already at least as large.
        if (snapshot >= newValue)
            return;

        var witnessed = Interlocked.CompareExchange(ref location, newValue, snapshot);
        // The swap happened exactly when the witnessed value is bit-identical to the snapshot.
        if (System.BitConverter.SingleToInt32Bits(witnessed) == System.BitConverter.SingleToInt32Bits(snapshot))
            return;
    }
}
// Atomically adds `value` to `location` using a compare-and-swap retry loop.
//
// The retry decision compares raw bit patterns instead of using float `!=`.
// With float comparison, once `location` holds NaN the value returned by CompareExchange
// never equals the snapshot and the loop would spin forever; a +0/-0 mismatch would also be
// mistaken for a successful swap and silently drop the addition.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void CASAdd(ref float location, float value)
{
    while (true)
    {
        var snapshot = location;
        var newValue = snapshot + value;

        var witnessed = Interlocked.CompareExchange(ref location, newValue, snapshot);
        // The swap happened exactly when the witnessed value is bit-identical to the snapshot.
        if (System.BitConverter.SingleToInt32Bits(witnessed) == System.BitConverter.SingleToInt32Bits(snapshot))
            return;
    }
}
}
+6 -34
View File
@@ -1,9 +1,6 @@
using System;
using System.ComponentModel;
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Craftimizer.Solver.Crafty;
@@ -28,51 +25,26 @@ public struct NodeScoresBuffer
// The benchmark reaches 20 at most, but here we have a little leeway just in case.
private const int MaxSize = 24;
private static int BatchSize = Vector<float>.Count;
private static int BatchSizeBits = int.Log2(BatchSize);
private static int BatchSizeMask = BatchSize - 1;
private static readonly int BatchSize = Vector<float>.Count;
private static readonly int BatchSizeBits = int.Log2(BatchSize);
private static readonly int BatchSizeMask = BatchSize - 1;
private static int BatchCount = MaxSize / BatchSize;
private static readonly int BatchCount = MaxSize / BatchSize;
public ScoresBatch[] Data;
private int index;
private int count;
public readonly int Count => count;
// Concurrent variant of Add: reserves one more score slot using Interlocked operations.
// NOTE(review): no bounds check against the buffer capacity is visible here — confirm callers never exceed it.
public void AddConcurrent()
{
// Lazily create the batch array; CompareExchange only stores the new array if Data is still null.
if (Data == null)
Interlocked.CompareExchange(ref Data, new ScoresBatch[BatchCount], null);
// Reserve a slot: Increment returns the incremented value, so subtract one to get this caller's index.
var idx = Interlocked.Increment(ref index) - 1;
// Only the batch index is needed here; the position inside the batch is discarded.
var (arrayIdx, _) = GetArrayIndex(idx);
// Lazily create the batch that the reserved slot falls into.
if (Data[arrayIdx] == null)
Interlocked.CompareExchange(ref Data[arrayIdx], new ScoresBatch(), null);
Interlocked.Increment(ref count);
}
public int Count { get; private set; }
public void Add()
{
Data ??= new ScoresBatch[BatchCount];
var idx = count++;
var idx = Count++;
var (arrayIdx, _) = GetArrayIndex(idx);
Data[arrayIdx] ??= new();
}
// Concurrent variant of Visit: folds `score` into the slot addressed by `at` by adding it to
// the running sum, raising the stored maximum if needed, and incrementing the visit counter,
// each through an atomic helper.
public readonly void VisitConcurrent((int arrayIdx, int subIdx) at, float score)
{
    var (batchIdx, slot) = at;
    var batch = Data[batchIdx];

    Intrinsics.CASAdd(ref batch.ScoreSum.Span[slot], score);
    Intrinsics.CASMax(ref batch.MaxScore.Span[slot], score);
    Interlocked.Increment(ref batch.Visits.Span[slot]);
}
public readonly void Visit((int arrayIdx, int subIdx) at, float score)
{
Data[at.arrayIdx].ScoreSum.Span[at.subIdx] += score;
-7
View File
@@ -9,13 +9,6 @@ public sealed class RootScores
public float MaxScore;
public int Visits;
// Concurrent variant of Visit: records one visit with the given score on the root statistics.
public void VisitConcurrent(float score)
{
// Each field is updated through its own atomic helper; the three updates are not applied as one unit.
Intrinsics.CASAdd(ref ScoreSum, score);
Intrinsics.CASMax(ref MaxScore, score);
Interlocked.Increment(ref Visits);
}
public void Visit(float score)
{
ScoreSum += score;
@@ -1,14 +1,35 @@
using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions;
using Craftimizer.Simulator;
using System.Diagnostics.Contracts;
using System.Numerics;
using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty;
public static class SolverUtils
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
public sealed class Solver
{
public static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
private SolverConfig config;
private Node rootNode;
private RootScores rootScores;
public float MaxScore => rootScores.MaxScore;
public Solver(SolverConfig config, SimulationState state, bool strict)
{
this.config = config;
var sim = new Simulator(state, config.MaxStepCount);
rootNode = new(new(
state,
null,
sim.CompletionState,
sim.AvailableActionsHeuristic(strict)
));
rootScores = new();
}
private static SimulationNode Execute(Simulator simulator, SimulationState state, ActionType action, bool strict)
{
(_, var newState) = simulator.Execute(state, action);
return new(
@@ -19,27 +40,50 @@ public static class SolverUtils
);
}
public static (Node EndNode, CompletionState State) ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict = false)
private static Node ExecuteActions(Simulator simulator, Node startNode, ReadOnlySpan<ActionType> actions, bool strict)
{
foreach (var action in actions)
{
var state = startNode.State;
if (state.IsComplete)
return (startNode, state.CompletionState);
return startNode;
if (!state.AvailableActions.HasAction(action))
return (startNode, CompletionState.InvalidAction);
return startNode;
state.AvailableActions.RemoveAction(action);
startNode = startNode.Add(Execute(simulator, state.State, action, strict));
}
return (startNode, startNode.State.CompletionState);
return startNode;
}
// Extracts the current solution from the search tree.
// Starting at the root, repeatedly steps into the child chosen by ChildMaxScore until a node
// without children is reached, collecting the action stored on each visited node.
// Returns the collected actions together with the state of the final node; when the root has
// no children the action list is empty and the root's own state is returned.
[Pure]
private (List<ActionType> Actions, SimulationNode Node) Solution()
{
    var actions = new List<ActionType>();
    var node = rootNode;
    while (node.Children.Count != 0)
    {
        node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!;
        // Nodes may carry no action; only real actions become part of the solution.
        if (node.State.Action != null)
            actions.Add(node.State.Action.Value);
    }

    // The former debug-only lookups of the final node's parent scores were removed: their
    // results were never used, and dereferencing ParentScores could throw when the walk
    // never left the starting node.
    return (actions, node.State);
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
private static (int arrayIdx, int subIdx) ChildMaxScore(ref NodeScoresBuffer scores)
{
var length = scores.Count;
var vecLength = Vector<float>.Count;
@@ -67,21 +111,6 @@ public static class SolverUtils
return max;
}
// Follows the child chosen by ChildMaxScore from `node` downwards until a node without
// children is reached, and returns the actions met along the way plus the final node's state.
[Pure]
public static (List<ActionType> Actions, SimulationNode Node) Solution(Node node)
{
    var path = new List<ActionType>();
    var current = node;
    while (current.Children.Count != 0)
    {
        var bestAt = ChildMaxScore(ref current.ChildScores);
        current = current.ChildAt(bestAt)!;

        // Nodes may carry no action; only real actions become part of the result.
        var taken = current.State.Action;
        if (taken != null)
            path.Add(taken.Value);
    }
    return (path, current.State);
}
// Calculates the best child node to explore next
// Exploitation: ((1 - w) * (s / v)) + (w * m)
// Exploration: sqrt(c * ln(V) / v)
@@ -98,10 +127,9 @@ public static class SolverUtils
// max score is tacked onto it
// N_i = parent visits
// c = exploration constant (but crafty places it inside the sqrt..?)
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
public static (int arrayIdx, int subIdx) EvalBestChild<S>(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver
private (int arrayIdx, int subIdx) EvalBestChild(int parentVisits, ref NodeScoresBuffer scores)
{
var length = scores.Count;
var vecLength = Vector<float>.Count;
@@ -143,9 +171,36 @@ public static class SolverUtils
return max;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (Node ExpandedNode, float Score) Rollout(ref SolverConfig config, float maxScore, Node rootNode, Node expandedNode, Random random, Simulator simulator)
// Selection step: walks down from the root and returns the first node that either still has
// untried actions or has no children at all.
[Pure]
public Node Select()
{
    var current = rootNode;
    var currentVisits = rootScores.Visits;

    // Keep descending while the node is fully expanded (no actions left) and has children.
    while (current.State.AvailableActions.IsEmpty && current.Children.Count != 0)
    {
        // Move into the child that EvalBestChild ranks highest.
        var bestAt = EvalBestChild(currentVisits, ref current.ChildScores);
        currentVisits = current.ChildScores.GetVisits(bestAt);
        current = current.ChildAt(bestAt)!;
    }

    return current;
}
public (Node ExpandedNode, float Score) ExpandAndRollout(Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
if (initialState.IsComplete)
return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0);
var poppedAction = initialState.AvailableActions.PopRandom(random);
var expandedNode = initialNode.Add(Execute(simulator, initialState.State, poppedAction, true));
// playout to a terminal state
var currentState = expandedNode.State.State;
var currentCompletionState = expandedNode.State.SimulationCompletionState;
@@ -168,56 +223,51 @@ public static class SolverUtils
var score = SimulationNode.CalculateScoreForState(currentState, currentCompletionState, config.MaxStepCount) ?? 0;
if (currentCompletionState == CompletionState.ProgressComplete)
{
if (score >= config.ScoreStorageThreshold && score >= maxScore)
if (score >= config.ScoreStorageThreshold && score >= MaxScore)
{
(var terminalNode, _) = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
var terminalNode = ExecuteActions(simulator, expandedNode, actions[..actionCount], true);
return (terminalNode, score);
}
}
return (expandedNode, score);
}
// Backpropagation step: records `score` on every node from `startNode` up to the root.
// Non-root nodes are recorded in their parent's score buffer; the root is recorded in rootScores.
public void Backpropagate(Node startNode, float score)
{
    for (var current = startNode; current != rootNode; current = current.Parent!)
        current.ParentScores!.Value.Visit(current.ChildIdx, score);

    rootScores.Visit(score);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Search<S>(ref SolverConfig config, int iterations, RootScores rootScores, Node rootNode, CancellationToken token) where S : ISolver
private void Search(CancellationToken token)
{
Simulator simulator = new(rootNode.State.State, config.MaxStepCount);
var random = rootNode.State.State.Input.Random;
for (var i = 0; i < iterations; i++)
for (var i = 0; i < config.Iterations; i++)
{
if (token.IsCancellationRequested)
break;
if (!S.SearchIter(ref config, rootScores, rootNode, random, simulator))
{
// Retry, count this iteration as moot
i--;
continue;
}
var selectedNode = Select();
var (endNode, score) = ExpandAndRollout(random, simulator, selectedNode);
Backpropagate(endNode, score);
}
}
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node CreateRootNode(SolverConfig config, SimulationInput input, bool strict) =>
CreateRootNode(config, new SimulationState(input), strict);
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) =>
SearchStepwiseForked(config, forkCount, new SimulationState(input), actionCallback, token);
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Node CreateRootNode(SolverConfig config, SimulationState state, bool strict)
{
var sim = new Simulator(state, config.MaxStepCount);
return new(new(
state,
null,
sim.CompletionState,
sim.AvailableActionsHeuristic(strict)
));
}
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
SearchStepwiseForked<S>(config, forkCount, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked<S>(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default)
{
var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount);
@@ -230,13 +280,11 @@ public static class SolverUtils
for (var i = 0; i < forkCount; ++i)
tasks[i] = Task.Run(() =>
{
var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
var solver = new Solver(config, state, true);
solver.Search(token);
var (solution_actions, solution_node) = solver.Solution();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
return (rootScores.MaxScore, solution_actions, solution_node.State);
return (solver.MaxScore, solution_actions, solution_node.State);
}, token);
Task.WaitAll(tasks, CancellationToken.None);
@@ -258,24 +306,23 @@ public static class SolverUtils
return (actions, state);
}
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver =>
SearchStepwise<S>(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback, CancellationToken token = default) =>
SearchStepwise(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwise<S>(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default) where S : ISolver
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback, CancellationToken token = default)
{
var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount);
var rootNode = CreateRootNode(config, state, true);
RootScores rootScores = new();
var solver = new Solver(config, state, true);
while (!sim.IsComplete)
{
if (token.IsCancellationRequested)
break;
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
solver.Search(token);
var (solution_actions, solution_node) = solver.Solution();
if (rootScores.MaxScore >= 1.0)
if (solver.MaxScore >= 1.0)
{
actions.AddRange(solution_actions);
return (actions, solution_node.State);
@@ -287,21 +334,20 @@ public static class SolverUtils
actionCallback?.Invoke(chosen_action);
rootNode = CreateRootNode(config, state, true);
solver = new Solver(config, state, true);
}
return (actions, state);
}
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationInput input, CancellationToken token = default) where S : ISolver =>
SearchOneshot<S>(config, new SimulationState(input), token);
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) =>
SearchOneshot(config, new SimulationState(input), token);
public static (List<ActionType> Actions, SimulationState State) SearchOneshot<S>(SolverConfig config, SimulationState state, CancellationToken token = default) where S : ISolver
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default)
{
var rootNode = CreateRootNode(config, state, false);
RootScores rootScores = new();
S.Search(ref config, rootScores, rootNode, token);
var (solution_actions, solution_node) = Solution(rootNode);
var solver = new Solver(config, state, false);
solver.Search(token);
var (solution_actions, solution_node) = solver.Solution();
return (solution_actions, solution_node.State);
}
}
-103
View File
@@ -1,103 +0,0 @@
using System.Diagnostics.Contracts;
using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
// ISolver implementation that runs the search on several tasks sharing one tree and one
// RootScores instance, using the *Concurrent variants of the tree operations.
public sealed class SolverConcurrent : ISolver
{
// Wraps SolverUtils.EvalBestChild; returns null when the parent has no recorded visits yet.
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (int arrayIdx, int subIdx)? EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children) =>
parentVisits == 0 ?
null :
SolverUtils.EvalBestChild<SolverConcurrent>(ref config, parentVisits, ref children);
// Selection step: walks down from the root and returns the first node that still has
// untried actions or has no children. Restarts from the root whenever the next child
// cannot be determined or resolved.
[Pure]
public static Node Select(ref SolverConfig config, int rootNodeVisits, Node rootNode)
{
var node = rootNode;
var nodeVisits = rootNodeVisits;
while (true)
{
var expandable = !node.State.AvailableActions.IsEmpty;
var likelyTerminal = node.Children.Count == 0;
if (expandable || likelyTerminal)
return node;
// select the node with the highest score
// if null (current node is invalid & not backpropagated just yet), try again from root
var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
if (at.HasValue)
{
nodeVisits = node.ChildScores.GetVisits(at.Value);
node = node.ChildAt(at.Value);
// ChildAt returned null for the chosen slot: restart from the root.
if (node == null)
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
}
else
{
node = rootNode;
nodeVisits = rootNodeVisits;
}
}
}
// Expansion + rollout step. Returns null when no action could be popped from the node,
// which SearchIter reports to its caller as a failed iteration.
public static (Node ExpandedNode, float Score)? ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
{
ref var initialState = ref initialNode.State;
// expand once
if (initialState.IsComplete)
return (initialNode, initialState.CalculateScore(config.MaxStepCount) ?? 0);
var poppedAction = initialState.AvailableActions.PopRandomConcurrent(random);
// The set was observed empty, so there is nothing left to expand here.
if (!poppedAction.HasValue)
return null;
var expandedNode = initialNode.AddConcurrent(SolverUtils.Execute(simulator, initialState.State, poppedAction.Value, true));
return SolverUtils.Rollout(ref config, maxScore, rootNode, expandedNode, random, simulator);
}
// Backpropagation step: records `score` on every node from `startNode` up to the root,
// using the concurrent visit methods.
public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
{
while (true)
{
// The root is recorded in rootScores rather than in a parent's score buffer, and ends the walk.
if (startNode == rootNode)
{
rootScores.VisitConcurrent(score);
break;
}
startNode.ParentScores!.Value.VisitConcurrent(startNode.ChildIdx, score);
startNode = startNode.Parent!;
}
}
// One search iteration: select, expand + rollout, backpropagate.
// Returns false when expansion yielded nothing.
public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
{
var selectedNode = Select(ref config, rootScores.Visits, rootNode);
var rolledOut = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selectedNode);
if (!rolledOut.HasValue)
return false;
var (endNode, score) = rolledOut.Value;
Backpropagate(rootScores, rootNode, endNode, score);
return true;
}
// Work done by one task: runs config.Iterations / config.ThreadCount iterations (integer division).
public static void SearchThread(SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token) =>
SolverUtils.Search<SolverConcurrent>(ref config, config.Iterations / config.ThreadCount, rootScores, rootNode, token);
// Starts config.ThreadCount tasks that all search the same tree and blocks until all of them finish.
public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
{
// Copy the config into a local because a ref parameter cannot be captured by the lambda below.
var configP = config;
var tasks = new Task[config.ThreadCount];
for (var i = 0; i < config.ThreadCount; ++i)
tasks[i] = Task.Run(() => SearchThread(configP, rootScores, rootNode, token), token);
// Waits without a cancellation token.
Task.WaitAll(tasks, CancellationToken.None);
}
}
-71
View File
@@ -1,71 +0,0 @@
using System.Diagnostics.Contracts;
using System.Runtime.CompilerServices;
using Node = Craftimizer.Solver.Crafty.ArenaNode<Craftimizer.Solver.Crafty.SimulationNode>;
namespace Craftimizer.Solver.Crafty;
// https://github.com/alostsock/crafty/blob/cffbd0cad8bab3cef9f52a3e3d5da4f5e3781842/crafty/src/simulator.rs
// ISolver implementation that runs the whole search on the calling thread,
// using the non-concurrent tree operations.
public sealed class SolverSingle : ISolver
{
    // Forwards to the shared implementation in SolverUtils.
    [Pure]
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer children)
    {
        return SolverUtils.EvalBestChild<SolverSingle>(ref config, parentVisits, ref children);
    }

    // Selection step: walks down from `node` and returns the first node that either still
    // has untried actions or has no children at all.
    [Pure]
    public static Node Select(ref SolverConfig config, int nodeVisits, Node node)
    {
        // Keep descending while the node is fully expanded (no actions left) and has children.
        while (node.State.AvailableActions.IsEmpty && node.Children.Count != 0)
        {
            var bestAt = EvalBestChild(ref config, nodeVisits, ref node.ChildScores);
            nodeVisits = node.ChildScores.GetVisits(bestAt);
            node = node.ChildAt(bestAt)!;
        }

        return node;
    }

    // Expansion + rollout step: scores a completed node directly, otherwise pops one
    // available action, adds the resulting child and hands it to SolverUtils.Rollout.
    public static (Node ExpandedNode, float Score) ExpandAndRollout(ref SolverConfig config, float maxScore, Node rootNode, Random random, Simulator simulator, Node initialNode)
    {
        ref var parentState = ref initialNode.State;

        if (parentState.IsComplete)
        {
            var terminalScore = parentState.CalculateScore(config.MaxStepCount) ?? 0;
            return (initialNode, terminalScore);
        }

        var nextAction = parentState.AvailableActions.PopRandom(random);
        var childState = SolverUtils.Execute(simulator, parentState.State, nextAction, true);
        var child = initialNode.Add(childState);

        return SolverUtils.Rollout(ref config, maxScore, rootNode, child, random, simulator);
    }

    // Backpropagation step: records `score` on every node from `startNode` up to the root.
    // Non-root nodes are recorded in their parent's score buffer; the root in `rootScores`.
    public static void Backpropagate(RootScores rootScores, Node rootNode, Node startNode, float score)
    {
        for (var current = startNode; current != rootNode; current = current.Parent!)
            current.ParentScores!.Value.Visit(current.ChildIdx, score);

        rootScores.Visit(score);
    }

    // One search iteration: select, expand + rollout, backpropagate. Always reports success.
    public static bool SearchIter(ref SolverConfig config, RootScores rootScores, Node rootNode, Random random, Simulator simulator)
    {
        var selected = Select(ref config, rootScores.Visits, rootNode);
        var (leaf, leafScore) = ExpandAndRollout(ref config, rootScores.MaxScore, rootNode, random, simulator, selected);
        Backpropagate(rootScores, rootNode, leaf, leafScore);
        return true;
    }

    // Runs config.Iterations iterations through the shared SolverUtils.Search driver.
    public static void Search(ref SolverConfig config, RootScores rootScores, Node rootNode, CancellationToken token)
    {
        SolverUtils.Search<SolverSingle>(ref config, config.Iterations, rootScores, rootNode, token);
    }
}