~noelle/aoc-2021

aa2de6916825dad8245c52ab4cbbefea3702e490 — Noelle Leigh 2 years ago 1b78c1c
09_2
1 files changed, 102 insertions(+), 0 deletions(-)

A 09/puzzle_2.py
A 09/puzzle_2.py => 09/puzzle_2.py +102 -0
@@ 0,0 1,102 @@
"""
Solution for AoC 2021 09 Puzzle 2

cat input.txt | python puzzle_2.py
"""
import sys
from functools import reduce
from operator import mul

def move_uphill(heightmap: list, row_index: int, col_index: int, seen_points: set):
    """
    Recursively crawl out of the basin and return its size.
    """
    point = heightmap[row_index][col_index]
    seen_points.add((row_index, col_index))
    left_neighbor = heightmap[row_index][col_index - 1] if col_index > 0 else None
    right_neighbor = (
        heightmap[row_index][col_index + 1] if col_index < num_cols - 1 else None
    )
    top_neighbor = heightmap[row_index - 1][col_index] if row_index > 0 else None
    bottom_neighbor = (
        heightmap[row_index + 1][col_index] if row_index < num_rows - 1 else None
    )

    total = 1

    if (
        left_neighbor is not None
        and left_neighbor < 9
        and left_neighbor > point
        and (row_index, col_index - 1) not in seen_points
    ):
        total += move_uphill(heightmap, row_index, col_index - 1, seen_points)

    if (
        right_neighbor is not None
        and right_neighbor < 9
        and right_neighbor > point
        and (row_index, col_index + 1) not in seen_points
    ):
        total += move_uphill(heightmap, row_index, col_index + 1, seen_points)

    if (
        top_neighbor is not None
        and top_neighbor < 9
        and top_neighbor > point
        and (row_index - 1, col_index) not in seen_points
    ):
        total += move_uphill(heightmap, row_index - 1, col_index, seen_points)

    if (
        bottom_neighbor is not None
        and bottom_neighbor < 9
        and bottom_neighbor > point
        and (row_index + 1, col_index) not in seen_points
    ):
        total += move_uphill(heightmap, row_index + 1, col_index, seen_points)

    return total


if __name__ == "__main__":
    heightmap = list(
        map(
            lambda line: list(map(int, line)),
            sys.stdin.read().splitlines(),
        )
    )
    num_rows = len(heightmap)
    num_cols = len(heightmap[0])
    basin_sizes = []
    for row_index, row in enumerate(heightmap):
        for col_index, point in enumerate(row):
            left_neighbor = (
                heightmap[row_index][col_index - 1] if col_index > 0 else None
            )
            right_neighbor = (
                heightmap[row_index][col_index + 1]
                if col_index < num_cols - 1
                else None
            )
            top_neighbor = (
                heightmap[row_index - 1][col_index] if row_index > 0 else None
            )
            bottom_neighbor = (
                heightmap[row_index + 1][col_index]
                if row_index < num_rows - 1
                else None
            )
            neighbors = filter(
                lambda val: val is not None,
                [left_neighbor, right_neighbor, top_neighbor, bottom_neighbor],
            )
            is_lowest = all(map(lambda val: val > point, neighbors))
            if is_lowest:
                seen_points = set()
                basin_size = move_uphill(heightmap, row_index, col_index, seen_points)
                basin_sizes.append(basin_size)

    basin_size_product = reduce(mul, sorted(basin_sizes, reverse=True)[:3], 1)

    sys.stdout.write(str(basin_size_product))