"""Bouncing-ball demo that rebuilds and draws a point quadtree every frame."""
import math
import pdb
import random
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import pyray as RL
from pyray import (Rectangle as Rect)

screen_width = 1280
screen_height = 1024
ball_r = 6
ball_speed = 3.5
num_balls = 1000
# A leaf holds up to this many points; inserting one more subdivides it.
qt_capacity = 4
# Maximum subdivision depth. Leaves at this depth accept any number of points.
# Without a cap, `qt_capacity + 1` coincident points (likely with the clustered
# integer spawn in `init`) would make `qt_insert` split forever.
qt_max_depth = 10

Point = Tuple[float, float]


@dataclass
class Ball:
    """A moving circle: centre (px, py) and per-frame velocity (vx, vy)."""
    px: float
    py: float
    vx: float
    vy: float


@dataclass
class QNode:
    """Payload of a quadtree node: its bounds and the points stored directly in it."""
    aabb: Rect
    points: List[Point] = field(default_factory=list)


@dataclass
class Quadtree:
    """A quadtree node; the four child links are set by `qt_split`."""
    node: QNode
    direction: str = ''  # which quadrant of the parent this is: 'NW', 'NE', 'SW', 'SE' ('' for the root)
    nw: Optional['Quadtree'] = None
    ne: Optional['Quadtree'] = None
    sw: Optional['Quadtree'] = None
    se: Optional['Quadtree'] = None
    parent: Optional['Quadtree'] = None
    # Appended after the original fields so existing positional construction still works.
    subdivided: bool = False
    depth: int = 0  # root is 0; used to enforce `qt_max_depth`


@dataclass
class World:
    """All mutable simulation state."""
    balls: List[Ball] = field(default_factory=list)
    qt: Optional[Quadtree] = None  # rebuilt by `update`; None until the first unpaused frame
    tick: int = 0
    paused: bool = False
    mouse_clicks: list = field(default_factory=list)  # positions returned by RL.get_mouse_position()


w = World()


def _qt_children(qt: Quadtree):
    """Return the four children of a node in NW, NE, SW, SE order (all None for a leaf)."""
    return (qt.nw, qt.ne, qt.sw, qt.se)


def qt_split(qt: Quadtree):
    """Give `qt` four equal child quadrants and mark it subdivided.

    Points already stored in `qt` are not moved here; `qt_insert` redistributes them.
    """
    r = qt.node.aabb
    x, y = r.x, r.y
    hw, hh = r.width * 0.5, r.height * 0.5
    child_depth = qt.depth + 1
    qt.nw = Quadtree(QNode(Rect(x, y, hw, hh)), parent=qt, direction='NW', depth=child_depth)
    qt.ne = Quadtree(QNode(Rect(x + hw, y, hw, hh)), parent=qt, direction='NE', depth=child_depth)
    qt.sw = Quadtree(QNode(Rect(x, y + hh, hw, hh)), parent=qt, direction='SW', depth=child_depth)
    qt.se = Quadtree(QNode(Rect(x + hw, y + hh, hw, hh)), parent=qt, direction='SE', depth=child_depth)
    qt.subdivided = True


def _qt_insert_into_children(qt: Quadtree, p) -> bool:
    """Insert `p` into the first child quadrant of `qt` that accepts it."""
    for child in _qt_children(qt):
        if qt_insert(child, p):
            return True
    return False


def qt_insert(qt: Quadtree, p) -> bool:
    """Insert the (x, y) point `p` into the subtree rooted at `qt`.

    Returns False if `p` is outside `qt`'s bounds (per RL.check_collision_point_rec)
    and True once it has been stored.
    """
    if not RL.check_collision_point_rec(p, qt.node.aabb):
        return False

    if qt.subdivided:
        return _qt_insert_into_children(qt, p)

    points = qt.node.points
    if len(points) < qt_capacity or qt.depth >= qt_max_depth:
        points.append(p)
        return True

    # Leaf is full: subdivide and push its points, plus the new one, down a level.
    qt_split(qt)
    points.append(p)
    # Any point no child accepts (only possible through float rounding at the
    # quadrant edges) stays in this node instead of being silently dropped.
    leftovers = []
    for q in points:
        if not _qt_insert_into_children(qt, q):
            leftovers.append(q)
    points[:] = leftovers
    return True


def _rect_distance(r: Rect, px: float, py: float) -> float:
    """Euclidean distance from (px, py) to the nearest point of `r` (0.0 if inside)."""
    dx = max(r.x - px, 0.0, px - (r.x + r.width))
    dy = max(r.y - py, 0.0, py - (r.y + r.height))
    return math.hypot(dx, dy)


def qt_find_nearest_point(qt: Quadtree, point) -> Optional[Tuple[float, float]]:
    """Return the stored point closest to `point`, skipping points equal to it.

    `point` is an (x, y) pair. Stored points with exactly the same coordinates are
    ignored, so a ball's own position is not reported as its nearest neighbour.
    Returns None when the tree holds no other point.

    Subtrees whose bounds are already at least as far away as the best candidate
    found so far are skipped. Children are visited nearest-first to tighten that
    bound early.
    """
    px, py = point
    best: Optional[Tuple[float, float]] = None
    best_dist = math.inf
    stack = [qt]
    while stack:
        node = stack.pop()
        if _rect_distance(node.node.aabb, px, py) >= best_dist:
            continue  # nothing inside this node can be strictly closer
        # Subdivided nodes may still hold leftover points (see qt_insert), so always scan.
        for q in node.node.points:
            qx, qy = q
            if qx == px and qy == py:
                continue
            d = math.hypot(qx - px, qy - py)
            if d < best_dist:
                best, best_dist = q, d
        if node.subdivided:
            # Push farthest first so the nearest child is popped next.
            stack.extend(sorted(_qt_children(node),
                                key=lambda c: _rect_distance(c.node.aabb, px, py),
                                reverse=True))
    return best


def construct_quadtree(points):
    """Build a screen-sized quadtree and insert every (x, y) point in `points`."""
    qt = Quadtree(QNode(Rect(0, 0, screen_width, screen_height)))
    for p in points:
        qt_insert(qt, p)
    return qt


def rect_values(r: Rect):
    """Return the rectangle's (x, y, width, height) as a tuple."""
    return r.x, r.y, r.width, r.height


def init():
    """Append `num_balls` balls with random headings and speeds to `w.balls`.

    Balls spawn on integer positions in [ball_r, 50) on both axes, i.e. clustered in
    the top-left corner. To spawn across the whole screen, use
    `screen_width - ball_r` / `screen_height - ball_r` as the upper bounds instead.
    """
    for _ in range(num_balls):
        px = random.randrange(ball_r, 50)
        py = random.randrange(ball_r, 50)
        angle = random.uniform(0, math.tau)  # math.cos/math.sin take radians
        # One speed factor for both components, so the heading really is `angle`.
        speed = ball_speed * random.uniform(1, 3)
        w.balls.append(Ball(px, py, math.cos(angle) * speed, math.sin(angle) * speed))


def player_input():
    """SPACE toggles pause; a left click (button 0) records and prints the mouse position."""
    if RL.is_key_pressed(RL.KEY_SPACE):
        w.paused = not w.paused
    if RL.is_mouse_button_pressed(0):
        pos = RL.get_mouse_position()
        print(pos)
        w.mouse_clicks.append(pos)


def update():
    """Move each ball one step, bounce it off the screen edges, then rebuild `w.qt`.

    Does nothing while paused.
    """
    if w.paused:
        return

    for ball in w.balls:
        ball.px += ball.vx
        ball.py += ball.vy
        if ball.px - ball_r <= 0 or ball.px + ball_r >= screen_width:
            # Clamp back inside so the ball cannot stick to the edge, then reflect.
            ball.px = RL.clamp(ball.px, ball_r + 0.1, screen_width - ball_r - 0.1)
            ball.vx *= -1
        if ball.py - ball_r <= 0 or ball.py + ball_r >= screen_height:
            ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1)
            ball.vy *= -1

    # Reconstruct the quadtree after moving, so the tree drawn this frame
    # matches the ball positions drawn this frame.
    w.qt = construct_quadtree([(b.px, b.py) for b in w.balls])


def draw_qt_dfs(qt: Optional[Quadtree]):
    """Outline `qt` and all its descendants (children are drawn before their parent)."""
    if qt is None:
        return
    draw_qt_dfs(qt.nw)
    draw_qt_dfs(qt.ne)
    draw_qt_dfs(qt.se)
    draw_qt_dfs(qt.sw)
    RL.draw_rectangle_lines_ex(qt.node.aabb, 0.5, RL.BLACK)


def draw():
    """Render one frame: quadtree outlines, ball outlines and recorded click markers."""
    RL.begin_drawing()
    RL.clear_background(RL.WHITE)
    draw_qt_dfs(w.qt)
    for ball in w.balls:
        RL.draw_circle_lines_v((ball.px, ball.py), ball_r, RL.BLACK)
    for mc in w.mouse_clicks:
        RL.draw_circle_v(mc, 5, RL.RED)
    RL.end_drawing()


def main():
    """Open the window, run the input/update/draw loop, and always close the window."""
    RL.init_window(screen_width, screen_height, "Quadtree")
    RL.set_target_fps(60)
    init()
    try:
        while not RL.window_should_close():
            player_input()
            update()
            draw()
    finally:
        RL.close_window()


if __name__ == "__main__":
    main()