~bsprague/advent-of-code

bcd7bf3a7b97b5324fb05b87efe3e18bf2a3bb40 — Brandon Sprague 5 months ago 26d3ff0
reviving the A* solution for day17, even though it's slower
1 files changed, 60 insertions(+), 73 deletions(-)

M 2023/day17/day17.rs
M 2023/day17/day17.rs => 2023/day17/day17.rs +60 -73
@@ 41,18 41,18 @@ impl Node {
}

#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct NodeWithDist {
pub struct NodeWithScore {
    n: Node,
    dist: u64,
    f_score: u64,
}

impl Ord for NodeWithDist {
impl Ord for NodeWithScore {
    fn cmp(&self, other: &Self) -> Ordering {
        other.dist.cmp(&self.dist)
        other.f_score.cmp(&self.f_score)
    }
}

impl PartialOrd for NodeWithDist {
impl PartialOrd for NodeWithScore {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }


@@ 60,50 60,52 @@ impl PartialOrd for NodeWithDist {

impl Data {
    pub fn minimal_heat_loss(&self) -> u64 {
        // It's Dijkstra time, just follow the Wikipedia page per usual: https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm
        let mut q: BinaryHeap<NodeWithDist> = BinaryHeap::new();
        let mut prev: HashMap<Node, Node> = HashMap::new();
        let mut dist: HashMap<Node, u64> = HashMap::new();
        // It's A* time, just follow the Wikipedia page per usual: https://en.wikipedia.org/wiki/A%2A_search_algorithm
        let mut open_set: BinaryHeap<NodeWithScore> = BinaryHeap::new();
        let mut came_from: HashMap<Node, Node> = HashMap::new();
        let mut g_score: HashMap<Node, u64> = HashMap::new();
        let mut f_score: HashMap<Node, u64> = HashMap::new();

        let (width, height) = (self.0[0].len(), self.0.len());

        let start = Loc { x: 0, y: 0 };
        let start_h = self.heuristic_from(&start, width, height);
        let start_1 = Node {
            x: 1,
            x: 0,
            y: 0,
            dir: Dir::Right,
            run_len: 2,
            run_len: 1,
        };
        let start_2 = Node {
            x: 0,
            y: 1,
            y: 0,
            dir: Dir::Down,
            run_len: 2,
            run_len: 1,
        };
        q.push(NodeWithDist {
        open_set.push(NodeWithScore {
            n: start_1.clone(),
            dist: self.0[0][1],
            f_score: start_h,
        });
        q.push(NodeWithDist {
        open_set.push(NodeWithScore {
            n: start_2.clone(),
            dist: self.0[1][0],
            f_score: start_h,
        });
        let goal = Loc {
            x: width - 1,
            y: height - 1,
        };
        dist.insert(start_1.clone(), self.0[0][1]);
        dist.insert(start_2.clone(), self.0[1][0]);
        g_score.insert(start_1.clone(), 0);
        g_score.insert(start_2.clone(), 0);
        f_score.insert(start_1.clone(), start_h);
        f_score.insert(start_2.clone(), start_h);

        while let Some(n) = q.pop() {
        while let Some(n) = open_set.pop() {
            if n.n.run_len > 3 {
                continue;
            }

            if n.n.to_loc() == goal {
                return n.dist;
                // println!("{}", n.dist);
                // return self.calc_cost(&prev, &n.n);
                return self.calc_cost(&came_from, &n.n);
            }

            let mut neighbors = vec![];


@@ 157,32 159,23 @@ impl Data {
            }

            for neighbor in neighbors.into_iter() {
                let alt = dist.get(&n.n).unwrap() + self.0[neighbor.y][neighbor.x];
                if alt < *dist.get(&neighbor).unwrap_or(&u64::MAX) {
                    dist.insert(neighbor.clone(), alt);
                    prev.insert(neighbor.clone(), n.n.clone());
                    q.push(NodeWithDist {
                        n: neighbor,
                        dist: alt,
                    });
                let tentative_g = if let Some(g_sc) = g_score.get(&n.n) {
                    g_sc + self.0[neighbor.y][neighbor.x]
                } else {
                    u64::MAX
                };
                if tentative_g < *g_score.get(&neighbor).unwrap_or(&u64::MAX) {
                    came_from.insert(neighbor.clone(), n.n.clone());
                    g_score.insert(neighbor.clone(), tentative_g);
                    let f_sc = tentative_g + self.heuristic_from(&neighbor.to_loc(), width, height);
                    f_score.insert(neighbor.clone(), f_sc);
                    if !open_set.iter().any(|n| n.n == neighbor) {
                        open_set.push(NodeWithScore {
                            n: neighbor,
                            f_score: f_sc,
                        });
                    }
                }
                // let tentative_g = if let Some(g_sc) = g_score.get(&n.n) {
                //     g_sc + self.0[neighbor.y][neighbor.x]
                // } else {
                //     u64::MAX
                // };
                // if tentative_g < *g_score.get(&neighbor).unwrap_or(&u64::MAX) {
                //     prev.insert(neighbor.clone(), n.n.clone());
                //     g_score.insert(neighbor.clone(), tentative_g);
                //     let f_sc = tentative_g + self.heuristic_from(&neighbor.to_loc(), width, height);
                //     dist.insert(neighbor.clone(), f_sc);
                //     if !q.iter().any(|n| n.n == neighbor) {
                //         q.push(NodeWithDist {
                //             n: neighbor,
                //             f_score: f_sc,
                //         });
                //     }
                // }
            }
        }



@@ 198,13 191,13 @@ impl Data {
    fn calc_cost(&self, came_from: &HashMap<Node, Node>, goal: &Node) -> u64 {
        let start = Loc { x: 0, y: 0 };

        let mut path: HashMap<Loc, Node> = HashMap::new();
        // path.insert(start.clone());
        // let mut path = HashSet::new();
        // path.insert(&start);

        let mut cost = 0u64;
        let mut cur = goal;
        while cur.to_loc() != start {
            path.insert(cur.to_loc(), cur.clone());
            // path.insert(cur);
            cost += self.0[cur.y][cur.x];
            if let Some(v) = came_from.get(cur) {
                cur = v;


@@ 212,28 205,22 @@ impl Data {
                panic!("node {:?} didn't come from anywhere", cur);
            }
        }
        path.insert(cur.to_loc(), cur.clone());

        println!();
        for y in 0..self.0.len() {
            for x in 0..self.0[y].len() {
                print!(
                    "{}",
                    if let Some(n) = path.get(&Loc { x: x, y: y }) {
                        match n.dir {
                            Dir::Up => '^',
                            Dir::Down => 'v',
                            Dir::Left => '<',
                            Dir::Right => '>',
                        }
                    } else {
                        self.0[y][x].to_string().chars().next().unwrap()
                    }
                );
            }
            println!();
        }
        println!();

        // println!();
        // for y in 0..self.0.len() {
        //     for x in 0..self.0[y].len() {
        //         print!(
        //             "{}",
        //             if path.contains(&Node { x: x, y: y }) {
        //                 'x'
        //             } else {
        //                 self.0[y][x].to_string().chars().next().unwrap()
        //             }
        //         );
        //     }
        //     println!();
        // }
        // println!();
        cost
    }