Quadtree: Finish nearest neighbor search alg

This commit is contained in:
Joseph Ferano 2024-08-06 12:29:05 +07:00
parent ccd46d07a8
commit 2d7c38365d
2 changed files with 97 additions and 91 deletions

View File

@ -1,7 +1,9 @@
# -*- mode: Org; eval: (olivetti-mode 0) -*- #
* The topics * The topics
#+begin_src syntree #+begin_src syntree
("Game Programming Fundamentals" ("Game Programming Fundamentals"
("Math" "_:Data Structures" "_:Algorithms" "_:Linear Algebra") ("Math" "_:Data Structures" "_:Algorithms" "_:Linear Algebra" "_:Geometry")
("Engineering" ("Engineering"
("Machine Architecture" "_:CPU Design" "_:Memory Hierarchy" "_:Processes" "_:Concurrency") ("Machine Architecture" "_:CPU Design" "_:Memory Hierarchy" "_:Processes" "_:Concurrency")
"_:Networking" "_:Networking"
@ -12,16 +14,26 @@
#+RESULTS: #+RESULTS:
#+begin_example #+begin_example
Game Programming Fundamentals Game Programming Fundamentals
._________________________________________|_________________________________________. .___________________________________________|___________________________________________.
| | | |
Math Engineering Math Engineering
.______________|_____________. .________________________________._|________._________________________. .______________._____|_______.____________. .________________________________._|________._________________________.
| | | | | | | | | | | | | | |
.______|______. .____|___. .______|_____. Machine Architecture .____|___. .___|___. Operating Systems .______|______. .____|___. .______|_____. .___|__. Machine Architecture .____|___. .___|___. Operating Systems
|_____________| |________| |____________| .______________._____|_______.___________. |________| |_______| .____________.|_____________. |_____________| |________| |____________| |______| .______________._____|_______.___________. |________| |_______| .____________.|_____________.
Data Structures Algorithms Linear Algebra | | | | Networking Compilers | | | Data Structures Algorithms Linear Algebra Geometry | | | | Networking Compilers | | |
.____|___. ._______|______. .___|___. .____|____. .____|___. ._____|____. .______|_____. .____|___. ._______|______. .___|___. .____|____. .____|___. ._____|____. .______|_____.
|________| |______________| |_______| |_________| |________| |__________| |____________| |________| |______________| |_______| |_________| |________| |__________| |____________|
CPU Design Memory Hierarchy Processes Concurrency Scheduling File Systems Virtual Memory CPU Design Memory Hierarchy Processes Concurrency Scheduling File Systems Virtual Memory
#+end_example #+end_example
#+begin_src syntree
("Graphics"
("Math" "_:Data Structures" "_:Algorithms" "_:Linear Algebra")
("Engineering"
("Machine Architecture" "_:CPU Design" "_:Memory Hierarchy" "_:Processes" "_:Concurrency")
"_:Networking"
"_:Compilers"
("Operating Systems" "_:Scheduling" "_:File Systems" "_:Virtual Memory")))
#+end_src

View File

@ -1,13 +1,13 @@
import pyray as RL import pyray as RL
from pyray import (Rectangle as Rect) from pyray import (Rectangle as Rect, Vector2 as Vec2)
import math import math
import pdb import pdb
import random import random
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
screen_width = 1280 screen_width = 1200
screen_height = 1024 screen_height = 960
ball_r = 6 ball_r = 6
ball_speed = 3.5 ball_speed = 3.5
@ -44,6 +44,10 @@ class World:
tick = 0 tick = 0
paused = False paused = False
mouse_clicks = [] mouse_clicks = []
nearest_points = []
visited_quadrants = []
nearest_pairs = []
current_visited = 0
w = World() w = World()
@ -85,79 +89,55 @@ def qt_insert(qt: Quadtree, p):
return True return True
def qt_find_nearest_point(qt: Quadtree, point) -> Tuple[float, float]: def qt_find_nearest_point(qt: Quadtree, point) -> Tuple[float, float]:
closest_point = None
closest_dist = None
last_direction = None
containing_qt = qt containing_qt = qt
# Find the containing subnode
while containing_qt.subdivided: while containing_qt.subdivided:
if RL.check_collision_point_rec(point, qt.nw.node.aabb): if RL.check_collision_point_rec(point, containing_qt.nw.node.aabb):
containing_qt = qt.nw containing_qt = containing_qt.nw
elif RL.check_collision_point_rec(point, qt.ne.node.aabb): elif RL.check_collision_point_rec(point, containing_qt.ne.node.aabb):
containing_qt = qt.ne containing_qt = containing_qt.ne
elif RL.check_collision_point_rec(point, qt.sw.node.aabb): elif RL.check_collision_point_rec(point, containing_qt.sw.node.aabb):
containing_qt = qt.sw containing_qt = containing_qt.sw
elif RL.check_collision_point_rec(point, qt.se.node.aabb): elif RL.check_collision_point_rec(point, containing_qt.se.node.aabb):
containing_qt = qt.se containing_qt = containing_qt.se
while containing_qt.parent is not None: def search_for_nearest(qt: Quadtree, direction = ''):
# If it's greater than 1, then we have a point inside we can compare to nonlocal closest_point, closest_dist
if len(containing_qt.node.points) > 1: contains_point = RL.check_collision_point_rec(point, qt.node.aabb)
if not contains_point:
px, py = point.x, point.y
dx, dy = 0,0
if px < qt.node.aabb.x:
dx = qt.node.aabb.x - px
elif px > qt.node.aabb.x + qt.node.aabb.width:
dx = px - (qt.node.aabb.x + qt.node.aabb.width)
if py < qt.node.aabb.y:
dy = qt.node.aabb.y - py
elif py > qt.node.aabb.y + qt.node.aabb.height:
dy = py - (qt.node.aabb.y + qt.node.aabb.height)
dist = RL.vector2_length(Vec2(dx, dy))
if dist >= closest_dist:
return
if qt.subdivided:
if direction != 'NW': search_for_nearest(qt.nw)
if direction != 'NE': search_for_nearest(qt.ne)
if direction != 'SW': search_for_nearest(qt.sw)
if direction != 'SE': search_for_nearest(qt.se)
w.visited_quadrants.append(qt)
for p in qt.node.points: for p in qt.node.points:
if p == point: d = RL.vector_2distance(point, Vec2(p[0], p[1]))
continue if d < closest_dist:
if closest_dist is None or RL.vector_2distance(Vec2(*point), Vec2(*p)) < closest_dist:
closest_point = p closest_point = p
last_direction = containing_qt.direction closest_dist = d
closest_point = None
closest_dist = float('inf')
previous_direction = ''
while containing_qt is not None:
search_for_nearest(containing_qt, previous_direction)
previous_direction = containing_qt.direction
containing_qt = containing_qt.parent containing_qt = containing_qt.parent
else:
# If there aren't any other points in here, then we can't create a
# closest_point or a closest_dist. We would have to handle that later on
if not containing_qt.subdivided:
last_direction = containing_qt.direction
containing_qt = containing_qt.parent
else:
# def search_for_nearest(child_qt: Quadtree):
# We have to generalize this code, most likely, because it feels like
# we have to do this recursively until we have exhausted all quadrants
px, py = point return closest_point
# This is where we check the surrounding nodes and try to discard nodes
if last_direction == 'NW':
xse, yse = containing_qt.se.node.aabb.x, containing_qt.se.node.aabb.y
ne_dist = containing_qt.ne.node.aabb.x - px
if ne_dist < closest_dist:
closest_dist = True
# Now we have to search inside, but we would have to do recursively
pass
sw_dist = containing_qt.sw.node.aabb.y - py
se_dist = RL.vector_2distance(Vec2(*point), Vec2(xse, yse))
assert se_dist >= 0, 'ITS LESS THAN 0!!!!'
if last_direction == 'NE':
xsw, ysw = containing_qt.sw.node.aabb.x, containing_qt.sw.node.aabb.y
nw_dist = px - containing_qt.nw.node.aabb.x
sw_dist = RL.vector_2distance(Vec2(xsw, ysw), Vec2(*point))
assert sw_dist >= 0, 'ITS LESS THAN 0!!!!'
se_dist = containing_qt.se.node.aabb.y - py
if last_direction == 'SW':
xne, yne = containing_qt.ne.node.aabb.x, containing_qt.ne.node.aabb.y
nw_dist = px - containing_qt.nw.node.aabb.x
ne_dist = RL.vector_2distance(Vec2(xne, yne), Vec2(*point))
assert ne_dist >= 0, 'ITS LESS THAN 0!!!!'
se_dist = containing_qt.se.node.aabb.x - px
if last_direction == 'SE':
xnw, ynw = containing_qt.nw.node.aabb.x, containing_qt.nw.node.aabb.y
nw_dist = RL.vector_2distance(Vec2(xnw, ynw), Vec2(*point))
ne_dist = py - containing_qt.nw.node.aabb.y
assert ne_dist >= 0, 'ITS LESS THAN 0!!!!'
sw_dist = px - containing_qt.se.node.aabb.x
last_direction = containing_qt.direction
containing_qt = containing_qt.parent
def construct_quadtree(points): def construct_quadtree(points):
@ -185,17 +165,17 @@ def player_input():
if RL.is_key_pressed(RL.KEY_SPACE): if RL.is_key_pressed(RL.KEY_SPACE):
w.paused = not w.paused w.paused = not w.paused
if RL.is_mouse_button_pressed(0): if RL.is_mouse_button_pressed(0):
print(RL.get_mouse_position()) mouse_pos = RL.get_mouse_position()
w.mouse_clicks.append(RL.get_mouse_position()) nearest = qt_find_nearest_point(w.qt, mouse_pos)
w.nearest_pairs.append((mouse_pos, nearest))
w.paused = True
if RL.is_key_pressed(RL.KEY_ENTER):
w.current_visited += 1
def update(): def update():
# Recontruct quadtree # Recontruct quadtree
if w.paused: if w.paused:
return return
points = []
for b in w.balls:
points.append((b.px, b.py))
w.qt = construct_quadtree(points)
for ball in w.balls: for ball in w.balls:
ball.px += ball.vx ball.px += ball.vx
@ -209,6 +189,11 @@ def update():
ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1) ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1)
ball.vy *= -1 ball.vy *= -1
points = []
for b in w.balls:
points.append((b.px, b.py))
w.qt = construct_quadtree(points)
def draw_qt_dfs(qt: Quadtree): def draw_qt_dfs(qt: Quadtree):
if not qt: if not qt:
return return
@ -223,10 +208,19 @@ def draw():
RL.clear_background(RL.WHITE) RL.clear_background(RL.WHITE)
draw_qt_dfs(w.qt) draw_qt_dfs(w.qt)
for i in range(w.current_visited):
RL.draw_rectangle_rec(w.visited_quadrants[i].node.aabb, RL.LIGHTGRAY)
for ball in w.balls: for ball in w.balls:
RL.draw_circle_lines_v((ball.px, ball.py), ball_r, RL.BLACK) RL.draw_circle_lines_v((ball.px, ball.py), ball_r, RL.BLACK)
for mc in w.mouse_clicks: for mc in w.mouse_clicks:
RL.draw_circle_v(mc, 5, RL.RED) RL.draw_circle_v(mc, 5, RL.RED)
for np in w.nearest_points:
RL.draw_circle_lines_v(Vec2(np[0], np[1]), ball_r, RL.GREEN)
for mc,(px,py) in w.nearest_pairs:
pos = Vec2(px, py)
RL.draw_circle_v(mc, 5, RL.RED)
RL.draw_circle_lines_v(pos, ball_r, RL.GREEN)
RL.draw_line_v(mc, pos, RL.BLUE)
RL.end_drawing() RL.end_drawing()