Minor solver refactoring

This commit is contained in:
Asriel Camora
2023-07-13 21:22:39 +02:00
parent df72386f19
commit 9a245aa9ad
3 changed files with 65 additions and 50 deletions
+7 -4
View File
@@ -55,10 +55,13 @@ public struct ActionSet
public readonly bool IsEmpty => bits == 0; public readonly bool IsEmpty => bits == 0;
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly ActionType SelectRandom(Random random) => public readonly ActionType SelectRandom(Random random)
IsDeterministic ? {
First() : if (IsDeterministic)
ElementAt(random.Next(Count)); return First();
return ElementAt(random.Next(Count));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public ActionType PopRandom(Random random) public ActionType PopRandom(Random random)
+52 -46
View File
@@ -61,7 +61,7 @@ public sealed class Solver
} }
[Pure] [Pure]
private (List<ActionType> Actions, SimulationNode Node) Solution() private SolverSolution Solution()
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var node = rootNode; var node = rootNode;
@@ -80,7 +80,7 @@ public sealed class Solver
//ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx]; //ref var visits = ref node.ParentScores!.Value.Data[at.arrayIdx].Visits.Span[at.subIdx];
//Console.WriteLine($"{sum} {max} {visits}"); //Console.WriteLine($"{sum} {max} {visits}");
return (actions, node.State); return new(actions, node.State.State);
} }
[Pure] [Pure]
@@ -332,17 +332,17 @@ public sealed class Solver
} }
} }
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseFurcated(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) => public static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) =>
SearchStepwiseFurcated(config, new SimulationState(input), actionCallback, token); SearchStepwiseFurcated(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default) public static SolverSolution SearchStepwiseFurcated(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{ {
var definiteActionCount = 0; var definiteActionCount = 0;
var bestSims = new List<(float Score, (List<ActionType> Actions, SimulationState State) Result)>(); var bestSims = new List<(float Score, SolverSolution Result)>();
var sim = new Simulator(state, config.MaxStepCount); var sim = new Simulator(state, config.MaxStepCount);
var activeStates = new List<(List<ActionType> Actions, SimulationState State)>() { (new(), state) }; var activeStates = new List<SolverSolution>() { new(new(), state) };
while (activeStates.Count != 0) while (activeStates.Count != 0)
{ {
@@ -350,7 +350,7 @@ public sealed class Solver
break; break;
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new List<Task<(float MaxScore, int FurcatedActionIdx, (List<ActionType> Actions, SimulationNode Node) Solution)>>(config.ForkCount); var tasks = new List<Task<(float MaxScore, int FurcatedActionIdx, SolverSolution Solution)>>(config.ForkCount);
for (var i = 0; i < config.ForkCount; i++) for (var i = 0; i < config.ForkCount; i++)
{ {
var stateIdx = (int)((float)i / config.ForkCount * activeStates.Count); var stateIdx = (int)((float)i / config.ForkCount * activeStates.Count);
@@ -372,14 +372,14 @@ public sealed class Solver
var bestAction = bestActions[0]; var bestAction = bestActions[0];
if (bestAction.MaxScore >= config.ScoreStorageThreshold) if (bestAction.MaxScore >= config.ScoreStorageThreshold)
{ {
var (maxScore, furcatedActionIdx, (solutionActions, solutionNode)) = bestAction; var (maxScore, furcatedActionIdx, solution) = bestAction;
var (activeActions, activeState) = activeStates[furcatedActionIdx]; var (activeActions, activeState) = activeStates[furcatedActionIdx];
activeActions.AddRange(solutionActions); activeActions.AddRange(solution.Actions);
return (activeActions, solutionNode.State); return solution with { Actions = activeActions };
} }
var newStates = new List<(List<ActionType> Actions, SimulationState State)>(config.FurcatedActionCount); var newStates = new List<SolverSolution>(config.FurcatedActionCount);
for (var i = 0; i < bestActions.Length; ++i) for (var i = 0; i < bestActions.Length; ++i)
{ {
var (maxScore, furcatedActionIdx, (solutionActions, solutionNode)) = bestActions[i]; var (maxScore, furcatedActionIdx, (solutionActions, solutionNode)) = bestActions[i];
@@ -393,9 +393,9 @@ public sealed class Solver
var newActions = new List<ActionType>(activeActions) { chosenAction }; var newActions = new List<ActionType>(activeActions) { chosenAction };
var newState = sim.Execute(activeState, chosenAction).NewState; var newState = sim.Execute(activeState, chosenAction).NewState;
if (sim.IsComplete) if (sim.IsComplete)
bestSims.Add((maxScore, (newActions, newState))); bestSims.Add((maxScore, new(newActions, newState)));
else else
newStates.Add((newActions, newState)); newStates.Add(new(newActions, newState));
} }
if (bestSims.Count == 0 && newStates.Count != 0) if (bestSims.Count == 0 && newStates.Count != 0)
@@ -435,7 +435,7 @@ public sealed class Solver
} }
if (bestSims.Count == 0) if (bestSims.Count == 0)
return (new(), state); return new(new(), state);
var result = bestSims.MaxBy(s => s.Score).Result; var result = bestSims.MaxBy(s => s.Score).Result;
for (var i = definiteActionCount; i < result.Actions.Count; ++i) for (var i = definiteActionCount; i < result.Actions.Count; ++i)
@@ -444,10 +444,10 @@ public sealed class Solver
return result; return result;
} }
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) => public static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) =>
SearchStepwiseForked(config, new SimulationState(input), actionCallback, token); SearchStepwiseForked(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwiseForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default) public static SolverSolution SearchStepwiseForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount); var sim = new Simulator(state, config.MaxStepCount);
@@ -461,7 +461,7 @@ public sealed class Solver
var s = Stopwatch.StartNew(); var s = Stopwatch.StartNew();
var tasks = new Task<(float MaxScore, (List<ActionType> Actions, SimulationNode Node) Solution)>[config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount];
for (var i = 0; i < config.ForkCount; ++i) for (var i = 0; i < config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(() =>
{ {
@@ -472,29 +472,29 @@ public sealed class Solver
Task.WaitAll(tasks, CancellationToken.None); Task.WaitAll(tasks, CancellationToken.None);
s.Stop(); s.Stop();
var (maxScore, (solutionActions, solutionNode)) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore); var (maxScore, solution) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore);
if (maxScore >= config.ScoreStorageThreshold) if (maxScore >= config.ScoreStorageThreshold)
{ {
actions.AddRange(solutionActions); actions.AddRange(solution.Actions);
return (actions, solutionNode.State); return solution with { Actions = actions };
} }
var chosen_action = solutionActions[0]; var chosenAction = solution.Actions[0];
actionCallback?.Invoke(chosen_action); actionCallback?.Invoke(chosenAction);
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t"); Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / config.ForkCount / s.Elapsed.TotalSeconds / 1000:0.00} kI/s/t");
(_, state) = sim.Execute(state, chosen_action); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosen_action); actions.Add(chosenAction);
} }
return (actions, state); return new(actions, state);
} }
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) => public static SolverSolution SearchStepwise(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) =>
SearchStepwise(config, new SimulationState(input), actionCallback, token); SearchStepwise(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default) public static SolverSolution SearchStepwise(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{ {
var actions = new List<ActionType>(); var actions = new List<ActionType>();
var sim = new Simulator(state, config.MaxStepCount); var sim = new Simulator(state, config.MaxStepCount);
@@ -512,31 +512,31 @@ public sealed class Solver
solver.Search(config.Iterations, token); solver.Search(config.Iterations, token);
s.Stop(); s.Stop();
var (solution_actions, solution_node) = solver.Solution(); var solution = solver.Solution();
if (solver.MaxScore >= config.ScoreStorageThreshold) if (solver.MaxScore >= config.ScoreStorageThreshold)
{ {
actions.AddRange(solution_actions); actions.AddRange(solution.Actions);
return (actions, solution_node.State); return solution with { Actions = actions };
} }
var chosen_action = solution_actions[0]; var chosenAction = solution.Actions[0];
actionCallback?.Invoke(chosen_action); actionCallback?.Invoke(chosenAction);
Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s"); Console.WriteLine($"{s.Elapsed.TotalMilliseconds:0.00}ms {config.Iterations / s.Elapsed.TotalSeconds / 1000:0.00} kI/s");
(_, state) = sim.Execute(state, chosen_action); (_, state) = sim.Execute(state, chosenAction);
actions.Add(chosen_action); actions.Add(chosenAction);
} }
return (actions, state); return new(actions, state);
} }
public static (List<ActionType> Actions, SimulationState State) SearchOneshotForked(SolverConfig config, SimulationInput input, CancellationToken token = default) => public static SolverSolution SearchOneshotForked(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) =>
SearchOneshotForked(config, new SimulationState(input), token); SearchOneshotForked(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchOneshotForked(SolverConfig config, SimulationState state, CancellationToken token = default) public static SolverSolution SearchOneshotForked(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{ {
var tasks = new Task<(float MaxScore, (List<ActionType> Actions, SimulationNode Node) Solution)>[config.ForkCount]; var tasks = new Task<(float MaxScore, SolverSolution Solution)>[config.ForkCount];
for (var i = 0; i < config.ForkCount; ++i) for (var i = 0; i < config.ForkCount; ++i)
tasks[i] = Task.Run(() => tasks[i] = Task.Run(() =>
{ {
@@ -546,18 +546,24 @@ public sealed class Solver
}, token); }, token);
Task.WaitAll(tasks, CancellationToken.None); Task.WaitAll(tasks, CancellationToken.None);
var (solutionActions, solutionNode) = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution; var solution = tasks.Select(t => t.Result).MaxBy(r => r.MaxScore).Solution;
return (solutionActions, solutionNode.State); foreach (var action in solution.Actions)
actionCallback?.Invoke(action);
return solution;
} }
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationInput input, CancellationToken token = default) => public static SolverSolution SearchOneshot(SolverConfig config, SimulationInput input, Action<ActionType>? actionCallback = null, CancellationToken token = default) =>
SearchOneshot(config, new SimulationState(input), token); SearchOneshot(config, new SimulationState(input), actionCallback, token);
public static (List<ActionType> Actions, SimulationState State) SearchOneshot(SolverConfig config, SimulationState state, CancellationToken token = default) public static SolverSolution SearchOneshot(SolverConfig config, SimulationState state, Action<ActionType>? actionCallback = null, CancellationToken token = default)
{ {
var solver = new Solver(config, state); var solver = new Solver(config, state);
solver.Search(config.Iterations, token); solver.Search(config.Iterations, token);
var (solution_actions, solution_node) = solver.Solution(); var solution = solver.Solution();
return (solution_actions, solution_node.State); foreach (var action in solution.Actions)
actionCallback?.Invoke(action);
return solution;
} }
} }
+6
View File
@@ -0,0 +1,6 @@
using Craftimizer.Simulator;
using Craftimizer.Simulator.Actions;
namespace Craftimizer.Solver.Crafty;
public readonly record struct SolverSolution(List<ActionType> Actions, SimulationState State);