From 855927dc6653a81593640891307aea0facec8bb0 Mon Sep 17 00:00:00 2001 From: Lungsangg Date: Mon, 22 Jun 2026 13:38:24 +0530 Subject: [PATCH 1/7] done_audio_generate_endpoint --- worker_api/app.py | 4 +- worker_api/audio/__init__.py | 0 worker_api/audio/audio_views.py | 20 ++ worker_api/audio/enums.py | 37 +++ worker_api/audio/models/__init__.py | 0 .../audio/models/plan_item_audio_models.py | 41 +++ worker_api/audio/models/plan_items_models.py | 31 +++ .../audio/models/plan_sub_tasks_models.py | 45 ++++ worker_api/audio/models/plan_tasks_models.py | 35 +++ .../models/sub_task_timestamps_models.py | 39 +++ worker_api/audio/repositories/__init__.py | 0 .../plan_item_audio_repository.py | 30 +++ .../repositories/plan_items_repository.py | 25 ++ .../repositories/plan_sub_tasks_repository.py | 8 + .../sub_task_timestamps_repository.py | 42 +++ worker_api/audio/schemas.py | 21 ++ worker_api/audio/services/__init__.py | 0 .../audio/services/audio_generate_service.py | 250 ++++++++++++++++++ worker_api/audio/services/audio_prompt.py | 75 ++++++ .../audio/services/monlam_tts_service.py | 49 ++++ worker_api/audio/services/tts_service.py | 133 ++++++++++ worker_api/config.py | 8 + worker_api/uploads/S3_utils.py | 31 +++ 23 files changed, 922 insertions(+), 2 deletions(-) create mode 100644 worker_api/audio/__init__.py create mode 100644 worker_api/audio/audio_views.py create mode 100644 worker_api/audio/enums.py create mode 100644 worker_api/audio/models/__init__.py create mode 100644 worker_api/audio/models/plan_item_audio_models.py create mode 100644 worker_api/audio/models/plan_items_models.py create mode 100644 worker_api/audio/models/plan_sub_tasks_models.py create mode 100644 worker_api/audio/models/plan_tasks_models.py create mode 100644 worker_api/audio/models/sub_task_timestamps_models.py create mode 100644 worker_api/audio/repositories/__init__.py create mode 100644 worker_api/audio/repositories/plan_item_audio_repository.py create mode 100644 
# (patch 1/7 header, continued) -- remaining "create mode 100644" entries:
#   worker_api/audio/repositories/plan_items_repository.py
#   worker_api/audio/repositories/plan_sub_tasks_repository.py
#   worker_api/audio/repositories/sub_task_timestamps_repository.py
#   worker_api/audio/schemas.py
#   worker_api/audio/services/__init__.py
#   worker_api/audio/services/audio_generate_service.py
#   worker_api/audio/services/audio_prompt.py
#   worker_api/audio/services/monlam_tts_service.py
#   worker_api/audio/services/tts_service.py

# ---- worker_api/app.py (modified, index 6be4cb8..6f62569) ----
# Hunk @@ -4,6 +4,7 @@ adds, after
# `from worker_api.db.mongo_database import lifespan`:
#     +from worker_api.audio.audio_views import audio_router
# Hunk @@ -15,8 +16,7 @@ replaces, between `lifespan=lifespan )` and
# `api.add_middleware( CORSMiddleware, ...`:
#     -# Add routers here as endpoints are transferred
#     -# api.include_router(example_router)
#     +api.include_router(audio_router)

# ---- worker_api/audio/__init__.py (new file, empty) ----

# ---- worker_api/audio/audio_views.py (new file) ----
from fastapi import APIRouter
from starlette import status

from worker_api.audio.enums import PlanAudioType
from worker_api.audio.schemas import GeneratePlanAudioRequest
from worker_api.audio.services.audio_generate_service import generate_plan_audio_service

audio_router = APIRouter(prefix="/audio", tags=["Audio"])


@audio_router.post("/generate", status_code=status.HTTP_200_OK)
async def generate_plan_audio(
    request: GeneratePlanAudioRequest,
):
    """Generate audio for one plan day or one sub task.

    Args:
        request: Validated body; exactly one of ``day_id`` / ``sub_task_id``
            is set (enforced by the schema's model validator).

    Returns:
        Whatever ``generate_plan_audio_service`` returns.
    """
    return await generate_plan_audio_service(
        day_id=request.day_id,
        sub_task_id=request.sub_task_id,
        language=request.language,
        # ``type`` is Optional in the schema, so an explicit JSON null arrives
        # as None; fall back to the documented default instead of forwarding
        # None into the prompt-config lookup.
        audio_type=request.type or PlanAudioType.TEXT_READING,
        voice_name=request.voice_name,
    )


# ---- worker_api/audio/enums.py (new file) ----
import enum
from sqlalchemy import Enum


class ContentType(enum.Enum):
    """Kinds of content a plan sub task can hold."""

    TEXT = "TEXT"
    AUDIO = "AUDIO"
    VIDEO = "VIDEO"
    IMAGE = "IMAGE"
    SOURCE_REFERENCE = "SOURCE_REFERENCE"


class PlanAudioType(enum.Enum):
    """Narration styles; each one keys a prompt configuration."""

    RECITATION = "RECITATION"
    INSTRUCTION = "INSTRUCTION"
    TEXT_READING = "TEXT_READING"


class MonlamVoiceName(str, enum.Enum):
    """Voice identifiers sent as ``voice_name`` to the Monlam TTS API.

    NOTE(review): ``histry_lhasa_male`` looks like a misspelling of "history"
    but is presumably the provider's literal voice id -- confirm before
    changing it.
    """

    DOLKAR_LHASA_FEMALE = "dolkar_lhasa_female"
    YANGCHEN_LHASA_FEMALE = "yangchen_lhasa_female"
    DARJEEYALPHEL_LHASA_MALE = "darjeeyalphel_lhasa_male"
    HISTRY_LHASA_MALE = "histry_lhasa_male"
    SONAMTSERING_LHASA_MALE = "sonamtsering_lhasa_male"
    DOLMA_AMDO_FEMALE = "dolma_amdo_female"
    KID_AMDO_FEMALE = "kid_amdo_female"
    BUDDHAHISTORY_AMDO_MALE = "buddhahistory_amdo_male"
    HISTORY_AMDO_MALE = "history_amdo_male"
    KALSANG_GYATSO_AMDO_MALE = "kalsang_gyatso_amdo_male"
    KOTHEKE_KHAM_MALE = "kotheke_kham_male"
    TIBET_TONGUE_KHAM_MALE = "tibet_tongue_kham_male"
    TSERING_WANGMO_KHAM_FEMALE = "tsering_wangmo_kham_female"
    WANGDONTSO_KHAM_FEMALE = "wangdontso_kham_female"


# SQLAlchemy column types built from the Python enums above.
ContentTypeEnum = Enum(ContentType)
PlanAudioTypeEnum = Enum(PlanAudioType)


# ---- worker_api/audio/models/__init__.py (new file, empty) ----

# ---- worker_api/audio/models/plan_item_audio_models.py (new file, imports) ----
from sqlalchemy import Column, Integer, DateTime, ForeignKey, BigInteger, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from uuid import uuid4
from worker_api.db.database import Base
from _datetime import datetime
# ---- worker_api/audio/models/plan_item_audio_models.py (continued) ----
# The public ``datetime`` module replaces the private ``_datetime`` C
# accelerator module that the original imported.
from datetime import datetime, timezone


def _utc_now() -> datetime:
    """Return the current timezone-aware UTC time.

    Passed to ``Column(default=...)`` as a *callable* so SQLAlchemy evaluates
    it per statement. The original passed ``datetime.now(...)`` itself, a
    value computed once at import and then stamped on every row.
    """
    return datetime.now(timezone.utc)


class PlanItemAudio(Base):
    """Audio file generated for one plan day (one row of ``items``)."""

    __tablename__ = "plan_item_audio"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
    # unique=True: at most one audio row per plan item.
    plan_item_id = Column(
        UUID(as_uuid=True),
        ForeignKey("items.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    audio_key = Column(String(1000), nullable=False)
    duration_ms = Column(Integer, nullable=True)
    mime_type = Column(String(64), nullable=True)
    file_size_bytes = Column(BigInteger, nullable=True)

    created_at = Column(
        DateTime(timezone=True),
        default=_utc_now,
        nullable=False,
    )
    created_by = Column(String(255), nullable=False)
    updated_at = Column(
        DateTime(timezone=True),
        default=_utc_now,
        onupdate=_utc_now,  # refresh on every UPDATE issued through the ORM
    )
    updated_by = Column(String(255))

    plan_item = relationship("PlanItem", back_populates="audio")

    __table_args__ = (
        Index("idx_plan_item_audio_plan_item_id", "plan_item_id"),
    )


# ---- worker_api/audio/models/plan_items_models.py (new file) ----
from sqlalchemy import Column, Integer, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from uuid import uuid4
from worker_api.db.database import Base
from datetime import datetime, timezone


def _utc_now() -> datetime:
    """Return the current timezone-aware UTC time (per-row column default)."""
    return datetime.now(timezone.utc)


class PlanItem(Base):
    """One day of a plan, identified by ``plan_id`` and ``day_number``."""

    __tablename__ = "items"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
    plan_id = Column(UUID(as_uuid=True), ForeignKey('plans.id', ondelete='CASCADE'), nullable=False)
    day_number = Column(Integer, nullable=False)

    created_at = Column(DateTime(timezone=True), default=_utc_now, nullable=False)
    created_by = Column(String(255), nullable=False)
    updated_at = Column(DateTime(timezone=True), default=_utc_now, onupdate=_utc_now)
    updated_by = Column(String(255))

    # One-to-one with PlanItemAudio (uselist=False).
    audio = relationship(
        "PlanItemAudio",
        back_populates="plan_item",
        uselist=False,
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        Index("idx_plan_items_plan_day", "plan_id", "day_number"),
    )


# ---- worker_api/audio/models/plan_sub_tasks_models.py (new file) ----
from sqlalchemy import Column, Integer, DateTime, Text, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from sqlalchemy.orm import relationship
from uuid import uuid4
from worker_api.db.database import Base
from worker_api.audio.enums import ContentTypeEnum
from datetime import datetime, timezone


def _utc_now() -> datetime:
    """Return the current timezone-aware UTC time (per-row column default)."""
    return datetime.now(timezone.utc)


class PlanSubTask(Base):
    """A content unit inside a task; may carry its own generated audio."""

    __tablename__ = "sub_tasks"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
    task_id = Column(UUID(as_uuid=True), ForeignKey('tasks.id', ondelete='CASCADE'), nullable=False)
    # The audio service in this patch writes an S3 object key here, not a URL.
    audio_url = Column(String(255), nullable=True)
    content_type = Column(ContentTypeEnum, nullable=False)
    content = Column(Text, nullable=True)
    # The audio service stores the duration in milliseconds as a string.
    duration = Column(String(255), nullable=True)
    source_text_id = Column(UUID(as_uuid=True), nullable=True)
    pecha_segment_id = Column(String(255), nullable=True)
    segment_ids = Column(ARRAY(UUID(as_uuid=True)), nullable=True)

    display_order = Column(Integer, nullable=False)

    created_at = Column(DateTime(timezone=True), default=_utc_now, nullable=False)
    created_by = Column(String(255), nullable=False)
    updated_at = Column(DateTime(timezone=True), default=_utc_now, onupdate=_utc_now)
    updated_by = Column(String(255))

    deleted_at = Column(DateTime(timezone=True))
    deleted_by = Column(String(255))

    task = relationship("PlanTask", back_populates="sub_tasks")
    # One-to-one with SubTaskTimestamp (uselist=False).
    timestamp = relationship(
        "SubTaskTimestamp",
        back_populates="sub_task",
        uselist=False,
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        Index("idx_sub_tasks_task_order", "task_id", "display_order"),
        Index("idx_sub_tasks_content_type", "content_type"),
    )


# ---- worker_api/audio/models/plan_tasks_models.py (new file) ----
from sqlalchemy import Column, Integer, DateTime, Boolean, Text, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from uuid import uuid4
from worker_api.db.database import Base
from datetime import datetime, timezone


def _utc_now() -> datetime:
    """Return the current timezone-aware UTC time (per-row column default)."""
    return datetime.now(timezone.utc)


class PlanTask(Base):
    """A task within a plan day; owns an ordered list of sub tasks."""

    __tablename__ = "tasks"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
    plan_item_id = Column(UUID(as_uuid=True), ForeignKey('items.id', ondelete='CASCADE'), nullable=False)

    title = Column(Text, nullable=True)

    display_order = Column(Integer, nullable=False)
    estimated_time = Column(Integer, nullable=True)
    is_required = Column(Boolean, default=True)

    created_at = Column(DateTime(timezone=True), default=_utc_now, nullable=False)
    created_by = Column(String(255), nullable=False)
    updated_at = Column(DateTime(timezone=True), default=_utc_now, onupdate=_utc_now)
    updated_by = Column(String(255))

    deleted_at = Column(DateTime(timezone=True))
    deleted_by = Column(String(255))

    # backref creates PlanItem.tasks, which the repository eager-loads.
    plan_item = relationship("PlanItem", backref="tasks")
    sub_tasks = relationship("PlanSubTask", back_populates="task", cascade="all, delete-orphan")

    __table_args__ = (
        Index("idx_tasks_plan_item_order", "plan_item_id", "display_order"),
    )


# ---- worker_api/audio/models/sub_task_timestamps_models.py (new file; diff header continues on the next source line) ----
# ---- worker_api/audio/models/sub_task_timestamps_models.py (new file) ----
from sqlalchemy import Column, Integer, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from uuid import uuid4
from worker_api.db.database import Base
# The public ``datetime`` module replaces the private ``_datetime`` module.
from datetime import datetime, timezone


def _utc_now() -> datetime:
    """Return the current timezone-aware UTC time.

    Used as a callable column default so it is evaluated per row; the
    original passed ``datetime.now(...)`` itself, which is computed once at
    import time.
    """
    return datetime.now(timezone.utc)


class SubTaskTimestamp(Base):
    """Start/end offsets (ms) of a sub task inside generated audio."""

    __tablename__ = "sub_task_timestamps"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
    # unique=True: at most one timestamp row per sub task.
    sub_task_id = Column(
        UUID(as_uuid=True),
        ForeignKey("sub_tasks.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    start_ms = Column(Integer, nullable=False)
    end_ms = Column(Integer, nullable=False)

    created_at = Column(
        DateTime(timezone=True),
        default=_utc_now,
        nullable=False,
    )
    created_by = Column(String(255), nullable=False)
    updated_at = Column(
        DateTime(timezone=True),
        default=_utc_now,
        onupdate=_utc_now,
    )
    updated_by = Column(String(255))

    sub_task = relationship("PlanSubTask", back_populates="timestamp")

    __table_args__ = (
        Index("idx_sub_task_timestamps_sub_task_id", "sub_task_id"),
    )


# ---- worker_api/audio/repositories/__init__.py (new file, empty) ----

# ---- worker_api/audio/repositories/plan_item_audio_repository.py (new file) ----
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session

from worker_api.audio.models.plan_item_audio_models import PlanItemAudio


def get_plan_item_audio_by_plan_item_id(db: Session, plan_item_id: UUID) -> Optional[PlanItemAudio]:
    """Return the audio row for ``plan_item_id``, or None if there is none."""
    return (
        db.query(PlanItemAudio)
        .filter(PlanItemAudio.plan_item_id == plan_item_id)
        .first()
    )


def upsert_plan_item_audio(db: Session, plan_item_audio: PlanItemAudio) -> PlanItemAudio:
    """Insert ``plan_item_audio`` or overwrite the existing row for its plan item.

    Commits the session in both branches.

    Args:
        db: Open SQLAlchemy session.
        plan_item_audio: Transient instance carrying the new values.

    Returns:
        The persisted, refreshed row: the pre-existing one on update, the
        passed instance on insert.
    """
    existing = get_plan_item_audio_by_plan_item_id(db=db, plan_item_id=plan_item_audio.plan_item_id)
    if existing is None:
        db.add(plan_item_audio)
        db.commit()
        db.refresh(plan_item_audio)
        return plan_item_audio

    existing.audio_key = plan_item_audio.audio_key
    existing.duration_ms = plan_item_audio.duration_ms
    existing.mime_type = plan_item_audio.mime_type
    existing.file_size_bytes = plan_item_audio.file_size_bytes
    # Callers in this patch set only created_by; do not wipe the audit column
    # with None in that case.
    existing.updated_by = plan_item_audio.updated_by or plan_item_audio.created_by
    # The original left updated_at at its insert-time value on update.
    existing.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(existing)
    return existing


# ---- worker_api/audio/repositories/plan_items_repository.py (new file) ----
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy.orm import Session, joinedload
from starlette import status

from worker_api.audio.models.plan_items_models import PlanItem
from worker_api.audio.models.plan_tasks_models import PlanTask
from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask


def get_plan_day_by_id_any_plan(db: Session, day_id: UUID) -> PlanItem:
    """Load a plan day with its tasks, sub tasks and timestamps eagerly.

    Args:
        db: Open SQLAlchemy session.
        day_id: Primary key of the ``items`` row.

    Returns:
        The matching PlanItem.

    Raises:
        HTTPException: 404 when no plan day has this id.
    """
    plan_item = (
        db.query(PlanItem)
        .options(
            joinedload(PlanItem.tasks).joinedload(PlanTask.sub_tasks).joinedload(PlanSubTask.timestamp),
        )
        .filter(PlanItem.id == day_id)
        .first()
    )
    if not plan_item:
        # NOTE(review): the payload says "BAD_REQUEST" although the status is
        # 404; left unchanged in case clients or tests match on the string.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error": "BAD_REQUEST", "message": "Plan day not found"},
        )
    return plan_item


# ---- worker_api/audio/repositories/plan_sub_tasks_repository.py (new file, imports) ----
from uuid import UUID
from sqlalchemy.orm import Session

from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask
# ---- worker_api/audio/repositories/plan_sub_tasks_repository.py (continued) ----
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session

from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask


def get_sub_task_by_subtask_id(db: Session, id: UUID) -> Optional[PlanSubTask]:
    """Return the sub task with primary key ``id``, or None when absent.

    The parameter name shadows the ``id`` builtin but is kept because callers
    pass it by keyword.
    """
    return db.query(PlanSubTask).filter(PlanSubTask.id == id).first()


# ---- worker_api/audio/repositories/sub_task_timestamps_repository.py (new file) ----
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session

from worker_api.audio.models.sub_task_timestamps_models import SubTaskTimestamp


def get_sub_task_timestamp_by_sub_task_id(
    db: Session, sub_task_id: UUID
) -> Optional[SubTaskTimestamp]:
    """Return the timestamp row for ``sub_task_id``, or None."""
    return (
        db.query(SubTaskTimestamp)
        .filter(SubTaskTimestamp.sub_task_id == sub_task_id)
        .first()
    )


def upsert_sub_task_timestamp(
    db: Session,
    sub_task_id: UUID,
    start_ms: int,
    end_ms: int,
    created_by: str,
) -> SubTaskTimestamp:
    """Create or overwrite the timestamp row of a sub task and commit.

    Args:
        db: Open SQLAlchemy session.
        sub_task_id: Sub task the offsets belong to.
        start_ms: Start offset in milliseconds.
        end_ms: End offset in milliseconds.
        created_by: Actor recorded as ``created_by`` on insert and as
            ``updated_by`` on update.

    Returns:
        The persisted, refreshed row.
    """
    existing = get_sub_task_timestamp_by_sub_task_id(db=db, sub_task_id=sub_task_id)
    if existing is None:
        row = SubTaskTimestamp(
            sub_task_id=sub_task_id,
            start_ms=start_ms,
            end_ms=end_ms,
            created_by=created_by,
        )
        db.add(row)
        db.commit()
        db.refresh(row)
        return row

    existing.start_ms = start_ms
    existing.end_ms = end_ms
    existing.updated_by = created_by
    # The original never touched updated_at on update.
    existing.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(existing)
    return existing


# ---- worker_api/audio/schemas.py (new file) ----
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, model_validator

from worker_api.audio.enums import PlanAudioType, MonlamVoiceName


class GeneratePlanAudioRequest(BaseModel):
    """Body of POST /audio/generate; exactly one target id must be given."""

    day_id: Optional[UUID] = None
    sub_task_id: Optional[UUID] = None
    language: str
    type: Optional[PlanAudioType] = PlanAudioType.TEXT_READING
    voice_name: MonlamVoiceName = MonlamVoiceName.DOLKAR_LHASA_FEMALE

    @model_validator(mode="after")
    def validate_either_day_or_subtask(self):
        """Reject requests that set neither or both of day_id / sub_task_id."""
        has_day = self.day_id is not None
        has_sub_task = self.sub_task_id is not None
        if not has_day and not has_sub_task:
            raise ValueError("Either day_id or sub_task_id must be provided")
        if has_day and has_sub_task:
            raise ValueError("Provide either day_id or sub_task_id, not both")
        return self


# ---- worker_api/audio/services/__init__.py (new file, empty) ----

# ---- worker_api/audio/services/audio_generate_service.py (new file) ----
import asyncio
import struct
from io import BytesIO
from typing import Optional, List
from uuid import UUID, uuid4

from fastapi import HTTPException
from sqlalchemy.orm import Session
from starlette import status

from worker_api.audio.enums import ContentType, PlanAudioType, MonlamVoiceName
from worker_api.audio.models.plan_item_audio_models import PlanItemAudio
from worker_api.audio.models.plan_items_models import PlanItem
from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask
from worker_api.audio.repositories.plan_items_repository import get_plan_day_by_id_any_plan
from worker_api.audio.repositories.plan_sub_tasks_repository import get_sub_task_by_subtask_id
from worker_api.audio.repositories.plan_item_audio_repository import upsert_plan_item_audio
from worker_api.audio.repositories.sub_task_timestamps_repository import upsert_sub_task_timestamp
from worker_api.audio.services.tts_service import generate_tts_audio
from worker_api.config import get
from worker_api.db.database import SessionLocal
from worker_api.uploads.S3_utils import upload_bytes, download_bytes, generate_presigned_access_url

WAV_CONTENT_TYPE = "audio/wav"

# Output format of the combined file; timestamps are computed against it.
_SAMPLE_RATE = 24000
_BYTES_PER_SAMPLE = 2
# Header length of a canonical PCM WAV (RIFF + 16-byte fmt + data header).
_CANONICAL_WAV_HEADER_SIZE = 44
_AUDIO_CONTENT_TYPES = frozenset({ContentType.TEXT, ContentType.SOURCE_REFERENCE})
_SYSTEM_ACTOR = "system"


def _extract_pcm(wav_bytes: bytes) -> bytes:
    """Return the sample payload of a WAV file.

    Walks the RIFF chunk list to find ``data`` instead of assuming a 44-byte
    header, so extra chunks before or after the samples are not mixed into
    the audio. A declared size of 0 or one that overruns the buffer is
    treated as "rest of file". Input that is not RIFF/WAVE falls back to the
    original fixed 44-byte slice.
    """
    total = len(wav_bytes)
    if total < 12 or wav_bytes[:4] != b"RIFF" or wav_bytes[8:12] != b"WAVE":
        return wav_bytes[_CANONICAL_WAV_HEADER_SIZE:]

    offset = 12
    while offset + 8 <= total:
        chunk_id = wav_bytes[offset:offset + 4]
        (chunk_size,) = struct.unpack_from("<I", wav_bytes, offset + 4)
        body_start = offset + 8
        if chunk_id == b"data":
            body_end = body_start + chunk_size
            if chunk_size == 0 or body_end > total:
                body_end = total
            return wav_bytes[body_start:body_end]
        # RIFF chunks are word-aligned: odd sizes carry one pad byte.
        offset = body_start + chunk_size + (chunk_size & 1)
    return wav_bytes[_CANONICAL_WAV_HEADER_SIZE:]


def _generate_audio_segments(
    tasks,
    audio_type: PlanAudioType,
    language: str,
    voice_name: Optional[str] = None,
) -> tuple[List[bytes], list]:
    """Collect one raw-PCM segment per task, from its first sub task.

    Only the first sub task of each task is considered, and only when its
    content type is TEXT or SOURCE_REFERENCE. Existing audio (``audio_url``
    holds an S3 key) is downloaded and reused; otherwise TTS is run on the
    text. Sub tasks with no audio and no text are skipped, because
    ``generate_tts_audio`` would otherwise fail the whole day.

    NOTE(review): segments are concatenated without checking that every
    source WAV is 24 kHz / 16-bit / mono -- confirm the Monlam output format.

    Returns:
        ``(audio_segments, subtask_refs)`` as parallel lists.
    """
    audio_segments: List[bytes] = []
    subtask_refs = []
    for task in tasks:
        subtask = task.sub_tasks[0] if task.sub_tasks else None
        if subtask is None or subtask.content_type not in _AUDIO_CONTENT_TYPES:
            continue

        if subtask.audio_url:
            wav_bytes = download_bytes(key=subtask.audio_url)
        elif subtask.content and subtask.content.strip():
            wav_bytes = generate_tts_audio(
                subtask.content, audio_type, language, voice_name=voice_name
            )
        else:
            continue

        audio_segments.append(_extract_pcm(wav_bytes))
        subtask_refs.append(subtask)
    return audio_segments, subtask_refs


def _update_subtask_timestamps(
    db: Session,
    audio_segments: List[bytes],
    subtask_refs: list,
    sample_rate: int,
    bytes_per_sample: int,
) -> int:
    """Persist start/end offsets for every segment; return the total in ms.

    Each call to ``upsert_sub_task_timestamp`` commits on its own.
    """
    current_offset_ms = 0
    for raw_pcm, subtask in zip(audio_segments, subtask_refs):
        segment_samples = len(raw_pcm) // bytes_per_sample
        segment_duration_ms = int((segment_samples / sample_rate) * 1000)
        upsert_sub_task_timestamp(
            db=db,
            sub_task_id=subtask.id,
            start_ms=current_offset_ms,
            end_ms=current_offset_ms + segment_duration_ms,
            created_by=_SYSTEM_ACTOR,
        )
        current_offset_ms += segment_duration_ms
    return current_offset_ms


def _build_combined_wav(
    audio_segments: List[bytes],
    sample_rate: int = _SAMPLE_RATE,
    bits_per_sample: int = 16,
    num_channels: int = 1,
) -> tuple[bytes, int]:
    """Concatenate PCM segments behind a single canonical WAV header.

    Args:
        audio_segments: Raw PCM payloads, all in the same format.
        sample_rate: Samples per second written to the header.
        bits_per_sample: Sample width written to the header.
        num_channels: Channel count written to the header.

    Returns:
        ``(wav_bytes, data_size)`` where ``data_size`` is the PCM byte count.
    """
    combined_pcm = b"".join(audio_segments)
    block_align = num_channels * (bits_per_sample // 8)
    byte_rate = sample_rate * block_align
    data_size = len(combined_pcm)
    chunk_size = 36 + data_size  # RIFF size = file size minus the first 8 bytes

    wav_header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", chunk_size, b"WAVE",
        b"fmt ", 16, 1, num_channels,  # 16-byte fmt chunk, format 1 = PCM
        sample_rate, byte_rate, block_align, bits_per_sample,
        b"data", data_size,
    )
    return wav_header + combined_pcm, data_size


def _upload_and_persist_audio(
    db: Session,
    combined_wav: bytes,
    duration_ms: int,
    plan_id: UUID,
    plan_item_id: UUID,
) -> PlanItemAudio:
    """Upload the day's WAV under a fresh key and upsert its metadata row."""
    s3_key = f"audio/plan_days/{plan_id}/{plan_item_id}/{uuid4()}.wav"
    upload_bytes(
        file_bytes=BytesIO(combined_wav),
        key=s3_key,
        content_type=WAV_CONTENT_TYPE,
    )
    return upsert_plan_item_audio(
        db=db,
        plan_item_audio=PlanItemAudio(
            plan_item_id=plan_item_id,
            audio_key=s3_key,
            duration_ms=duration_ms,
            mime_type=WAV_CONTENT_TYPE,
            file_size_bytes=len(combined_wav),
            created_by=_SYSTEM_ACTOR,
            # The update branch of the upsert copies updated_by; without this
            # a regeneration would overwrite the column with None.
            updated_by=_SYSTEM_ACTOR,
        ),
    )


def _generate_plan_day_audio_blocking(
    day_id: Optional[UUID],
    audio_type: PlanAudioType,
    language: str,
    voice_name: Optional[str],
):
    """Synchronous day-level pipeline: TTS, timestamps, upload, presigned URL."""
    with SessionLocal() as db:
        plan_item: PlanItem = get_plan_day_by_id_any_plan(db=db, day_id=day_id)

        audio_segments, subtask_refs = _generate_audio_segments(
            plan_item.tasks, audio_type, language, voice_name
        )
        if not audio_segments:
            # Kept as an empty list for backward compatibility with callers.
            return []

        duration_ms = _update_subtask_timestamps(
            db=db,
            audio_segments=audio_segments,
            subtask_refs=subtask_refs,
            sample_rate=_SAMPLE_RATE,
            bytes_per_sample=_BYTES_PER_SAMPLE,
        )

        combined_wav, _ = _build_combined_wav(audio_segments)

        audio_row = _upload_and_persist_audio(
            db=db,
            combined_wav=combined_wav,
            duration_ms=duration_ms,
            plan_id=plan_item.plan_id,
            plan_item_id=plan_item.id,
        )

        audio_url = generate_presigned_access_url(
            key=audio_row.audio_key,
        )

        return {
            "audio_url": audio_url,
            "audio_duration_ms": audio_row.duration_ms,
            "s3_key": audio_row.audio_key,
        }


async def generate_plan_audio_service(
    language: str,
    day_id: Optional[UUID] = None,
    sub_task_id: Optional[UUID] = None,
    audio_type: PlanAudioType = PlanAudioType.TEXT_READING,
    voice_name: MonlamVoiceName = MonlamVoiceName.DOLKAR_LHASA_FEMALE,
):
    """Generate audio for a single sub task or for a whole plan day.

    The blocking work (SQLAlchemy, TTS HTTP calls, S3) runs in a worker
    thread so the event loop is not stalled for the length of a TTS request.

    Args:
        language: Language code understood by ``generate_tts_audio``.
        day_id: Plan day to render; used when ``sub_task_id`` is not given.
        sub_task_id: Render just this sub task; takes precedence.
        audio_type: Narration style; None falls back to TEXT_READING.
        voice_name: Voice forwarded to the TTS service.

    Returns:
        Dict with ``audio_url``, ``audio_duration_ms`` and ``s3_key``, or an
        empty list when the day has nothing to narrate.

    Raises:
        HTTPException: 404 for an unknown day or sub task, 400 for a sub task
            that cannot be narrated.
    """
    audio_type = audio_type or PlanAudioType.TEXT_READING
    if sub_task_id:
        return await _generate_subtask_audio(
            sub_task_id=sub_task_id,
            audio_type=audio_type,
            language=language,
            voice_name=voice_name,
        )

    return await asyncio.to_thread(
        _generate_plan_day_audio_blocking,
        day_id,
        audio_type,
        language,
        voice_name,
    )


def _generate_subtask_audio_blocking(
    sub_task_id: UUID,
    audio_type: PlanAudioType,
    language: str,
    voice_name: Optional[str],
):
    """Synchronous sub-task pipeline: validate, TTS, upload, persist."""
    with SessionLocal() as db:
        subtask: PlanSubTask = get_sub_task_by_subtask_id(db=db, id=sub_task_id)
        if not subtask:
            # NOTE(review): "BAD_REQUEST" code on a 404 kept for compatibility.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"error": "BAD_REQUEST", "message": "Sub task not found"},
            )

        if subtask.content_type not in _AUDIO_CONTENT_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "BAD_REQUEST",
                    "message": "Sub task content type must be TEXT or SOURCE_REFERENCE for audio generation",
                },
            )

        # content is nullable; the original crashed with AttributeError on
        # None and an unhandled ValueError on blank text.
        if not subtask.content or not subtask.content.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "BAD_REQUEST",
                    "message": "Sub task has no content for audio generation",
                },
            )

        wav_bytes = generate_tts_audio(
            subtask.content, audio_type, language, voice_name=voice_name
        )
        raw_pcm = _extract_pcm(wav_bytes)

        segment_samples = len(raw_pcm) // _BYTES_PER_SAMPLE
        duration_ms = int((segment_samples / _SAMPLE_RATE) * 1000)

        combined_wav, _ = _build_combined_wav([raw_pcm])

        # task_id is read before commit() expires the instance.
        s3_key = f"audio/plan_subtasks/{subtask.task_id}/{sub_task_id}/{uuid4()}.wav"
        upload_bytes(
            file_bytes=BytesIO(combined_wav),
            key=s3_key,
            content_type=WAV_CONTENT_TYPE,
        )

        subtask.audio_url = s3_key
        subtask.duration = str(duration_ms)
        db.commit()

        upsert_sub_task_timestamp(
            db=db,
            sub_task_id=sub_task_id,
            start_ms=0,
            end_ms=duration_ms,
            created_by=_SYSTEM_ACTOR,
        )

        audio_url = generate_presigned_access_url(
            key=s3_key,
        )

        return {
            "audio_url": audio_url,
            "audio_duration_ms": duration_ms,
            "s3_key": s3_key,
        }


async def _generate_subtask_audio(
    sub_task_id: UUID,
    audio_type: PlanAudioType,
    language: str,
    voice_name: Optional[str] = None,
):
    """Async wrapper that runs the sub-task pipeline in a worker thread."""
    return await asyncio.to_thread(
        _generate_subtask_audio_blocking,
        sub_task_id,
        audio_type,
        language,
        voice_name,
    )


# ---- worker_api/audio/services/audio_prompt.py (new file, first lines) ----
from worker_api.audio.enums import PlanAudioType

DEFAULT_VOICE_NAME = "Algenib"
# ---- worker_api/audio/services/audio_prompt.py (continued) ----
from worker_api.audio.enums import PlanAudioType

DEFAULT_ACCENT = "Neutral"

RECITATION_SCENE = """## Scene:
"""

RECITATION_SAMPLE_CONTEXT = """## Sample Context:
"""

INSTRUCTION_SCENE = """## Scene:
The Corporate Studio.
"""

# Typo fix: the original read "InstructionAal E-learning. Measured pacAing"
# and ended with a stray "." on its own line.
INSTRUCTION_SAMPLE_CONTEXT = """## Sample Context:
Instructional E-learning. Measured pacing with clear pauses for clarity. Tone is authoritative, accessible, and articulate.
"""

TEXT_READING_SCENE = """## Scene:
It is dawn inside a vast, silent meditation hall nestled deep within a quiet forest monastery. A senior monk sits perfectly upright, speaking to a small circle of practitioners. The room possesses a subtle, warm, and spacious acoustic resonance. The atmosphere demands an organic, deeply grounded, and entirely unhurried delivery, where every word is heavy with presence.
"""

TEXT_READING_SAMPLE_CONTEXT = """## Sample Context:
This voice profile is the definitive standard for serene scriptural readings, timeless philosophical translations, and monastic audiobooks where the listener requires immense breathing room to contemplate profound truths
"""

# Style and pace values carry no trailing period: build_tts_prompt adds one.
AUDIO_TYPE_CONFIGS = {
    PlanAudioType.RECITATION: {
        "style": "Warm, understanding, soft tone with gentle inflections",
        "pace": "Natural conversational pace",
        "scene": RECITATION_SCENE,
        "sample_context": RECITATION_SAMPLE_CONTEXT,
    },
    PlanAudioType.INSTRUCTION: {
        "style": "Authoritative, accessible, and articulate",
        "pace": "Natural conversational pace with clear pauses for clarity",
        "scene": INSTRUCTION_SCENE,
        "sample_context": INSTRUCTION_SAMPLE_CONTEXT,
    },
    PlanAudioType.TEXT_READING: {
        # The original style string also embedded "Pace: Natural
        # conversational pace.", contradicting the "pace" entry below.
        "style": "Warm, understanding, soft tone with gentle inflections",
        "pace": "Natural comfortable reading pace",
        "scene": TEXT_READING_SCENE,
        "sample_context": TEXT_READING_SAMPLE_CONTEXT,
    },
}


def build_tts_prompt(
    transcript: str,
    audio_type: PlanAudioType,
) -> str:
    """Assemble the director's-note prompt for the generative TTS model.

    Args:
        transcript: Text to be read aloud.
        audio_type: Narration style; unknown or None values use the
            TEXT_READING configuration instead of raising KeyError.

    Returns:
        The prompt as one newline-joined string.
    """
    type_config = AUDIO_TYPE_CONFIGS.get(
        audio_type, AUDIO_TYPE_CONFIGS[PlanAudioType.TEXT_READING]
    )

    # rstrip(".") guards against a doubled period if a config value ends in one.
    director_note = (
        f"# Director's note\n"
        f"Style: {type_config['style'].rstrip('.')}. "
        f"Pace: {type_config['pace'].rstrip('.')}. "
        f"Accent: {DEFAULT_ACCENT}."
    )

    parts = [
        "Read the following transcript based on the director's note.",
        "",
        director_note,
        "",
        type_config["scene"],
        "",
        type_config["sample_context"],
        "",
        f"## Transcript:\n{transcript}",
    ]

    return "\n".join(parts)


# ---- worker_api/audio/services/monlam_tts_service.py (new file) ----
import httpx

from worker_api.config import get
from worker_api.audio.enums import MonlamVoiceName

DEFAULT_MONLAM_VOICE_NAME = MonlamVoiceName.DOLKAR_LHASA_FEMALE.value


def generate_monlam_tts_audio(content: str, voice_name: str | None = None) -> bytes:
    """Synthesize ``content`` with the Monlam TTS HTTP API.

    Args:
        content: Text to speak; must contain non-whitespace characters.
        voice_name: Voice id (string or ``MonlamVoiceName``). Falls back to
            the ``MONLAM_TTS_VOICE_NAME`` setting, then the module default.

    Returns:
        The response body, verified to start with the ``RIFF`` magic.

    Raises:
        ValueError: If ``content`` is None, empty or whitespace.
        RuntimeError: If configuration is missing, the request fails, or the
            response is not RIFF data.
    """
    if not content or not content.strip():
        raise ValueError("Content cannot be empty")

    # Both settings default to "" in config, so check before building a URL.
    base_url = (get("MONLAM_BASE_URL") or "").rstrip("/")
    if not base_url:
        raise RuntimeError("MONLAM_BASE_URL is not configured")
    api_key = get("MONLAM_API_KEY")
    if not api_key:
        raise RuntimeError("MONLAM_API_KEY is not configured")

    resolved_voice_name = voice_name or get("MONLAM_TTS_VOICE_NAME") or DEFAULT_MONLAM_VOICE_NAME
    payload = {
        "text": content,
        "provider": get("MONLAM_TTS_PROVIDER"),
        "model_name": get("MONLAM_TTS_MODEL_NAME"),
        # Enum members are sent as their plain string value.
        "voice_name": getattr(resolved_voice_name, "value", resolved_voice_name),
    }

    try:
        response = httpx.post(
            f"{base_url}/api/v1/text-to-speech/stream",
            headers={
                "X-API-Key": api_key,
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=300.0,
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        detail = exc.response.text
        raise RuntimeError(
            f"Monlam TTS request failed with status {exc.response.status_code}: {detail}"
        ) from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Monlam TTS request failed: {exc}") from exc

    wav_bytes = response.content
    if not wav_bytes or wav_bytes[:4] != b"RIFF":
        raise RuntimeError("Monlam TTS generation returned invalid audio data")

    return wav_bytes


# ---- worker_api/audio/services/tts_service.py (new file) ----
import struct

from worker_api.config import get
from worker_api.audio.services.audio_prompt import build_tts_prompt, DEFAULT_VOICE_NAME
from worker_api.audio.services.monlam_tts_service import generate_monlam_tts_audio
from worker_api.audio.enums import PlanAudioType

SUPPORTED_TTS_LANGUAGES = {"en", "bo"}


def _normalize_language(language: str) -> str:
    """Lower-case and trim a language code; None or empty becomes "en"."""
    return (language or "en").strip().lower()


def generate_tts_audio(
    content: str,
    audio_type: PlanAudioType,
    language: str = "en",
    voice_name: str | None = None,
) -> bytes:
    """Turn text into WAV bytes using the engine for ``language``.

    "bo" is routed to Monlam (which honours ``voice_name``); "en" is routed
    to Gemini (which uses ``audio_type`` to build the prompt).

    Raises:
        ValueError: If ``content`` is None or blank, or the language is not
            in ``SUPPORTED_TTS_LANGUAGES``.
        RuntimeError: If the selected engine fails.
    """
    # ``not content`` first: nullable DB text used to crash on None.strip().
    if not content or not content.strip():
        raise ValueError("Content cannot be empty")

    normalized_language = _normalize_language(language)
    if normalized_language not in SUPPORTED_TTS_LANGUAGES:
        raise ValueError(
            f"Unsupported language for TTS: {language}. Supported: {', '.join(sorted(SUPPORTED_TTS_LANGUAGES))}"
        )

    if normalized_language == "bo":
        return generate_monlam_tts_audio(content, voice_name=voice_name)

    return _generate_gemini_tts_audio(content=content, audio_type=audio_type)


def _generate_gemini_tts_audio(
    content: str,
    audio_type: PlanAudioType,
) -> bytes:
    """Call the Gemini TTS preview model and wrap its raw audio as WAV.

    Raises:
        RuntimeError: If the API key is missing or the response has no audio.
    """
    prompt = build_tts_prompt(transcript=content, audio_type=audio_type)

    # Imported lazily so the module loads without the Google SDK installed.
    from google import genai
    from google.genai import types

    api_key = get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY is not configured")

    client = genai.Client(api_key=api_key)

    response = client.models.generate_content(
        model="gemini-2.5-flash-preview-tts",
        contents=[
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            ),
        ],
        config=types.GenerateContentConfig(
            temperature=1,
            response_modalities=["audio"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=DEFAULT_VOICE_NAME
                    )
                )
            ),
        ),
    )

    # A candidate's content can be None; check it before touching .parts.
    candidate = response.candidates[0] if response.candidates else None
    if candidate is None or candidate.content is None or not candidate.content.parts:
        raise RuntimeError("TTS generation returned no audio data")

    part = candidate.content.parts[0]
    if not part.inline_data or not part.inline_data.data:
        raise RuntimeError("TTS generation returned no audio data")

    audio_data = part.inline_data.data
    mime_type = part.inline_data.mime_type or "audio/L16;rate=24000"

    return _convert_to_wav(audio_data, mime_type)


def _convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Prefix raw mono PCM with a 44-byte WAV header derived from ``mime_type``."""
    params = _parse_audio_mime_type(mime_type)
    bits_per_sample = params["bits_per_sample"]
    sample_rate = params["rate"]
    num_channels = 1
    data_size = len(audio_data)
    block_align = num_channels * (bits_per_sample // 8)
    byte_rate = sample_rate * block_align
    chunk_size = 36 + data_size  # RIFF size = file size minus the first 8 bytes

    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        chunk_size,
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk length
        1,  # audio format 1 = PCM
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b"data",
        data_size,
    )
    return header + audio_data


def _parse_audio_mime_type(mime_type: str) -> dict:
    """Extract sample width and rate from a MIME type like ``audio/L16;rate=24000``.

    Matching is case-insensitive for both the ``audio/L<bits>`` type and the
    ``rate=`` parameter. Missing or malformed values keep the defaults of
    16 bits and 24000 Hz.

    Returns:
        ``{"bits_per_sample": int, "rate": int}``.
    """
    bits_per_sample = 16
    rate = 24000

    for raw_param in mime_type.split(";"):
        param = raw_param.strip()
        lowered = param.lower()
        if lowered.startswith("rate="):
            try:
                rate = int(param.split("=", 1)[1])
            except (ValueError, IndexError):
                pass
        elif lowered.startswith("audio/l"):
            try:
                bits_per_sample = int(param[len("audio/L"):])
            except ValueError:
                pass

    return {"bits_per_sample": bits_per_sample, "rate": rate}


# ---- worker_api/config.py (modified, index 835a6f2..8085f4f) ----
# Hunk @@ -92,6 +92,14 @@ adds, after
# `REQUEST_OBSERVABILITY_SKIP_PATHS="/health",` and before the closing `)`:
#     +    # TTS Configuration
#     +    GEMINI_API_KEY="",
#     +    MONLAM_BASE_URL="",
#     +    MONLAM_API_KEY="",
#     +    MONLAM_TTS_PROVIDER="",
#     +    MONLAM_TTS_MODEL_NAME="",
#     +    MONLAM_TTS_VOICE_NAME="",
# Trailing context: `TIME_FORMAT_PATTERN = re.compile(r"^([01]\d|2[0-3]):[0-5]\d$")`

# ---- worker_api/uploads/S3_utils.py (modified, index 4b1e152..30fbd75) ----
# Hunk @@ -127,6 +127,37 @@ inserts download_bytes between
# generate_presigned_access_url() and delete_file(). The names get,
# s3_client, ClientError, logger, HTTPException and status come from the
# existing module, which is not part of this patch.
def download_bytes(key: str) -> bytes:
    """
    Download file bytes from S3 bucket.

    Args:
        key: S3 object key (path in bucket)

    Returns:
        File content as bytes

    Raises:
        HTTPException: If download fails
    """
    try:
        bucket_name = get("AWS_BUCKET_NAME")
        response = s3_client.get_object(Bucket=bucket_name, Key=key)
        return response["Body"].read()
    except ClientError as e:
        # NOTE(review): every S3 client error, including a missing key, is
        # reported as 400 -- confirm whether NoSuchKey should map to 404.
        logger.error(f"Failed to download file from S3: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Failed to download file: {str(e)}",
        )
    except Exception as e:
        logger.error(f"Unexpected error downloading file from S3: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unexpected error: {str(e)}",
        )


# Trailing hunk context: `def delete_file(key: str):` with the docstring
# "Delete a file from S3 bucket."

# ==== PATCH 2/7 ====
# From 8dcb729e225847571a5ba78a5657675e448f33c1 Mon Sep 17 00:00:00 2001
# From: Lungsangg
# Date: Mon, 22 Jun 2026 15:32:29 +0530
# Subject: [PATCH 2/7] test case update
# ---
#  tests/audio/__init__.py                    |   0
#  tests/audio/test_audio_generate_service.py | 348 +++++++++++++++++++++
#  tests/audio/test_audio_views.py            | 195 ++++++++++++
#  3 files changed, 543 insertions(+)
#  create mode 100644 tests/audio/__init__.py
#  create mode 100644 tests/audio/test_audio_generate_service.py
#  create mode 100644 tests/audio/test_audio_views.py

# ---- tests/audio/__init__.py (new file, empty) ----

# ---- tests/audio/test_audio_generate_service.py (new file, index 0000000..796ca52) ----
# Hunk @@ -0,0 +1,348 @@ begins with the module docstring
# "Tests for audio generation service."; the rest lies beyond this chunk.
+""" +import pytest +from uuid import uuid4 +from unittest.mock import patch, MagicMock, AsyncMock +from io import BytesIO + +from worker_api.audio.enums import ContentType, PlanAudioType, MonlamVoiceName +from worker_api.audio.services.audio_generate_service import ( + generate_plan_audio_service, + _generate_audio_segments, + _update_subtask_timestamps, + _build_combined_wav, + _upload_and_persist_audio, +) + + +class TestGeneratePlanAudioService: + """Tests for generate_plan_audio_service function.""" + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_plan_day_by_id_any_plan") + @patch("worker_api.audio.services.audio_generate_service._generate_audio_segments") + @patch("worker_api.audio.services.audio_generate_service._update_subtask_timestamps") + @patch("worker_api.audio.services.audio_generate_service._build_combined_wav") + @patch("worker_api.audio.services.audio_generate_service._upload_and_persist_audio") + @patch("worker_api.audio.services.audio_generate_service.generate_presigned_access_url") + async def test_generate_day_audio_success( + self, + mock_presigned_url, + mock_upload, + mock_build_wav, + mock_update_timestamps, + mock_generate_segments, + mock_get_day, + mock_session_local, + ): + """Test successful audio generation for a plan day.""" + day_id = uuid4() + plan_id = uuid4() + + # Mock database session + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + # Mock plan item + mock_plan_item = MagicMock() + mock_plan_item.id = day_id + mock_plan_item.plan_id = plan_id + mock_plan_item.tasks = [] + mock_get_day.return_value = mock_plan_item + + # Mock audio segments + mock_generate_segments.return_value = ([b"audio_data"], [MagicMock()]) + mock_update_timestamps.return_value = 45000 + mock_build_wav.return_value = (b"wav_data", 1000) + + # Mock audio row + mock_audio_row = MagicMock() + 
mock_audio_row.audio_key = "audio/test.wav" + mock_audio_row.duration_ms = 45000 + mock_upload.return_value = mock_audio_row + + mock_presigned_url.return_value = "https://s3.example.com/audio.wav" + + # Execute + result = await generate_plan_audio_service( + language="en", + day_id=day_id, + audio_type=PlanAudioType.TEXT_READING, + ) + + # Assert + assert result["audio_url"] == "https://s3.example.com/audio.wav" + assert result["audio_duration_ms"] == 45000 + assert result["s3_key"] == "audio/test.wav" + + mock_get_day.assert_called_once_with(db=mock_db, day_id=day_id) + mock_generate_segments.assert_called_once() + mock_upload.assert_called_once() + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_plan_day_by_id_any_plan") + @patch("worker_api.audio.services.audio_generate_service._generate_audio_segments") + async def test_generate_day_audio_no_segments( + self, + mock_generate_segments, + mock_get_day, + mock_session_local, + ): + """Test audio generation returns empty when no segments are generated.""" + day_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_plan_item = MagicMock() + mock_plan_item.tasks = [] + mock_get_day.return_value = mock_plan_item + + # No audio segments + mock_generate_segments.return_value = ([], []) + + result = await generate_plan_audio_service( + language="en", + day_id=day_id, + ) + + assert result == [] + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + @patch("worker_api.audio.services.audio_generate_service.generate_tts_audio") + @patch("worker_api.audio.services.audio_generate_service.upload_bytes") + @patch("worker_api.audio.services.audio_generate_service.upsert_sub_task_timestamp") + 
@patch("worker_api.audio.services.audio_generate_service.generate_presigned_access_url") + async def test_generate_subtask_audio_success( + self, + mock_presigned_url, + mock_upsert_timestamp, + mock_upload, + mock_tts, + mock_get_subtask, + mock_session_local, + ): + """Test successful audio generation for a single subtask.""" + sub_task_id = uuid4() + task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + # Mock subtask + mock_subtask = MagicMock() + mock_subtask.id = sub_task_id + mock_subtask.task_id = task_id + mock_subtask.content = "Test content" + mock_subtask.content_type = ContentType.TEXT + mock_get_subtask.return_value = mock_subtask + + # Mock TTS audio (44 byte header + audio data) + wav_header = b"RIFF" + b"\x00" * 40 + audio_data = b"\x00" * 1000 + mock_tts.return_value = wav_header + audio_data + + mock_presigned_url.return_value = "https://s3.example.com/audio.wav" + + # Execute + result = await generate_plan_audio_service( + language="en", + sub_task_id=sub_task_id, + audio_type=PlanAudioType.TEXT_READING, + ) + + # Assert + assert "audio_url" in result + assert "audio_duration_ms" in result + assert "s3_key" in result + + mock_get_subtask.assert_called_once_with(db=mock_db, id=sub_task_id) + mock_tts.assert_called_once() + mock_upload.assert_called_once() + mock_upsert_timestamp.assert_called_once() + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + async def test_generate_subtask_audio_not_found( + self, + mock_get_subtask, + mock_session_local, + ): + """Test error when subtask is not found.""" + sub_task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_get_subtask.return_value = None + + with pytest.raises(Exception) as exc_info: + await generate_plan_audio_service( + language="en", 
+ sub_task_id=sub_task_id, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + async def test_generate_subtask_audio_invalid_content_type( + self, + mock_get_subtask, + mock_session_local, + ): + """Test error when subtask has invalid content type.""" + sub_task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_subtask = MagicMock() + mock_subtask.content_type = ContentType.VIDEO # Invalid for audio + mock_get_subtask.return_value = mock_subtask + + with pytest.raises(Exception) as exc_info: + await generate_plan_audio_service( + language="en", + sub_task_id=sub_task_id, + ) + + assert exc_info.value.status_code == 400 + + +class TestGenerateAudioSegments: + """Tests for _generate_audio_segments helper function.""" + + @patch("worker_api.audio.services.audio_generate_service.generate_tts_audio") + def test_generate_segments_with_text_content(self, mock_tts): + """Test generating audio segments from text content.""" + wav_data = b"RIFF" + b"\x00" * 40 + b"audio_data" + mock_tts.return_value = wav_data + + mock_subtask = MagicMock() + mock_subtask.content = "Test content" + mock_subtask.content_type = ContentType.TEXT + mock_subtask.audio_url = None + + mock_task = MagicMock() + mock_task.sub_tasks = [mock_subtask] + + segments, refs = _generate_audio_segments( + [mock_task], + PlanAudioType.TEXT_READING, + "en", + ) + + assert len(segments) == 1 + assert len(refs) == 1 + assert refs[0] == mock_subtask + mock_tts.assert_called_once() + + @patch("worker_api.audio.services.audio_generate_service.download_bytes") + def test_generate_segments_with_existing_audio(self, mock_download): + """Test reusing existing audio from subtask.""" + wav_data = b"RIFF" + b"\x00" * 40 + b"existing_audio" + mock_download.return_value = 
class TestBuildCombinedWav:
    """Checks for the ``_build_combined_wav`` helper."""

    def test_build_wav_single_segment(self):
        """One segment yields a RIFF/WAVE file whose data size matches it."""
        pcm = b"\x00" * 1000

        combined, data_size = _build_combined_wav([pcm])

        # The output must be longer than the payload because of the header.
        assert len(combined) > len(pcm)
        assert combined[:4] == b"RIFF"
        assert b"WAVE" in combined
        assert data_size == len(pcm)

    def test_build_wav_multiple_segments(self):
        """Several segments are reported as the sum of their sizes."""
        chunks = [b"\x00" * 500, b"\x01" * 500]

        combined, data_size = _build_combined_wav(chunks)

        assert combined[:4] == b"RIFF"
        assert data_size == len(chunks[0]) + len(chunks[1])

    def test_build_wav_empty_segments(self):
        """No segments still yields a RIFF header with zero data bytes."""
        combined, data_size = _build_combined_wav([])

        assert combined[:4] == b"RIFF"
        assert data_size == 0
timestamps.""" + mock_db = MagicMock() + + # 1000 bytes at 2 bytes per sample = 500 samples + # 500 samples / 24000 Hz = 0.020833 seconds = 20.833 ms + audio_segment = b"\x00" * 1000 + + mock_subtask = MagicMock() + mock_subtask.id = uuid4() + + duration = _update_subtask_timestamps( + mock_db, + [audio_segment], + [mock_subtask], + 24000, # sample_rate + 2, # bytes_per_sample + ) + + assert duration > 0 + mock_upsert.assert_called_once() + call_kwargs = mock_upsert.call_args.kwargs + assert call_kwargs["sub_task_id"] == mock_subtask.id + assert call_kwargs["start_ms"] == 0 + assert call_kwargs["end_ms"] > 0 diff --git a/tests/audio/test_audio_views.py b/tests/audio/test_audio_views.py new file mode 100644 index 0000000..57eb737 --- /dev/null +++ b/tests/audio/test_audio_views.py @@ -0,0 +1,195 @@ +""" +Tests for audio generation API endpoints. +""" +import pytest +from uuid import uuid4 +from unittest.mock import patch, AsyncMock + +from worker_api.audio.enums import PlanAudioType, MonlamVoiceName + + +class TestGeneratePlanAudio: + """Tests for POST /audio/generate endpoint.""" + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_day_id(self, mock_service, client): + """Test generating audio for a plan day.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 45000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 200 + data = response.json() + assert "audio_url" in data + assert "audio_duration_ms" in data + assert "s3_key" in data + assert data["audio_duration_ms"] == 45000 + + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["day_id"] == day_id + assert call_kwargs["language"] == "en" + 
assert call_kwargs["audio_type"] == PlanAudioType.TEXT_READING + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_sub_task_id(self, mock_service, client): + """Test generating audio for a single subtask.""" + sub_task_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 12000, + "s3_key": "audio/plan_subtasks/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "sub_task_id": str(sub_task_id), + "language": "bo", + "type": "RECITATION", + "voice_name": "dolkar_lhasa_female" + } + ) + + assert response.status_code == 200 + data = response.json() + assert "audio_url" in data + assert data["audio_duration_ms"] == 12000 + + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["sub_task_id"] == sub_task_id + assert call_kwargs["language"] == "bo" + assert call_kwargs["audio_type"] == PlanAudioType.RECITATION + assert call_kwargs["voice_name"] == MonlamVoiceName.DOLKAR_LHASA_FEMALE + + def test_generate_audio_missing_both_ids(self, client): + """Test validation error when both day_id and sub_task_id are missing.""" + response = client.post( + "/api/v1/audio/generate", + json={ + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + def test_generate_audio_with_both_ids(self, client): + """Test validation error when both day_id and sub_task_id are provided.""" + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(uuid4()), + "sub_task_id": str(uuid4()), + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + def test_generate_audio_missing_language(self, client): + """Test validation error when language is missing.""" + response = client.post( + 
"/api/v1/audio/generate", + json={ + "day_id": str(uuid4()), + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_instruction_type(self, mock_service, client): + """Test generating instruction audio.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 30000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en", + "type": "INSTRUCTION" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["audio_type"] == PlanAudioType.INSTRUCTION + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_default_type(self, mock_service, client): + """Test that TEXT_READING is the default audio type.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 30000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["audio_type"] == PlanAudioType.TEXT_READING + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_tibetan_with_voice(self, mock_service, client): + """Test generating Tibetan audio with specific voice.""" + sub_task_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 15000, + "s3_key": "audio/plan_subtasks/test.wav" + } + + response = 
client.post( + "/api/v1/audio/generate", + json={ + "sub_task_id": str(sub_task_id), + "language": "bo", + "type": "TEXT_READING", + "voice_name": "sonamtsering_lhasa_male" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["voice_name"] == MonlamVoiceName.SONAMTSERING_LHASA_MALE From 52f498ff4f6103a3f2e93d93b56f41739d5c7967 Mon Sep 17 00:00:00 2001 From: Lungsangg Date: Mon, 22 Jun 2026 16:19:07 +0530 Subject: [PATCH 3/7] llm_endpoint_update --- worker_api/app.py | 2 + worker_api/config.py | 2 +- worker_api/llm/__init__.py | 0 worker_api/llm/llm_response_models.py | 13 +++++ worker_api/llm/llm_service.py | 78 +++++++++++++++++++++++++++ worker_api/llm/llm_views.py | 17 ++++++ 6 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 worker_api/llm/__init__.py create mode 100644 worker_api/llm/llm_response_models.py create mode 100644 worker_api/llm/llm_service.py create mode 100644 worker_api/llm/llm_views.py diff --git a/worker_api/app.py b/worker_api/app.py index 6f62569..3787d99 100644 --- a/worker_api/app.py +++ b/worker_api/app.py @@ -5,6 +5,7 @@ from worker_api.middleware.request_observability import RequestObservabilityMiddleware from worker_api.db.mongo_database import lifespan from worker_api.audio.audio_views import audio_router +from worker_api.llm.llm_views import llm_router import uvicorn @@ -17,6 +18,7 @@ ) api.include_router(audio_router) +api.include_router(llm_router) api.add_middleware( CORSMiddleware, diff --git a/worker_api/config.py b/worker_api/config.py index 8085f4f..08eb3db 100644 --- a/worker_api/config.py +++ b/worker_api/config.py @@ -85,7 +85,7 @@ SQS_TIMEOUT=1800, GROUP_INVITE_EXPIRY_MINUTES=30, - WEBUDDHIST_EMAIL_LOGO_URL="https://studio.webuddhist.com/assets/pecha_icon-DkKJLXuA.png", + WEBUDDHIST_EMAIL_LOGO_URL="", # Request observability (per-endpoint memory and latency logging) REQUEST_OBSERVABILITY_ENABLED="true", diff 
def _chat_with_gemini_sync(
    prompt: str,
    system_prompt: Optional[str] = None,
    model: Optional[str] = None
) -> dict:
    """Send one user prompt to Gemini and return the generated text.

    Blocking call; ``chat_with_gemini`` runs it in a worker thread.

    Args:
        prompt: User prompt, sent as a single ``user`` content entry.
        system_prompt: Optional system instruction; ignored when empty.
        model: Gemini model name; falls back to ``DEFAULT_MODEL`` when empty.

    Returns:
        A dict with ``"response"`` (the concatenated text of the first
        candidate's parts) and ``"model"`` (the model name that was used).

    Raises:
        HTTPException: 500 when ``GEMINI_API_KEY`` is not configured; 502 when
            the Gemini call fails or the reply contains no text.
    """
    api_key = get("GEMINI_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="GEMINI_API_KEY is not configured"
        )

    client = genai.Client(api_key=api_key)

    model_name = model or DEFAULT_MODEL

    config = types.GenerateContentConfig(
        temperature=1.0,
        system_instruction=system_prompt or None,
    )

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=prompt)],
        ),
    ]

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=config,
        )
    except Exception as e:
        # Boundary with the external service: any failure becomes a 502, with
        # the original exception kept as the cause for logs and tracebacks.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Gemini API error: {str(e)}"
        ) from e

    # A candidate can come back without content or parts; handle that here so
    # it surfaces as a 502 rather than an AttributeError (HTTP 500).
    candidates = response.candidates or []
    candidate_content = candidates[0].content if candidates else None
    parts = (candidate_content.parts if candidate_content is not None else None) or []

    # Join every text part: a reply may be split across parts, and a part
    # without text would otherwise put None into the response payload.
    response_text = "".join(part.text for part in parts if part.text)

    if not response_text:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Gemini returned no response"
        )

    return {
        "response": response_text,
        "model": model_name
    }
class TestChatWithGemini:
    """Checks for the async ``chat_with_gemini`` wrapper."""

    @pytest.mark.asyncio
    async def test_chat_calls_sync_function(self):
        """The wrapper forwards its arguments positionally to the sync helper."""
        canned_reply = {
            "response": "Test response",
            "model": "gemini-2.5-flash"
        }

        with patch(
            "worker_api.llm.llm_service._chat_with_gemini_sync",
            return_value=canned_reply,
        ) as sync_helper:
            outcome = await chat_with_gemini(
                prompt="Test prompt",
                system_prompt="Test system",
                model="gemini-2.5-flash"
            )

        assert outcome["response"] == "Test response"
        assert outcome["model"] == "gemini-2.5-flash"
        sync_helper.assert_called_once_with(
            "Test prompt", "Test system", "gemini-2.5-flash"
        )
+ + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + # Execute + result = _chat_with_gemini_sync( + prompt="What is the capital of France?", + system_prompt=None, + model=None + ) + + # Assert + assert result["response"] == "Paris is the capital of France." + assert result["model"] == DEFAULT_MODEL + + mock_client_class.assert_called_once_with(api_key="test-api-key") + mock_client.models.generate_content.assert_called_once() + + @patch("worker_api.llm.llm_service.get") + def test_missing_api_key(self, mock_get): + """Test error when GEMINI_API_KEY is not configured.""" + mock_get.return_value = "" + + with pytest.raises(HTTPException) as exc_info: + _chat_with_gemini_sync(prompt="Test", system_prompt=None, model=None) + + assert exc_info.value.status_code == 500 + assert "GEMINI_API_KEY is not configured" in exc_info.value.detail + + @patch("worker_api.llm.llm_service.get") + @patch("worker_api.llm.llm_service.genai.Client") + def test_with_system_prompt(self, mock_client_class, mock_get): + """Test chat with system prompt.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.text = "Response with system prompt" + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + result = _chat_with_gemini_sync( + prompt="Test", + system_prompt="You are a helpful assistant", + model=None + ) + + assert result["response"] == "Response with system prompt" + + # Verify generate_content was called + call_args = mock_client.models.generate_content.call_args + assert call_args is not None + + @patch("worker_api.llm.llm_service.get") + 
@patch("worker_api.llm.llm_service.genai.Client") + def test_with_custom_model(self, mock_client_class, mock_get): + """Test chat with custom model.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.text = "Response from custom model" + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + result = _chat_with_gemini_sync( + prompt="Test", + system_prompt=None, + model="gemini-1.5-pro" + ) + + assert result["response"] == "Response from custom model" + assert result["model"] == "gemini-1.5-pro" + + # Verify model was passed correctly + call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs["model"] == "gemini-1.5-pro" + + @patch("worker_api.llm.llm_service.get") + @patch("worker_api.llm.llm_service.genai.Client") + def test_gemini_api_error(self, mock_client_class, mock_get): + """Test error handling when Gemini API fails.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_client.models.generate_content.side_effect = Exception("API connection failed") + + with pytest.raises(HTTPException) as exc_info: + _chat_with_gemini_sync(prompt="Test", system_prompt=None, model=None) + + assert exc_info.value.status_code == 502 + assert "Gemini API error" in exc_info.value.detail + assert "API connection failed" in exc_info.value.detail + + @patch("worker_api.llm.llm_service.get") + @patch("worker_api.llm.llm_service.genai.Client") + def test_no_response_candidates(self, mock_client_class, mock_get): + """Test error when Gemini returns no candidates.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = 
MagicMock() + mock_response.candidates = [] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(HTTPException) as exc_info: + _chat_with_gemini_sync(prompt="Test", system_prompt=None, model=None) + + assert exc_info.value.status_code == 502 + assert "Gemini returned no response" in exc_info.value.detail + + @patch("worker_api.llm.llm_service.get") + @patch("worker_api.llm.llm_service.genai.Client") + def test_no_response_parts(self, mock_client_class, mock_get): + """Test error when Gemini candidate has no parts.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_candidate = MagicMock() + mock_candidate.content.parts = [] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(HTTPException) as exc_info: + _chat_with_gemini_sync(prompt="Test", system_prompt=None, model=None) + + assert exc_info.value.status_code == 502 + assert "Gemini returned no response" in exc_info.value.detail + + @patch("worker_api.llm.llm_service.get") + @patch("worker_api.llm.llm_service.genai.Client") + def test_default_model_used(self, mock_client_class, mock_get): + """Test that default model is used when none specified.""" + mock_get.return_value = "test-api-key" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.text = "Response" + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + result = _chat_with_gemini_sync(prompt="Test", system_prompt=None, model=None) + + assert result["model"] == DEFAULT_MODEL + + call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs["model"] == DEFAULT_MODEL diff 
--git a/tests/llm/test_llm_views.py b/tests/llm/test_llm_views.py new file mode 100644 index 0000000..28279a7 --- /dev/null +++ b/tests/llm/test_llm_views.py @@ -0,0 +1,161 @@ +""" +Tests for LLM chat API endpoints. +""" +import pytest +from unittest.mock import patch, AsyncMock +from fastapi import HTTPException + + +class TestLLMChat: + """Tests for POST /llm/chat endpoint.""" + + @pytest.mark.asyncio + @patch("worker_api.llm.llm_views.chat_with_gemini") + async def test_chat_with_prompt_only(self, mock_chat, client): + """Test chat with only prompt provided.""" + mock_chat.return_value = { + "response": "Paris is the capital of France.", + "model": "gemini-2.5-flash" + } + + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "What is the capital of France?" + } + ) + + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Paris is the capital of France." + assert data["model"] == "gemini-2.5-flash" + + mock_chat.assert_called_once() + call_kwargs = mock_chat.call_args.kwargs + assert call_kwargs["prompt"] == "What is the capital of France?" + assert call_kwargs["system_prompt"] is None + assert call_kwargs["model"] is None + + @pytest.mark.asyncio + @patch("worker_api.llm.llm_views.chat_with_gemini") + async def test_chat_with_system_prompt(self, mock_chat, client): + """Test chat with system prompt provided.""" + mock_chat.return_value = { + "response": "As a geography expert, Paris is the capital of France.", + "model": "gemini-2.5-flash" + } + + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant." + } + ) + + assert response.status_code == 200 + data = response.json() + assert "Paris" in data["response"] + + mock_chat.assert_called_once() + call_kwargs = mock_chat.call_args.kwargs + assert call_kwargs["system_prompt"] == "You are a helpful geography assistant." 
+ + @pytest.mark.asyncio + @patch("worker_api.llm.llm_views.chat_with_gemini") + async def test_chat_with_custom_model(self, mock_chat, client): + """Test chat with custom model specified.""" + mock_chat.return_value = { + "response": "Test response", + "model": "gemini-1.5-pro" + } + + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "Hello", + "model": "gemini-1.5-pro" + } + ) + + assert response.status_code == 200 + data = response.json() + assert data["model"] == "gemini-1.5-pro" + + mock_chat.assert_called_once() + call_kwargs = mock_chat.call_args.kwargs + assert call_kwargs["model"] == "gemini-1.5-pro" + + @pytest.mark.asyncio + async def test_chat_missing_prompt(self, client): + """Test chat with missing prompt returns validation error.""" + response = client.post( + "/api/v1/llm/chat", + json={} + ) + + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + @pytest.mark.asyncio + async def test_chat_empty_prompt(self, client): + """Test chat with empty prompt.""" + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "" + } + ) + + # Empty string is valid for Pydantic, but will likely fail at service level + assert response.status_code in [200, 422, 500, 502] + + @pytest.mark.asyncio + @patch("worker_api.llm.llm_views.chat_with_gemini") + async def test_chat_service_error(self, mock_chat, client): + """Test chat when service raises HTTPException.""" + mock_chat.side_effect = HTTPException( + status_code=502, + detail="Gemini API error: Connection timeout" + ) + + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "Test prompt" + } + ) + + assert response.status_code == 502 + data = response.json() + assert "Gemini API error" in data["detail"] + + @pytest.mark.asyncio + @patch("worker_api.llm.llm_views.chat_with_gemini") + async def test_chat_with_all_parameters(self, mock_chat, client): + """Test chat with all parameters provided.""" + mock_chat.return_value = { + 
"response": "Complete response", + "model": "gemini-1.5-pro" + } + + response = client.post( + "/api/v1/llm/chat", + json={ + "prompt": "Tell me about AI", + "system_prompt": "You are an AI expert", + "model": "gemini-1.5-pro" + } + ) + + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Complete response" + assert data["model"] == "gemini-1.5-pro" + + mock_chat.assert_called_once() + call_kwargs = mock_chat.call_args.kwargs + assert call_kwargs["prompt"] == "Tell me about AI" + assert call_kwargs["system_prompt"] == "You are an AI expert" + assert call_kwargs["model"] == "gemini-1.5-pro" From e427acebd0e1cc7c1104a9edc67e7e78db7152ac Mon Sep 17 00:00:00 2001 From: Lungsangg Date: Mon, 22 Jun 2026 16:28:55 +0530 Subject: [PATCH 5/7] sonarqube_setup --- .github/workflows/github-ci.yml | 37 +++++++++++++++++++++++++++++++++ sonar-project.properties | 12 +++++++++++ 2 files changed, 49 insertions(+) create mode 100644 sonar-project.properties diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index d6da52b..51a1729 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -97,3 +97,40 @@ jobs: - name: Run tests with coverage run: | poetry run pytest --cov=worker_api --cov-report=xml + + sonarqube: + name: SonarQube Scan + runs-on: ubuntu-latest + needs: [ buildDockerImage ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: 1.7.1 + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Install dependencies + run: | + poetry config virtualenvs.create false --local + poetry install --no-interaction --with dev + + - name: Run tests with coverage + run: | + poetry run pytest --cov=worker_api --cov-report=xml + + - name: SonarQube Scan + uses: 
SonarSource/sonarqube-scan-action@v5 + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..140df82 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,12 @@ +sonar.projectKey=webuddhist_worker +sonar.organization=webuddhist +sonar.sources=worker_api +sonar.tests=tests +sonar.python.version=3.12 +sonar.python.coverage.reportPaths=coverage.xml + +# Exclude model files, repository files, and generated code from code coverage +sonar.coverage.exclusions=**/*_models.py,**/*_response_models.py,**/*_repository.py,**/external_clients/** + +# Exclude generated code from analysis entirely +sonar.exclusions=**/external_clients/** From e9258497bd66350a9f389bd563aab7ecc93a008f Mon Sep 17 00:00:00 2001 From: Lungsangg Date: Mon, 22 Jun 2026 17:27:59 +0530 Subject: [PATCH 6/7] test_coverage_update --- tests/audio/test_monlam_tts_service.py | 240 ++++++++++++++++++++++ tests/audio/test_tts_service.py | 234 +++++++++++++++++++++ tests/test_constants.py | 56 +++++ tests/test_error_constants.py | 44 ++++ tests/test_s3_utils.py | 270 +++++++++++++++++++++++++ 5 files changed, 844 insertions(+) create mode 100644 tests/audio/test_monlam_tts_service.py create mode 100644 tests/audio/test_tts_service.py create mode 100644 tests/test_constants.py create mode 100644 tests/test_error_constants.py create mode 100644 tests/test_s3_utils.py diff --git a/tests/audio/test_monlam_tts_service.py b/tests/audio/test_monlam_tts_service.py new file mode 100644 index 0000000..1fa3796 --- /dev/null +++ b/tests/audio/test_monlam_tts_service.py @@ -0,0 +1,240 @@ +import pytest +from unittest.mock import patch, MagicMock +import httpx + +from worker_api.audio.services.monlam_tts_service import ( + generate_monlam_tts_audio, + DEFAULT_MONLAM_VOICE_NAME, +) +from worker_api.audio.enums import MonlamVoiceName + + +class TestGenerateMonlamTtsAudio: + def test_empty_content_raises_error(self): + with 
pytest.raises(ValueError, match="Content cannot be empty"): + generate_monlam_tts_audio("") + + def test_whitespace_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_monlam_tts_audio(" ") + + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_missing_api_key_raises_error(self, mock_get): + def get_side_effect(key): + if key == "MONLAM_BASE_URL": + return "https://api.monlam.ai" + if key == "MONLAM_API_KEY": + return None + return "default" + + mock_get.side_effect = get_side_effect + + with pytest.raises(RuntimeError, match="MONLAM_API_KEY is not configured"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_successful_generation(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།") + + assert result == b"RIFF" + b"\x00" * 100 + mock_post.assert_called_once() + + call_args = mock_post.call_args + assert call_args[0][0] == "https://api.monlam.ai/api/v1/text-to-speech/stream" + assert call_args[1]["headers"]["X-API-Key"] == "fake_api_key" + assert call_args[1]["json"]["text"] == "བོད་སྐད།" + assert call_args[1]["json"]["provider"] == "test_provider" + assert call_args[1]["json"]["model_name"] == "test_model" + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_custom_voice_name(self, 
mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།", voice_name="custom_voice") + + assert result == b"RIFF" + b"\x00" * 100 + + call_args = mock_post.call_args + assert call_args[1]["json"]["voice_name"] == "custom_voice" + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_default_voice_name(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།") + + call_args = mock_post.call_args + assert call_args[1]["json"]["voice_name"] == DEFAULT_MONLAM_VOICE_NAME + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_http_status_error(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + 
mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + mock_post.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=mock_response + ) + + with pytest.raises(RuntimeError, match="Monlam TTS request failed with status 500"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_request_error(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_post.side_effect = httpx.RequestError("Connection failed") + + with pytest.raises(RuntimeError, match="Monlam TTS request failed"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_invalid_audio_data_empty(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"" + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with pytest.raises(RuntimeError, match="Monlam TTS generation returned invalid audio data"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def 
test_invalid_audio_data_not_riff(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"INVALID_DATA" + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with pytest.raises(RuntimeError, match="Monlam TTS generation returned invalid audio data"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_base_url_trailing_slash_removed(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai/", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + generate_monlam_tts_audio("བོད་སྐད།") + + call_args = mock_post.call_args + assert call_args[0][0] == "https://api.monlam.ai/api/v1/text-to-speech/stream" + + +def test_default_monlam_voice_name(): + assert DEFAULT_MONLAM_VOICE_NAME == MonlamVoiceName.DOLKAR_LHASA_FEMALE.value diff --git a/tests/audio/test_tts_service.py b/tests/audio/test_tts_service.py new file mode 100644 index 0000000..94d4e79 --- /dev/null +++ b/tests/audio/test_tts_service.py @@ -0,0 +1,234 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import struct + +from worker_api.audio.services.tts_service import ( + generate_tts_audio, + _normalize_language, + 
_generate_gemini_tts_audio, + _convert_to_wav, + _parse_audio_mime_type, + SUPPORTED_TTS_LANGUAGES, +) +from worker_api.audio.enums import PlanAudioType + + +class TestNormalizeLanguage: + def test_normalize_language_lowercase(self): + assert _normalize_language("EN") == "en" + assert _normalize_language("Bo") == "bo" + + def test_normalize_language_strip(self): + assert _normalize_language(" en ") == "en" + assert _normalize_language(" bo ") == "bo" + + def test_normalize_language_none(self): + assert _normalize_language(None) == "en" + + def test_normalize_language_empty(self): + assert _normalize_language("") == "en" + + +class TestGenerateTtsAudio: + def test_empty_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_tts_audio("", PlanAudioType.RECITATION) + + def test_whitespace_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_tts_audio(" ", PlanAudioType.RECITATION) + + def test_unsupported_language_raises_error(self): + with pytest.raises(ValueError, match="Unsupported language for TTS"): + generate_tts_audio("Hello", PlanAudioType.RECITATION, language="fr") + + @patch("worker_api.audio.services.tts_service.generate_monlam_tts_audio") + def test_tibetan_language_uses_monlam(self, mock_monlam): + mock_monlam.return_value = b"fake_audio_data" + + result = generate_tts_audio("བོད་སྐད།", PlanAudioType.RECITATION, language="bo") + + mock_monlam.assert_called_once_with("བོད་སྐད།", voice_name=None) + assert result == b"fake_audio_data" + + @patch("worker_api.audio.services.tts_service.generate_monlam_tts_audio") + def test_tibetan_language_with_voice_name(self, mock_monlam): + mock_monlam.return_value = b"fake_audio_data" + + result = generate_tts_audio( + "བོད་སྐད།", + PlanAudioType.RECITATION, + language="bo", + voice_name="custom_voice" + ) + + mock_monlam.assert_called_once_with("བོད་སྐད།", voice_name="custom_voice") + assert result == 
b"fake_audio_data" + + @patch("worker_api.audio.services.tts_service._generate_gemini_tts_audio") + def test_english_language_uses_gemini(self, mock_gemini): + mock_gemini.return_value = b"fake_wav_data" + + result = generate_tts_audio("Hello world", PlanAudioType.RECITATION, language="en") + + mock_gemini.assert_called_once_with( + content="Hello world", + audio_type=PlanAudioType.RECITATION + ) + assert result == b"fake_wav_data" + + def test_supported_languages(self): + assert "en" in SUPPORTED_TTS_LANGUAGES + assert "bo" in SUPPORTED_TTS_LANGUAGES + + +class TestGenerateGeminiTtsAudio: + @patch("worker_api.audio.services.tts_service.get") + def test_missing_api_key_raises_error(self, mock_get): + mock_get.return_value = None + + with pytest.raises(RuntimeError, match="GEMINI_API_KEY is not configured"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + @patch("worker_api.audio.services.tts_service._convert_to_wav") + @patch("worker_api.audio.services.tts_service.get") + def test_successful_generation(self, mock_get, mock_convert): + mock_get.return_value = "fake_api_key" + mock_convert.return_value = b"fake_wav_data" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.inline_data.data = b"raw_audio_data" + mock_part.inline_data.mime_type = "audio/L16;rate=24000" + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + result = _generate_gemini_tts_audio("Hello world", PlanAudioType.RECITATION) + + assert result == b"fake_wav_data" + mock_convert.assert_called_once_with(b"raw_audio_data", "audio/L16;rate=24000") + + @patch("worker_api.audio.services.tts_service.get") + def test_no_candidates_raises_error(self, mock_get): + mock_get.return_value = 
"fake_api_key" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.candidates = [] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(RuntimeError, match="TTS generation returned no audio data"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + @patch("worker_api.audio.services.tts_service.get") + def test_no_inline_data_raises_error(self, mock_get): + mock_get.return_value = "fake_api_key" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.inline_data = None + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(RuntimeError, match="TTS generation returned no audio data"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + +class TestParseAudioMimeType: + def test_default_values(self): + result = _parse_audio_mime_type("audio/wav") + assert result["bits_per_sample"] == 16 + assert result["rate"] == 24000 + + def test_parse_rate(self): + result = _parse_audio_mime_type("audio/L16;rate=48000") + assert result["rate"] == 48000 + assert result["bits_per_sample"] == 16 + + def test_parse_bits_per_sample(self): + result = _parse_audio_mime_type("audio/L24;rate=24000") + assert result["bits_per_sample"] == 24 + assert result["rate"] == 24000 + + def test_parse_both_parameters(self): + result = _parse_audio_mime_type("audio/L32;rate=16000") + assert result["bits_per_sample"] == 32 + assert result["rate"] == 16000 + + def test_invalid_rate_uses_default(self): + result = _parse_audio_mime_type("audio/L16;rate=invalid") + assert result["rate"] == 24000 + + def 
test_invalid_bits_uses_default(self): + result = _parse_audio_mime_type("audio/Linvalid;rate=24000") + assert result["bits_per_sample"] == 16 + + +class TestConvertToWav: + def test_convert_basic_audio(self): + audio_data = b"\x00\x01" * 100 + mime_type = "audio/L16;rate=24000" + + result = _convert_to_wav(audio_data, mime_type) + + assert result[:4] == b"RIFF" + assert result[8:12] == b"WAVE" + assert result[12:16] == b"fmt " + assert result[36:40] == b"data" + assert len(result) > len(audio_data) + + def test_convert_different_sample_rate(self): + audio_data = b"\x00\x01" * 50 + mime_type = "audio/L16;rate=48000" + + result = _convert_to_wav(audio_data, mime_type) + + assert result[:4] == b"RIFF" + sample_rate = struct.unpack(" Date: Tue, 23 Jun 2026 11:33:22 +0530 Subject: [PATCH 7/7] update --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 35bd1b7..ce0e0d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,5 +28,5 @@ COPY . /app # Expose the port that the app runs on EXPOSE 8001 -# Command to run the application -CMD ["sh", "-c", "poetry run alembic upgrade head && poetry run uvicorn worker_api.app:api --host 0.0.0.0 --port 8001 --log-level debug"] +# Command to run the application (skip alembic since migrations are managed by app-pecha-backend) +CMD ["uvicorn", "worker_api.app:api", "--host", "0.0.0.0", "--port", "8001"]