~athorp96/kattis-solutions

71b250002be43726978a9e5dea7ac77ec458b519 — Andrew Thorp 1 year, 2 months ago 0dd940a
IMplement union find
1 files changed, 89 insertions(+), 89 deletions(-)

M 10_kinds_of_people/10_kinds_of_people.py
M 10_kinds_of_people/10_kinds_of_people.py => 10_kinds_of_people/10_kinds_of_people.py +89 -89
@@ 9,8 9,13 @@ def log(msg):
        print(f"Debug:\t{msg}")
    return

def ten_kinds_of_people(map, queries):
    pass
def parse_query(input):
    parts = input.split()

    if len(parts) != 4:
        raise Exception(f"query input \"{input}\" if of incorrect length")

    return [int(parts[0]) - 1, int(parts[1]) - 1],[int(parts[2]) - 1, int(parts[3]) - 1]

def parse_dimensions(dim_string):
    dims = dim_string.split()


@@ 26,101 31,85 @@ class Map:
        self.n_cols = dims[1]
        self.rows = []

        size = self.n_rows * self.n_cols

        self.parent = [i for i in range(size)]
        self.rank = [0 for i in range(size)]

        # Load the map into memory
        for r in range(self.n_rows):
            row = sys.stdin.readline().strip()
            if len(row) != self.n_cols:
                raise Exception(f"too many columns on row {r}: {len(row)}")
                raise Exception(f"incorrect number of columns on row {r}: {len(row)}")
            self.rows.append(row)

    def get(self, coord):
        x = coord[0]
        y = coord[1]
        return self.rows[x][y]

    def get_adjacent_coordinates(self, c):
        expected_value = self.rows[c[0]][c[1]]
        # log(f"getting adjacent for {c}")
        coords = [
            (c[0], c[1] - 1),
            (c[0], c[1] + 1),
            (c[0] - 1, c[1]),
            (c[0] + 1, c[1])
        ]

        coords = {c for c in coords if c[0] >= 0 and c[0] < self.n_rows and c[1] >= 0 and c[1] < self.n_cols and self.rows[c[0]][c[1]] == expected_value}
        # log(f"adjacents: {coords}")
        return coords

def log_breadcrumbs(crumbs):
    for r in crumbs:
        row = ""
        for c in r:
            if c:
                row += "X"
            else:
                row += "."
        print(row)

def distance(p1, p2):
    return math.sqrt(math.pow(p2[0] - p1[0], 2) + math.pow(p2[1] - p1[1], 2))


def bfs(map, start, finish):
    # log(f"navigating from {start} to {finish}")
    if map.get(start) != map.get(finish):
        print("neither")
        return

    desired_value = map.get(start)

    breadcrumbs = [[False for j in range(map.n_cols)] for i in range(map.n_rows)]

    queue = PriorityQueue()
    queue.put((distance(start, finish), start))

    while True:
        # Use exception handling instead of `.empty()` to avoid numerous calls to `len()`
        try:
            current = queue.get(block=False)[1]
        except:
            print("neither")
            return
        # print(f"current: {current}")
        # log(f"testing {current}")
        if coordinates_equal(current, finish):
            if desired_value == "0":
                print("binary")
            else:
                print("decimal")
        # Scan over each node, connecting it to the node to the right and below, if applicable
        # Note: "connect" is not correct here, they're being unioned into a tree.
        # E.g.:
        # 0  0  1      0--0  1      0--0  1<     0--0  1      0--0  1
        # ^            |  ^         |     |      |     |      |     |
        # 0  1  1  ->  0  1  1  ->  0  1  1  ->  0  1  1  ->  0 >1--1
        #                                        ^               |
        # 1  1  0      1  1  0      1  1  0      1  1  0      1  1  0
        #
        #
        # 0--0  1      0--0  1      0--0  1      0--0  1      0--0  1
        # |     |      |     |      |     |      |     |      |     |
        # 0  1--1  ->  0  1--1  ->  0  1--1  ->  0  1--1  ->  0  1--1
        #    |  ^         |            |            |            |
        # 1  1  0      1--1  0      1--1  0      1--1  0      1--1  0
        #              ^               ^               ^        Done
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                value = self.rows[i][j]

                # print(f"checking {i},{j}")

                # Check below
                if i < self.n_rows - 1:
                    if self.rows[i+1][j] == value:
                        self.union(i,j,i+1,j)

                # Check to the right
                if j < self.n_cols - 1:
                    if self.rows[i][j+1] == value:
                        self.union(i,j,i,j+1)

    def key(self, v1, v2):
        return v1 * self.n_cols + v2

    def find(self, x, y):
        n = self.key(x,y)
        return self._find(n)

    def _find(self, n):
        parent = self.parent[n]
        while parent != n:
            n = parent
            parent = self.parent[n]

        return parent


    def union(self, x1, y1, x2, y2):
        return self._union(self.key(x1, y1), self.key(x2, y2))

    def _union(self, n1, n2):
        root_1 = self._find(n1)
        root_2 = self._find(n2)

        # print(f"\tunion {n1} to {n2}")

        if root_1 == root_2:
            return

        if breadcrumbs[current[0]][current[1]]:
            continue
        if self.rank[root_1] >= self.rank[root_2]:
            self.parent[root_2] = root_1
            self.rank[root_1] = max(self.rank[root_2] + 1, self.rank[root_1])
        else:
            breadcrumbs[current[0]][current[1]] = True

        log_breadcrumbs(breadcrumbs)


        for c in map.get_adjacent_coordinates(current):
            if not breadcrumbs[c[0]][c[1]]:
                queue.put((distance(c, finish), c))

        # log("queue:")
        # for i in queue:
            # log(f"\t{i}")

            self.parent[root_1] = root_2
            self.rank[root_2] = max(self.rank[root_1] + 1, self.rank[root_2])

def coordinates_equal(c1, c2):
    return c1[0] == c2[0] and c1[1] == c2[1]

def parse_query(input):
    parts = input.split()

    if len(parts) != 4:
        raise Exception(f"query input \"{input}\" if of incorrect length")

    return [int(parts[0]) - 1, int(parts[1]) - 1],[int(parts[2]) - 1, int(parts[3]) - 1]

def main():
    dimensions = sys.stdin.readline()


@@ 133,13 122,24 @@ def main():
    # log(dimensions)

    map = Map(dimensions)
    # print(map.parent)
    # print(map.rank)
    # print(map.graph)
    # log(map.rows)

    num_queries = int(sys.stdin.readline().strip())

    for i in range(num_queries):
        query = parse_query(sys.stdin.readline().strip())
        bfs(map, query[0], query[1])

        if map.find(query[0][0],query[0][1]) == map.find(query[1][0],query[1][1]):
            if map.rows[query[0][0]][query[0][1]] == "0":
                print("binary")
            else:
                print("decimal")
        else:
            print("neither")


if __name__ == "__main__":
    main()