
Commit 956aed9

Merge pull request #12 from Y-B-Class-Projects/refactor-fall-function
Refactor fall function
2 parents 5fb2c8f + 215341f commit 956aed9

6 files changed

Lines changed: 169 additions & 41 deletions


README.md

Lines changed: 17 additions & 0 deletions
@@ -60,6 +60,23 @@ These videos demonstrate how the model successfully and accurately recognizes hu

- **Short Term**: Optimize the if-else-based model using a time-sliding window
- **Long Term**: Integrate a time-series model (e.g., LSTM) for more accurate detection

## Fall Detection Logic

The fall detection logic uses the following parameters:

| Parameter | Description |
|------------------------|-------------|
| `FPS` | Frames per second; used to convert frame counts into elapsed time and velocity. |
| `WINDOW_SIZE` | Number of frames in each analysis window, i.e., how many frames are used to evaluate pose changes. |
| `V_THRESH` | Threshold for center-of-mass velocity (px/s); movements faster than this are considered potentially abnormal. |
| `DY_THRESH` | Threshold for vertical (Y-axis) displacement of the center of mass, in pixels. |
| `ASPECT_RATIO_THRESH` | Threshold for the change in the body's bounding-box aspect ratio (width/height); indicates whether the body has become horizontal. |

These values can be configured via a `.env` file; if not provided, the defaults defined in the code are used (an illustrative example follows the note below).
> **Note:** This logic is still under development and subject to change as part of our ongoing implementation and validation process.
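For illustration, a `.env` overriding the defaults might look like the following (the threshold values below are hypothetical placeholders, not tuned recommendations):

```env
FPS=30
WINDOW_SIZE=30
V_THRESH=120
DY_THRESH=40
ASPECT_RATIO_THRESH=0.5
```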
## Possible Future Improvements

To detect human falls, raise alerts, and save lives, the real-time system could be deployed in nursing homes, hospitals, and senior living facilities.

config.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@

```python
# config.py
import os
from dotenv import load_dotenv

load_dotenv()


def get_env_int(var_name, default):
    """Read an int from the environment, falling back to `default` on bad input."""
    try:
        return int(os.getenv(var_name, default))
    except ValueError:
        print(f"Warning: {var_name} is not a valid int, using default {default}")
        return default


def get_env_float(var_name, default):
    """Read a float from the environment, falling back to `default` on bad input."""
    try:
        return float(os.getenv(var_name, default))
    except ValueError:
        print(f"Warning: {var_name} is not a valid float, using default {default}")
        return default


FPS = get_env_int("FPS", 30)
WINDOW_SIZE = get_env_int("WINDOW_SIZE", 30)
V_THRESH = get_env_int("V_THRESH", 0)
DY_THRESH = get_env_int("DY_THRESH", 0)
ASPECT_RATIO_THRESH = get_env_float("ASPECT_RATIO_THRESH", 0)
```
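A quick, hypothetical sanity check of the fallback behavior (setting variables in-process rather than via `.env`):

```python
import os

# Simulate one valid override and one malformed value before importing config.
os.environ["FPS"] = "60"         # valid int: parsed as 60
os.environ["V_THRESH"] = "fast"  # invalid: triggers the warning and the default 0

import config
print(config.FPS, config.V_THRESH)  # -> 60 0 (plus a warning line for V_THRESH)
```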

fall_core.py

Lines changed: 82 additions & 33 deletions
```diff
@@ -9,6 +9,54 @@
 from utils.plots import output_to_keypoint
 
 
+def compute_center_of_mass(pose):
+    left_shoulder = (pose[10], pose[11])
+    right_shoulder = (pose[13], pose[14])
+    left_hip = (pose[22], pose[23])
+    right_hip = (pose[25], pose[26])
+    cx = (left_shoulder[0] + right_shoulder[0] + left_hip[0] + right_hip[0]) / 4
+    cy = (left_shoulder[1] + right_shoulder[1] + left_hip[1] + right_hip[1]) / 4
+    return cx, cy
+
+
+def compute_center_velocity(pose_start, pose_end, fps, window_size):
+    com_start = compute_center_of_mass(pose_start)
+    com_end = compute_center_of_mass(pose_end)
+    dx = com_end[0] - com_start[0]
+    dy = com_end[1] - com_start[1]
+    distance = math.sqrt(dx ** 2 + dy ** 2)
+    time_elapsed = (window_size - 1) / fps
+    velocity = distance / time_elapsed
+    return min(velocity, 300.0), dy
+
+
+def compute_bbox_aspect_ratio(pose):
+    x_vals = [pose[i] for i in range(0, len(pose), 3)]
+    y_vals = [pose[i] for i in range(1, len(pose), 3)]
+    width = max(x_vals) - min(x_vals)
+    height = max(y_vals) - min(y_vals)
+    return width / height if height != 0 else 0
+
+
+def compute_aspect_ratio_delta(pose_start, pose_end):
+    ar_start = compute_bbox_aspect_ratio(pose_start)
+    ar_end = compute_bbox_aspect_ratio(pose_end)
+    return ar_end - ar_start, ar_start, ar_end
+
+
+def find_most_similar_pose(reference_pose, candidate_poses):
+    ref_cx, ref_cy = compute_center_of_mass(reference_pose)
+    min_dist = float("inf")
+    best_pose = None
+    for pose in candidate_poses:
+        cx, cy = compute_center_of_mass(pose)
+        dist = (ref_cx - cx) ** 2 + (ref_cy - cy) ** 2
+        if dist < min_dist:
+            min_dist = dist
+            best_pose = pose
+    return best_pose
+
+
 def get_pose_model():
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     print("device: ", device)
```
```diff
@@ -49,7 +97,6 @@ def prepare_vid_out(video_path, vid_cap, output_dir):
     if not success:
         raise RuntimeError(f"Failed to read first frame for output setup: {video_path}")
 
-    # Reset video capture position to beginning
     vid_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
 
     vid_write_image = letterbox(first_frame, 960, stride=64, auto=True)[0]
@@ -67,37 +114,40 @@ def prepare_vid_out(video_path, vid_cap, output_dir):
     return out
 
 
-def fall_detection(poses):
-    for pose in poses:
-        xmin, ymin = (pose[2] - pose[4] / 2), (pose[3] - pose[5] / 2)
-        xmax, ymax = (pose[2] + pose[4] / 2), (pose[3] + pose[5] / 2)
-        left_shoulder_y = pose[23]
-        left_shoulder_x = pose[22]
-        right_shoulder_y = pose[26]
-        left_body_y = pose[41]
-        left_body_x = pose[40]
-        right_body_y = pose[44]
-        len_factor = math.sqrt(
-            (left_shoulder_y - left_body_y) ** 2 + (left_shoulder_x - left_body_x) ** 2
-        )
-        left_foot_y = pose[53]
-        right_foot_y = pose[56]
-        dx = int(xmax) - int(xmin)
-        dy = int(ymax) - int(ymin)
-        difference = dy - dx
-        if (
-            left_shoulder_y > left_foot_y - len_factor
-            and left_body_y > left_foot_y - (len_factor / 2)
-            and left_shoulder_y > left_body_y - (len_factor / 2)
-            or (
-                right_shoulder_y > right_foot_y - len_factor
-                and right_body_y > right_foot_y - (len_factor / 2)
-                and right_shoulder_y > right_body_y - (len_factor / 2)
-            )
-            or difference < 0
-        ):
-            return True, (xmin, ymin, xmax, ymax)
-    return False, None
+def fall_detection(pose_window, window_size, fps, v_thresh, aspect_ratio_thresh, dy_thresh):
+    if len(pose_window) < window_size:
+        return False, None, None, None
+
+    pose_start = pose_window[0][0]
+    pose_end_all = pose_window[-1]
+    pose_end = find_most_similar_pose(pose_start, pose_end_all)
+
+    v, dy = compute_center_velocity(pose_start, pose_end, fps, window_size)
+    ar_delta, ar_start, ar_end = compute_aspect_ratio_delta(pose_start, pose_end)
+
+    debug_text = (
+        f"v={v:.2f}/{v_thresh:.2f}px/s, "
+        f"y={dy:.1f}/{dy_thresh:.1f}, "
+        f"ar={ar_delta:.2f}/{aspect_ratio_thresh:.2f}"
+    )
+    print(f"[TRACE] {debug_text}")
+
+    cond_speed_drop = v > v_thresh and dy > dy_thresh
+    cond_down_flat = dy > dy_thresh and ar_delta > aspect_ratio_thresh
+
+    if cond_speed_drop or cond_down_flat:
+        tag = (
+            ("SpeedDrop " if cond_speed_drop else "") +
+            ("DownFlat " if cond_down_flat else "")
+        ).strip()
+
+        xmin = pose_end[2] - pose_end[4] / 2
+        ymin = pose_end[3] - pose_end[5] / 2
+        xmax = pose_end[2] + pose_end[4] / 2
+        ymax = pose_end[3] + pose_end[5] / 2
+        return True, (xmin, ymin, xmax, ymax), debug_text, tag
+
+    return False, None, debug_text, ""
 
 
 def falling_alarm(image, bbox):
@@ -121,7 +171,6 @@ def falling_alarm(image, bbox):
         lineType=cv2.LINE_AA,
     )
 
-
 def draw_fps(frame, prev_time):
     import time
     curr_time = time.time()
```
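End to end, a toy driver for the new `fall_detection` might look like this (hypothetical thresholds; same synthetic-pose assumptions as the sketch above):

```python
import numpy as np
from fall_core import fall_detection

def make_pose(x, y, w=50.0, h=150.0):
    pose = np.zeros(7 + 17 * 3)
    pose[2:6] = (x, y, w, h)  # box center x/y, width, height (used for the returned bbox)
    pose[7::3], pose[8::3], pose[9::3] = x, y, 1.0
    return pose

# A 30-frame window with one detected person whose center drops 150 px.
window = [[make_pose(100.0, 100.0 + 150.0 * i / 29)] for i in range(30)]

is_fall, bbox, debug_text, tag = fall_detection(
    window, 30, 30, 100.0, 0.5, 50.0  # window_size, fps, v_thresh, ar_thresh, dy_thresh
)
print(is_fall, tag)  # expect True; the tag names whichever condition(s) fired
```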

realtime.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -5,12 +5,15 @@
 import cv2
 import time
 import torch
+from collections import deque
+from config import WINDOW_SIZE, FPS, V_THRESH, ASPECT_RATIO_THRESH, DY_THRESH
 
 
 def process_realtime_camera():
     model, device = get_pose_model()
     cap = cv2.VideoCapture(0)
     prev_time = time.time()
+    pose_window = deque(maxlen=WINDOW_SIZE)
 
     while cap.isOpened():
         ret, frame = cap.read()
@@ -20,10 +23,24 @@ def process_realtime_camera():
         frame, prev_time = draw_fps(frame, prev_time)
         image_tensor, output = get_pose(frame, model, device)
         _image = prepare_image(image_tensor)
-        is_fall, bbox = fall_detection(output)
-
-        if is_fall:
-            falling_alarm(_image, bbox)
+        if len(output) > 0:
+            pose_window.append(output)
+            if len(pose_window) == WINDOW_SIZE:
+                is_fall, bbox, debug_text, tag = fall_detection(
+                    pose_window,
+                    WINDOW_SIZE,
+                    FPS,
+                    V_THRESH,
+                    ASPECT_RATIO_THRESH,
+                    DY_THRESH
+                )
+                if debug_text:
+                    _image = cv2.putText(
+                        _image, f"{tag}: {debug_text}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
+                        0.7, (0, 255, 0), 2
+                    )
+                if is_fall:
+                    falling_alarm(_image, bbox)
 
         cv2.imshow("Real-Time Fall Detection", _image[:,:,::-1])
         if cv2.waitKey(1) & 0xFF == ord("q"):
```
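A hypothetical entry point for the webcam path (assuming a camera at index 0 and the repo dependencies installed):

```python
from realtime import process_realtime_camera

process_realtime_camera()  # opens camera 0; press "q" in the preview window to stop
```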

requirements.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,6 +12,8 @@ torch>=1.7.0,!=1.12.0
 torchvision>=0.8.1,!=0.13.0
 tqdm>=4.41.0
 protobuf<4.21.3
+python-dotenv>=1.0.0
+
 
 # Logging -------------------------------------
 tensorboard>=2.4.1
```
video.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -5,6 +5,8 @@
 import cv2
 import os
 from tqdm import tqdm
+from collections import deque
+from config import WINDOW_SIZE, FPS, V_THRESH, ASPECT_RATIO_THRESH, DY_THRESH
 
 
 def process_video_file(video_path, output_dir):
@@ -15,6 +17,7 @@ def process_video_file(video_path, output_dir):
 
     model, device = get_pose_model()
     vid_out = prepare_vid_out(video_path, vid_cap, output_dir)
+    pose_window = deque(maxlen=WINDOW_SIZE)
 
     success, frame = vid_cap.read()
     frames = []
@@ -25,10 +28,24 @@
     for image in tqdm(frames, desc=f"Processing {os.path.basename(video_path)}"):
         image, output = get_pose(image, model, device)
         _image = prepare_image(image)
-        is_fall, bbox = fall_detection(output)
-
-        if is_fall:
-            falling_alarm(_image, bbox)
+        if len(output) > 0:
+            pose_window.append(output)
+            if len(pose_window) == WINDOW_SIZE:
+                is_fall, bbox, debug_text, tag = fall_detection(
+                    pose_window,
+                    WINDOW_SIZE,
+                    FPS,
+                    V_THRESH,
+                    ASPECT_RATIO_THRESH,
+                    DY_THRESH
+                )
+                # debug
+                _image = cv2.putText(
+                    _image, f"{tag}: {debug_text}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
+                    0.7, (0, 255, 0), 2
+                )
+                if is_fall:
+                    falling_alarm(_image, bbox)
         vid_out.write(_image[:,:,::-1])
 
     vid_out.release()
```
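And a matching hypothetical smoke test for the batch path (paths are placeholders):

```python
from video import process_video_file

# Writes the annotated result into the output directory.
process_video_file("samples/fall.mp4", "output/")
```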
