5 files changed, 35 insertions(+), 14 deletions(-)
R spadmon/__init__.py => pyquad/__init__.py
R spadmon/base_quadtree.py => pyquad/base_quadtree.py
R spadmon/geometry_objects.py => pyquad/geometry_objects.py
R spadmon/point_quadtree.py => pyquad/point_quadtree.py
R spadmon/region_quadtree.py => pyquad/region_quadtree.py
@@ 173,7 173,7 @@ class BoundingBox:
self,
ax: MplAxes,
c: str = "k",
- lw: int | float = 0.5,
+ lw: int | float = 0.1,
**kwargs: Dict[Any, Any],
) -> None:
"""
@@ 1,8 1,8 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Any, Callable, Dict, TYPE_CHECKING, Union
-import numpy as np
from matplotlib.axes._axes import Axes as MplAxes
+import numpy as np
if TYPE_CHECKING or __package__:
from .geometry_objects import BoundingBox
@@ 35,9 35,11 @@ class RegionNode:
self.depth = depth
self.children = 0
- self.split_criteria = split_func(array)
+ self.val = self.split_criteria = split_func(array.flatten())
self.split_func = split_func
+ self.data: Union[None, TArray2D] = array
+
self._divided = False
self._leaf = True
@@ 69,8 71,13 @@ class RegionNode:
self._divided = True
self._leaf = False
+ self.data = None
- def draw(self, ax: MplAxes) -> None:
+ def draw(
+ self,
+ ax: MplAxes,
+ **kwargs: Dict[Any, Any],
+ ) -> None:
"""
Helper method to plot tree nodes on a matplotlib axis
@@ 84,12 91,12 @@ class RegionNode:
None
"""
- self.bounding_box.draw(ax)
+ self.bounding_box.draw(ax, **kwargs)
if self._divided:
- self.nw.draw(ax) if self.nw else None
- self.ne.draw(ax) if self.ne else None
- self.se.draw(ax) if self.se else None
- self.sw.draw(ax) if self.sw else None
+ self.nw.draw(ax, **kwargs) if self.nw else None
+ self.ne.draw(ax, **kwargs) if self.ne else None
+ self.se.draw(ax, **kwargs) if self.se else None
+ self.sw.draw(ax, **kwargs) if self.sw else None
def __str__(self) -> str:
"""
@@ 100,10 107,15 @@ class RegionNode:
str
"""
sp = " " * self.depth * 2
- s = f"depth={self.depth} var={self.split_criteria} {self.bounding_box}\n"
+ s = (
+ f"depth={self.depth}"
+ f"\ndecomp={self.split_criteria}"
+ f"\ndata={self.data.shape if isinstance(self.data, np.ndarray) else None}"
+ f"{self.bounding_box}\n"
+ )
if not self._divided:
return s
- return f"{s} \n".join(
+ return f"{s if self.data != None else None} \n".join(
[
sp + "nw: " + str(self.nw),
sp + "ne: " + str(self.ne),
@@ 181,6 193,10 @@ class RegionQuadTree:
if not node:
return
+ # Ensure root is split
+ if node.depth == 0:
+ node.split(array)
+
if (
node.depth >= self.max_depth
or node.split_criteria <= self.split_thresh
@@ 190,6 206,7 @@ class RegionQuadTree:
# assign quadrant to leaf and stop recursing
node.leaf = True
+
return
# split quadrant if there is too much detail
@@ 198,7 215,11 @@ class RegionQuadTree:
for children in CHILDREN_NAMES:
self.build(node.__dict__[children], array)
- def draw(self, ax: MplAxes) -> None:
+ def draw(
+ self,
+ ax: MplAxes,
+ **kwargs: Dict[Any, Any],
+ ) -> None:
"""
Visualize quadtree
@@ 212,4 233,4 @@ class RegionQuadTree:
None
"""
- self.root.draw(ax)
+ self.root.draw(ax, **kwargs)