WIP: Find nearest neighbor algorithm

This commit is contained in:
Joseph Ferano 2024-08-01 12:16:56 +07:00
parent db3f5d9407
commit ccd46d07a8

View File

@ -24,16 +24,18 @@ class Ball:
@dataclass @dataclass
class QNode: class QNode:
aabb: Rect aabb: Rect
points: List = field(default_factory=list) points: List[Tuple[float, float]] = field(default_factory=list)
@dataclass @dataclass
class Quadtree: class Quadtree:
node: QNode node: QNode
subdivided = False subdivided = False
direction: str = ''
nw: Optional['Quadtree'] = None nw: Optional['Quadtree'] = None
ne: Optional['Quadtree'] = None ne: Optional['Quadtree'] = None
sw: Optional['Quadtree'] = None sw: Optional['Quadtree'] = None
se: Optional['Quadtree'] = None se: Optional['Quadtree'] = None
parent: Optional['Quadtree'] = None
@dataclass @dataclass
class World: class World:
@ -51,10 +53,10 @@ def qt_split(qt: Quadtree):
ne = Rect(x + hw, y , hw, hh) ne = Rect(x + hw, y , hw, hh)
sw = Rect(x , y + hh, hw, hh) sw = Rect(x , y + hh, hw, hh)
se = Rect(x + hw, y + hh, hw, hh) se = Rect(x + hw, y + hh, hw, hh)
qt.nw = Quadtree(QNode(nw)) qt.nw = Quadtree(QNode(nw), parent=qt, direction='NW')
qt.ne = Quadtree(QNode(ne)) qt.ne = Quadtree(QNode(ne), parent=qt, direction='NE')
qt.sw = Quadtree(QNode(sw)) qt.sw = Quadtree(QNode(sw), parent=qt, direction='SW')
qt.se = Quadtree(QNode(se)) qt.se = Quadtree(QNode(se), parent=qt, direction='SE')
qt.subdivided = True qt.subdivided = True
def qt_insert(qt: Quadtree, p): def qt_insert(qt: Quadtree, p):
@ -82,6 +84,82 @@ def qt_insert(qt: Quadtree, p):
qt.node.points.append(p) qt.node.points.append(p)
return True return True
def qt_find_nearest_point(qt: Quadtree, point) -> Tuple[float, float]:
closest_point = None
closest_dist = None
last_direction = None
containing_qt = qt
# Find the containing subnode
while containing_qt.subdivided:
if RL.check_collision_point_rec(point, qt.nw.node.aabb):
containing_qt = qt.nw
elif RL.check_collision_point_rec(point, qt.ne.node.aabb):
containing_qt = qt.ne
elif RL.check_collision_point_rec(point, qt.sw.node.aabb):
containing_qt = qt.sw
elif RL.check_collision_point_rec(point, qt.se.node.aabb):
containing_qt = qt.se
while containing_qt.parent is not None:
# If it's greater than 1, then we have a point inside we can compare to
if len(containing_qt.node.points) > 1:
for p in qt.node.points:
if p == point:
continue
if closest_dist is None or RL.vector_2distance(Vec2(*point), Vec2(*p)) < closest_dist:
closest_point = p
last_direction = containing_qt.direction
containing_qt = containing_qt.parent
else:
# If there aren't any other points in here, then we can't create a
# closest_point or a closest_dist. We would have to handle that later on
if not containing_qt.subdivided:
last_direction = containing_qt.direction
containing_qt = containing_qt.parent
else:
# def search_for_nearest(child_qt: Quadtree):
# We have to generalize this code, most likely, because it feels like
# we have to do this recursively until we have exhausted all quadrants
px, py = point
# This is where we check the surrounding nodes and try to discard nodes
if last_direction == 'NW':
xse, yse = containing_qt.se.node.aabb.x, containing_qt.se.node.aabb.y
ne_dist = containing_qt.ne.node.aabb.x - px
if ne_dist < closest_dist:
closest_dist = True
# Now we have to search inside, but we would have to do recursively
pass
sw_dist = containing_qt.sw.node.aabb.y - py
se_dist = RL.vector_2distance(Vec2(*point), Vec2(xse, yse))
assert se_dist >= 0, 'ITS LESS THAN 0!!!!'
if last_direction == 'NE':
xsw, ysw = containing_qt.sw.node.aabb.x, containing_qt.sw.node.aabb.y
nw_dist = px - containing_qt.nw.node.aabb.x
sw_dist = RL.vector_2distance(Vec2(xsw, ysw), Vec2(*point))
assert sw_dist >= 0, 'ITS LESS THAN 0!!!!'
se_dist = containing_qt.se.node.aabb.y - py
if last_direction == 'SW':
xne, yne = containing_qt.ne.node.aabb.x, containing_qt.ne.node.aabb.y
nw_dist = px - containing_qt.nw.node.aabb.x
ne_dist = RL.vector_2distance(Vec2(xne, yne), Vec2(*point))
assert ne_dist >= 0, 'ITS LESS THAN 0!!!!'
se_dist = containing_qt.se.node.aabb.x - px
if last_direction == 'SE':
xnw, ynw = containing_qt.nw.node.aabb.x, containing_qt.nw.node.aabb.y
nw_dist = RL.vector_2distance(Vec2(xnw, ynw), Vec2(*point))
ne_dist = py - containing_qt.nw.node.aabb.y
assert ne_dist >= 0, 'ITS LESS THAN 0!!!!'
sw_dist = px - containing_qt.se.node.aabb.x
last_direction = containing_qt.direction
containing_qt = containing_qt.parent
def construct_quadtree(points): def construct_quadtree(points):
root_node = QNode(Rect(0, 0, screen_width, screen_height)) root_node = QNode(Rect(0, 0, screen_width, screen_height))
qt = Quadtree(root_node) qt = Quadtree(root_node)