Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions App/routes/explain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import io
from pydantic import Field
from fastapi import APIRouter, status, Depends
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import Connection
from db.database import context_get_conn
from schemas.explain_schema import ExplainImageRequest
from services import session_svc, explain_svc, image_svc
from schemas.explain_schema import ExplainImageRequest, ExplainFrameRequest
from services import session_svc, explain_svc, image_svc, video_svc
from celery_app import celery_app
from celery.result import AsyncResult

Expand Down Expand Up @@ -47,7 +45,11 @@ async def explain_image(
# Celery Task 호출(Redis Broker 활용)
task = explain_svc.process_explain_image_task.delay(
user_email = session_user["email"],
result_dict = result.model_dump(mode='json'),
version_type = result.version_type,
domain_type = result.domain_type,
image_loc = result.image_loc,
image_id = result.image_id,
category = 1 if result.label == "FAKE" else 0,
explain_req_dict = explain_req.model_dump())
return {
"message": "딥페이크 이미지 위조 흔적 시각화 접수 완료. 시각화 분석 시작 ...",
Expand All @@ -66,10 +68,10 @@ async def get_explain_image_result(
task = AsyncResult(task_id, app=celery_app)

if task.state in ("PENDING", "STARTED", "RETRY"):
return {
"status": "PENDING",
"message": "딥페이크 이미지 위조 흔적 시각화 분석 중 ..."
}
return JSONResponse(
status_code = status.HTTP_202_ACCEPTED,
content = {"message": "딥페이크 이미지 위조 흔적 시각화 분석 중 ..."}
)

if task.state == "FAILURE":
raise HTTPException(
Expand All @@ -89,6 +91,94 @@ async def get_explain_image_result(
"message": result["message"],
"cam_loc": result["cam_loc"],
}

@router.post("/video/{video_id}/frame/{frame_index}", status_code=status.HTTP_202_ACCEPTED,
response_class=JSONResponse, summary="딥페이크 비디오 프레임 위조 흔적 시각화 비동기 접수")
async def explain_frame(
video_id: int,
frame_index: int,
explain_req: ExplainFrameRequest,
conn: Connection = Depends(context_get_conn),
session_user = Depends(session_svc.get_session_user_prt), # 로그인 필수
):
# 딥페이크 비디오 추론 결과 가져오기
result = await video_svc.get_video_result(conn, video_id)

# 딥페이크 비디오 추론 성공 여부 확인하기
if result.status != "SUCCESS":
raise HTTPException(
status_code = status.HTTP_400_BAD_REQUEST,
detail = "비디오 프레임 위조 흔적 분석은 추론이 성공한 비디오에서만 가능합니다"
)

# 비디오 파일 저장 경로 가져오기
video_path = "." + result.video_loc
if not os.path.exists(video_path):
raise HTTPException(
status_code = status.HTTP_404_NOT_FOUND,
detail = f"요청하신 비디오 파일을 찾을 수 없습니다. 삭제하였는지 다시 확인해주세요."
)

# 딥페이크 비디오 프레임 위조 흔적 분석 (pro model는 aug_smooth 사용 불가, 연산이 너무 많아짐)
if result.model_type == "pro" and explain_req.aug_smooth:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Pro 모델은 aug_smooth 기능을 지원하지 않습니다",
)

# 비디오 내 해당 frame이 몇초에 위치한 frame인지 확인
frame_time = video_svc.get_video_frame_by_index(conn, video_id, frame_index)


# Celery Task 호출(Redis Broker 활용)
task = explain_svc.process_explain_frame_task.delay(
user_email = session_user["email"],
version_type = result.version_type,
domain_type = result.domain_type,
video_loc = result.video_loc,
video_id = video_id,
category = 1 if result.label == "FAKE" else 0,
frame_time = frame_time,
explain_req_dict = explain_req.model_dump())
return {
"message": "딥페이크 비디오 프레임 위조 흔적 시각화 접수 완료. 시각화 분석 시작 ...",
"task_id": task.id,
}

@router.get("/frame/result/{task_id}", status_code=status.HTTP_200_OK,
response_class=JSONResponse, summary="딥페이크 비디오 프레임 위조 흔적 시각화 결과 가져오기")
async def get_explain_frame_result(
task_id: str,
session_user = Depends(session_svc.get_session_user_prt), # 로그인 필수
):

# Redis Broker에서 Task ID에 해당하는 비동기 작업 상태 가져오기
task = AsyncResult(task_id, app=celery_app)

# 비동기 작업 진행 상태 Check
if task.state in ("PENDING", "STARTED", "RETRY"):
return JSONResponse(
status_code = status.HTTP_202_ACCEPTED,
content = {"message": "딥페이크 비디오 프레임 위조 흔적 시각화 분석 중 ..."}
)

# 비동기 작업 실패
if task.state == "FAILURE":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="딥페이크 비디오 프레임 위조 흔적 시각화 중 알 수 없는 오류가 발생하였습니다")

# Celery Task 결과 가져오기
result = task.result

# 딥페이크 비디오 프레임 위조 흔적 시각화 생성 또는 파일 저장 도중 오류 발생
if result["status"] == "FAILED":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result["message"])

return {
"status": result["status"],
"message": result["message"],
"cam_loc": result["cam_loc"],
}

2 changes: 1 addition & 1 deletion App/routes/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def predict_video(
}

@router.get("/video/{video_id}", status_code=status.HTTP_200_OK,
response_model=VideoDetailData, summary="딥페이크 비디오 비디오 추론 결과값 가져오기")
response_model=VideoDetailData, summary="딥페이크 비디오 추론 결과값 가져오기")
async def get_video_result(
video_id: int,
conn: Connection = Depends(context_get_conn),
Expand Down
19 changes: 11 additions & 8 deletions App/schemas/explain_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@


class ExplainImageRequest(BaseModel):
model_type: Literal["fast", "pro"] = Field("fast",
description="추론 모드 (fast: 속도 우선, pro: 정확도 우선)")

branch_level: Literal["low","high"] = Field("high",
description="브랜치 레벨\nlow: 국소 위조 흔적 포착\nhigh: 전역적 위조 흔적 포착")
description="브랜치 레벨\nlow: 국소 위조 흔적 포착\nhigh: 전역 위조 흔적 포착")

explainer_type: str = Field("eigengradcam",
description = ("선택 가능한 XAI 기법. low: [hirescam, gradcamelementwise, layercam], ""high: [eigengradcam, gradcamplusplus, xgradcam]"))

display_type: Literal["heatmap", "bbox", "heatmap_bbox"] = Field("heatmap",
display_type: Literal["heatmap", "bbox", "heatmap_bbox"] = Field("heatmap_bbox",
description="시각화 형태. heatmap: 전체 분포, bbox: 위조 의심 영역 사각형, heatmap_bbox: 위조 의심 사각형 내부에 블러 처리된 히트맵을 중첩")

category: Literal[0, 1] = Field(1,
description="판단 클래스 인덱스 (0: Real / 1: Fake)")

overlay_ratio: float = Field(0.5, ge=0.0, le=1.0,
overlay_ratio: float = Field(0.7, ge=0.5, le=1.0,
description = "Heatmap 투명도 (0: 히트맵만 강조, 1: 원본 이미지 위주)")

threshold: float = Field(0.5, ge=0.5, le=1.0,
threshold: float = Field(0.9, ge=0.5, le=1.0,
description="contour/bbox 이진화 임계값 (0.0~1.0)")

aug_smooth: bool = Field(False,
Expand Down Expand Up @@ -66,4 +66,7 @@ def validate_explainer_for_eigen_smooth(self) -> "ExplainImageRequest":
f"{sorted(_EIGEN_ALLOWED)} "
f"(입력값: '{self.explainer_type}')"
)
return self
return self

class ExplainFrameRequest(ExplainImageRequest):
pass
118 changes: 103 additions & 15 deletions App/services/explain_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@

_explainer_cache: dict = {}

# 비디오 내 특정 Frame 추출 + Face Cropping
def _extract_face_from_frame(video_path: str, frame_time: float, explainer: CAMExplainer):
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_MSEC, frame_time * 1000)
ret, frame = cap.read()
cap.release()

frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
bbox = explainer._get_face_bbox(frame_rgb)

return explainer._crop_face(frame_rgb, bbox[:4])

# 캐시된 CAMExplainer 객체 반환하거나 새로 생성
def _get_or_create_explainer(model_name: str, dataset: str, explain_req_dict: dict):
cache_key = (model_name, dataset, explain_req_dict["explainer_type"], explain_req_dict["branch_level"])
Expand All @@ -40,37 +52,49 @@ def _get_or_create_explainer(model_name: str, dataset: str, explain_req_dict: di
return _explainer_cache[cache_key]

# 시각화 이미지 생성 (heatmap, contour, bbox 선택)
def _run_visualization(explainer: CAMExplainer, image_path: str, explain_req_dict: dict) -> np.ndarray:
def _run_visualization(explainer: CAMExplainer, image_path: str, category: int, explain_req_dict: dict) -> np.ndarray:
if explain_req_dict["display_type"] == "heatmap":
return explainer.display_heatmap_on_image(image_path, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"],
category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
elif explain_req_dict["display_type"] == "bbox":
return explainer.display_bbox_on_image(image_path, threshold=explain_req_dict["threshold"],
category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
else: # display_type == "heatmap_bbox"
return explainer.display_heatmap_bbox_on_image(image_path, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"],
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])

# 비디오 프레임 시각화 생성 (heatmap, contour, bbox 선택)
def _run_visualization_from_array(explainer: CAMExplainer, face: str, category: int, explain_req_dict: dict) -> np.ndarray:
if explain_req_dict["display_type"] == "heatmap":
return explainer.display_heatmap_from_array(face, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"],
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
elif explain_req_dict["display_type"] == "bbox":
return explainer.display_bbox_from_array(face, threshold=explain_req_dict["threshold"],
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
else: # display_type == "heatmap_bbox"
return explainer.display_heatmap_bbox_on_image(image_path, threshold=explain_req_dict["threshold"],
category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])
return explainer.display_heatmap_bbox_from_array(face, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"],
category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"])

# 딥페이크 이미지 위조 흔적 시각화 처리
@celery_app.task(name="process_explain_image_task")
def process_explain_image_task(user_email: str,
result_dict: dict,
version_type: str,
domain_type: str,
image_loc: str,
image_id: int,
category: int,
explain_req_dict: dict):
async def run_explain():
cam_loc = None
try:
# 추론에 사용된 모델 및 데이터셋 설정 로드
version_type = result_dict["version_type"]
model_type = result_dict["model_type"]
domain_type = result_dict["domain_type"]

model_name, dataset = inference_svc.MODEL_CONFIG[version_type][model_type][domain_type]
model_name, dataset = inference_svc.MODEL_CONFIG[version_type][explain_req_dict["model_type"]][domain_type]

explainer = _get_or_create_explainer(model_name, dataset, explain_req_dict)
image_path = "." + result_dict["image_loc"]
image_path = "." + image_loc

# 시각화 이미지 생성 시작
try:
image = _run_visualization(explainer, image_path, explain_req_dict)
image = _run_visualization(explainer, image_path, category, explain_req_dict)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
Expand All @@ -79,7 +103,7 @@ async def run_explain():

# 생성된 시각화 이미지 파일 저장
try:
cam_loc = await image_svc.upload_image_cam(user_email, result_dict["image_id"], image)
cam_loc = await image_svc.upload_image_cam(user_email, image_id, image)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -110,3 +134,67 @@ async def run_explain():
return loop.run_until_complete(run_explain())
finally:
loop.close()

# 딥페이크 비디오 프레임 위조 흔적 시각화 처리
@celery_app.task(name="process_explain_frame_task")
def process_explain_frame_task(user_email: str,
version_type: str,
domain_type: str,
video_loc: str,
video_id: int,
category: int,
frame_time: float,
explain_req_dict: dict):
async def run_explain():
cam_loc = None
try:
model_name, dataset = inference_svc.MODEL_CONFIG[version_type][explain_req_dict["model_type"]][domain_type]
video_path = "." + video_loc

explainer = _get_or_create_explainer(model_name, dataset, explain_req_dict)

face = _extract_face_from_frame(video_path, frame_time, explainer)

# 비디오 프레임 시각화 생성 시작
try:
image = _run_visualization_from_array(explainer, face, category, explain_req_dict)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="딥페이크 비디오 프레임 위조 흔적을 생성하는 중 오류가 발생하였습니다"
)

# 비디오 프레임 시각화 파일 저장
try:
cam_loc = await image_svc.upload_frame_cam(user_email, video_id, frame_time, image)
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="딥페이크 비디오 프레임 위조 흔적 파일을 저장하는 중 오류가 발생했습니다."
)

return {"status": "SUCCESS",
"message": "딥페이크 비디오 프레임 위조 흔적 시각화가 성공적으로 이루어졌습니다",
"cam_loc": cam_loc}

except HTTPException as e:
print(e.detail)
return {"status": "FAILED", "message": str(e.detail)}

except Exception as e:
print(str(e))
return {"status": "FAILED", "message": str(e)}

finally:
# 임시 저장된 비디오 프레임 시각화 파일 삭제
if cam_loc:
image_svc.cleanup_image_cam.apply_async(args=[cam_loc], countdown=60)

# 동기식 Celery 워커 내 비동기 이벤트 루프 구동
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(run_explain())
finally:
loop.close()

19 changes: 19 additions & 0 deletions App/services/image_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ async def upload_image_cam(user_email: str, image_id: int, image: np.ndarray) ->

return cam_loc[1:].replace("\\", "/")

# 딥페이크 비디오 프레임 위조 흔적 시각화 파일 서버 내 저장 (회원 전용)
async def upload_frame_cam(user_email: str, video_id: int, frame_time: float, image: np.ndarray) -> str:
user_dir = os.path.join(EXPLAIN_UPLOAD_DIR, user_email)
os.makedirs(user_dir, exist_ok=True)

cam_filename = f"v{video_id}_t{frame_time}_{int(time.time())}.png"
cam_loc = os.path.join(user_dir, cam_filename)

image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
_, buf = cv2.imencode(".png", image_bgr)

try:
async with aio.open(cam_loc, "wb") as outfile:
await outfile.write(buf.tobytes())
except Exception as e:
raise e

return cam_loc[1:].replace("\\", "/")

# 사용자 업로드 이미지 서버 내 삭제
# 호출 : image.py : history 삭제 할 때 db 와 실제 파일 삭제
# 호출 : inference.py : 추론 FAIL일 때 delete_video and delete_video_db 실행
Expand Down
5 changes: 2 additions & 3 deletions inference/video_predictor_prt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ def _extract_frames(self, video_path: str) -> Dict[int, np.ndarray]:

frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
duration_sec = int(frame_cnt / fps)
num_frames = max(10, min(duration_sec, 60))

num_frames = int(frame_cnt / fps)

frame_indices = [min(int(sec * fps), frame_cnt - 1) for sec in range(num_frames)]
frame_indices = sorted(set(frame_indices))
frames = {}
Expand Down
Loading