diff --git a/App/routes/explain.py b/App/routes/explain.py index 674edcb..be238da 100644 --- a/App/routes/explain.py +++ b/App/routes/explain.py @@ -1,13 +1,11 @@ import os -import io -from pydantic import Field from fastapi import APIRouter, status, Depends from fastapi.exceptions import HTTPException from fastapi.responses import JSONResponse from sqlalchemy import Connection from db.database import context_get_conn -from schemas.explain_schema import ExplainImageRequest -from services import session_svc, explain_svc, image_svc +from schemas.explain_schema import ExplainImageRequest, ExplainFrameRequest +from services import session_svc, explain_svc, image_svc, video_svc from celery_app import celery_app from celery.result import AsyncResult @@ -47,7 +45,11 @@ async def explain_image( # Celery Task 호출(Redis Broker 활용) task = explain_svc.process_explain_image_task.delay( user_email = session_user["email"], - result_dict = result.model_dump(mode='json'), + version_type = result.version_type, + domain_type = result.domain_type, + image_loc = result.image_loc, + image_id = result.image_id, + category = 1 if result.label == "FAKE" else 0, explain_req_dict = explain_req.model_dump()) return { "message": "딥페이크 이미지 위조 흔적 시각화 접수 완료. 시각화 분석 시작 ...", @@ -66,10 +68,10 @@ async def get_explain_image_result( task = AsyncResult(task_id, app=celery_app) if task.state in ("PENDING", "STARTED", "RETRY"): - return { - "status": "PENDING", - "message": "딥페이크 이미지 위조 흔적 시각화 분석 중 ..." - } + return JSONResponse( + status_code = status.HTTP_202_ACCEPTED, + content = {"message": "딥페이크 이미지 위조 흔적 시각화 분석 중 ..."} + ) if task.state == "FAILURE": raise HTTPException( @@ -89,6 +91,94 @@ async def get_explain_image_result( "message": result["message"], "cam_loc": result["cam_loc"], } + +@router.post("/video/{video_id}/frame/{frame_index}", status_code=status.HTTP_202_ACCEPTED, + response_class=JSONResponse, summary="딥페이크 비디오 프레임 위조 흔적 시각화 비동기 접수") +async def explain_frame( + video_id: int, + frame_index: int, + explain_req: ExplainFrameRequest, + conn: Connection = Depends(context_get_conn), + session_user = Depends(session_svc.get_session_user_prt), # 로그인 필수 +): + # 딥페이크 비디오 추론 결과 가져오기 + result = await video_svc.get_video_result(conn, video_id) + + # 딥페이크 비디오 추론 성공 여부 확인하기 + if result.status != "SUCCESS": + raise HTTPException( + status_code = status.HTTP_400_BAD_REQUEST, + detail = "비디오 프레임 위조 흔적 분석은 추론이 성공한 비디오에서만 가능합니다" + ) + + # 비디오 파일 저장 경로 가져오기 + video_path = "." + result.video_loc + if not os.path.exists(video_path): + raise HTTPException( + status_code = status.HTTP_404_NOT_FOUND, + detail = f"요청하신 비디오 파일을 찾을 수 없습니다. 삭제하였는지 다시 확인해주세요." + ) + + # 딥페이크 비디오 프레임 위조 흔적 분석 (pro model는 aug_smooth 사용 불가, 연산이 너무 많아짐) + if result.model_type == "pro" and explain_req.aug_smooth: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Pro 모델은 aug_smooth 기능을 지원하지 않습니다", + ) + # 비디오 내 해당 frame이 몇초에 위치한 frame인지 확인 + frame_time = video_svc.get_video_frame_by_index(conn, video_id, frame_index) - \ No newline at end of file + # Celery Task 호출(Redis Broker 활용) + task = explain_svc.process_explain_frame_task.delay( + user_email = session_user["email"], + version_type = result.version_type, + domain_type = result.domain_type, + video_loc = result.video_loc, + video_id = video_id, + category = 1 if result.label == "FAKE" else 0, + frame_time = frame_time, + explain_req_dict = explain_req.model_dump()) + return { + "message": "딥페이크 비디오 프레임 위조 흔적 시각화 접수 완료. 시각화 분석 시작 ...", + "task_id": task.id, + } + +@router.get("/frame/result/{task_id}", status_code=status.HTTP_200_OK, + response_class=JSONResponse, summary="딥페이크 비디오 프레임 위조 흔적 시각화 결과 가져오기") +async def get_explain_frame_result( + task_id: str, + session_user = Depends(session_svc.get_session_user_prt), # 로그인 필수 + ): + + # Redis Broker에서 Task ID에 해당하는 비동기 작업 상태 가져오기 + task = AsyncResult(task_id, app=celery_app) + + # 비동기 작업 진행 상태 Check + if task.state in ("PENDING", "STARTED", "RETRY"): + return JSONResponse( + status_code = status.HTTP_202_ACCEPTED, + content = {"message": "딥페이크 비디오 프레임 위조 흔적 시각화 분석 중 ..."} + ) + + # 비동기 작업 실패 + if task.state == "FAILURE": + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="딥페이크 비디오 프레임 위조 흔적 시각화 중 알 수 없는 오류가 발생하였습니다") + + # Celery Task 결과 가져오기 + result = task.result + + # 딥페이크 비디오 프레임 위조 흔적 시각화 생성 또는 파일 저장 도중 오류 발생 + if result["status"] == "FAILED": + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=result["message"]) + + return { + "status": result["status"], + "message": result["message"], + "cam_loc": result["cam_loc"], + } + \ No newline at end of file diff --git a/App/routes/inference.py b/App/routes/inference.py index 377c6da..9e6010b 100644 --- a/App/routes/inference.py +++ b/App/routes/inference.py @@ -99,7 +99,7 @@ async def predict_video( } @router.get("/video/{video_id}", status_code=status.HTTP_200_OK, - response_model=VideoDetailData, summary="딥페이크 비디오 비디오 추론 결과값 가져오기") + response_model=VideoDetailData, summary="딥페이크 비디오 추론 결과값 가져오기") async def get_video_result( video_id: int, conn: Connection = Depends(context_get_conn), diff --git a/App/schemas/explain_schema.py b/App/schemas/explain_schema.py index 276f806..f815a95 100644 --- a/App/schemas/explain_schema.py +++ b/App/schemas/explain_schema.py @@ -23,22 +23,22 @@ class ExplainImageRequest(BaseModel): + model_type: Literal["fast", "pro"] = Field("fast", + description="추론 모드 (fast: 속도 우선, pro: 정확도 우선)") + branch_level: Literal["low","high"] = Field("high", - description="브랜치 레벨\nlow: 국소 위조 흔적 포착\nhigh: 전역적 위조 흔적 포착") + description="브랜치 레벨\nlow: 국소 위조 흔적 포착\nhigh: 전역 위조 흔적 포착") explainer_type: str = Field("eigengradcam", description = ("선택 가능한 XAI 기법. low: [hirescam, gradcamelementwise, layercam], ""high: [eigengradcam, gradcamplusplus, xgradcam]")) - display_type: Literal["heatmap", "bbox", "heatmap_bbox"] = Field("heatmap", + display_type: Literal["heatmap", "bbox", "heatmap_bbox"] = Field("heatmap_bbox", description="시각화 형태. heatmap: 전체 분포, bbox: 위조 의심 영역 사각형, heatmap_bbox: 위조 의심 사각형 내부에 블러 처리된 히트맵을 중첩") - category: Literal[0, 1] = Field(1, - description="판단 클래스 인덱스 (0: Real / 1: Fake)") - - overlay_ratio: float = Field(0.5, ge=0.0, le=1.0, + overlay_ratio: float = Field(0.7, ge=0.5, le=1.0, description = "Heatmap 투명도 (0: 히트맵만 강조, 1: 원본 이미지 위주)") - threshold: float = Field(0.5, ge=0.5, le=1.0, + threshold: float = Field(0.9, ge=0.5, le=1.0, description="contour/bbox 이진화 임계값 (0.0~1.0)") aug_smooth: bool = Field(False, @@ -66,4 +66,7 @@ def validate_explainer_for_eigen_smooth(self) -> "ExplainImageRequest": f"{sorted(_EIGEN_ALLOWED)} " f"(입력값: '{self.explainer_type}')" ) - return self \ No newline at end of file + return self + +class ExplainFrameRequest(ExplainImageRequest): + pass \ No newline at end of file diff --git a/App/services/explain_svc.py b/App/services/explain_svc.py index 89c2f14..1a725ef 100644 --- a/App/services/explain_svc.py +++ b/App/services/explain_svc.py @@ -27,6 +27,18 @@ _explainer_cache: dict = {} +# 비디오 내 특정 Frame 추출 + Face Cropping +def _extract_face_from_frame(video_path: str, frame_time: float, explainer: CAMExplainer): + cap = cv2.VideoCapture(video_path) + cap.set(cv2.CAP_PROP_POS_MSEC, frame_time * 1000) + ret, frame = cap.read() + cap.release() + + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + bbox = explainer._get_face_bbox(frame_rgb) + + return explainer._crop_face(frame_rgb, bbox[:4]) + # 캐시된 CAMExplainer 객체 반환하거나 새로 생성 def _get_or_create_explainer(model_name: str, dataset: str, explain_req_dict: dict): cache_key = (model_name, dataset, explain_req_dict["explainer_type"], explain_req_dict["branch_level"]) @@ -40,37 +52,49 @@ def _get_or_create_explainer(model_name: str, dataset: str, explain_req_dict: di return _explainer_cache[cache_key] # 시각화 이미지 생성 (heatmap, contour, bbox 선택) -def _run_visualization(explainer: CAMExplainer, image_path: str, explain_req_dict: dict) -> np.ndarray: +def _run_visualization(explainer: CAMExplainer, image_path: str, category: int, explain_req_dict: dict) -> np.ndarray: if explain_req_dict["display_type"] == "heatmap": return explainer.display_heatmap_on_image(image_path, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"], - category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) elif explain_req_dict["display_type"] == "bbox": return explainer.display_bbox_on_image(image_path, threshold=explain_req_dict["threshold"], - category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + else: # display_type == "heatmap_bbox" + return explainer.display_heatmap_bbox_on_image(image_path, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"], + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + +# 비디오 프레임 시각화 생성 (heatmap, contour, bbox 선택) +def _run_visualization_from_array(explainer: CAMExplainer, face: str, category: int, explain_req_dict: dict) -> np.ndarray: + if explain_req_dict["display_type"] == "heatmap": + return explainer.display_heatmap_from_array(face, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"], + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + elif explain_req_dict["display_type"] == "bbox": + return explainer.display_bbox_from_array(face, threshold=explain_req_dict["threshold"], + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) else: # display_type == "heatmap_bbox" - return explainer.display_heatmap_bbox_on_image(image_path, threshold=explain_req_dict["threshold"], - category=explain_req_dict["category"], aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + return explainer.display_heatmap_bbox_from_array(face, image_weight=explain_req_dict["overlay_ratio"], threshold=explain_req_dict["threshold"], + category=category, aug_smooth=explain_req_dict["aug_smooth"], eigen_smooth=explain_req_dict["eigen_smooth"]) + # 딥페이크 이미지 위조 흔적 시각화 처리 @celery_app.task(name="process_explain_image_task") def process_explain_image_task(user_email: str, - result_dict: dict, + version_type: str, + domain_type: str, + image_loc: str, + image_id: int, + category: int, explain_req_dict: dict): async def run_explain(): cam_loc = None try: - # 추론에 사용된 모델 및 데이터셋 설정 로드 - version_type = result_dict["version_type"] - model_type = result_dict["model_type"] - domain_type = result_dict["domain_type"] - - model_name, dataset = inference_svc.MODEL_CONFIG[version_type][model_type][domain_type] + model_name, dataset = inference_svc.MODEL_CONFIG[version_type][explain_req_dict["model_type"]][domain_type] explainer = _get_or_create_explainer(model_name, dataset, explain_req_dict) - image_path = "." + result_dict["image_loc"] + image_path = "." + image_loc # 시각화 이미지 생성 시작 try: - image = _run_visualization(explainer, image_path, explain_req_dict) + image = _run_visualization(explainer, image_path, category, explain_req_dict) except Exception: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -79,7 +103,7 @@ async def run_explain(): # 생성된 시각화 이미지 파일 저장 try: - cam_loc = await image_svc.upload_image_cam(user_email, result_dict["image_id"], image) + cam_loc = await image_svc.upload_image_cam(user_email, image_id, image) except Exception: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -110,3 +134,67 @@ async def run_explain(): return loop.run_until_complete(run_explain()) finally: loop.close() + +# 딥페이크 비디오 프레임 위조 흔적 시각화 처리 +@celery_app.task(name="process_explain_frame_task") +def process_explain_frame_task(user_email: str, + version_type: str, + domain_type: str, + video_loc: str, + video_id: int, + category: int, + frame_time: float, + explain_req_dict: dict): + async def run_explain(): + cam_loc = None + try: + model_name, dataset = inference_svc.MODEL_CONFIG[version_type][explain_req_dict["model_type"]][domain_type] + video_path = "." + video_loc + + explainer = _get_or_create_explainer(model_name, dataset, explain_req_dict) + + face = _extract_face_from_frame(video_path, frame_time, explainer) + + # 비디오 프레임 시각화 생성 시작 + try: + image = _run_visualization_from_array(explainer, face, category, explain_req_dict) + except Exception: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="딥페이크 비디오 프레임 위조 흔적을 생성하는 중 오류가 발생하였습니다" + ) + + # 비디오 프레임 시각화 파일 저장 + try: + cam_loc = await image_svc.upload_frame_cam(user_email, video_id, frame_time, image) + except Exception: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="딥페이크 비디오 프레임 위조 흔적 파일을 저장하는 중 오류가 발생했습니다." + ) + + return {"status": "SUCCESS", + "message": "딥페이크 비디오 프레임 위조 흔적 시각화가 성공적으로 이루어졌습니다", + "cam_loc": cam_loc} + + except HTTPException as e: + print(e.detail) + return {"status": "FAILED", "message": str(e.detail)} + + except Exception as e: + print(str(e)) + return {"status": "FAILED", "message": str(e)} + + finally: + # 임시 저장된 비디오 프레임 시각화 파일 삭제 + if cam_loc: + image_svc.cleanup_image_cam.apply_async(args=[cam_loc], countdown=60) + + # 동기식 Celery 워커 내 비동기 이벤트 루프 구동 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(run_explain()) + finally: + loop.close() + diff --git a/App/services/image_svc.py b/App/services/image_svc.py index 4088006..14f0f53 100644 --- a/App/services/image_svc.py +++ b/App/services/image_svc.py @@ -83,6 +83,25 @@ async def upload_image_cam(user_email: str, image_id: int, image: np.ndarray) -> return cam_loc[1:].replace("\\", "/") +# 딥페이크 비디오 프레임 위조 흔적 시각화 파일 서버 내 저장 (회원 전용) +async def upload_frame_cam(user_email: str, video_id: int, frame_time: float, image: np.ndarray) -> str: + user_dir = os.path.join(EXPLAIN_UPLOAD_DIR, user_email) + os.makedirs(user_dir, exist_ok=True) + + cam_filename = f"v{video_id}_t{frame_time}_{int(time.time())}.png" + cam_loc = os.path.join(user_dir, cam_filename) + + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + _, buf = cv2.imencode(".png", image_bgr) + + try: + async with aio.open(cam_loc, "wb") as outfile: + await outfile.write(buf.tobytes()) + except Exception as e: + raise e + + return cam_loc[1:].replace("\\", "/") + # 사용자 업로드 이미지 서버 내 삭제 # 호출 : image.py : history 삭제 할 때 db 와 실제 파일 삭제 # 호출 : inference.py : 추론 FAIL일 때 delete_video and delete_video_db 실행 diff --git a/inference/video_predictor_prt.py b/inference/video_predictor_prt.py index d854bc8..d6e6e9a 100644 --- a/inference/video_predictor_prt.py +++ b/inference/video_predictor_prt.py @@ -47,9 +47,8 @@ def _extract_frames(self, video_path: str) -> Dict[int, np.ndarray]: frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 - duration_sec = int(frame_cnt / fps) - num_frames = max(10, min(duration_sec, 60)) - + num_frames = int(frame_cnt / fps) + frame_indices = [min(int(sec * fps), frame_cnt - 1) for sec in range(num_frames)] frame_indices = sorted(set(frame_indices)) frames = {}