242 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			242 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pyray as RL
 | |
| from pyray import (Rectangle as Rect)
 | |
| import math
 | |
| import pdb
 | |
| import random
 | |
| from typing import Optional, Tuple, List
 | |
| from dataclasses import dataclass, field
 | |
| 
 | |
# Window size in pixels; the quadtree root also spans exactly this area.
screen_width, screen_height = 1280, 1024

# Ball simulation parameters.
ball_r = 6          # ball radius (drawn radius and wall-collision margin)
ball_speed = 3.5    # base speed; init() scales it by a random factor
num_balls = 1000    # how many balls init() spawns

# A quadtree leaf is split once it would hold this many points.
qt_capacity = 4
| 
 | |
@dataclass
class Ball:
    """A moving circle: centre ``(px, py)`` and per-tick velocity ``(vx, vy)``."""

    px: float  # centre x
    py: float  # centre y
    vx: float  # added to px on every update tick
    vy: float  # added to py on every update tick
| 
 | |
@dataclass
class QNode:
    """Payload of a single quadtree node.

    ``aabb`` is the axis-aligned rectangle the node covers. ``points`` holds
    the ``(x, y)`` pairs stored directly in this node; each node gets its own
    fresh list.
    """

    aabb: Rect  # region of the plane covered by this node
    points: List[Tuple[float, float]] = field(default_factory=list)  # points stored here
| 
 | |
@dataclass
class Quadtree:
    """One node of a point quadtree; a leaf until ``qt_split`` subdivides it.

    ``node`` carries the AABB and, for leaves, the stored points. After a
    split, the four children ``nw``/``ne``/``sw``/``se`` are set and
    ``subdivided`` is True. ``direction`` is the quadrant tag ('NW', 'NE',
    'SW' or 'SE') this node occupies inside ``parent``. On the root,
    ``direction`` is ``''`` and ``parent`` is None.
    """

    node: 'QNode'
    direction: str = ''
    nw: Optional['Quadtree'] = None
    ne: Optional['Quadtree'] = None
    sw: Optional['Quadtree'] = None
    se: Optional['Quadtree'] = None
    # Back-reference to the enclosing node. It is excluded from repr and eq
    # because it forms a cycle with the parent's child links. Otherwise,
    # comparing two distinct trees recurses without end, and repr() of any
    # child prints the entire parent tree.
    parent: Optional['Quadtree'] = field(default=None, repr=False, compare=False)
    # Previously an unannotated class attribute, so it was not a dataclass
    # field. It is declared last so the positional order of the existing
    # fields is unchanged.
    subdivided: bool = False
| 
 | |
@dataclass
class World:
    """Global mutable state shared by the input/update/draw loop.

    Every container is created per instance via ``default_factory``. The
    previous unannotated class attributes (``balls = []`` etc.) were not
    dataclass fields, so all World instances shared the same lists.
    """

    # All simulated balls.
    balls: List['Ball'] = field(default_factory=list)
    # Quadtree rebuilt by update(). It stays None until the first build;
    # draw_qt_dfs() returns immediately for a falsy tree.
    qt: Optional['Quadtree'] = None
    # Tick counter; NOTE(review): nothing visible here increments it yet.
    tick: int = 0
    # Toggled by the space bar; update() returns early while True.
    paused: bool = False
    # Cursor positions recorded on mouse clicks; draw() marks each one.
    mouse_clicks: List = field(default_factory=list)


# The single world instance the game loop reads and mutates.
w = World()
| 
 | |
def qt_split(qt: Quadtree):
    """Turn *qt* into an internal node with four equal quadrant children.

    Each child covers one quarter of the parent's AABB. It gets an empty
    QNode, a ``parent`` link back to *qt*, and its compass ``direction``
    tag. Points already held by *qt* are not moved here; redistributing
    them is the caller's job.
    """
    box = qt.node.aabb
    left, top = box.x, box.y
    half_w = box.width * 0.5
    half_h = box.height * 0.5
    mid_x = left + half_w
    mid_y = top + half_h

    def make_child(cx, cy, tag):
        return Quadtree(QNode(Rect(cx, cy, half_w, half_h)), parent=qt, direction=tag)

    qt.nw = make_child(left, top, 'NW')
    qt.ne = make_child(mid_x, top, 'NE')
    qt.sw = make_child(left, mid_y, 'SW')
    qt.se = make_child(mid_x, mid_y, 'SE')
    qt.subdivided = True
def qt_insert(qt: Quadtree, p, max_depth: int = 10) -> bool:
    """Insert point *p* into the quadtree rooted at *qt*.

    When a leaf reaches ``qt_capacity`` points, it is split with
    ``qt_split`` and its points are pushed down into the new children.

    Leaves that are already ``max_depth`` levels below the tree's root are
    never split; they keep accumulating points instead. Without that cap,
    ``qt_capacity`` coincident points would be split forever (ending in
    RecursionError), because every split sends all of them into the same
    child.

    Args:
        qt: node to insert into, normally the root.
        p: ``(x, y)`` position to store.
        max_depth: maximum subdivision depth, counted from the root reached
            by following ``parent`` links.

    Returns:
        True if *p* lies inside ``qt``'s AABB and was stored, else False.
    """
    depth = 0
    ancestor = qt.parent
    while ancestor is not None:
        depth += 1
        ancestor = ancestor.parent
    return _qt_insert_at(qt, p, depth, max_depth)


def _qt_insert_at(qt: Quadtree, p, depth: int, max_depth: int) -> bool:
    """Recursive worker for qt_insert; *depth* is qt's distance from the root."""
    if not RL.check_collision_point_rec(p, qt.node.aabb):
        return False
    if qt.subdivided:
        return _qt_insert_into_children(qt, p, depth + 1, max_depth)

    points = qt.node.points
    points.append(p)
    # Same threshold as before: split once the leaf holds qt_capacity points,
    # unless the depth cap has been reached.
    if len(points) >= qt_capacity and depth < max_depth:
        qt_split(qt)
        for q in points:
            _qt_insert_into_children(qt, q, depth + 1, max_depth)
        points.clear()
    return True


def _qt_insert_into_children(qt: Quadtree, p, child_depth: int, max_depth: int) -> bool:
    """Offer *p* to qt's children in NW, NE, SW, SE order; stop at the first taker."""
    for child in (qt.nw, qt.ne, qt.sw, qt.se):
        if _qt_insert_at(child, p, child_depth, max_depth):
            return True
    return False
| 
 | |
def qt_find_nearest_point(qt: 'Quadtree', point) -> Optional[Tuple[float, float]]:
    """Return the stored point closest to *point*, excluding *point* itself.

    This is a branch-and-bound search. Nodes are visited nearest-first, and
    any node whose AABB lies no closer than the best match found so far is
    skipped, so whole quadrants are pruned without being inspected.

    Args:
        qt: root (or any subtree) of the quadtree to search.
        point: query position, either an ``(x, y)`` pair or an object with
            ``.x``/``.y`` attributes (e.g. a position from
            ``RL.get_mouse_position()``).

    Returns:
        The nearest stored ``(x, y)`` point that is not equal to *point*, or
        None when the tree holds no such point.
    """
    if hasattr(point, 'x') and hasattr(point, 'y'):
        px, py = point.x, point.y
    else:
        px, py = point

    def rect_distance(r) -> float:
        # Euclidean distance from the query to the closest point of
        # rectangle r (0 when the query lies inside it).
        dx = max(r.x - px, 0.0, px - (r.x + r.width))
        dy = max(r.y - py, 0.0, py - (r.y + r.height))
        return math.hypot(dx, dy)

    best = None
    best_dist = math.inf
    stack = [qt]
    while stack:
        current = stack.pop()
        # A node whose box is no closer than the current best cannot hold
        # a strictly closer point.
        if current is None or rect_distance(current.node.aabb) >= best_dist:
            continue
        for cand in current.node.points:
            if cand[0] == px and cand[1] == py:
                continue  # skip the query point itself
            d = math.hypot(cand[0] - px, cand[1] - py)
            if d < best_dist:
                best, best_dist = cand, d
        if current.subdivided:
            children = [c for c in (current.nw, current.ne, current.sw, current.se)
                        if c is not None]
            # Push the farthest child first so the nearest one is popped next;
            # early good matches then prune more of the tree.
            children.sort(key=lambda c: rect_distance(c.node.aabb), reverse=True)
            stack.extend(children)
    return best
| 
 | |
def construct_quadtree(points):
    """Build and return a new quadtree spanning the whole window.

    The root's AABB is ``(0, 0, screen_width, screen_height)``. Each entry
    of *points* is inserted in order with ``qt_insert``; the insert result
    is ignored, so points outside the window are simply left out.
    """
    tree = Quadtree(QNode(Rect(0, 0, screen_width, screen_height)))
    for point in points:
        qt_insert(tree, point)
    return tree
| 
 | |
def rect_values(r: 'Rect') -> Tuple[float, float, float, float]:
    """Return the ``(x, y, width, height)`` components of rectangle *r*.

    raylib's Rectangle names its size fields ``width`` and ``height``. The
    previous ``r.w`` / ``r.h`` lookups raised AttributeError.
    """
    return r.x, r.y, r.width, r.height
| 
 | |
def init(spawn_max_x: int = 50, spawn_max_y: int = 50):
    """Append ``num_balls`` randomly placed, moving balls to ``w.balls``.

    Spawn positions are integers drawn from ``[ball_r, spawn_max_x)`` x
    ``[ball_r, spawn_max_y)``. The defaults keep the original clustered
    top-left spawn. Pass ``screen_width - ball_r`` and
    ``screen_height - ball_r`` to spread balls over the whole window.

    Each ball moves along a uniformly random heading at ``ball_speed``
    times a random factor in ``[1, 3]``.
    """
    for _ in range(num_balls):
        px = random.randrange(ball_r, spawn_max_x)
        py = random.randrange(ball_r, spawn_max_y)
        # math.cos/sin take radians, so sample the heading from [0, 2*pi].
        angle = random.uniform(0, 2 * math.pi)
        # One shared speed factor keeps the velocity pointing along `angle`.
        # Scaling vx and vy by independent factors would skew the heading.
        speed = ball_speed * random.uniform(1, 3)
        vx = math.cos(angle) * speed
        vy = math.sin(angle) * speed
        w.balls.append(Ball(px, py, vx, vy))
| 
 | |
def player_input():
    """Handle keyboard and mouse input for the current frame.

    - Space toggles ``w.paused``.
    - Mouse button 0 prints the cursor position and appends it to
      ``w.mouse_clicks``.
    """
    if RL.is_key_pressed(RL.KEY_SPACE):
        w.paused = not w.paused

    clicked = RL.is_mouse_button_pressed(0)
    if not clicked:
        return
    print(RL.get_mouse_position())
    w.mouse_clicks.append(RL.get_mouse_position())
def update():
    """Advance the simulation by one tick; does nothing while paused.

    Moves every ball by its velocity and bounces it off the window edges.
    It then rebuilds ``w.qt`` from the new positions, so the tree drawn this
    frame matches the balls drawn this frame. Previously the tree was built
    before moving and lagged one frame behind.
    """
    if w.paused:
        return

    for ball in w.balls:
        ball.px += ball.vx
        ball.py += ball.vy
        # Bounce: clamp back inside the window, then reverse that axis.
        # Both axes use the same inclusive (>=) edge test.
        if ball.px - ball_r <= 0 or ball.px + ball_r >= screen_width:
            ball.px = RL.clamp(ball.px, ball_r + 0.1, screen_width - ball_r - 0.1)
            ball.vx *= -1
        if ball.py - ball_r <= 0 or ball.py + ball_r >= screen_height:
            ball.py = RL.clamp(ball.py, ball_r + 0.1, screen_height - ball_r - 0.1)
            ball.vy *= -1

    # Reconstruct the quadtree from the post-move positions.
    w.qt = construct_quadtree([(b.px, b.py) for b in w.balls])
| 
 | |
def draw_qt_dfs(qt: Quadtree):
    """Outline every node of *qt*, drawing children before their parent.

    Children are visited in NW, NE, SE, SW order. A falsy *qt* (a missing
    child or an unset tree) draws nothing.
    """
    if not qt:
        return
    for child in (qt.nw, qt.ne, qt.se, qt.sw):
        draw_qt_dfs(child)
    RL.draw_rectangle_lines_ex(qt.node.aabb, 0.5, RL.BLACK)
| 
 | |
def draw():
    """Render one frame: quadtree outlines, ball outlines, then click markers."""
    RL.begin_drawing()
    RL.clear_background(RL.WHITE)

    draw_qt_dfs(w.qt)

    for b in w.balls:
        centre = (b.px, b.py)
        RL.draw_circle_lines_v(centre, ball_r, RL.BLACK)

    for click in w.mouse_clicks:
        RL.draw_circle_v(click, 5, RL.RED)

    RL.end_drawing()
| 
 | |
# Script entry: open the window, spawn balls, and run the frame loop until
# the user closes the window.
RL.init_window(screen_width, screen_height, "Quadtree")
RL.set_target_fps(60)

try:
    init()
    while not RL.window_should_close():
        player_input()
        update()
        draw()
finally:
    # Always release the window and its GL context, even if a frame raises.
    RL.close_window()