Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
get_job_attached_volumes,
get_job_runtime_data,
get_job_spec,
interpolate_job_spec_secrets,
is_master_job,
job_model_to_job_submission,
)
Expand All @@ -105,7 +106,7 @@
from dstack._internal.server.services.storage import get_default_storage
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.interpolator import InterpolatorError
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -504,7 +505,7 @@ async def _prepare_startup_context(
).repo_creds

try:
_interpolate_secrets(secrets, context.job.job_spec)
interpolate_job_spec_secrets(context.job.job_spec, secrets)
except InterpolatorError as e:
_terminate_job(
job_model=context.job_model,
Expand Down Expand Up @@ -1667,16 +1668,6 @@ async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes
return blob


def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None:
interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error
job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()}
if job_spec.registry_auth is not None:
job_spec.registry_auth = RegistryAuth(
username=interpolate(job_spec.registry_auth.username),
password=interpolate(job_spec.registry_auth.password),
)


def _emit_reachability_change_event(
session: AsyncSession,
job_model: JobModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
get_job_configured_volumes,
get_job_runtime_data,
get_job_spec,
interpolate_job_spec_secrets,
is_master_job,
is_multinode_job,
switch_job_status,
Expand Down Expand Up @@ -135,9 +136,11 @@
from dstack._internal.server.services.runs.spec import (
check_run_spec_requires_instance_mounts,
)
from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.volumes import volume_model_to_volume
from dstack._internal.server.utils import sentry_utils
from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async
from dstack._internal.utils.interpolator import InterpolatorError
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -1335,6 +1338,8 @@ async def _process_new_capacity_provisioning(
master_job_provisioning_data=master_provisioning_data,
volumes=preconditions.prepared_job_volumes.volumes,
)
if isinstance(provision_new_capacity_result, _TerminateSubmittedJobResult):
return provision_new_capacity_result
if isinstance(provision_new_capacity_result, _FailedNewCapacityProvisioning):
logger.debug("%s: provisioning failed", fmt(context.job_model))
return _TerminateSubmittedJobResult(
Expand Down Expand Up @@ -2060,12 +2065,22 @@ async def _provision_new_capacity(
project_ssh_private_key: str,
master_job_provisioning_data: Optional[JobProvisioningData] = None,
volumes: Optional[list[list[Volume]]] = None,
) -> Union[_FailedNewCapacityProvisioning, _ProvisionNewCapacityResult]:
) -> Union[
_TerminateSubmittedJobResult, _FailedNewCapacityProvisioning, _ProvisionNewCapacityResult
]:
secrets = await _load_project_secrets(project=project)
jobs = copy.deepcopy(jobs)
for job in jobs:
job.job_spec.image_name, job.job_spec.registry_auth = apply_server_docker_defaults(
job.job_spec.image_name, job.job_spec.registry_auth
)
try:
interpolate_job_spec_secrets(job.job_spec, secrets)
except InterpolatorError as e:
return _TerminateSubmittedJobResult(
reason=JobTerminationReason.TERMINATED_BY_SERVER,
message=f"Secrets interpolation error: {e.args[0]}",
)
job = jobs[0]
if volumes is None:
volumes = []
Expand Down Expand Up @@ -2228,6 +2243,11 @@ async def _provision_new_capacity(
)


async def _load_project_secrets(project: ProjectModel) -> dict[str, str]:
async with get_session_ctx() as session:
return await get_project_secrets_mapping(session=session, project=project)


async def _load_fleet_placement_group_models(fleet_id: uuid.UUID) -> list["PlacementGroupModel"]:
async with get_session_ctx() as session:
res = await session.execute(
Expand Down
13 changes: 13 additions & 0 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import json
from collections.abc import Mapping
from datetime import timedelta
from typing import Dict, Iterable, List, Optional, Tuple
from uuid import UUID
Expand Down Expand Up @@ -27,6 +28,7 @@
JobStatus,
JobSubmission,
JobTerminationReason,
RegistryAuth,
RunSpec,
)
from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus
Expand Down Expand Up @@ -62,6 +64,7 @@
)
from dstack._internal.utils import common
from dstack._internal.utils.common import run_async
from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority

Expand Down Expand Up @@ -166,6 +169,16 @@ async def get_job_specs_from_run_spec(
return job_specs


def interpolate_job_spec_secrets(job_spec: JobSpec, secrets: Mapping[str, str]) -> None:
interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error
job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()}
if job_spec.registry_auth is not None:
job_spec.registry_auth = RegistryAuth(
username=interpolate(job_spec.registry_auth.username),
password=interpolate(job_spec.registry_auth.password),
)


def find_job(jobs: List[Job], replica_num: int, job_num: int) -> Job:
for job in jobs:
if job.job_spec.replica_num == replica_num and job.job_spec.job_num == job_num:
Expand Down
5 changes: 3 additions & 2 deletions src/dstack/_internal/utils/interpolator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import string
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union, overload
from collections.abc import Mapping
from typing import Iterable, List, Literal, Optional, Tuple, Union, overload


class Pattern:
Expand All @@ -25,7 +26,7 @@ class InterpolatorError(ValueError):

class VariablesInterpolator:
def __init__(
self, namespaces: Dict[str, Dict[str, str]], *, skip: Optional[Iterable[str]] = None
self, namespaces: Mapping[str, Mapping[str, str]], *, skip: Optional[Iterable[str]] = None
):
self.skip = set(skip) if skip is not None else set()
self.variables = {f"{ns}.{k}": v for ns in namespaces for k, v in namespaces[ns].items()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import TaskConfiguration
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.placement import PlacementGroup
Expand Down Expand Up @@ -40,6 +41,8 @@
PlacementGroupModel,
VolumeAttachmentModel,
)
from dstack._internal.server.services.docker import ImageConfig
from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
from dstack._internal.server.testing.common import (
ComputeMockSpec,
create_export,
Expand All @@ -50,6 +53,7 @@
create_project,
create_repo,
create_run,
create_secret,
create_user,
create_volume,
get_compute_group_provisioning_data,
Expand Down Expand Up @@ -1873,6 +1877,125 @@ async def test_run_jobs_uses_server_default_registry(
username="server-user", password="server-pass"
)

async def test_interpolates_secrets_when_provisioning_new_capacity(
self,
test_db,
session: AsyncSession,
image_config_mock: ImageConfig,
worker: JobSubmittedWorker,
):
project = await create_project(session=session)
user = await create_user(session=session)
repo = await create_repo(session=session, project_id=project.id)
fleet = await create_fleet(session=session, project=project)
await create_secret(session=session, project=project, name="token", value="s3cret")
await create_secret(
session=session, project=project, name="registry_user", value="docker-user"
)
await create_secret(
session=session, project=project, name="registry_pass", value="docker-pass"
)
run_spec = get_run_spec(
run_name="test-run",
repo_id=repo.name,
configuration=TaskConfiguration(
image="ubuntu",
env=Env.parse_obj({"TOKEN": "${{ secrets.token }}"}),
registry_auth=RegistryAuth(
username="${{ secrets.registry_user }}",
password="${{ secrets.registry_pass }}",
),
),
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
fleet=fleet,
run_spec=run_spec,
)
with patch.object(JobConfigurator, "_get_image_config") as m:
m.return_value = image_config_mock
job = await create_job(session=session, run=run, instance_assigned=True)

offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data(
dockerized=False, backend=BackendType.RUNPOD
)

await _process_job(session=session, worker=worker, job_model=job)

backend_mock.compute.return_value.run_job.assert_called_once()
submitted_job = backend_mock.compute.return_value.run_job.call_args[0][1]
assert submitted_job.job_spec.env == {"TOKEN": "s3cret"}
assert submitted_job.job_spec.registry_auth == RegistryAuth(
username="docker-user", password="docker-pass"
)
# The persisted JobModel keeps the unresolved literals so secrets aren't leaked.
await session.refresh(job)
assert "${{ secrets.token }}" in job.job_spec_data
assert "${{ secrets.registry_user }}" in job.job_spec_data
assert "${{ secrets.registry_pass }}" in job.job_spec_data

async def test_terminates_job_when_secret_is_missing(
self,
test_db,
session: AsyncSession,
image_config_mock: ImageConfig,
worker: JobSubmittedWorker,
):
project = await create_project(session=session)
user = await create_user(session=session)
repo = await create_repo(session=session, project_id=project.id)
fleet = await create_fleet(session=session, project=project)
run_spec = get_run_spec(
run_name="test-run",
repo_id=repo.name,
configuration=TaskConfiguration(
image="ubuntu",
registry_auth=RegistryAuth(
username="registry_user",
password="${{ secrets.registry_pass }}",
),
),
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
fleet=fleet,
run_spec=run_spec,
)
with patch.object(JobConfigurator, "_get_image_config") as m:
m.return_value = image_config_mock
job = await create_job(session=session, run=run, instance_assigned=True)

offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data(
dockerized=False, backend=BackendType.RUNPOD
)

await _process_job(session=session, worker=worker, job_model=job)

await session.refresh(job)
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
assert job.termination_reason_message is not None
assert "Secrets interpolation error" in job.termination_reason_message
backend_mock.compute.return_value.run_job.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down
Loading