simplify day 7

This commit is contained in:
Joseph Montanaro 2021-12-09 13:18:32 -08:00
parent 19e219bcc3
commit 2133290565

View File

@ -1,59 +1,54 @@
use color_eyre::eyre;
#[derive(Default)]
struct Map {
counts: Vec<usize>, // index is position, value is how many at that position
index: Vec<usize>, // values are indices of `counts`, sorted from most to least
fn distance_between(a: usize, b: usize) -> usize {
// can't do the fancypants (a - b).abs() thing here because these are unsigned
if a > b {
a - b
}
else {
b - a
}
}
impl Map {
fn from(nums: &[usize]) -> Map {
let hi = nums.iter().max().unwrap_or(&0);
let mut counts = vec![0; hi + 1];
for n in nums {
counts[*n] += 1;
}
let mut index: Vec<usize> = (0..counts.len()).collect();
index.sort_by(|&a, &b| counts[b].cmp(&counts[a])); // flip the comparison function to sort in reverse
Map {counts, index}
}
fn cost_simple(counts: &[usize], pos: usize) -> usize {
counts.iter()
.enumerate()
.fold(0, |total, (i, count)| {
total + (distance_between(pos, i) * count) // yes, these parens are unnecessary
})
}
fn cost_simple(&self, pos: usize) -> usize {
let mut total = 0;
for (i, count) in self.counts.iter().enumerate() {
let distance = if pos > i {pos - i} else {i - pos};
total += distance * count;
}
total
}
fn cost_complex(&self, pos: usize) -> usize {
let mut total = 0;
for (i, count) in self.counts.iter().enumerate() {
let distance = if pos > i {pos -i} else {i - pos};
let cost_single = distance * (distance + 1) / 2; // Gauss' method for summing numbers 0 to n
total += cost_single * count;
}
total
}
fn cost_complex(counts: &[usize], pos: usize) -> usize {
counts.iter()
.enumerate()
.fold(0, |total, (i, count)| {
let dx = distance_between(pos, i);
let cost_single = dx * (dx + 1) / 2; // Gauss' method for summing numbers from 1 to n
total + (cost_single * count)
})
}
pub fn run(data: &str) -> eyre::Result<(usize, usize)> {
let nums = data.trim().split(',')
.map(|s| s.parse::<usize>())
.collect::<Result<Vec<_>, _>>()?;
let map = Map::from(&nums);
let mut counts = Vec::new();
for s in data.trim().split(',') {
let n = s.parse::<usize>()?;
if n >= counts.len() {
counts.resize(n + 1, 0);
}
counts[n] += 1;
}
let mut min_simple = usize::MAX;
let mut min_complex = usize::MAX;
for i in 0..map.counts.len() {
let simple = map.cost_simple(i);
for i in 0..counts.len() {
let simple = cost_simple(&counts, i);
if simple < min_simple {min_simple = simple;}
let complex = map.cost_complex(i);
let complex = cost_complex(&counts, i);
if complex < min_complex {min_complex = complex;}
}