From 5ba3062685d6ee9dd1064d7c81a570f6c2943969 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Tue, 28 Apr 2026 14:22:39 +0000 Subject: [PATCH] Interpolate JobSpec secrets for Compute.run_job() Fixes: https://github.com/dstackai/dstack/issues/3833 --- .../background/pipeline_tasks/jobs_running.py | 15 +-- .../pipeline_tasks/jobs_submitted.py | 22 +++- .../server/services/jobs/__init__.py | 13 ++ src/dstack/_internal/utils/interpolator.py | 5 +- .../pipeline_tasks/test_submitted_jobs.py | 123 ++++++++++++++++++ 5 files changed, 163 insertions(+), 15 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index e63e6dd5ae..d441a9e2d3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -86,6 +86,7 @@ get_job_attached_volumes, get_job_runtime_data, get_job_spec, + interpolate_job_spec_secrets, is_master_job, job_model_to_job_submission, ) @@ -105,7 +106,7 @@ from dstack._internal.server.services.storage import get_default_storage from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async -from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator +from dstack._internal.utils.interpolator import InterpolatorError from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -504,7 +505,7 @@ async def _prepare_startup_context( ).repo_creds try: - _interpolate_secrets(secrets, context.job.job_spec) + interpolate_job_spec_secrets(context.job.job_spec, secrets) except InterpolatorError as e: _terminate_job( job_model=context.job_model, @@ -1667,16 +1668,6 @@ async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes return blob -def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: - interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error - job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} - if job_spec.registry_auth is not None: - job_spec.registry_auth = RegistryAuth( - username=interpolate(job_spec.registry_auth.username), - password=interpolate(job_spec.registry_auth.password), - ) - - def _emit_reachability_change_event( session: AsyncSession, job_model: JobModel, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py index 9bf59b8fd2..76738a8cf1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py @@ -107,6 +107,7 @@ get_job_configured_volumes, get_job_runtime_data, get_job_spec, + interpolate_job_spec_secrets, is_master_job, is_multinode_job, switch_job_status, @@ -135,9 +136,11 @@ from dstack._internal.server.services.runs.spec import ( check_run_spec_requires_instance_mounts, ) +from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.interpolator import InterpolatorError from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -1335,6 +1338,8 @@ async def _process_new_capacity_provisioning( master_job_provisioning_data=master_provisioning_data, volumes=preconditions.prepared_job_volumes.volumes, ) + if isinstance(provision_new_capacity_result, _TerminateSubmittedJobResult): + return provision_new_capacity_result if isinstance(provision_new_capacity_result, _FailedNewCapacityProvisioning): logger.debug("%s: provisioning failed", fmt(context.job_model)) return _TerminateSubmittedJobResult( @@ -2060,12 +2065,22 @@ async def _provision_new_capacity( project_ssh_private_key: str, master_job_provisioning_data: Optional[JobProvisioningData] = None, volumes: Optional[list[list[Volume]]] = None, -) -> Union[_FailedNewCapacityProvisioning, _ProvisionNewCapacityResult]: +) -> Union[ + _TerminateSubmittedJobResult, _FailedNewCapacityProvisioning, _ProvisionNewCapacityResult +]: + secrets = await _load_project_secrets(project=project) jobs = copy.deepcopy(jobs) for job in jobs: job.job_spec.image_name, job.job_spec.registry_auth = apply_server_docker_defaults( job.job_spec.image_name, job.job_spec.registry_auth ) + try: + interpolate_job_spec_secrets(job.job_spec, secrets) + except InterpolatorError as e: + return _TerminateSubmittedJobResult( + reason=JobTerminationReason.TERMINATED_BY_SERVER, + message=f"Secrets interpolation error: {e.args[0]}", + ) job = jobs[0] if volumes is None: volumes = [] @@ -2228,6 +2243,11 @@ async def _provision_new_capacity( ) +async def _load_project_secrets(project: ProjectModel) -> dict[str, str]: + async with get_session_ctx() as session: + return await get_project_secrets_mapping(session=session, project=project) + + async def _load_fleet_placement_group_models(fleet_id: uuid.UUID) -> list["PlacementGroupModel"]: async with get_session_ctx() as session: res = await session.execute( diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index da39e661ea..5dc0699113 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -1,5 +1,6 @@ import itertools import json +from collections.abc import Mapping from datetime import timedelta from typing import Dict, Iterable, List, Optional, Tuple from uuid import UUID @@ -27,6 +28,7 @@ JobStatus, JobSubmission, JobTerminationReason, + RegistryAuth, RunSpec, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus @@ -62,6 +64,7 @@ ) from dstack._internal.utils import common from dstack._internal.utils.common import run_async +from dstack._internal.utils.interpolator import VariablesInterpolator from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority @@ -166,6 +169,16 @@ async def get_job_specs_from_run_spec( return job_specs +def interpolate_job_spec_secrets(job_spec: JobSpec, secrets: Mapping[str, str]) -> None: + interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error + job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} + if job_spec.registry_auth is not None: + job_spec.registry_auth = RegistryAuth( + username=interpolate(job_spec.registry_auth.username), + password=interpolate(job_spec.registry_auth.password), + ) + + def find_job(jobs: List[Job], replica_num: int, job_num: int) -> Job: for job in jobs: if job.job_spec.replica_num == replica_num and job.job_spec.job_num == job_num: diff --git a/src/dstack/_internal/utils/interpolator.py b/src/dstack/_internal/utils/interpolator.py index 9d2fd8997a..9a4e44659b 100644 --- a/src/dstack/_internal/utils/interpolator.py +++ b/src/dstack/_internal/utils/interpolator.py @@ -1,5 +1,6 @@ import string -from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union, overload +from collections.abc import Mapping +from typing import Iterable, List, Literal, Optional, Tuple, Union, overload class Pattern: @@ -25,7 +26,7 @@ class InterpolatorError(ValueError): class VariablesInterpolator: def __init__( - self, namespaces: Dict[str, Dict[str, str]], *, skip: Optional[Iterable[str]] = None + self, namespaces: Mapping[str, Mapping[str, str]], *, skip: Optional[Iterable[str]] = None ): self.skip = set(skip) if skip is not None else set() self.variables = {f"{ns}.{k}": v for ns in namespaces for k, v in namespaces[ns].items()} diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py index cb705da5b5..c4baf6f897 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py @@ -13,6 +13,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import RegistryAuth from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.placement import PlacementGroup @@ -40,6 +41,8 @@ PlacementGroupModel, VolumeAttachmentModel, ) +from dstack._internal.server.services.docker import ImageConfig +from dstack._internal.server.services.jobs.configurators.base import JobConfigurator from dstack._internal.server.testing.common import ( ComputeMockSpec, create_export, @@ -50,6 +53,7 @@ create_project, create_repo, create_run, + create_secret, create_user, create_volume, get_compute_group_provisioning_data, @@ -1873,6 +1877,125 @@ async def test_run_jobs_uses_server_default_registry( username="server-user", password="server-pass" ) + async def test_interpolates_secrets_when_provisioning_new_capacity( + self, + test_db, + session: AsyncSession, + image_config_mock: ImageConfig, + worker: JobSubmittedWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + await create_secret(session=session, project=project, name="token", value="s3cret") + await create_secret( + session=session, project=project, name="registry_user", value="docker-user" + ) + await create_secret( + session=session, project=project, name="registry_pass", value="docker-pass" + ) + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=TaskConfiguration( + image="ubuntu", + env=Env.parse_obj({"TOKEN": "${{ secrets.token }}"}), + registry_auth=RegistryAuth( + username="${{ secrets.registry_user }}", + password="${{ secrets.registry_pass }}", + ), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + with patch.object(JobConfigurator, "_get_image_config") as m: + m.return_value = image_config_mock + job = await create_job(session=session, run=run, instance_assigned=True) + + offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.RUNPOD + backend_mock.compute.return_value.get_offers.return_value = [offer] + backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data( + dockerized=False, backend=BackendType.RUNPOD + ) + + await _process_job(session=session, worker=worker, job_model=job) + + backend_mock.compute.return_value.run_job.assert_called_once() + submitted_job = backend_mock.compute.return_value.run_job.call_args[0][1] + assert submitted_job.job_spec.env == {"TOKEN": "s3cret"} + assert submitted_job.job_spec.registry_auth == RegistryAuth( + username="docker-user", password="docker-pass" + ) + # The persisted JobModel keeps the unresolved literals so secrets aren't leaked. + await session.refresh(job) + assert "${{ secrets.token }}" in job.job_spec_data + assert "${{ secrets.registry_user }}" in job.job_spec_data + assert "${{ secrets.registry_pass }}" in job.job_spec_data + + async def test_terminates_job_when_secret_is_missing( + self, + test_db, + session: AsyncSession, + image_config_mock: ImageConfig, + worker: JobSubmittedWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + run_spec = get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=TaskConfiguration( + image="ubuntu", + registry_auth=RegistryAuth( + username="registry_user", + password="${{ secrets.registry_pass }}", + ), + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + with patch.object(JobConfigurator, "_get_image_config") as m: + m.return_value = image_config_mock + job = await create_job(session=session, run=run, instance_assigned=True) + + offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.RUNPOD + backend_mock.compute.return_value.get_offers.return_value = [offer] + backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data( + dockerized=False, backend=BackendType.RUNPOD + ) + + await _process_job(session=session, worker=worker, job_model=job) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert job.termination_reason_message is not None + assert "Secrets interpolation error" in job.termination_reason_message + backend_mock.compute.return_value.run_job.assert_not_called() + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)