diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index d6da52b..51a1729 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -97,3 +97,40 @@ jobs: - name: Run tests with coverage run: | poetry run pytest --cov=worker_api --cov-report=xml + + sonarqube: + name: SonarQube Scan + runs-on: ubuntu-latest + needs: [ buildDockerImage ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: 1.7.1 + virtualenvs-create: true + virtualenvs-in-project: true + + - name: Install dependencies + run: | + poetry config virtualenvs.create false --local + poetry install --no-interaction --with dev + + - name: Run tests with coverage + run: | + poetry run pytest --cov=worker_api --cov-report=xml + + - name: SonarQube Scan + uses: SonarSource/sonarqube-scan-action@v5 + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + diff --git a/Dockerfile b/Dockerfile index 35bd1b7..ce0e0d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,5 +28,5 @@ COPY . /app # Expose the port that the app runs on EXPOSE 8001 -# Command to run the application -CMD ["sh", "-c", "poetry run alembic upgrade head && poetry run uvicorn worker_api.app:api --host 0.0.0.0 --port 8001 --log-level debug"] +# Command to run the application (skip alembic since migrations are managed by app-pecha-backend) +CMD ["uvicorn", "worker_api.app:api", "--host", "0.0.0.0", "--port", "8001"] diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..140df82 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,12 @@ +sonar.projectKey=webuddhist_worker +sonar.organization=webuddhist +sonar.sources=worker_api +sonar.tests=tests +sonar.python.version=3.12 +sonar.python.coverage.reportPaths=coverage.xml + +# Exclude model files, repository files, and generated code from code coverage +sonar.coverage.exclusions=**/*_models.py,**/*_response_models.py,**/*_repository.py,**/external_clients/** + +# Exclude generated code from analysis entirely +sonar.exclusions=**/external_clients/** diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/audio/test_audio_generate_service.py b/tests/audio/test_audio_generate_service.py new file mode 100644 index 0000000..796ca52 --- /dev/null +++ b/tests/audio/test_audio_generate_service.py @@ -0,0 +1,348 @@ +""" +Tests for audio generation service. +""" +import pytest +from uuid import uuid4 +from unittest.mock import patch, MagicMock, AsyncMock +from io import BytesIO + +from worker_api.audio.enums import ContentType, PlanAudioType, MonlamVoiceName +from worker_api.audio.services.audio_generate_service import ( + generate_plan_audio_service, + _generate_audio_segments, + _update_subtask_timestamps, + _build_combined_wav, + _upload_and_persist_audio, +) + + +class TestGeneratePlanAudioService: + """Tests for generate_plan_audio_service function.""" + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_plan_day_by_id_any_plan") + @patch("worker_api.audio.services.audio_generate_service._generate_audio_segments") + @patch("worker_api.audio.services.audio_generate_service._update_subtask_timestamps") + @patch("worker_api.audio.services.audio_generate_service._build_combined_wav") + @patch("worker_api.audio.services.audio_generate_service._upload_and_persist_audio") + @patch("worker_api.audio.services.audio_generate_service.generate_presigned_access_url") + async def test_generate_day_audio_success( + self, + mock_presigned_url, + mock_upload, + mock_build_wav, + mock_update_timestamps, + mock_generate_segments, + mock_get_day, + mock_session_local, + ): + """Test successful audio generation for a plan day.""" + day_id = uuid4() + plan_id = uuid4() + + # Mock database session + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + # Mock plan item + mock_plan_item = MagicMock() + mock_plan_item.id = day_id + mock_plan_item.plan_id = plan_id + mock_plan_item.tasks = [] + mock_get_day.return_value = mock_plan_item + + # Mock audio segments + mock_generate_segments.return_value = ([b"audio_data"], [MagicMock()]) + mock_update_timestamps.return_value = 45000 + mock_build_wav.return_value = (b"wav_data", 1000) + + # Mock audio row + mock_audio_row = MagicMock() + mock_audio_row.audio_key = "audio/test.wav" + mock_audio_row.duration_ms = 45000 + mock_upload.return_value = mock_audio_row + + mock_presigned_url.return_value = "https://s3.example.com/audio.wav" + + # Execute + result = await generate_plan_audio_service( + language="en", + day_id=day_id, + audio_type=PlanAudioType.TEXT_READING, + ) + + # Assert + assert result["audio_url"] == "https://s3.example.com/audio.wav" + assert result["audio_duration_ms"] == 45000 + assert result["s3_key"] == "audio/test.wav" + + mock_get_day.assert_called_once_with(db=mock_db, day_id=day_id) + mock_generate_segments.assert_called_once() + mock_upload.assert_called_once() + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_plan_day_by_id_any_plan") + @patch("worker_api.audio.services.audio_generate_service._generate_audio_segments") + async def test_generate_day_audio_no_segments( + self, + mock_generate_segments, + mock_get_day, + mock_session_local, + ): + """Test audio generation returns empty when no segments are generated.""" + day_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_plan_item = MagicMock() + mock_plan_item.tasks = [] + mock_get_day.return_value = mock_plan_item + + # No audio segments + mock_generate_segments.return_value = ([], []) + + result = await generate_plan_audio_service( + language="en", + day_id=day_id, + ) + + assert result == [] + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + @patch("worker_api.audio.services.audio_generate_service.generate_tts_audio") + @patch("worker_api.audio.services.audio_generate_service.upload_bytes") + @patch("worker_api.audio.services.audio_generate_service.upsert_sub_task_timestamp") + @patch("worker_api.audio.services.audio_generate_service.generate_presigned_access_url") + async def test_generate_subtask_audio_success( + self, + mock_presigned_url, + mock_upsert_timestamp, + mock_upload, + mock_tts, + mock_get_subtask, + mock_session_local, + ): + """Test successful audio generation for a single subtask.""" + sub_task_id = uuid4() + task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + # Mock subtask + mock_subtask = MagicMock() + mock_subtask.id = sub_task_id + mock_subtask.task_id = task_id + mock_subtask.content = "Test content" + mock_subtask.content_type = ContentType.TEXT + mock_get_subtask.return_value = mock_subtask + + # Mock TTS audio (44 byte header + audio data) + wav_header = b"RIFF" + b"\x00" * 40 + audio_data = b"\x00" * 1000 + mock_tts.return_value = wav_header + audio_data + + mock_presigned_url.return_value = "https://s3.example.com/audio.wav" + + # Execute + result = await generate_plan_audio_service( + language="en", + sub_task_id=sub_task_id, + audio_type=PlanAudioType.TEXT_READING, + ) + + # Assert + assert "audio_url" in result + assert "audio_duration_ms" in result + assert "s3_key" in result + + mock_get_subtask.assert_called_once_with(db=mock_db, id=sub_task_id) + mock_tts.assert_called_once() + mock_upload.assert_called_once() + mock_upsert_timestamp.assert_called_once() + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + async def test_generate_subtask_audio_not_found( + self, + mock_get_subtask, + mock_session_local, + ): + """Test error when subtask is not found.""" + sub_task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_get_subtask.return_value = None + + with pytest.raises(Exception) as exc_info: + await generate_plan_audio_service( + language="en", + sub_task_id=sub_task_id, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("worker_api.audio.services.audio_generate_service.SessionLocal") + @patch("worker_api.audio.services.audio_generate_service.get_sub_task_by_subtask_id") + async def test_generate_subtask_audio_invalid_content_type( + self, + mock_get_subtask, + mock_session_local, + ): + """Test error when subtask has invalid content type.""" + sub_task_id = uuid4() + + mock_db = MagicMock() + mock_session_local.return_value.__enter__.return_value = mock_db + + mock_subtask = MagicMock() + mock_subtask.content_type = ContentType.VIDEO # Invalid for audio + mock_get_subtask.return_value = mock_subtask + + with pytest.raises(Exception) as exc_info: + await generate_plan_audio_service( + language="en", + sub_task_id=sub_task_id, + ) + + assert exc_info.value.status_code == 400 + + +class TestGenerateAudioSegments: + """Tests for _generate_audio_segments helper function.""" + + @patch("worker_api.audio.services.audio_generate_service.generate_tts_audio") + def test_generate_segments_with_text_content(self, mock_tts): + """Test generating audio segments from text content.""" + wav_data = b"RIFF" + b"\x00" * 40 + b"audio_data" + mock_tts.return_value = wav_data + + mock_subtask = MagicMock() + mock_subtask.content = "Test content" + mock_subtask.content_type = ContentType.TEXT + mock_subtask.audio_url = None + + mock_task = MagicMock() + mock_task.sub_tasks = [mock_subtask] + + segments, refs = _generate_audio_segments( + [mock_task], + PlanAudioType.TEXT_READING, + "en", + ) + + assert len(segments) == 1 + assert len(refs) == 1 + assert refs[0] == mock_subtask + mock_tts.assert_called_once() + + @patch("worker_api.audio.services.audio_generate_service.download_bytes") + def test_generate_segments_with_existing_audio(self, mock_download): + """Test reusing existing audio from subtask.""" + wav_data = b"RIFF" + b"\x00" * 40 + b"existing_audio" + mock_download.return_value = wav_data + + mock_subtask = MagicMock() + mock_subtask.content_type = ContentType.TEXT + mock_subtask.audio_url = "audio/existing.wav" + + mock_task = MagicMock() + mock_task.sub_tasks = [mock_subtask] + + segments, refs = _generate_audio_segments( + [mock_task], + PlanAudioType.TEXT_READING, + "en", + ) + + assert len(segments) == 1 + mock_download.assert_called_once_with(key="audio/existing.wav") + + def test_generate_segments_skips_invalid_content_types(self): + """Test that non-text/source_reference subtasks are skipped.""" + mock_subtask = MagicMock() + mock_subtask.content_type = ContentType.VIDEO + + mock_task = MagicMock() + mock_task.sub_tasks = [mock_subtask] + + segments, refs = _generate_audio_segments( + [mock_task], + PlanAudioType.TEXT_READING, + "en", + ) + + assert len(segments) == 0 + assert len(refs) == 0 + + +class TestBuildCombinedWav: + """Tests for _build_combined_wav helper function.""" + + def test_build_wav_single_segment(self): + """Test building WAV file from single audio segment.""" + audio_data = b"\x00" * 1000 + + wav, size = _build_combined_wav([audio_data]) + + assert len(wav) > len(audio_data) # Header + data + assert wav[:4] == b"RIFF" + assert b"WAVE" in wav + assert size == len(audio_data) + + def test_build_wav_multiple_segments(self): + """Test building WAV file from multiple audio segments.""" + segment1 = b"\x00" * 500 + segment2 = b"\x01" * 500 + + wav, size = _build_combined_wav([segment1, segment2]) + + assert wav[:4] == b"RIFF" + assert size == len(segment1) + len(segment2) + + def test_build_wav_empty_segments(self): + """Test building WAV file with no segments.""" + wav, size = _build_combined_wav([]) + + assert wav[:4] == b"RIFF" + assert size == 0 + + +class TestUpdateSubtaskTimestamps: + """Tests for _update_subtask_timestamps helper function.""" + + @patch("worker_api.audio.services.audio_generate_service.upsert_sub_task_timestamp") + def test_update_timestamps(self, mock_upsert): + """Test updating subtask timestamps.""" + mock_db = MagicMock() + + # 1000 bytes at 2 bytes per sample = 500 samples + # 500 samples / 24000 Hz = 0.020833 seconds = 20.833 ms + audio_segment = b"\x00" * 1000 + + mock_subtask = MagicMock() + mock_subtask.id = uuid4() + + duration = _update_subtask_timestamps( + mock_db, + [audio_segment], + [mock_subtask], + 24000, # sample_rate + 2, # bytes_per_sample + ) + + assert duration > 0 + mock_upsert.assert_called_once() + call_kwargs = mock_upsert.call_args.kwargs + assert call_kwargs["sub_task_id"] == mock_subtask.id + assert call_kwargs["start_ms"] == 0 + assert call_kwargs["end_ms"] > 0 diff --git a/tests/audio/test_audio_views.py b/tests/audio/test_audio_views.py new file mode 100644 index 0000000..57eb737 --- /dev/null +++ b/tests/audio/test_audio_views.py @@ -0,0 +1,195 @@ +""" +Tests for audio generation API endpoints. +""" +import pytest +from uuid import uuid4 +from unittest.mock import patch, AsyncMock + +from worker_api.audio.enums import PlanAudioType, MonlamVoiceName + + +class TestGeneratePlanAudio: + """Tests for POST /audio/generate endpoint.""" + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_day_id(self, mock_service, client): + """Test generating audio for a plan day.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 45000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 200 + data = response.json() + assert "audio_url" in data + assert "audio_duration_ms" in data + assert "s3_key" in data + assert data["audio_duration_ms"] == 45000 + + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["day_id"] == day_id + assert call_kwargs["language"] == "en" + assert call_kwargs["audio_type"] == PlanAudioType.TEXT_READING + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_sub_task_id(self, mock_service, client): + """Test generating audio for a single subtask.""" + sub_task_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 12000, + "s3_key": "audio/plan_subtasks/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "sub_task_id": str(sub_task_id), + "language": "bo", + "type": "RECITATION", + "voice_name": "dolkar_lhasa_female" + } + ) + + assert response.status_code == 200 + data = response.json() + assert "audio_url" in data + assert data["audio_duration_ms"] == 12000 + + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["sub_task_id"] == sub_task_id + assert call_kwargs["language"] == "bo" + assert call_kwargs["audio_type"] == PlanAudioType.RECITATION + assert call_kwargs["voice_name"] == MonlamVoiceName.DOLKAR_LHASA_FEMALE + + def test_generate_audio_missing_both_ids(self, client): + """Test validation error when both day_id and sub_task_id are missing.""" + response = client.post( + "/api/v1/audio/generate", + json={ + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + def test_generate_audio_with_both_ids(self, client): + """Test validation error when both day_id and sub_task_id are provided.""" + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(uuid4()), + "sub_task_id": str(uuid4()), + "language": "en", + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + data = response.json() + assert "detail" in data + + def test_generate_audio_missing_language(self, client): + """Test validation error when language is missing.""" + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(uuid4()), + "type": "TEXT_READING" + } + ) + + assert response.status_code == 422 + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_with_instruction_type(self, mock_service, client): + """Test generating instruction audio.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 30000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en", + "type": "INSTRUCTION" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["audio_type"] == PlanAudioType.INSTRUCTION + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_default_type(self, mock_service, client): + """Test that TEXT_READING is the default audio type.""" + day_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 30000, + "s3_key": "audio/plan_days/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "day_id": str(day_id), + "language": "en" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["audio_type"] == PlanAudioType.TEXT_READING + + @pytest.mark.asyncio + @patch("worker_api.audio.audio_views.generate_plan_audio_service") + async def test_generate_audio_tibetan_with_voice(self, mock_service, client): + """Test generating Tibetan audio with specific voice.""" + sub_task_id = uuid4() + mock_service.return_value = { + "audio_url": "https://s3.example.com/audio.wav", + "audio_duration_ms": 15000, + "s3_key": "audio/plan_subtasks/test.wav" + } + + response = client.post( + "/api/v1/audio/generate", + json={ + "sub_task_id": str(sub_task_id), + "language": "bo", + "type": "TEXT_READING", + "voice_name": "sonamtsering_lhasa_male" + } + ) + + assert response.status_code == 200 + mock_service.assert_called_once() + call_kwargs = mock_service.call_args.kwargs + assert call_kwargs["voice_name"] == MonlamVoiceName.SONAMTSERING_LHASA_MALE diff --git a/tests/audio/test_monlam_tts_service.py b/tests/audio/test_monlam_tts_service.py new file mode 100644 index 0000000..1fa3796 --- /dev/null +++ b/tests/audio/test_monlam_tts_service.py @@ -0,0 +1,240 @@ +import pytest +from unittest.mock import patch, MagicMock +import httpx + +from worker_api.audio.services.monlam_tts_service import ( + generate_monlam_tts_audio, + DEFAULT_MONLAM_VOICE_NAME, +) +from worker_api.audio.enums import MonlamVoiceName + + +class TestGenerateMonlamTtsAudio: + def test_empty_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_monlam_tts_audio("") + + def test_whitespace_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_monlam_tts_audio(" ") + + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_missing_api_key_raises_error(self, mock_get): + def get_side_effect(key): + if key == "MONLAM_BASE_URL": + return "https://api.monlam.ai" + if key == "MONLAM_API_KEY": + return None + return "default" + + mock_get.side_effect = get_side_effect + + with pytest.raises(RuntimeError, match="MONLAM_API_KEY is not configured"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_successful_generation(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།") + + assert result == b"RIFF" + b"\x00" * 100 + mock_post.assert_called_once() + + call_args = mock_post.call_args + assert call_args[0][0] == "https://api.monlam.ai/api/v1/text-to-speech/stream" + assert call_args[1]["headers"]["X-API-Key"] == "fake_api_key" + assert call_args[1]["json"]["text"] == "བོད་སྐད།" + assert call_args[1]["json"]["provider"] == "test_provider" + assert call_args[1]["json"]["model_name"] == "test_model" + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_custom_voice_name(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།", voice_name="custom_voice") + + assert result == b"RIFF" + b"\x00" * 100 + + call_args = mock_post.call_args + assert call_args[1]["json"]["voice_name"] == "custom_voice" + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_default_voice_name(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = generate_monlam_tts_audio("བོད་སྐད།") + + call_args = mock_post.call_args + assert call_args[1]["json"]["voice_name"] == DEFAULT_MONLAM_VOICE_NAME + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_http_status_error(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + mock_post.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=mock_response + ) + + with pytest.raises(RuntimeError, match="Monlam TTS request failed with status 500"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_request_error(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_post.side_effect = httpx.RequestError("Connection failed") + + with pytest.raises(RuntimeError, match="Monlam TTS request failed"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_invalid_audio_data_empty(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"" + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with pytest.raises(RuntimeError, match="Monlam TTS generation returned invalid audio data"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_invalid_audio_data_not_riff(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"INVALID_DATA" + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with pytest.raises(RuntimeError, match="Monlam TTS generation returned invalid audio data"): + generate_monlam_tts_audio("བོད་སྐད།") + + @patch("worker_api.audio.services.monlam_tts_service.httpx.post") + @patch("worker_api.audio.services.monlam_tts_service.get") + def test_base_url_trailing_slash_removed(self, mock_get, mock_post): + def get_side_effect(key): + config = { + "MONLAM_BASE_URL": "https://api.monlam.ai/", + "MONLAM_API_KEY": "fake_api_key", + "MONLAM_TTS_PROVIDER": "test_provider", + "MONLAM_TTS_MODEL_NAME": "test_model", + "MONLAM_TTS_VOICE_NAME": None, + } + return config.get(key) + + mock_get.side_effect = get_side_effect + + mock_response = MagicMock() + mock_response.content = b"RIFF" + b"\x00" * 100 + mock_response.status_code = 200 + mock_post.return_value = mock_response + + generate_monlam_tts_audio("བོད་སྐད།") + + call_args = mock_post.call_args + assert call_args[0][0] == "https://api.monlam.ai/api/v1/text-to-speech/stream" + + +def test_default_monlam_voice_name(): + assert DEFAULT_MONLAM_VOICE_NAME == MonlamVoiceName.DOLKAR_LHASA_FEMALE.value diff --git a/tests/audio/test_tts_service.py b/tests/audio/test_tts_service.py new file mode 100644 index 0000000..94d4e79 --- /dev/null +++ b/tests/audio/test_tts_service.py @@ -0,0 +1,234 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import struct + +from worker_api.audio.services.tts_service import ( + generate_tts_audio, + _normalize_language, + _generate_gemini_tts_audio, + _convert_to_wav, + _parse_audio_mime_type, + SUPPORTED_TTS_LANGUAGES, +) +from worker_api.audio.enums import PlanAudioType + + +class TestNormalizeLanguage: + def test_normalize_language_lowercase(self): + assert _normalize_language("EN") == "en" + assert _normalize_language("Bo") == "bo" + + def test_normalize_language_strip(self): + assert _normalize_language(" en ") == "en" + assert _normalize_language(" bo ") == "bo" + + def test_normalize_language_none(self): + assert _normalize_language(None) == "en" + + def test_normalize_language_empty(self): + assert _normalize_language("") == "en" + + +class TestGenerateTtsAudio: + def test_empty_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_tts_audio("", PlanAudioType.RECITATION) + + def test_whitespace_content_raises_error(self): + with pytest.raises(ValueError, match="Content cannot be empty"): + generate_tts_audio(" ", PlanAudioType.RECITATION) + + def test_unsupported_language_raises_error(self): + with pytest.raises(ValueError, match="Unsupported language for TTS"): + generate_tts_audio("Hello", PlanAudioType.RECITATION, language="fr") + + @patch("worker_api.audio.services.tts_service.generate_monlam_tts_audio") + def test_tibetan_language_uses_monlam(self, mock_monlam): + mock_monlam.return_value = b"fake_audio_data" + + result = generate_tts_audio("བོད་སྐད།", PlanAudioType.RECITATION, language="bo") + + mock_monlam.assert_called_once_with("བོད་སྐད།", voice_name=None) + assert result == b"fake_audio_data" + + @patch("worker_api.audio.services.tts_service.generate_monlam_tts_audio") + def test_tibetan_language_with_voice_name(self, mock_monlam): + mock_monlam.return_value = b"fake_audio_data" + + result = generate_tts_audio( + "བོད་སྐད།", + PlanAudioType.RECITATION, + language="bo", + voice_name="custom_voice" + ) + + mock_monlam.assert_called_once_with("བོད་སྐད།", voice_name="custom_voice") + assert result == b"fake_audio_data" + + @patch("worker_api.audio.services.tts_service._generate_gemini_tts_audio") + def test_english_language_uses_gemini(self, mock_gemini): + mock_gemini.return_value = b"fake_wav_data" + + result = generate_tts_audio("Hello world", PlanAudioType.RECITATION, language="en") + + mock_gemini.assert_called_once_with( + content="Hello world", + audio_type=PlanAudioType.RECITATION + ) + assert result == b"fake_wav_data" + + def test_supported_languages(self): + assert "en" in SUPPORTED_TTS_LANGUAGES + assert "bo" in SUPPORTED_TTS_LANGUAGES + + +class TestGenerateGeminiTtsAudio: + @patch("worker_api.audio.services.tts_service.get") + def test_missing_api_key_raises_error(self, mock_get): + mock_get.return_value = None + + with pytest.raises(RuntimeError, match="GEMINI_API_KEY is not configured"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + @patch("worker_api.audio.services.tts_service._convert_to_wav") + @patch("worker_api.audio.services.tts_service.get") + def test_successful_generation(self, mock_get, mock_convert): + mock_get.return_value = "fake_api_key" + mock_convert.return_value = b"fake_wav_data" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.inline_data.data = b"raw_audio_data" + mock_part.inline_data.mime_type = "audio/L16;rate=24000" + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + result = _generate_gemini_tts_audio("Hello world", PlanAudioType.RECITATION) + + assert result == b"fake_wav_data" + mock_convert.assert_called_once_with(b"raw_audio_data", "audio/L16;rate=24000") + + @patch("worker_api.audio.services.tts_service.get") + def test_no_candidates_raises_error(self, mock_get): + mock_get.return_value = "fake_api_key" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.candidates = [] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(RuntimeError, match="TTS generation returned no audio data"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + @patch("worker_api.audio.services.tts_service.get") + def test_no_inline_data_raises_error(self, mock_get): + mock_get.return_value = "fake_api_key" + + with patch("google.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_part = MagicMock() + mock_part.inline_data = None + + mock_candidate = MagicMock() + mock_candidate.content.parts = [mock_part] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + + mock_client.models.generate_content.return_value = mock_response + + with pytest.raises(RuntimeError, match="TTS generation returned no audio data"): + _generate_gemini_tts_audio("Hello", PlanAudioType.RECITATION) + + +class TestParseAudioMimeType: + def test_default_values(self): + result = _parse_audio_mime_type("audio/wav") + assert result["bits_per_sample"] == 16 + assert result["rate"] == 24000 + + def test_parse_rate(self): + result = _parse_audio_mime_type("audio/L16;rate=48000") + assert result["rate"] == 48000 + assert result["bits_per_sample"] == 16 + + def test_parse_bits_per_sample(self): + result = _parse_audio_mime_type("audio/L24;rate=24000") + assert result["bits_per_sample"] == 24 + assert result["rate"] == 24000 + + def test_parse_both_parameters(self): + result = _parse_audio_mime_type("audio/L32;rate=16000") + assert result["bits_per_sample"] == 32 + assert result["rate"] == 16000 + + def test_invalid_rate_uses_default(self): + result = _parse_audio_mime_type("audio/L16;rate=invalid") + assert result["rate"] == 24000 + + def test_invalid_bits_uses_default(self): + result = _parse_audio_mime_type("audio/Linvalid;rate=24000") + assert result["bits_per_sample"] == 16 + + +class TestConvertToWav: + def test_convert_basic_audio(self): + audio_data = b"\x00\x01" * 100 + mime_type = "audio/L16;rate=24000" + + result = _convert_to_wav(audio_data, mime_type) + + assert result[:4] == b"RIFF" + assert result[8:12] == b"WAVE" + assert result[12:16] == b"fmt " + assert result[36:40] == b"data" + assert len(result) > len(audio_data) + + def test_convert_different_sample_rate(self): + audio_data = b"\x00\x01" * 50 + mime_type = "audio/L16;rate=48000" + + result = _convert_to_wav(audio_data, mime_type) + + assert result[:4] == b"RIFF" + sample_rate = struct.unpack(" Optional[PlanItemAudio]: + return ( + db.query(PlanItemAudio) + .filter(PlanItemAudio.plan_item_id == plan_item_id) + .first() + ) + + +def upsert_plan_item_audio(db: Session, plan_item_audio: PlanItemAudio) -> PlanItemAudio: + existing = get_plan_item_audio_by_plan_item_id(db=db, plan_item_id=plan_item_audio.plan_item_id) + if existing: + existing.audio_key = plan_item_audio.audio_key + existing.duration_ms = plan_item_audio.duration_ms + existing.mime_type = plan_item_audio.mime_type + existing.file_size_bytes = plan_item_audio.file_size_bytes + existing.updated_by = plan_item_audio.updated_by + db.commit() + db.refresh(existing) + return existing + db.add(plan_item_audio) + db.commit() + db.refresh(plan_item_audio) + return plan_item_audio diff --git a/worker_api/audio/repositories/plan_items_repository.py b/worker_api/audio/repositories/plan_items_repository.py new file mode 100644 index 0000000..6561227 --- /dev/null +++ b/worker_api/audio/repositories/plan_items_repository.py @@ -0,0 +1,25 @@ +from uuid import UUID +from fastapi import HTTPException +from sqlalchemy.orm import Session, joinedload +from starlette import status + +from worker_api.audio.models.plan_items_models import PlanItem +from worker_api.audio.models.plan_tasks_models import PlanTask +from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask + + +def get_plan_day_by_id_any_plan(db: Session, day_id: UUID) -> PlanItem: + plan_item = ( + db.query(PlanItem) + .options( + joinedload(PlanItem.tasks).joinedload(PlanTask.sub_tasks).joinedload(PlanSubTask.timestamp), + ) + .filter(PlanItem.id == day_id) + .first() + ) + if not plan_item: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": "BAD_REQUEST", "message": "Plan day not found"}, + ) + return plan_item diff --git a/worker_api/audio/repositories/plan_sub_tasks_repository.py b/worker_api/audio/repositories/plan_sub_tasks_repository.py new file mode 100644 index 0000000..9bb4faa --- /dev/null +++ b/worker_api/audio/repositories/plan_sub_tasks_repository.py @@ -0,0 +1,8 @@ +from uuid import UUID +from sqlalchemy.orm import Session + +from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask + + +def get_sub_task_by_subtask_id(db: Session, id: UUID) -> PlanSubTask: + return db.query(PlanSubTask).filter(PlanSubTask.id == id).first() diff --git a/worker_api/audio/repositories/sub_task_timestamps_repository.py b/worker_api/audio/repositories/sub_task_timestamps_repository.py new file mode 100644 index 0000000..6295d89 --- /dev/null +++ b/worker_api/audio/repositories/sub_task_timestamps_repository.py @@ -0,0 +1,42 @@ +from typing import Optional +from uuid import UUID +from sqlalchemy.orm import Session + +from worker_api.audio.models.sub_task_timestamps_models import SubTaskTimestamp + + +def get_sub_task_timestamp_by_sub_task_id( + db: Session, sub_task_id: UUID +) -> Optional[SubTaskTimestamp]: + return ( + db.query(SubTaskTimestamp) + .filter(SubTaskTimestamp.sub_task_id == sub_task_id) + .first() + ) + + +def upsert_sub_task_timestamp( + db: Session, + sub_task_id: UUID, + start_ms: int, + end_ms: int, + created_by: str, +) -> SubTaskTimestamp: + existing = get_sub_task_timestamp_by_sub_task_id(db=db, sub_task_id=sub_task_id) + if existing: + existing.start_ms = start_ms + existing.end_ms = end_ms + existing.updated_by = created_by + db.commit() + db.refresh(existing) + return existing + row = SubTaskTimestamp( + sub_task_id=sub_task_id, + start_ms=start_ms, + end_ms=end_ms, + created_by=created_by, + ) + db.add(row) + db.commit() + db.refresh(row) + return row diff --git a/worker_api/audio/schemas.py b/worker_api/audio/schemas.py new file mode 100644 index 0000000..68c125b --- /dev/null +++ b/worker_api/audio/schemas.py @@ -0,0 +1,21 @@ +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, model_validator + +from worker_api.audio.enums import PlanAudioType, MonlamVoiceName + + +class GeneratePlanAudioRequest(BaseModel): + day_id: Optional[UUID] = None + sub_task_id: Optional[UUID] = None + language: str + type: Optional[PlanAudioType] = PlanAudioType.TEXT_READING + voice_name: MonlamVoiceName = MonlamVoiceName.DOLKAR_LHASA_FEMALE + + @model_validator(mode="after") + def validate_either_day_or_subtask(self): + if not self.day_id and not self.sub_task_id: + raise ValueError("Either day_id or sub_task_id must be provided") + if self.day_id and self.sub_task_id: + raise ValueError("Provide either day_id or sub_task_id, not both") + return self diff --git a/worker_api/audio/services/__init__.py b/worker_api/audio/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/worker_api/audio/services/audio_generate_service.py b/worker_api/audio/services/audio_generate_service.py new file mode 100644 index 0000000..0edfd90 --- /dev/null +++ b/worker_api/audio/services/audio_generate_service.py @@ -0,0 +1,250 @@ +import struct +from io import BytesIO +from typing import Optional, List +from uuid import UUID, uuid4 + +from fastapi import HTTPException +from sqlalchemy.orm import Session +from starlette import status + +from worker_api.audio.enums import ContentType, PlanAudioType, MonlamVoiceName +from worker_api.audio.models.plan_item_audio_models import PlanItemAudio +from worker_api.audio.models.plan_items_models import PlanItem +from worker_api.audio.models.plan_sub_tasks_models import PlanSubTask +from worker_api.audio.repositories.plan_items_repository import get_plan_day_by_id_any_plan +from worker_api.audio.repositories.plan_sub_tasks_repository import get_sub_task_by_subtask_id +from worker_api.audio.repositories.plan_item_audio_repository import upsert_plan_item_audio +from worker_api.audio.repositories.sub_task_timestamps_repository import upsert_sub_task_timestamp +from worker_api.audio.services.tts_service import generate_tts_audio +from worker_api.config import get +from worker_api.db.database import SessionLocal +from worker_api.uploads.S3_utils import upload_bytes, download_bytes, generate_presigned_access_url + +WAV_CONTENT_TYPE = "audio/wav" + + +def _generate_audio_segments( + tasks, + audio_type: PlanAudioType, + language: str, + voice_name: Optional[str] = None, +) -> tuple[List[bytes], list]: + wav_header_size = 44 + audio_segments: List[bytes] = [] + subtask_refs = [] + allowed_types = {ContentType.TEXT, ContentType.SOURCE_REFERENCE} + for task in tasks: + subtask = task.sub_tasks[0] if task.sub_tasks else None + if not subtask: + continue + if subtask.content_type not in allowed_types: + continue + + if subtask.audio_url: + existing_wav = download_bytes( + key=subtask.audio_url, + ) + raw_pcm = existing_wav[wav_header_size:] + else: + wav_bytes = generate_tts_audio( + subtask.content, audio_type, language, voice_name=voice_name + ) + raw_pcm = wav_bytes[wav_header_size:] + + audio_segments.append(raw_pcm) + subtask_refs.append(subtask) + return audio_segments, subtask_refs + + +def _update_subtask_timestamps( + db: Session, + audio_segments: List[bytes], + subtask_refs: list, + sample_rate: int, + bytes_per_sample: int, +) -> int: + current_offset_ms = 0 + for i, raw_pcm in enumerate(audio_segments): + segment_samples = len(raw_pcm) // bytes_per_sample + segment_duration_ms = int((segment_samples / sample_rate) * 1000) + upsert_sub_task_timestamp( + db=db, + sub_task_id=subtask_refs[i].id, + start_ms=current_offset_ms, + end_ms=current_offset_ms + segment_duration_ms, + created_by="system", + ) + current_offset_ms += segment_duration_ms + return current_offset_ms + + +def _build_combined_wav(audio_segments: List[bytes]) -> tuple[bytes, int]: + sample_rate = 24000 + bits_per_sample = 16 + num_channels = 1 + bytes_per_sample = bits_per_sample // 8 + + combined_pcm = b"".join(audio_segments) + block_align = num_channels * bytes_per_sample + byte_rate = sample_rate * block_align + data_size = len(combined_pcm) + chunk_size = 36 + data_size + + wav_header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", chunk_size, b"WAVE", + b"fmt ", 16, 1, num_channels, + sample_rate, byte_rate, block_align, bits_per_sample, + b"data", data_size, + ) + return wav_header + combined_pcm, data_size + + +def _upload_and_persist_audio( + db: Session, + combined_wav: bytes, + duration_ms: int, + plan_id: UUID, + plan_item_id: UUID, +) -> PlanItemAudio: + s3_key = f"audio/plan_days/{plan_id}/{plan_item_id}/{uuid4()}.wav" + upload_bytes( + file_bytes=BytesIO(combined_wav), + key=s3_key, + content_type=WAV_CONTENT_TYPE, + ) + return upsert_plan_item_audio( + db=db, + plan_item_audio=PlanItemAudio( + plan_item_id=plan_item_id, + audio_key=s3_key, + duration_ms=duration_ms, + mime_type=WAV_CONTENT_TYPE, + file_size_bytes=len(combined_wav), + created_by="system", + ), + ) + + +async def generate_plan_audio_service( + language: str, + day_id: Optional[UUID] = None, + sub_task_id: Optional[UUID] = None, + audio_type: PlanAudioType = PlanAudioType.TEXT_READING, + voice_name: MonlamVoiceName = MonlamVoiceName.DOLKAR_LHASA_FEMALE, +): + if sub_task_id: + return await _generate_subtask_audio( + sub_task_id=sub_task_id, + audio_type=audio_type, + language=language, + voice_name=voice_name, + ) + + SAMPLE_RATE = 24000 + BYTES_PER_SAMPLE = 2 + + with SessionLocal() as db: + plan_item: PlanItem = get_plan_day_by_id_any_plan(db=db, day_id=day_id) + + audio_segments, subtask_refs = _generate_audio_segments( + plan_item.tasks, audio_type, language, voice_name + ) + if not audio_segments: + return [] + + duration_ms = _update_subtask_timestamps( + db=db, + audio_segments=audio_segments, + subtask_refs=subtask_refs, + sample_rate=SAMPLE_RATE, + bytes_per_sample=BYTES_PER_SAMPLE, + ) + + combined_wav, _ = _build_combined_wav(audio_segments) + + audio_row = _upload_and_persist_audio( + db=db, + combined_wav=combined_wav, + duration_ms=duration_ms, + plan_id=plan_item.plan_id, + plan_item_id=plan_item.id, + ) + + audio_url = generate_presigned_access_url( + key=audio_row.audio_key, + ) + + return { + "audio_url": audio_url, + "audio_duration_ms": audio_row.duration_ms, + "s3_key": audio_row.audio_key, + } + + +async def _generate_subtask_audio( + sub_task_id: UUID, + audio_type: PlanAudioType, + language: str, + voice_name: Optional[str] = None, +): + SAMPLE_RATE = 24000 + BYTES_PER_SAMPLE = 2 + WAV_HEADER_SIZE = 44 + + with SessionLocal() as db: + subtask: PlanSubTask = get_sub_task_by_subtask_id(db=db, id=sub_task_id) + if not subtask: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": "BAD_REQUEST", "message": "Sub task not found"}, + ) + + allowed_types = {ContentType.TEXT, ContentType.SOURCE_REFERENCE} + if subtask.content_type not in allowed_types: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "BAD_REQUEST", + "message": "Sub task content type must be TEXT or SOURCE_REFERENCE for audio generation", + }, + ) + + wav_bytes = generate_tts_audio( + subtask.content, audio_type, language, voice_name=voice_name + ) + raw_pcm = wav_bytes[WAV_HEADER_SIZE:] + + segment_samples = len(raw_pcm) // BYTES_PER_SAMPLE + duration_ms = int((segment_samples / SAMPLE_RATE) * 1000) + + combined_wav, _ = _build_combined_wav([raw_pcm]) + + s3_key = f"audio/plan_subtasks/{subtask.task_id}/{sub_task_id}/{uuid4()}.wav" + upload_bytes( + file_bytes=BytesIO(combined_wav), + key=s3_key, + content_type=WAV_CONTENT_TYPE, + ) + + subtask.audio_url = s3_key + subtask.duration = str(duration_ms) + db.commit() + + upsert_sub_task_timestamp( + db=db, + sub_task_id=sub_task_id, + start_ms=0, + end_ms=duration_ms, + created_by="system", + ) + + audio_url = generate_presigned_access_url( + key=s3_key, + ) + + return { + "audio_url": audio_url, + "audio_duration_ms": duration_ms, + "s3_key": s3_key, + } diff --git a/worker_api/audio/services/audio_prompt.py b/worker_api/audio/services/audio_prompt.py new file mode 100644 index 0000000..4c2f3bd --- /dev/null +++ b/worker_api/audio/services/audio_prompt.py @@ -0,0 +1,75 @@ +from worker_api.audio.enums import PlanAudioType + +DEFAULT_VOICE_NAME = "Algenib" +DEFAULT_ACCENT = "Neutral" + +RECITATION_SCENE = """## Scene: +""" + +RECITATION_SAMPLE_CONTEXT = """## Sample Context: +""" + +INSTRUCTION_SCENE = """## Scene: + The Corporate Studio. +""" + +INSTRUCTION_SAMPLE_CONTEXT = """## Sample Context: +InstructionAal E-learning. Measured pacAing with clear pauses for clarity. Tone is authoritative, accessible, and articulate. +.""" + +TEXT_READING_SCENE = """## Scene: +It is dawn inside a vast, silent meditation hall nestled deep within a quiet forest monastery. A senior monk sits perfectly upright, speaking to a small circle of practitioners. The room possesses a subtle, warm, and spacious acoustic resonance. The atmosphere demands an organic, deeply grounded, and entirely unhurried delivery, where every word is heavy with presence. +""" + +TEXT_READING_SAMPLE_CONTEXT = """## Sample Context: +This voice profile is the definitive standard for serene scriptural readings, timeless philosophical translations, and monastic audiobooks where the listener requires immense breathing room to contemplate profound truths +""" + +AUDIO_TYPE_CONFIGS = { + PlanAudioType.RECITATION: { + "style": "Warm, understanding, soft tone with gentle inflections.", + "pace": "Natural conversational pace.", + "scene": RECITATION_SCENE, + "sample_context": RECITATION_SAMPLE_CONTEXT, + }, + PlanAudioType.INSTRUCTION: { + "style": "Authoritative, accessible, and articulate", + "pace": "Natural conversational pace with clear pauses for clarity", + "scene": INSTRUCTION_SCENE, + "sample_context": INSTRUCTION_SAMPLE_CONTEXT, + }, + PlanAudioType.TEXT_READING: { + "style": "Warm, understanding, soft tone with gentle inflections. Pace: Natural conversational pace.", + "pace": "Natural comfortable reading pace", + "scene": TEXT_READING_SCENE, + "sample_context": TEXT_READING_SAMPLE_CONTEXT, + }, +} + + +def build_tts_prompt( + transcript: str, + audio_type: PlanAudioType, +) -> str: + type_config = AUDIO_TYPE_CONFIGS[audio_type] + + director_note = ( + f"# Director's note\n" + f"Style: {type_config['style']}. " + f"Pace: {type_config['pace']}. " + f"Accent: {DEFAULT_ACCENT}." + ) + + parts = [ + "Read the following transcript based on the director's note.", + "", + director_note, + "", + type_config["scene"], + "", + type_config["sample_context"], + "", + f"## Transcript:\n{transcript}", + ] + + return "\n".join(parts) diff --git a/worker_api/audio/services/monlam_tts_service.py b/worker_api/audio/services/monlam_tts_service.py new file mode 100644 index 0000000..fe8929f --- /dev/null +++ b/worker_api/audio/services/monlam_tts_service.py @@ -0,0 +1,49 @@ +import httpx + +from worker_api.config import get +from worker_api.audio.enums import MonlamVoiceName + +DEFAULT_MONLAM_VOICE_NAME = MonlamVoiceName.DOLKAR_LHASA_FEMALE.value + + +def generate_monlam_tts_audio(content: str, voice_name: str | None = None) -> bytes: + if not content.strip(): + raise ValueError("Content cannot be empty") + + base_url = get("MONLAM_BASE_URL").rstrip("/") + api_key = get("MONLAM_API_KEY") + if not api_key: + raise RuntimeError("MONLAM_API_KEY is not configured") + + payload = { + "text": content, + "provider": get("MONLAM_TTS_PROVIDER"), + "model_name": get("MONLAM_TTS_MODEL_NAME"), + } + resolved_voice_name = voice_name or get("MONLAM_TTS_VOICE_NAME") or DEFAULT_MONLAM_VOICE_NAME + payload["voice_name"] = resolved_voice_name + + try: + response = httpx.post( + f"{base_url}/api/v1/text-to-speech/stream", + headers={ + "X-API-Key": api_key, + "Content-Type": "application/json", + }, + json=payload, + timeout=300.0, + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + detail = exc.response.text + raise RuntimeError( + f"Monlam TTS request failed with status {exc.response.status_code}: {detail}" + ) from exc + except httpx.RequestError as exc: + raise RuntimeError(f"Monlam TTS request failed: {exc}") from exc + + wav_bytes = response.content + if not wav_bytes or wav_bytes[:4] != b"RIFF": + raise RuntimeError("Monlam TTS generation returned invalid audio data") + + return wav_bytes diff --git a/worker_api/audio/services/tts_service.py b/worker_api/audio/services/tts_service.py new file mode 100644 index 0000000..ad310ce --- /dev/null +++ b/worker_api/audio/services/tts_service.py @@ -0,0 +1,133 @@ +import struct + +from worker_api.config import get +from worker_api.audio.services.audio_prompt import build_tts_prompt, DEFAULT_VOICE_NAME +from worker_api.audio.services.monlam_tts_service import generate_monlam_tts_audio +from worker_api.audio.enums import PlanAudioType + +SUPPORTED_TTS_LANGUAGES = {"en", "bo"} + + +def _normalize_language(language: str) -> str: + return (language or "en").strip().lower() + + +def generate_tts_audio( + content: str, + audio_type: PlanAudioType, + language: str = "en", + voice_name: str | None = None, +) -> bytes: + if not content.strip(): + raise ValueError("Content cannot be empty") + + normalized_language = _normalize_language(language) + if normalized_language not in SUPPORTED_TTS_LANGUAGES: + raise ValueError( + f"Unsupported language for TTS: {language}. Supported: {', '.join(sorted(SUPPORTED_TTS_LANGUAGES))}" + ) + + if normalized_language == "bo": + return generate_monlam_tts_audio(content, voice_name=voice_name) + + return _generate_gemini_tts_audio(content=content, audio_type=audio_type) + + +def _generate_gemini_tts_audio( + content: str, + audio_type: PlanAudioType, +) -> bytes: + prompt = build_tts_prompt(transcript=content, audio_type=audio_type) + + from google import genai + from google.genai import types + + api_key = get("GEMINI_API_KEY") + if not api_key: + raise RuntimeError("GEMINI_API_KEY is not configured") + + client = genai.Client(api_key=api_key) + + response = client.models.generate_content( + model="gemini-2.5-flash-preview-tts", + contents=[ + types.Content( + role="user", + parts=[types.Part.from_text(text=prompt)], + ), + ], + config=types.GenerateContentConfig( + temperature=1, + response_modalities=["audio"], + speech_config=types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name=DEFAULT_VOICE_NAME + ) + ) + ), + ), + ) + + if not response.candidates or not response.candidates[0].content.parts: + raise RuntimeError("TTS generation returned no audio data") + + part = response.candidates[0].content.parts[0] + if not part.inline_data or not part.inline_data.data: + raise RuntimeError("TTS generation returned no audio data") + + audio_data = part.inline_data.data + mime_type = part.inline_data.mime_type or "audio/L16;rate=24000" + + return _convert_to_wav(audio_data, mime_type) + + +def _convert_to_wav(audio_data: bytes, mime_type: str) -> bytes: + params = _parse_audio_mime_type(mime_type) + bits_per_sample = params["bits_per_sample"] + sample_rate = params["rate"] + num_channels = 1 + data_size = len(audio_data) + bytes_per_sample = bits_per_sample // 8 + block_align = num_channels * bytes_per_sample + byte_rate = sample_rate * block_align + chunk_size = 36 + data_size + + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", + chunk_size, + b"WAVE", + b"fmt ", + 16, + 1, + num_channels, + sample_rate, + byte_rate, + block_align, + bits_per_sample, + b"data", + data_size, + ) + return header + audio_data + + +def _parse_audio_mime_type(mime_type: str) -> dict: + bits_per_sample = 16 + rate = 24000 + + parts = mime_type.split(";") + for param in parts: + param = param.strip() + if param.lower().startswith("rate="): + try: + rate = int(param.split("=", 1)[1]) + except (ValueError, IndexError): + pass + elif param.startswith("audio/L"): + try: + bits_per_sample = int(param.split("L", 1)[1]) + except (ValueError, IndexError): + pass + + return {"bits_per_sample": bits_per_sample, "rate": rate} diff --git a/worker_api/config.py b/worker_api/config.py index 835a6f2..08eb3db 100644 --- a/worker_api/config.py +++ b/worker_api/config.py @@ -85,13 +85,21 @@ SQS_TIMEOUT=1800, GROUP_INVITE_EXPIRY_MINUTES=30, - WEBUDDHIST_EMAIL_LOGO_URL="https://studio.webuddhist.com/assets/pecha_icon-DkKJLXuA.png", + WEBUDDHIST_EMAIL_LOGO_URL="", # Request observability (per-endpoint memory and latency logging) REQUEST_OBSERVABILITY_ENABLED="true", REQUEST_OBSERVABILITY_MEMORY_WARN_MB=50, REQUEST_OBSERVABILITY_SKIP_PATHS="/health", + # TTS Configuration + GEMINI_API_KEY="", + MONLAM_BASE_URL="", + MONLAM_API_KEY="", + MONLAM_TTS_PROVIDER="", + MONLAM_TTS_MODEL_NAME="", + MONLAM_TTS_VOICE_NAME="", + ) TIME_FORMAT_PATTERN = re.compile(r"^([01]\d|2[0-3]):[0-5]\d$") diff --git a/worker_api/llm/__init__.py b/worker_api/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/worker_api/llm/llm_response_models.py b/worker_api/llm/llm_response_models.py new file mode 100644 index 0000000..a081376 --- /dev/null +++ b/worker_api/llm/llm_response_models.py @@ -0,0 +1,13 @@ +from typing import Optional +from pydantic import BaseModel, Field + + +class LLMChatRequest(BaseModel): + prompt: str = Field(..., description="User prompt to send to Gemini") + system_prompt: Optional[str] = Field(None, description="Optional system instruction for Gemini") + model: Optional[str] = Field(None, description="Optional Gemini model name (defaults to gemini-2.5-flash)") + + +class LLMChatResponse(BaseModel): + response: str = Field(..., description="Gemini's response text") + model: str = Field(..., description="Model used for generation") diff --git a/worker_api/llm/llm_service.py b/worker_api/llm/llm_service.py new file mode 100644 index 0000000..ccda716 --- /dev/null +++ b/worker_api/llm/llm_service.py @@ -0,0 +1,78 @@ +import asyncio +from typing import Optional +from google import genai +from google.genai import types +from fastapi import HTTPException, status + +from worker_api.config import get + +DEFAULT_MODEL = "gemini-2.5-flash" + + +def _chat_with_gemini_sync( + prompt: str, + system_prompt: Optional[str] = None, + model: Optional[str] = None +) -> dict: + api_key = get("GEMINI_API_KEY") + if not api_key: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="GEMINI_API_KEY is not configured" + ) + + client = genai.Client(api_key=api_key) + + model_name = model or DEFAULT_MODEL + + config = types.GenerateContentConfig( + temperature=1.0, + ) + + if system_prompt: + config.system_instruction = system_prompt + + contents = [ + types.Content( + role="user", + parts=[types.Part.from_text(text=prompt)], + ), + ] + + try: + response = client.models.generate_content( + model=model_name, + contents=contents, + config=config, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Gemini API error: {str(e)}" + ) + + if not response.candidates or not response.candidates[0].content.parts: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Gemini returned no response" + ) + + response_text = response.candidates[0].content.parts[0].text + + return { + "response": response_text, + "model": model_name + } + + +async def chat_with_gemini( + prompt: str, + system_prompt: Optional[str] = None, + model: Optional[str] = None +) -> dict: + return await asyncio.to_thread( + _chat_with_gemini_sync, + prompt, + system_prompt, + model + ) diff --git a/worker_api/llm/llm_views.py b/worker_api/llm/llm_views.py new file mode 100644 index 0000000..6858e16 --- /dev/null +++ b/worker_api/llm/llm_views.py @@ -0,0 +1,17 @@ +from fastapi import APIRouter +from starlette import status + +from worker_api.llm.llm_response_models import LLMChatRequest, LLMChatResponse +from worker_api.llm.llm_service import chat_with_gemini + +llm_router = APIRouter(prefix="/llm", tags=["LLM"]) + + +@llm_router.post("/chat", status_code=status.HTTP_200_OK) +async def chat(request: LLMChatRequest) -> LLMChatResponse: + result = await chat_with_gemini( + prompt=request.prompt, + system_prompt=request.system_prompt, + model=request.model + ) + return LLMChatResponse(**result) diff --git a/worker_api/uploads/S3_utils.py b/worker_api/uploads/S3_utils.py index 4b1e152..30fbd75 100644 --- a/worker_api/uploads/S3_utils.py +++ b/worker_api/uploads/S3_utils.py @@ -127,6 +127,37 @@ def generate_presigned_access_url(key: str, expiration: int = None) -> str: ) +def download_bytes(key: str) -> bytes: + """ + Download file bytes from S3 bucket. + + Args: + key: S3 object key (path in bucket) + + Returns: + File content as bytes + + Raises: + HTTPException: If download fails + """ + try: + bucket_name = get("AWS_BUCKET_NAME") + response = s3_client.get_object(Bucket=bucket_name, Key=key) + return response["Body"].read() + except ClientError as e: + logger.error(f"Failed to download file from S3: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Failed to download file: {str(e)}", + ) + except Exception as e: + logger.error(f"Unexpected error downloading file from S3: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Unexpected error: {str(e)}", + ) + + def delete_file(key: str): """ Delete a file from S3 bucket.