242 lines
8.3 KiB
Python
242 lines
8.3 KiB
Python
import pyray as RL
|
|
from pyray import (Rectangle as Rect)
|
|
import math
|
|
import pdb
|
|
import random
|
|
from typing import Optional, Tuple, List
|
|
from dataclasses import dataclass, field
|
|
|
|
# Window size in pixels.
screen_width, screen_height = 1280, 1024

# Ball simulation parameters.
ball_r = 6          # radius in pixels, used for drawing and for wall bounces
ball_speed = 3.5    # base per-frame speed, scaled by a random factor in init()
num_balls = 1000    # number of balls created by init()

# Leaf capacity used by qt_insert to decide when a quadtree node is split.
qt_capacity = 4
|
|
|
|
@dataclass
class Ball:
    """A ball in the simulation.

    ``(px, py)`` is the centre position and ``(vx, vy)`` is the displacement
    added to it on every ``update()`` tick.
    """

    # Centre position.
    px: float
    py: float
    # Per-tick velocity.
    vx: float
    vy: float
|
|
|
|
@dataclass
class QNode:
    """Payload of a single quadtree node: its bounds and the points it stores."""

    # Axis-aligned bounding rectangle of the node (a pyray Rectangle).
    aabb: Rect
    # Points stored directly in this node; every instance gets its own list.
    points: List[Tuple[float, float]] = field(default_factory=list)
|
|
|
|
@dataclass
class Quadtree:
    """One node of a region quadtree.

    Children are ``None`` until ``qt_split`` creates all four quadrants and
    sets ``subdivided``. ``direction`` names which quadrant of ``parent`` this
    node occupies ('NW', 'NE', 'SW' or 'SE') and is '' for the root.
    """

    node: 'QNode'  # bounding box and stored points of this node
    direction: str = ''
    nw: Optional['Quadtree'] = None
    ne: Optional['Quadtree'] = None
    sw: Optional['Quadtree'] = None
    se: Optional['Quadtree'] = None
    # Back-reference to the enclosing node (None for the root). It forms a
    # cycle with the child links, so it is excluded from the generated
    # __eq__ (which would otherwise recurse until RecursionError when two
    # trees are compared) and from __repr__ (which would dump the ancestors).
    parent: Optional['Quadtree'] = field(default=None, repr=False, compare=False)
    # A real per-instance dataclass field (settable via __init__, shown in
    # repr). Declared last so the positional order of the fields above is
    # unchanged for existing callers.
    subdivided: bool = False
|
|
|
|
@dataclass
class World:
    """Mutable state of the whole simulation.

    Every attribute is a real dataclass field, so each ``World()`` owns its own
    lists rather than sharing class-level ones between instances.
    """

    balls: List['Ball'] = field(default_factory=list)  # populated by init()
    # Quadtree rebuilt by update(). It is None (falsy) until the first rebuild;
    # draw_qt_dfs returns immediately for falsy values.
    qt: Optional['Quadtree'] = None
    tick: int = 0  # NOTE(review): not advanced anywhere in the visible code
    paused: bool = False  # toggled by player_input()
    mouse_clicks: list = field(default_factory=list)  # click positions recorded by player_input()
|
|
|
|
# Global simulation state, read and mutated by init(), player_input(),
# update() and draw().
w: World = World()
|
|
|
|
def qt_split(qt: Quadtree):
    """Subdivide ``qt`` into four equally sized, empty child quadrants.

    Each child is linked back to ``qt`` through ``parent`` and tagged with its
    ``direction``. ``qt`` is then marked as subdivided. Points already stored
    in ``qt`` are not moved by this function.
    """
    box = qt.node.aabb
    half_w = box.width * 0.5
    half_h = box.height * 0.5
    left, right = box.x, box.x + half_w
    top, bottom = box.y, box.y + half_h

    # (attribute, direction tag, child origin) in NW, NE, SW, SE order.
    quadrants = (
        ('nw', 'NW', left, top),
        ('ne', 'NE', right, top),
        ('sw', 'SW', left, bottom),
        ('se', 'SE', right, bottom),
    )
    for attr, tag, origin_x, origin_y in quadrants:
        child_box = Rect(origin_x, origin_y, half_w, half_h)
        setattr(qt, attr, Quadtree(QNode(child_box), parent=qt, direction=tag))

    qt.subdivided = True
|
|
|
|
def qt_insert(qt: Quadtree, p, min_size: float = 1.0):
    """Insert point ``p`` into the quadtree rooted at ``qt``.

    A leaf stores up to ``qt_capacity`` points. Adding one more splits the
    leaf and pushes all of its points down into the new children. A leaf is
    not split when its children would be narrower or shorter than
    ``min_size``; such a leaf simply holds more than ``qt_capacity`` points.
    This keeps coincident points from forcing endless subdivision.

    Args:
        qt: Subtree to insert into.
        p: The point as an ``(x, y)`` pair (anything accepted as a vector by
            ``RL.check_collision_point_rec``).
        min_size: Smallest width/height a child produced by a split may have.

    Returns:
        False if ``p`` lies outside ``qt``'s rectangle. Otherwise True: the
        point has been stored in some node of the subtree.
    """
    aabb = qt.node.aabb
    if not RL.check_collision_point_rec(p, aabb):
        return False

    if qt.subdivided:
        if not _qt_insert_into_children(qt, p, min_size):
            # Float rounding of the child bounds can leave a sliver that no
            # child covers; keep the point here instead of losing it.
            qt.node.points.append(p)
        return True

    qt.node.points.append(p)
    can_split = min(aabb.width, aabb.height) * 0.5 >= min_size
    if len(qt.node.points) > qt_capacity and can_split:
        qt_split(qt)
        pending = list(qt.node.points)
        qt.node.points.clear()
        for point in pending:
            if not _qt_insert_into_children(qt, point, min_size):
                qt.node.points.append(point)
    return True


def _qt_insert_into_children(qt: Quadtree, p, min_size: float) -> bool:
    """Offer ``p`` to ``qt``'s children in NW, NE, SW, SE order; True if one took it."""
    return any(qt_insert(child, p, min_size) for child in (qt.nw, qt.ne, qt.sw, qt.se))
|
|
|
|
def qt_find_nearest_point(qt: 'Quadtree', point) -> Optional[Tuple[float, float]]:
    """Return the stored point closest (Euclidean) to ``point``, or None.

    Branch-and-bound search over the tree. A subtree is skipped when the
    distance from ``point`` to its bounding box is no better than the best
    candidate found so far. Children are visited nearest-box-first so that
    good candidates are found early. Points equal to ``point`` are ignored, so
    a point never reports itself (or an exact duplicate) as its neighbour.
    Points stored in any node are examined, not only those in leaves.

    Args:
        qt: Root of the (sub)tree to search.
        point: Query position as an ``(x, y)`` pair. It need not be stored in
            the tree or lie inside its bounds.

    Returns:
        The nearest stored point other than ``point``, or None when the tree
        holds no such point. Ties are resolved by traversal order.
    """
    px, py = point

    def box_distance(sub) -> float:
        # Distance from the query to the nearest point of sub's rectangle;
        # 0 when the query lies inside it.
        box = sub.node.aabb
        dx = max(box.x - px, 0.0, px - (box.x + box.width))
        dy = max(box.y - py, 0.0, py - (box.y + box.height))
        return math.hypot(dx, dy)

    best = None
    best_dist = math.inf
    # Stack of (distance to box, node) pairs; the distance is cached so it is
    # not recomputed when the node is popped.
    stack = [(box_distance(qt), qt)]
    while stack:
        box_dist, sub = stack.pop()
        if box_dist >= best_dist:
            continue  # nothing inside this box can beat the current best
        for candidate in sub.node.points:
            if candidate == point:
                continue
            cx, cy = candidate
            dist = math.hypot(cx - px, cy - py)
            if dist < best_dist:
                best, best_dist = candidate, dist
        if sub.subdivided:
            children = [
                (box_distance(child), child)
                for child in (sub.nw, sub.ne, sub.sw, sub.se)
                if child is not None
            ]
            # Push farthest first so the nearest child is popped next. Sort on
            # the distance only; tree nodes themselves are not orderable.
            children.sort(key=lambda pair: pair[0], reverse=True)
            stack.extend(children)
    return best
|
|
|
|
|
|
def construct_quadtree(points):
    """Build and return a new quadtree spanning the whole screen.

    Each point is inserted in iteration order with ``qt_insert``. Its return
    value is ignored, so points that ``qt_insert`` rejects (outside the
    screen rectangle) are simply not stored.
    """
    tree = Quadtree(QNode(Rect(0, 0, screen_width, screen_height)))
    for point in points:
        qt_insert(tree, point)
    return tree
|
|
|
|
def rect_values(r: 'Rect'):
    """Unpack a rectangle into an ``(x, y, width, height)`` tuple.

    raylib's ``Rectangle`` struct has ``width``/``height`` fields (there are no
    ``w``/``h`` attributes), so those are the ones read here.
    """
    return r.x, r.y, r.width, r.height
|
|
|
|
def init(spawn_max_x: int = 50, spawn_max_y: int = 50):
    """Append ``num_balls`` randomly placed, randomly moving balls to ``w.balls``.

    Args:
        spawn_max_x: Exclusive upper bound of the integer spawn x coordinate
            (the lower bound is ``ball_r``). The default of 50 clusters every
            ball in the top-left corner. Pass ``screen_width - ball_r`` to
            spread them across the screen.
        spawn_max_y: Same for the y coordinate (``screen_height - ball_r``
            for the full screen).
    """
    for _ in range(num_balls):
        px = random.randrange(ball_r, spawn_max_x)
        py = random.randrange(ball_r, spawn_max_y)
        # math.cos/math.sin take radians, so draw the heading from [0, 2*pi].
        angle = random.uniform(0, math.tau)
        # One speed factor for both components keeps the motion along `angle`.
        speed = ball_speed * random.uniform(1, 3)
        w.balls.append(Ball(px, py, math.cos(angle) * speed, math.sin(angle) * speed))
|
|
|
|
def player_input():
    """Handle this frame's input.

    SPACE toggles ``w.paused``. Pressing mouse button 0 (the left button in
    raylib) prints the cursor position and appends it to ``w.mouse_clicks``.
    """
    if RL.is_key_pressed(RL.KEY_SPACE):
        w.paused = not w.paused

    if not RL.is_mouse_button_pressed(0):
        return

    # Query the cursor once and reuse it for both the log and the record.
    click = RL.get_mouse_position()
    print(click)
    w.mouse_clicks.append(click)
|
|
|
|
def update():
    """Advance the simulation by one frame unless it is paused.

    Each ball moves by its velocity and is reflected off the screen edges.
    ``w.qt`` is then rebuilt from the new positions, so the quadtree drawn
    this frame matches the balls drawn this frame.
    """
    if w.paused:
        return

    for ball in w.balls:
        ball.px += ball.vx
        ball.py += ball.vy

        # Left/right walls: clamp back inside the screen before reversing the
        # velocity so the ball cannot remain outside for another frame.
        if ball.px - ball_r <= 0 or ball.px + ball_r >= screen_width:
            ball.px = RL.clamp(ball.px, ball_r + 0.1, screen_width - ball_r - 0.1)
            ball.vx = -ball.vx

        # Top/bottom walls, using the same inclusive test as the sides.
        if ball.py - ball_r <= 0 or ball.py + ball_r >= screen_height:
            ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1)
            ball.vy = -ball.vy

    # Reconstruct the quadtree from the positions just computed.
    w.qt = construct_quadtree([(ball.px, ball.py) for ball in w.balls])
|
|
|
|
def draw_qt_dfs(qt: Quadtree):
    """Outline every node of the quadtree, visiting children before the parent.

    Falsy values (``None`` for a missing child, or an empty placeholder)
    draw nothing.
    """
    if qt:
        for child in (qt.nw, qt.ne, qt.se, qt.sw):
            draw_qt_dfs(child)
        RL.draw_rectangle_lines_ex(qt.node.aabb, 0.5, RL.BLACK)
|
|
|
|
def draw():
    """Render one frame: quadtree cells, ball outlines, then recorded clicks."""
    RL.begin_drawing()
    RL.clear_background(RL.WHITE)

    # Cell outlines first so the balls and click markers are drawn over them.
    draw_qt_dfs(w.qt)

    for b in w.balls:
        RL.draw_circle_lines_v((b.px, b.py), ball_r, RL.BLACK)

    # Filled red dot at every position stored in w.mouse_clicks.
    for click in w.mouse_clicks:
        RL.draw_circle_v(click, 5, RL.RED)

    RL.end_drawing()
|
|
|
|
# Open the window, populate the world, and run the input/update/draw loop
# until the window is asked to close.
RL.init_window(screen_width, screen_height, "Quadtree")
RL.set_target_fps(60)

try:
    init()
    while not RL.window_should_close():
        player_input()
        update()
        draw()
finally:
    # Release the window and its graphics context even if a frame raises.
    RL.close_window()
|
|
|