# gaming-pads/quadtree.py
#
# 242 lines
# 8.3 KiB
# Python

import pyray as RL
from pyray import (Rectangle as Rect)
import math
import pdb
import random
from typing import Optional, Tuple, List
from dataclasses import dataclass, field
# Window size in pixels.
screen_width, screen_height = 1280, 1024
# Ball radius in pixels and base speed in pixels per frame (init() scales it
# by a random factor per velocity component).
ball_r = 6
ball_speed = 3.5
# Number of balls spawned by init().
num_balls = 1000
# Point-count threshold qt_insert uses when deciding whether to split a leaf.
qt_capacity = 4
@dataclass
class Ball:
    """A moving ball: centre position and per-frame velocity, in screen pixels."""

    # Centre of the ball (where draw() renders its outline).
    px: float
    py: float
    # Displacement added to the position on every update() step.
    vx: float
    vy: float
@dataclass
class QNode:
    """Payload of a quadtree node: the area it covers and the points stored on it."""

    # Axis-aligned bounding box covered by this node (a raylib Rectangle).
    aabb: Rect
    # (x, y) points held directly by this node; every instance gets a fresh list.
    points: List[Tuple[float, float]] = field(default_factory=list)
@dataclass
class Quadtree:
    """A quadtree node: its QNode payload plus links to four children and a parent.

    The children stay None until qt_split() creates all four at once and sets
    ``subdivided``.
    """

    node: 'QNode'
    # Compass label of this node inside its parent ('NW', 'NE', 'SW', 'SE');
    # '' for the root.
    direction: str = ''
    nw: Optional['Quadtree'] = None
    ne: Optional['Quadtree'] = None
    sw: Optional['Quadtree'] = None
    se: Optional['Quadtree'] = None
    # Back-link to the enclosing node. It is excluded from repr/eq because the
    # parent refers back to its children: including it would make every child's
    # repr/comparison walk the whole tree.
    parent: Optional['Quadtree'] = field(default=None, repr=False, compare=False)
    # Annotated so it is a real per-instance dataclass field rather than a shared
    # class attribute. It is declared last so the positional field order above
    # is unchanged for existing callers.
    subdivided: bool = False
@dataclass
class World:
    """Mutable simulation state shared by the input/update/draw loop.

    Every mutable default uses ``default_factory`` so each World owns its own
    containers. Bare ``balls = []`` class attributes are not dataclass fields,
    and all instances would share them.
    """

    balls: List['Ball'] = field(default_factory=list)
    # Quadtree rebuilt from ball positions on unpaused frames. It is None until
    # the first build; draw_qt_dfs treats any falsy value as "nothing to draw".
    qt: Optional['Quadtree'] = None
    tick: int = 0
    paused: bool = False
    # Cursor positions recorded on mouse clicks (values from RL.get_mouse_position()).
    mouse_clicks: list = field(default_factory=list)


w = World()
def qt_split(qt: Quadtree):
    """Give ``qt`` four equal-sized child quadrants and mark it as subdivided.

    Each child links back to ``qt`` through ``parent`` and records its compass
    direction. Points already stored on ``qt`` are left untouched; moving them
    into the children is the caller's job.
    """
    box = qt.node.aabb
    half_w = box.width * 0.5
    half_h = box.height * 0.5
    left, top = box.x, box.y
    mid_x, mid_y = left + half_w, top + half_h  # origin of the east / south halves
    quadrants = (
        ('nw', 'NW', left, top),
        ('ne', 'NE', mid_x, top),
        ('sw', 'SW', left, mid_y),
        ('se', 'SE', mid_x, mid_y),
    )
    for attr, direction, x0, y0 in quadrants:
        child = Quadtree(QNode(Rect(x0, y0, half_w, half_h)), parent=qt, direction=direction)
        setattr(qt, attr, child)
    qt.subdivided = True
def qt_insert(qt: Quadtree, p, min_node_size: float = 1.0) -> bool:
    """Insert point ``p`` into the quadtree rooted at ``qt``.

    A leaf holds up to ``qt_capacity`` points. Inserting one more splits it and
    pushes all of its points down into the new children.

    Args:
        qt: Root (or subtree) to insert into.
        p: An ``(x, y)`` point, tested against node boxes with
            ``RL.check_collision_point_rec``.
        min_node_size: A leaf whose children would be narrower or shorter than
            this is never split; it keeps accumulating points instead. This
            bounds the tree depth. Without it, coincident points, which no
            amount of splitting can separate, recurse until RecursionError.

    Returns:
        True if the point was stored somewhere in the subtree. False if it lies
        outside ``qt``'s box, or if, after a split, it fits none of the children.
    """
    if not RL.check_collision_point_rec(p, qt.node.aabb):
        return False
    if qt.subdivided:
        return _qt_insert_into_children(qt, p, min_node_size)

    aabb = qt.node.aabb
    points = qt.node.points
    too_small = aabb.width * 0.5 < min_node_size or aabb.height * 0.5 < min_node_size
    if len(points) < qt_capacity or too_small:
        points.append(p)
        return True

    # Leaf is full: split it, then redistribute its points plus the new one.
    qt_split(qt)
    pending = points + [p]
    points.clear()
    for q in pending:
        _qt_insert_into_children(qt, q, min_node_size)
    return True


def _qt_insert_into_children(qt: Quadtree, p, min_node_size: float) -> bool:
    """Insert ``p`` into the first child of ``qt`` (NW, NE, SW, SE order) that contains it."""
    return any(qt_insert(child, p, min_node_size) for child in (qt.nw, qt.ne, qt.sw, qt.se))
def qt_find_nearest_point(qt: 'Quadtree', point) -> Optional[Tuple[float, float]]:
    """Return the stored point closest to ``point``, excluding ``point`` itself.

    This is a branch-and-bound search. A node is visited only if its bounding
    box could still contain something closer than the best match found so far.
    A subdivided node's children are visited nearest-box-first, so good matches
    are found early and most of the tree is pruned. Neighbouring quadrants are
    therefore searched automatically whenever they could hold a closer point.

    Args:
        qt: Root of the (sub)tree to search.
        point: Query location as an ``(x, y)`` pair. Stored points equal to it
            are skipped, so querying with a point that is in the tree returns
            its nearest *other* point.

    Returns:
        The nearest stored point, or None if the tree holds no other point.
    """
    px, py = point
    best_point = None
    best_dist = math.inf

    def box_dist(aabb) -> float:
        # Distance from the query to the nearest point of the box (0 when the
        # query is inside it). This is a lower bound for every point in the box.
        dx = max(aabb.x - px, 0.0, px - (aabb.x + aabb.width))
        dy = max(aabb.y - py, 0.0, py - (aabb.y + aabb.height))
        return math.hypot(dx, dy)

    def search(node) -> None:
        nonlocal best_point, best_dist
        if node is None or box_dist(node.node.aabb) >= best_dist:
            return  # nothing in this box can beat the current best
        if node.subdivided:
            children = (node.nw, node.ne, node.sw, node.se)
            for child in sorted(children, key=lambda c: box_dist(c.node.aabb)):
                search(child)
            return
        for candidate in node.node.points:
            if candidate == point:
                continue
            d = math.hypot(candidate[0] - px, candidate[1] - py)
            if d < best_dist:
                best_point, best_dist = candidate, d

    search(qt)
    return best_point
def construct_quadtree(points):
    """Build a fresh screen-sized quadtree and insert every point from ``points``.

    Points outside the screen rectangle are rejected by qt_insert and are simply
    not stored. Returns the root Quadtree.
    """
    tree = Quadtree(QNode(Rect(0, 0, screen_width, screen_height)))
    for point in points:
        qt_insert(tree, point)
    return tree
def rect_values(r: 'Rect'):
    """Return ``(x, y, width, height)`` of a raylib Rectangle as a plain tuple.

    raylib's Rectangle fields are named ``width``/``height``. The previous
    ``r.w``/``r.h`` raised AttributeError.
    """
    return r.x, r.y, r.width, r.height
def init():
    """Append ``num_balls`` randomly placed, randomly moving balls to ``w.balls``.

    Balls currently spawn at integer coordinates in [ball_r, 50) on both axes,
    i.e. a small square near the top-left corner. The commented lines below are
    the full-screen alternative. Each velocity component gets its own random
    multiplier in [1, 3], so both speed and heading vary per ball.
    """
    for _ in range(num_balls):
        px = random.randrange(ball_r, 50)
        py = random.randrange(ball_r, 50)
        # Full-screen spawn alternative:
        # px = random.randrange(ball_r, screen_width - ball_r)
        # py = random.randrange(ball_r, screen_height - ball_r)
        # math.cos/math.sin take radians, so draw the heading from one full turn
        # in radians rather than from 0..360.
        angle = random.uniform(0, math.tau)
        vx = math.cos(angle) * ball_speed * random.uniform(1, 3)
        vy = math.sin(angle) * ball_speed * random.uniform(1, 3)
        w.balls.append(Ball(px, py, vx, vy))
def player_input():
    """Poll keyboard and mouse once per frame.

    SPACE toggles ``w.paused``. Pressing mouse button 0 (raylib's left button)
    prints the cursor position and records it in ``w.mouse_clicks``.
    """
    if RL.is_key_pressed(RL.KEY_SPACE):
        w.paused = not w.paused
    if not RL.is_mouse_button_pressed(0):
        return
    click = RL.get_mouse_position()
    print(click)
    w.mouse_clicks.append(click)
def update():
    """Advance the simulation by one frame; does nothing while paused.

    Moves every ball by its velocity and bounces it off the screen edges. It
    then rebuilds ``w.qt`` from the *new* positions, so the tree drawn this frame
    matches the balls drawn this frame instead of lagging one step behind.
    """
    if w.paused:
        return
    for ball in w.balls:
        ball.px += ball.vx
        ball.py += ball.vy
        # On touching a side wall: clamp strictly inside (0.1 px margin), then
        # reflect the horizontal velocity.
        if ball.px - ball_r <= 0 or ball.px + ball_r >= screen_width:
            ball.px = RL.clamp(ball.px, ball_r + 0.1, screen_width - ball_r - 0.1)
            ball.vx *= -1
        # Same for top/bottom, using >= to match the horizontal test.
        if ball.py - ball_r <= 0 or ball.py + ball_r >= screen_height:
            ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1)
            ball.vy *= -1
    # Reconstruct the quadtree after moving so it reflects the rendered positions.
    w.qt = construct_quadtree([(b.px, b.py) for b in w.balls])
def draw_qt_dfs(qt: Quadtree):
    """Outline every node of the tree rooted at ``qt``, children before parent.

    Children are visited in NW, NE, SE, SW order. A falsy ``qt`` (None, or an
    empty placeholder before the first tree exists) draws nothing.
    """
    if qt:
        for child in (qt.nw, qt.ne, qt.se, qt.sw):
            draw_qt_dfs(child)
        RL.draw_rectangle_lines_ex(qt.node.aabb, 0.5, RL.BLACK)
def draw():
    """Render one frame: quadtree outlines, ball outlines, then red click markers."""
    RL.begin_drawing()
    RL.clear_background(RL.WHITE)
    # Tree first so the balls and click markers are painted on top of the grid.
    draw_qt_dfs(w.qt)
    outline_color, marker_color, marker_radius = RL.BLACK, RL.RED, 5
    for ball_obj in w.balls:
        center = (ball_obj.px, ball_obj.py)
        RL.draw_circle_lines_v(center, ball_r, outline_color)
    for click_pos in w.mouse_clicks:
        RL.draw_circle_v(click_pos, marker_radius, marker_color)
    RL.end_drawing()
# Open the window, seed the balls, then run input -> update -> draw until the
# window is asked to close.
RL.init_window(screen_width, screen_height, "Quadtree")
RL.set_target_fps(60)
init()
try:
    while not RL.window_should_close():
        player_input()
        update()
        draw()
finally:
    # Release the window and its OpenGL context even if a frame raises.
    RL.close_window()