From 3ab50d389e60eab64b0f240761cfc6dd4aa6fc0f Mon Sep 17 00:00:00 2001 From: Asriel Camora Date: Fri, 7 Jul 2023 18:13:27 +0200 Subject: [PATCH] Fix concurrency, add forked simulations --- Benchmark/Program.cs | 8 ++-- Solver/Crafty/ArenaNode.cs | 8 ++-- Solver/Crafty/SolverConcurrent.cs | 5 +++ Solver/Crafty/SolverConfig.cs | 4 +- Solver/Crafty/SolverSingle.cs | 2 +- Solver/Crafty/SolverUtils.cs | 71 ++++++++++++++++++++++++++++--- 6 files changed, 84 insertions(+), 14 deletions(-) diff --git a/Benchmark/Program.cs b/Benchmark/Program.cs index ef2ee2d..69f0147 100644 --- a/Benchmark/Program.cs +++ b/Benchmark/Program.cs @@ -46,14 +46,16 @@ internal static class Program var config = new SolverConfig() { - Iterations = 1_000_000, + Iterations = 100_000, ThreadCount = 8, }; Debugger.Break(); var s = Stopwatch.StartNew(); - if (true) - _ = SolverUtils.SearchStepwise(config, input, a => Console.WriteLine(a)); + if (true) { + (_, var state) = SolverUtils.SearchStepwise(config, input, a => Console.WriteLine(a)); + Console.WriteLine($"Qual: {state.Quality}/{state.Input.Recipe.MaxQuality}"); + } else { (var actions, _) = SolverUtils.SearchOneshot(config, input); diff --git a/Solver/Crafty/ArenaNode.cs b/Solver/Crafty/ArenaNode.cs index 90d5b6c..89fba01 100644 --- a/Solver/Crafty/ArenaNode.cs +++ b/Solver/Crafty/ArenaNode.cs @@ -20,15 +20,15 @@ public sealed class ArenaNode where T : struct Parent = parent; } - public ArenaNode ChildAt((int arrayIdx, int subIdx) at) => - Children.Data[at.arrayIdx][at.subIdx]; + public ArenaNode? ChildAt((int arrayIdx, int subIdx) at) => + Children.Data?[at.arrayIdx]?[at.subIdx]; [MethodImpl(MethodImplOptions.AggressiveInlining)] public ArenaNode AddConcurrent(T state) { var node = new ArenaNode(state, this); - Children.AddConcurrent(node); ChildScores.AddConcurrent(); + Children.AddConcurrent(node); return node; } @@ -36,8 +36,8 @@ public sealed class ArenaNode where T : struct public ArenaNode Add(T state) { var node = new ArenaNode(state, this); - Children.Add(node); ChildScores.Add(); + Children.Add(node); return node; } } diff --git a/Solver/Crafty/SolverConcurrent.cs b/Solver/Crafty/SolverConcurrent.cs index 91080a0..3bb6032 100644 --- a/Solver/Crafty/SolverConcurrent.cs +++ b/Solver/Crafty/SolverConcurrent.cs @@ -33,6 +33,11 @@ public sealed class SolverConcurrent : ISolver { nodeVisits = node.ChildScores.GetVisits(at.Value); node = node.ChildAt(at.Value); + if (node == null) + { + node = rootNode; + nodeVisits = rootNodeVisits; + } } else { diff --git a/Solver/Crafty/SolverConfig.cs b/Solver/Crafty/SolverConfig.cs index 46c7515..4ffd308 100644 --- a/Solver/Crafty/SolverConfig.cs +++ b/Solver/Crafty/SolverConfig.cs @@ -10,6 +10,7 @@ public readonly record struct SolverConfig public float MaxScoreWeightingConstant { get; init; } public float ExplorationConstant { get; init; } public int MaxStepCount { get; init; } + public int MaxRolloutStepCount { get; init; } public int ThreadCount { get; init; } public SolverConfig() @@ -17,8 +18,9 @@ public readonly record struct SolverConfig Iterations = 300000; ScoreStorageThreshold = 1f; MaxScoreWeightingConstant = 0.1f; - ExplorationConstant = 4f; + ExplorationConstant = 2; MaxStepCount = 25; + MaxRolloutStepCount = 99; ThreadCount = Environment.ProcessorCount; } } diff --git a/Solver/Crafty/SolverSingle.cs b/Solver/Crafty/SolverSingle.cs index b53fe08..0ecc287 100644 --- a/Solver/Crafty/SolverSingle.cs +++ b/Solver/Crafty/SolverSingle.cs @@ -25,7 +25,7 @@ public sealed class SolverSingle : ISolver // select the node with the highest score var at = EvalBestChild(ref config, nodeVisits, ref node.ChildScores); nodeVisits = node.ChildScores.GetVisits(at); - node = node.ChildAt(at); + node = node.ChildAt(at)!; } } diff --git a/Solver/Crafty/SolverUtils.cs b/Solver/Crafty/SolverUtils.cs index fbf92a0..959a8c8 100644 --- a/Solver/Crafty/SolverUtils.cs +++ b/Solver/Crafty/SolverUtils.cs @@ -1,9 +1,9 @@ -using Craftimizer.Simulator.Actions; using Craftimizer.Simulator; -using Node = Craftimizer.Solver.Crafty.ArenaNode; +using Craftimizer.Simulator.Actions; using System.Diagnostics.Contracts; using System.Numerics; using System.Runtime.CompilerServices; +using Node = Craftimizer.Solver.Crafty.ArenaNode; namespace Craftimizer.Solver.Crafty; public static class SolverUtils @@ -73,7 +73,7 @@ public static class SolverUtils var actions = new List(); while (node.Children.Count != 0) { - node = node.ChildAt(ChildMaxScore(ref node.ChildScores)); + node = node.ChildAt(ChildMaxScore(ref node.ChildScores))!; if (node.State.Action != null) actions.Add(node.State.Action.Value); @@ -82,6 +82,23 @@ public static class SolverUtils return (actions, node.State); } + // Calculates the best child node to explore next + // Exploitation: ((1 - w) * (s / v)) + (w * m) + // Exploration: sqrt(c * ln(V) / v) + // w = maxScoreWeightingConstant + // s = score sum + // m = max score + // v = visits + // V = parentVisits + // c = explorationConstant + + // Somewhat based off of https://en.wikipedia.org/wiki/Monte_Carlo_tree_search#Exploration_and_exploitation + // Here, w_i = (1-w)*score sum + // n_i = visits + // max score is tacked onto it + // N_i = parent visits + // c = exploration constant (but crafty places it inside the sqrt..?) + [Pure] [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] public static (int arrayIdx, int subIdx) EvalBestChild(ref SolverConfig config, int parentVisits, ref NodeScoresBuffer scores) where S : ISolver @@ -135,8 +152,8 @@ public static class SolverUtils var currentActions = expandedNode.State.AvailableActions; byte actionCount = 0; - Span actions = stackalloc ActionType[config.MaxStepCount - currentState.ActionCount]; - while (true) + Span actions = stackalloc ActionType[Math.Min(config.MaxStepCount - currentState.ActionCount, config.MaxRolloutStepCount)]; + while (actionCount < actions.Length) { if (SimulationNode.GetCompletionState(currentCompletionState, currentActions) != CompletionState.Incomplete) break; @@ -197,6 +214,50 @@ public static class SolverUtils )); } + public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationInput input, Action? actionCallback, CancellationToken token = default) where S : ISolver => + SearchStepwiseForked(config, forkCount, new SimulationState(input), actionCallback, token); + + public static (List Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, int forkCount, SimulationState state, Action? actionCallback, CancellationToken token = default) where S : ISolver + { + var actions = new List(); + var sim = new Simulator(state, config.MaxStepCount); + while (!sim.IsComplete) + { + if (token.IsCancellationRequested) + break; + + var tasks = new Task<(float score, List actions, SimulationState state)>[forkCount]; + for (var i = 0; i < forkCount; ++i) + tasks[i] = Task.Run(() => + { + var rootNode = CreateRootNode(config, state, true); + RootScores rootScores = new(); + + S.Search(ref config, rootScores, rootNode, token); + var (solution_actions, solution_node) = Solution(rootNode); + + return (rootScores.MaxScore, solution_actions, solution_node.State); + }, token); + Task.WaitAll(tasks, CancellationToken.None); + + var (score, solution_actions, solution_state) = tasks.Select(t => t.Result).MaxBy(r => r.score); + + if (score >= 1.0) + { + actions.AddRange(solution_actions); + return (actions, solution_state); + } + + var chosen_action = solution_actions[0]; + (_, state) = sim.Execute(state, chosen_action); + actions.Add(chosen_action); + + actionCallback?.Invoke(chosen_action); + } + + return (actions, state); + } + public static (List Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action? actionCallback, CancellationToken token = default) where S : ISolver => SearchStepwise(config, new SimulationState(input), actionCallback, token);