## ~athorp96/kattis-solutions

71b250002be43726978a9e5dea7ac77ec458b519 — Andrew Thorp 1 year, 11 months ago
```IMplement union find
```
```1 files changed, 89 insertions(+), 89 deletions(-)

M 10_kinds_of_people/10_kinds_of_people.py
```
`M 10_kinds_of_people/10_kinds_of_people.py => 10_kinds_of_people/10_kinds_of_people.py +89 -89`
```@@ 9,8 9,13 @@ def log(msg):
print(f"Debug:\t{msg}")
return

-def ten_kinds_of_people(map, queries):
-    pass
+def parse_query(input):
+    parts = input.split()
+
+    if len(parts) != 4:
+        raise Exception(f"query input \"{input}\" if of incorrect length")
+
+    return [int(parts[0]) - 1, int(parts[1]) - 1],[int(parts[2]) - 1, int(parts[3]) - 1]

def parse_dimensions(dim_string):
dims = dim_string.split()

@@ 26,101 31,85 @@ class Map:
self.n_cols = dims[1]
self.rows = []

+        size = self.n_rows * self.n_cols
+
+        self.parent = [i for i in range(size)]
+        self.rank = [0 for i in range(size)]
+
+        # Load the map into memory
for r in range(self.n_rows):
if len(row) != self.n_cols:
-                raise Exception(f"too many columns on row {r}: {len(row)}")
+                raise Exception(f"incorrect number of columns on row {r}: {len(row)}")
self.rows.append(row)

-    def get(self, coord):
-        x = coord[0]
-        y = coord[1]
-        return self.rows[x][y]
-
-    def get_adjacent_coordinates(self, c):
-        expected_value = self.rows[c[0]][c[1]]
-        # log(f"getting adjacent for {c}")
-        coords = [
-            (c[0], c[1] - 1),
-            (c[0], c[1] + 1),
-            (c[0] - 1, c[1]),
-            (c[0] + 1, c[1])
-        ]
-
-        coords = {c for c in coords if c[0] >= 0 and c[0] < self.n_rows and c[1] >= 0 and c[1] < self.n_cols and self.rows[c[0]][c[1]] == expected_value}
-        # log(f"adjacents: {coords}")
-        return coords
-
-    for r in crumbs:
-        row = ""
-        for c in r:
-            if c:
-                row += "X"
-            else:
-                row += "."
-        print(row)
-
-def distance(p1, p2):
-    return math.sqrt(math.pow(p2[0] - p1[0], 2) + math.pow(p2[1] - p1[1], 2))
-
-
-def bfs(map, start, finish):
-    # log(f"navigating from {start} to {finish}")
-    if map.get(start) != map.get(finish):
-        print("neither")
-        return
-
-    desired_value = map.get(start)
-
-    breadcrumbs = [[False for j in range(map.n_cols)] for i in range(map.n_rows)]
-
-    queue = PriorityQueue()
-    queue.put((distance(start, finish), start))
-
-    while True:
-        # Use exception handling instead of `.empty()` to avoid numerous calls to `len()`
-        try:
-            current = queue.get(block=False)[1]
-        except:
-            print("neither")
-            return
-        # print(f"current: {current}")
-        # log(f"testing {current}")
-        if coordinates_equal(current, finish):
-            if desired_value == "0":
-                print("binary")
-            else:
-                print("decimal")
+        # Scan over each node, connecting it to the node to the right and below, if applicable
+        # Note: "connect" is not correct here, they're being unioned into a tree.
+        # E.g.:
+        # 0  0  1      0--0  1      0--0  1<     0--0  1      0--0  1
+        # ^            |  ^         |     |      |     |      |     |
+        # 0  1  1  ->  0  1  1  ->  0  1  1  ->  0  1  1  ->  0 >1--1
+        #                                        ^               |
+        # 1  1  0      1  1  0      1  1  0      1  1  0      1  1  0
+        #
+        #
+        # 0--0  1      0--0  1      0--0  1      0--0  1      0--0  1
+        # |     |      |     |      |     |      |     |      |     |
+        # 0  1--1  ->  0  1--1  ->  0  1--1  ->  0  1--1  ->  0  1--1
+        #    |  ^         |            |            |            |
+        # 1  1  0      1--1  0      1--1  0      1--1  0      1--1  0
+        #              ^               ^               ^        Done
+        for i in range(self.n_rows):
+            for j in range(self.n_cols):
+                value = self.rows[i][j]
+
+                # print(f"checking {i},{j}")
+
+                # Check below
+                if i < self.n_rows - 1:
+                    if self.rows[i+1][j] == value:
+                        self.union(i,j,i+1,j)
+
+                # Check to the right
+                if j < self.n_cols - 1:
+                    if self.rows[i][j+1] == value:
+                        self.union(i,j,i,j+1)
+
+    def key(self, v1, v2):
+        return v1 * self.n_cols + v2
+
+    def find(self, x, y):
+        n = self.key(x,y)
+        return self._find(n)
+
+    def _find(self, n):
+        parent = self.parent[n]
+        while parent != n:
+            n = parent
+            parent = self.parent[n]
+
+        return parent
+
+
+    def union(self, x1, y1, x2, y2):
+        return self._union(self.key(x1, y1), self.key(x2, y2))
+
+    def _union(self, n1, n2):
+        root_1 = self._find(n1)
+        root_2 = self._find(n2)
+
+        # print(f"\tunion {n1} to {n2}")
+
+        if root_1 == root_2:
return

-            continue
+        if self.rank[root_1] >= self.rank[root_2]:
+            self.parent[root_2] = root_1
+            self.rank[root_1] = max(self.rank[root_2] + 1, self.rank[root_1])
else:
-            breadcrumbs[current[0]][current[1]] = True
-
-
-
-        for c in map.get_adjacent_coordinates(current):
-            if not breadcrumbs[c[0]][c[1]]:
-                queue.put((distance(c, finish), c))
-
-        # log("queue:")
-        # for i in queue:
-            # log(f"\t{i}")
-
+            self.parent[root_1] = root_2
+            self.rank[root_2] = max(self.rank[root_1] + 1, self.rank[root_2])

-def coordinates_equal(c1, c2):
-    return c1[0] == c2[0] and c1[1] == c2[1]
-
-def parse_query(input):
-    parts = input.split()
-
-    if len(parts) != 4:
-        raise Exception(f"query input \"{input}\" if of incorrect length")
-
-    return [int(parts[0]) - 1, int(parts[1]) - 1],[int(parts[2]) - 1, int(parts[3]) - 1]

def main():

@@ 133,13 122,24 @@ def main():
# log(dimensions)

map = Map(dimensions)
+    # print(map.parent)
+    # print(map.rank)
+    # print(map.graph)
# log(map.rows)

for i in range(num_queries):
-        bfs(map, query[0], query[1])
+
+        if map.find(query[0][0],query[0][1]) == map.find(query[1][0],query[1][1]):
+            if map.rows[query[0][0]][query[0][1]] == "0":
+                print("binary")
+            else:
+                print("decimal")
+        else:
+            print("neither")
+

if __name__ == "__main__":
main()

```