Generalized branch-and-bound implementation

This commit is contained in:
jazzpi 2022-12-21 19:32:31 +01:00
parent 47cb2fefa5
commit 0279edb249
2 changed files with 88 additions and 67 deletions

View File

@ -1,7 +1,7 @@
use std::collections::HashSet;
use regex::Regex;
use crate::util::{self, BnBState};
#[derive(Debug, Clone)]
pub struct Blueprint {
pub ore_ore_cost: usize,
@ -77,29 +77,6 @@ impl State {
}
}
pub fn finished(&self) -> bool {
self.time_remaining == 0
}
pub fn possible_actions(&self, blueprint: &Blueprint) -> Vec<State> {
assert!(!self.finished());
let mut result = Vec::new();
if self.time_remaining > 1 {
self.produce_ore_next(blueprint).map(|a| result.push(a));
self.produce_clay_next(blueprint).map(|a| result.push(a));
self.produce_obsidian_next(blueprint)
.map(|a| result.push(a));
self.produce_geode_next(blueprint).map(|a| result.push(a));
}
let mut do_nothing = self.clone();
do_nothing.run_steps(self.time_remaining);
result.push(do_nothing);
result
}
fn produce<F: Fn(&mut State)>(&self, time_for_ore_prod: isize, produce: F) -> Option<State> {
let time_until_robot_ready = (time_for_ore_prod.max(0) as usize) + 1;
// For this to make sense, we also need at least one minute
@ -180,16 +157,6 @@ impl State {
})
}
pub fn upper_bound(&self) -> usize {
// Build one geode robot each remaining turn
// \sum_{k=1}^n {k - 1} = \sum_{k=0}^{n-1} {k} = 1/2 (n-1) n
self.lower_bound() + ((self.time_remaining - 1) * self.time_remaining) / 2
}
pub fn lower_bound(&self) -> usize {
self.geodes + self.geode_robots * self.time_remaining
}
fn run_steps(&mut self, n: usize) {
assert!(self.time_remaining >= n);
@ -201,38 +168,43 @@ impl State {
}
}
pub fn max_geodes(minutes: usize, blueprint: &Blueprint) -> usize {
let initial = State::new(minutes);
let initial_upper = initial.upper_bound();
let mut next = Vec::new();
next.push((initial.clone(), initial_upper));
let mut visited = HashSet::new();
visited.insert(initial);
let mut lower_bound = 0;
while let Some(n) = next.pop() {
if n.1 < lower_bound {
// Between pushing this state and popping it, we've found a better lower bound
continue;
}
let state = n.0;
for action in state.possible_actions(blueprint) {
if action.finished() {
let action_lower = action.lower_bound();
if action_lower > lower_bound {
lower_bound = action_lower;
}
} else {
let action_upper = action.upper_bound();
if action_upper > lower_bound && !visited.contains(&action) {
next.push((action.clone(), action_upper));
visited.insert(action);
}
}
}
impl util::BnBState<Blueprint> for State {
fn finished(&self) -> bool {
self.time_remaining == 0
}
lower_bound
fn possible_actions(&self, blueprint: &Blueprint) -> Vec<State> {
assert!(!self.finished());
let mut result = Vec::new();
if self.time_remaining > 1 {
self.produce_ore_next(blueprint).map(|a| result.push(a));
self.produce_clay_next(blueprint).map(|a| result.push(a));
self.produce_obsidian_next(blueprint)
.map(|a| result.push(a));
self.produce_geode_next(blueprint).map(|a| result.push(a));
}
let mut do_nothing = self.clone();
do_nothing.run_steps(self.time_remaining);
result.push(do_nothing);
result
}
fn upper_bound(&self, b: &Blueprint) -> usize {
// Build one geode robot each remaining turn
// \sum_{k=1}^n {k - 1} = \sum_{k=0}^{n-1} {k} = 1/2 (n-1) n
self.lower_bound(b) + ((self.time_remaining - 1) * self.time_remaining) / 2
}
fn lower_bound(&self, _: &Blueprint) -> usize {
self.geodes + self.geode_robots * self.time_remaining
}
}
pub fn max_geodes(minutes: usize, blueprint: &Blueprint) -> usize {
let initial = State::new(minutes);
util::maximize(&initial, blueprint).lower_bound(blueprint)
}

View File

@ -1,5 +1,7 @@
use std::collections::HashSet;
use std::env;
use std::fs;
use std::hash::Hash;
pub fn parse_input() -> String {
let args: Vec<String> = env::args().collect();
@ -73,3 +75,50 @@ impl Coordinate for SignedCoord {
(self.0 as usize, self.1 as usize)
}
}
pub trait BnBState<T> {
fn finished(&self) -> bool;
fn lower_bound(&self, extra: &T) -> usize;
fn upper_bound(&self, extra: &T) -> usize;
fn possible_actions(&self, extra: &T) -> Vec<Self>
where
Self: Sized;
}
pub fn maximize<E, S>(initial_state: &S, extra: &E) -> S
where
S: BnBState<E> + Clone + Hash + Eq,
{
let mut lower_bound = initial_state.lower_bound(extra);
let mut best = initial_state.clone();
let mut next = vec![(initial_state.clone(), initial_state.upper_bound(extra))];
let mut visited = HashSet::new();
visited.insert(initial_state.clone());
while let Some(n) = next.pop() {
if n.1 < lower_bound {
// Between pushing this state and popping it, we've found a better solution
continue;
}
let state = n.0;
for action in state.possible_actions(extra) {
if action.finished() {
let action_lower = action.lower_bound(extra);
if action_lower > lower_bound {
lower_bound = action_lower;
best = action;
}
} else {
let action_upper = action.upper_bound(extra);
if action_upper > lower_bound && !visited.contains(&action) {
next.push((action.clone(), action_upper));
visited.insert(action);
}
}
}
}
best
}