Skip to content

Commit 7584ee1

Browse files
committed
refactor: unify video and realtime frame handling via shared process_frame()
1 parent 6413461 commit 7584ee1

2 files changed

Lines changed: 64 additions & 107 deletions

File tree

fall_core.py

Lines changed: 63 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -248,83 +248,6 @@ def center_and_height(p):
248248
assigned_ids.add(best_tid)
249249
return best_tid
250250

251-
def process_video_file(self, path, out_dir="output_videos"):
252-
cap = cv2.VideoCapture(path)
253-
if not cap.isOpened():
254-
print("[ERR] can't open video:", path)
255-
return
256-
257-
os.makedirs(out_dir, exist_ok=True)
258-
success, first = cap.read()
259-
vid_shape = letterbox(first, 960, stride=64, auto=True)[0].shape
260-
out_path = os.path.join(
261-
out_dir, os.path.basename(path).split(".")[0] + "_output.mp4"
262-
)
263-
writer = cv2.VideoWriter(
264-
out_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (vid_shape[1], vid_shape[0])
265-
)
266-
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
267-
268-
while True:
269-
success, frame = cap.read()
270-
if not success:
271-
break
272-
273-
people, processed_frame = self.get_pose(frame)
274-
_image = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
275-
assigned_ids = set()
276-
results = []
277-
278-
for pose in people:
279-
tid = self.match_pose_to_tracker(pose, self.trackers, assigned_ids)
280-
if tid is None:
281-
tid = str(self.next_id)
282-
self.next_id += 1
283-
self.trackers[tid] = PersonFallTracker(
284-
self.window_size,
285-
self.fps,
286-
self.v_thresh,
287-
self.ar_thresh,
288-
self.dy_thresh,
289-
)
290-
self.trackers[tid].add_pose(pose)
291-
self.trackers[tid].last_update = time.time()
292-
293-
tag = debug = bbox = None
294-
v = dy = ar = None
295-
296-
if self.trackers[tid].is_ready():
297-
is_fall, bbox, debug, tag = self.trackers[tid].check_fall()
298-
p1, p2 = (
299-
self.trackers[tid].pose_window[0],
300-
self.trackers[tid].pose_window[-1],
301-
)
302-
v, dy = self.trackers[tid].compute_velocity(p1, p2)
303-
ar = self.trackers[tid].compute_ar_delta(p1, p2)
304-
305-
if is_fall and bbox:
306-
x1, y1, x2, y2 = map(int, bbox)
307-
cv2.rectangle(_image, (x1, y1), (x2, y2), (255, 0, 0), 4)
308-
cv2.putText(
309-
_image,
310-
"FALL DETECTED",
311-
(x1, y1 - 10),
312-
0,
313-
0.8,
314-
(0, 0, 255),
315-
2,
316-
)
317-
318-
cx, cy = int(pose[2]), int(pose[3])
319-
results.append((tid, pose, tag, debug, bbox, v, dy, ar))
320-
321-
_image = self.draw_debug_overlay(_image, results)
322-
writer.write(_image)
323-
324-
cap.release()
325-
writer.release()
326-
print(f"[DONE] Saved to {out_path}")
327-
328251
def draw_fps(self, frame, prev_time):
329252
import time
330253

@@ -342,44 +265,48 @@ def draw_fps(self, frame, prev_time):
342265
)
343266
return frame, curr_time
344267

345-
def handle_frame(self, frame, prev_time=None, writer=None):
268+
def process_frame(self, frame, prev_time=None, writer=None):
346269
people, processed_frame = self.get_pose(frame)
347270
_image = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
348-
271+
assigned_ids = set()
349272
results = []
273+
350274
for pose in people:
351-
tid = self.match_pose_to_tracker(pose, self.trackers)
275+
tid = self.match_pose_to_tracker(pose, self.trackers, assigned_ids)
352276
if tid is None:
353-
tid = str(uuid4())[:8]
354-
self.trackers[tid] = deque(maxlen=self.window_size)
355-
self.trackers[tid].append(pose)
277+
tid = str(self.next_id)
278+
self.next_id += 1
279+
self.trackers[tid] = PersonFallTracker(
280+
self.window_size,
281+
self.fps,
282+
self.v_thresh,
283+
self.ar_thresh,
284+
self.dy_thresh,
285+
)
286+
self.trackers[tid].add_pose(pose)
287+
self.trackers[tid].last_update = time.time()
356288

357289
tag = debug = bbox = None
358-
v = dy = ar_delta = None
359-
if len(self.trackers[tid]) == self.window_size:
360-
p1, p2 = self.trackers[tid][0], self.trackers[tid][-1]
361-
v, dy = self.compute_velocity(p1, p2)
362-
ar_delta = self.compute_ar_delta(p1, p2)
363-
364-
cond_v = v > self.v_thresh and dy > self.dy_thresh
365-
cond_ar = dy > self.dy_thresh and ar_delta > self.ar_thresh
366-
tag_list = []
367-
if cond_v:
368-
tag_list.append("SpeedDrop")
369-
if cond_ar:
370-
tag_list.append("DownFlat")
371-
tag = " ".join(tag_list)
372-
debug = f"v={v:.1f}, dy={dy:.1f}, arΔ={ar_delta:.2f}"
373-
374-
if tag:
375-
bbox = (
376-
int(p2[2] - p2[4] / 2),
377-
int(p2[3] - p2[5] / 2),
378-
int(p2[2] + p2[4] / 2),
379-
int(p2[3] + p2[5] / 2),
290+
v = dy = ar = None
291+
292+
if self.trackers[tid].is_ready():
293+
is_fall, bbox, debug, tag = self.trackers[tid].check_fall()
294+
p1, p2 = (
295+
self.trackers[tid].pose_window[0],
296+
self.trackers[tid].pose_window[-1],
297+
)
298+
v, dy = self.trackers[tid].compute_velocity(p1, p2)
299+
ar = self.trackers[tid].compute_ar_delta(p1, p2)
300+
301+
if is_fall and bbox:
302+
x1, y1, x2, y2 = map(int, bbox)
303+
cv2.rectangle(_image, (x1, y1), (x2, y2), (255, 0, 0), 4)
304+
cv2.putText(
305+
_image, "FALL DETECTED", (x1, y1 - 10), 0, 0.8, (0, 0, 255), 2
380306
)
381307

382-
results.append((tid, pose, tag, debug, bbox, v, dy, ar_delta))
308+
cx, cy = int(pose[2]), int(pose[3])
309+
results.append((tid, pose, tag, debug, bbox, v, dy, ar))
383310

384311
_image = self.draw_debug_overlay(_image, results)
385312

@@ -392,3 +319,33 @@ def handle_frame(self, frame, prev_time=None, writer=None):
392319
if writer:
393320
writer.write(_image)
394321
return _image
322+
323+
def process_video_file(self, path, out_dir="output_videos"):
324+
cap = cv2.VideoCapture(path)
325+
if not cap.isOpened():
326+
print("[ERR] can't open video:", path)
327+
return
328+
329+
os.makedirs(out_dir, exist_ok=True)
330+
success, first = cap.read()
331+
vid_shape = letterbox(first, 960, stride=64, auto=True)[0].shape
332+
out_path = os.path.join(
333+
out_dir, os.path.basename(path).split(".")[0] + "_output.mp4"
334+
)
335+
writer = cv2.VideoWriter(
336+
out_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (vid_shape[1], vid_shape[0])
337+
)
338+
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
339+
340+
while True:
341+
success, frame = cap.read()
342+
if not success:
343+
break
344+
self.process_frame(frame, writer=writer)
345+
346+
cap.release()
347+
writer.release()
348+
print(f"[DONE] Saved to {out_path}")
349+
350+
def handle_frame(self, frame, prev_time=None, writer=None):
    """Backward-compatible alias for :meth:`process_frame`.

    Kept so existing realtime callers of the old entry point keep working
    after the video/realtime unification; simply forwards all arguments.
    """
    return self.process_frame(frame, prev_time=prev_time, writer=writer)

video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def process_video():
88
videos_path = "fall_dataset/ci_videos"
99
print("[CI MODE] Only running on CI test videos...")
1010
else:
11-
videos_path = "fall_dataset/test_videos"
11+
videos_path = "fall_dataset/videos"
1212

1313
output_dir = "output_videos"
1414
os.makedirs(output_dir, exist_ok=True)

0 commit comments

Comments
 (0)