Skip to content

Commit be634f6

Browse files
committed
feat: enhance multi-person fall tracking with better overlay and filtering
- add is_pose_complete filtering - unify per-person overlay - improve debug info and fall detection display
1 parent ad00521 commit be634f6

1 file changed

Lines changed: 93 additions & 67 deletions

File tree

fall_core.py

Lines changed: 93 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__(self, window_size, fps, v_thresh, ar_thresh, dy_thresh):
2222
self.dy_thresh = dy_thresh
2323

2424
def add_pose(self, pose):
25-
self.pose_window.append(pose)
25+
if self.is_pose_complete(pose):
26+
self.pose_window.append(pose)
2627

2728
def is_ready(self):
2829
return len(self.pose_window) == self.window_size
@@ -64,13 +65,13 @@ def check_fall(self):
6465
v, dy = self.compute_velocity(p1, p2)
6566
ar_delta = self.compute_ar_delta(p1, p2)
6667

67-
cond_v = v > self.v_thresh and dy > self.dy_thresh
68-
cond_ar = dy > self.dy_thresh and ar_delta > self.ar_thresh
68+
ar_start = self._safe_aspect_ratio(p1)
69+
ar_end = self._safe_aspect_ratio(p2)
6970

7071
tag = []
71-
if cond_v:
72+
if v > self.v_thresh and dy > self.dy_thresh and ar_end > 0.1:
7273
tag.append("SpeedDrop")
73-
if cond_ar:
74+
if dy > self.dy_thresh and ar_delta > self.ar_thresh:
7475
tag.append("DownFlat")
7576

7677
debug = (
@@ -90,6 +91,23 @@ def check_fall(self):
9091

9192
return False, None, debug, ""
9293

94+
def _safe_aspect_ratio(self, p):
    """Return the width/height ratio of the points stored in *p*.

    *p* is read as a flat sequence of consecutive triplets. The first
    value of each triplet is used as an x coordinate and the second as a
    y coordinate; the third value is ignored. Trailing values that do not
    form a complete triplet are ignored as well.

    NOTE(review): triplets are read starting at index 0 -- confirm that
    the pose layout has no leading non-keypoint fields.

    Args:
        p: Indexable sequence of numbers supporting ``len()``.

    Returns:
        ``width / height`` of the bounding extent of the points, or ``0``
        when the height is zero or *p* holds no complete triplet.
    """
    usable = len(p) - (len(p) % 3)
    if usable == 0:
        # No complete triplet: max()/min() on an empty list would raise
        # ValueError, which defeats the purpose of a "safe" helper.
        return 0

    xs = [p[i] for i in range(0, usable, 3)]
    ys = [p[i + 1] for i in range(0, usable, 3)]
    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    # Guard the division: a flat (zero-height) extent yields 0.
    return width / height if height else 0
100+
101+
def is_pose_complete(self, pose, required_joints=(10, 11, 13, 14, 22, 23, 25, 26)):
    """Return True when every required slot of *pose* holds a non-zero value.

    For each index ``i`` in *required_joints* the two values ``pose[i]``
    and ``pose[i + 1]`` are read; the pose is rejected as soon as either
    of them equals 0 or lies beyond the end of *pose*.

    NOTE(review): the default indices contain adjacent entries (10 and 11,
    13 and 14, ...), so consecutive entries read overlapping pairs and
    ``pose[12]``, ``pose[15]``, ``pose[24]`` and ``pose[27]`` are checked
    too. Confirm whether the tuple was meant to list only the first index
    of each joint.

    Args:
        pose: Indexable sequence of numbers, or ``None``.
        required_joints: Start indices of the value pairs that must be
            present and non-zero.

    Returns:
        ``True`` if all inspected values exist and are non-zero, otherwise
        ``False`` (including when *pose* is ``None``).
    """
    if pose is None:
        # A missing detection is simply "not complete", not an error.
        return False

    for idx in required_joints:
        # Keep the try body to the indexing only, so that nothing else
        # can be mistaken for a "pose too short" condition.
        try:
            first, second = pose[idx], pose[idx + 1]
        except IndexError:
            return False
        if first == 0 or second == 0:
            return False
    return True
110+
93111

94112
class FallDetectorMulti:
95113
def __init__(
@@ -110,6 +128,46 @@ def __init__(
110128
self.dy_thresh = dy_thresh
111129
self.next_id = 1
112130

131+
def draw_debug_overlay(self, image, results):
    """Draw the per-person ID label and metric read-out onto *image*.

    For every entry in *results* two text lines are rendered with
    ``cv2.putText``: ``"ID: <tid>"`` 20 px above the anchor point and a
    ``"v=.../..., dy=.../..., ar=.../..."`` line 20 px below it, where each
    measured value is shown next to the corresponding threshold stored on
    ``self`` (``v_thresh``, ``dy_thresh``, ``ar_thresh``). A metric whose
    value is ``None`` is shown as ``N/A``.

    The anchor point is ``(int(pose[2]), int(pose[3]))`` -- presumably the
    centre of the person's bounding box; verify against the pose layout.

    Args:
        image: Image passed to ``cv2.putText``; it is drawn on in place.
        results: Iterable of 8-tuples
            ``(tid, pose, tag, debug, bbox, v, dy, ar)``. Only ``tid``,
            ``pose``, ``v``, ``dy`` and ``ar`` are used here.

    Returns:
        The same *image* object that was passed in.
    """

    def metric(label, value, threshold, digits):
        # One "label=value/threshold" fragment, or "label=N/A" when no
        # measurement is available yet.
        if value is None:
            return f"{label}=N/A"
        return f"{label}={value:.{digits}f}/{threshold:.{digits}f}"

    # tag, debug and bbox are part of the tuple layout but not drawn here.
    for tid, pose, _tag, _debug, _bbox, v, dy, ar in results:
        anchor_x, anchor_y = int(pose[2]), int(pose[3])

        cv2.putText(
            image,
            f"ID: {tid}",
            (anchor_x, anchor_y - 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 255, 255),
            2,
        )

        debug_text = ", ".join(
            (
                metric("v", v, self.v_thresh, 1),
                metric("dy", dy, self.dy_thresh, 1),
                metric("ar", ar, self.ar_thresh, 2),
            )
        )

        cv2.putText(
            image,
            debug_text,
            (anchor_x, anchor_y + 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.45,
            (100, 255, 100),
            1,
        )

    return image
170+
113171
def load_model(self, path):
114172
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115173
weights = torch.load(path, map_location=device)
@@ -160,7 +218,7 @@ def center_and_height(p):
160218

161219
for tid, tracker in trackers.items():
162220
if tid in assigned_ids:
163-
continue # 🚫 本幀已被用過
221+
continue
164222

165223
if len(tracker.pose_window) == 0:
166224
continue
@@ -206,6 +264,8 @@ def process_video_file(self, path, out_dir="output_videos"):
206264
people, processed_frame = self.get_pose(frame)
207265
_image = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
208266
assigned_ids = set()
267+
results = []
268+
209269
for pose in people:
210270
tid = self.match_pose_to_tracker(pose, self.trackers, assigned_ids)
211271
if tid is None:
@@ -221,20 +281,19 @@ def process_video_file(self, path, out_dir="output_videos"):
221281
self.trackers[tid].add_pose(pose)
222282
self.trackers[tid].last_update = time.time()
223283

284+
tag = debug = bbox = None
285+
v = dy = ar = None
286+
224287
if self.trackers[tid].is_ready():
225288
is_fall, bbox, debug, tag = self.trackers[tid].check_fall()
226-
if tag:
227-
print(f"[{tid}] {tag}{debug}")
228-
cv2.putText(
229-
_image,
230-
f"{tid} {tag}: {debug}",
231-
(10, 30),
232-
cv2.FONT_HERSHEY_SIMPLEX,
233-
0.6,
234-
(0, 255, 0),
235-
2,
236-
)
237-
if is_fall:
289+
p1, p2 = (
290+
self.trackers[tid].pose_window[0],
291+
self.trackers[tid].pose_window[-1],
292+
)
293+
v, dy = self.trackers[tid].compute_velocity(p1, p2)
294+
ar = self.trackers[tid].compute_ar_delta(p1, p2)
295+
296+
if is_fall and bbox:
238297
x1, y1, x2, y2 = map(int, bbox)
239298
cv2.rectangle(_image, (x1, y1), (x2, y2), (255, 0, 0), 4)
240299
cv2.putText(
@@ -246,16 +305,11 @@ def process_video_file(self, path, out_dir="output_videos"):
246305
(0, 0, 255),
247306
2,
248307
)
308+
249309
cx, cy = int(pose[2]), int(pose[3])
250-
cv2.putText(
251-
_image,
252-
f"ID: {tid}",
253-
(cx, cy - 20),
254-
cv2.FONT_HERSHEY_SIMPLEX,
255-
0.6,
256-
(0, 255, 255),
257-
2,
258-
)
310+
results.append((tid, pose, tag, debug, bbox, v, dy, ar))
311+
312+
_image = self.draw_debug_overlay(_image, results)
259313
writer.write(_image)
260314

261315
cap.release()
@@ -279,11 +333,11 @@ def draw_fps(self, frame, prev_time):
279333
)
280334
return frame, curr_time
281335

282-
def handle_frame(self, frame, prev_time=None):
336+
def handle_frame(self, frame, prev_time=None, writer=None):
283337
people, processed_frame = self.get_pose(frame)
284338
_image = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
285339

286-
results = [] # for visualization (tid, pose, tag, debug, bbox)
340+
results = []
287341
for pose in people:
288342
tid = self.match_pose_to_tracker(pose, self.trackers)
289343
if tid is None:
@@ -292,6 +346,7 @@ def handle_frame(self, frame, prev_time=None):
292346
self.trackers[tid].append(pose)
293347

294348
tag = debug = bbox = None
349+
v = dy = ar_delta = None
295350
if len(self.trackers[tid]) == self.window_size:
296351
p1, p2 = self.trackers[tid][0], self.trackers[tid][-1]
297352
v, dy = self.compute_velocity(p1, p2)
@@ -314,46 +369,17 @@ def handle_frame(self, frame, prev_time=None):
314369
int(p2[2] + p2[4] / 2),
315370
int(p2[3] + p2[5] / 2),
316371
)
317-
results.append((tid, pose, tag, debug, bbox))
318372

319-
for tid, pose, tag, debug, bbox in results:
320-
cx, cy = int(pose[2]), int(pose[3])
321-
cv2.putText(
322-
_image,
323-
f"ID: {tid}",
324-
(cx, cy - 20),
325-
cv2.FONT_HERSHEY_SIMPLEX,
326-
0.6,
327-
(0, 255, 255),
328-
2,
329-
)
330-
cv2.putText(
331-
_image,
332-
debug,
333-
(cx, cy - 40),
334-
cv2.FONT_HERSHEY_SIMPLEX,
335-
0.5,
336-
(0, 255, 0),
337-
1,
338-
)
339-
if tag and bbox:
340-
x1, y1, x2, y2 = bbox
341-
print(f"[{tid}] {tag}{debug}")
342-
cv2.rectangle(_image, (x1, y1), (x2, y2), (255, 0, 0), 4)
343-
cv2.putText(
344-
_image, "FALL DETECTED", (x1, y1 - 10), 0, 0.8, (0, 0, 255), 2
345-
)
346-
cv2.putText(
347-
_image,
348-
f"{tid} {tag}",
349-
(x1, y1 - 30),
350-
cv2.FONT_HERSHEY_SIMPLEX,
351-
0.6,
352-
(0, 255, 0),
353-
2,
354-
)
373+
results.append((tid, pose, tag, debug, bbox, v, dy, ar_delta))
374+
375+
_image = self.draw_debug_overlay(_image, results)
355376

356377
if prev_time is not None:
357378
_image, new_time = self.draw_fps(_image, prev_time)
379+
if writer:
380+
writer.write(_image)
358381
return _image, new_time
359-
return _image
382+
else:
383+
if writer:
384+
writer.write(_image)
385+
return _image

0 commit comments

Comments
 (0)