diff --git a/mkdocs/docs/concepts/backends.md b/mkdocs/docs/concepts/backends.md
index 8f7b3325d8..e0724f840c 100644
--- a/mkdocs/docs/concepts/backends.md
+++ b/mkdocs/docs/concepts/backends.md
@@ -1051,9 +1051,9 @@ Compared to [VM-based](#vm-based) backends, they offer less fine-grained control
### Kubernetes
-Regardless of whether it’s on-prem Kubernetes or managed, `dstack` can orchestrate container-based runs across your clusters.
+Regardless of whether it’s on-prem Kubernetes or managed, `dstack` can orchestrate container-based runs across your clusters. A single `kubernetes` backend can manage one or many clusters — each cluster is selected via a kubeconfig [context](https://kubernetes.io/docs/concepts/configuration/organize-cluster-access-kubeconfig/#context).
-To use the `kubernetes` backend with `dstack`, you need to configure it with the path to the kubeconfig file, the IP address of any node in the cluster, and the port that `dstack` will use for proxying SSH traffic.
+The recommended way is to enable clusters explicitly via the `contexts` property:
@@ -1066,22 +1066,48 @@ projects:
kubeconfig:
filename: ~/.kube/config
- proxy_jump:
- hostname: 204.12.171.137
- port: 32000
+ contexts:
+ - name: gpu-cluster-a
+ - name: gpu-cluster-b
```
!!! info "Proxy jump"
- To allow the `dstack` server and CLI to access runs via SSH, `dstack` requires a node that acts as a jump host to proxy SSH traffic into containers.
+ To allow the `dstack` server and CLI to access runs via SSH, `dstack` uses a node in each cluster as a jump host to proxy SSH traffic into containers. No additional setup is required — `dstack` configures and manages the proxy automatically.
- To configure this node, specify `hostname` and `port` under the `proxy_jump` property:
+ By default, `dstack` autodetects the jump host:
- - `hostname` — the IP address of any cluster node selected as the jump host. Both the `dstack` server and CLI must be able to reach it. This node can be either a GPU node or a CPU-only node — it makes no difference.
- - `port` — any accessible port on that node, which `dstack` uses to forward SSH traffic.
+ - `hostname` — picks the `ExternalIP` of the jump pod's node, or a random node `ExternalIP` from the cluster if the jump pod's node has none. If no node in the cluster has an `ExternalIP`, provisioning fails and you must set `hostname` explicitly.
+ - `port` — Kubernetes allocates a port from the cluster's NodePort range.
- No additional setup is required — `dstack` configures and manages the proxy automatically.
+ Set `proxy_jump.hostname` and `proxy_jump.port` per context to override autodetection — useful when nodes lack `ExternalIP`s, or when you want a stable, firewall-friendly port:
+
+ ```yaml
+ contexts:
+ - name: gpu-cluster-a
+ proxy_jump:
+ hostname: 204.12.171.137
+ port: 32000
+ ```
+
+ The two fields are independent; you can set either one on its own.
+
+ The jump host can be a GPU node or a CPU-only node — it makes no difference. The only requirement is that both the `dstack` server and CLI can reach `hostname:port`.
+
+!!! info "Region and namespace"
+ Each enabled context becomes its own `dstack` region, named after the context. When creating a `dstack` [volume](volumes.md) or [gateway](gateways.md), the `region` field selects which cluster the resource is provisioned in.
+
+ The namespace `dstack` uses for managed resources is taken from each kubeconfig context's `namespace` property, defaulting to `default` if not set:
+
+ ```yaml
+ contexts:
+ - name: gpu-cluster-a
+ context:
+ cluster: gpu-cluster-a
+ user: kubernetes-admin
+ namespace: dstack
+ ```
??? info "User interface"
If you are configuring the `kubernetes` backend on the [project settings page](projects.md#backends),
@@ -1091,17 +1117,16 @@ projects:
```yaml
type: kubernetes
-
+
kubeconfig:
data: |
apiVersion: v1
kind: Config
- current-context: kubernetes-admin@gpu-cluster
clusters:
- - name: gpu-cluster
+ - name: gpu-cluster-a
cluster:
- server: https://gpu-cluster.internal.example.com:6443
+ server: https://gpu-cluster-a.internal.example.com:6443
certificate-authority-data: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0t...LS0tLQo=
users:
@@ -1111,17 +1136,50 @@ projects:
client-key-data: LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0t...LS0tLQo=
contexts:
- - name: kubernetes-admin@gpu-cluster
+ - name: gpu-cluster-a
context:
- cluster: gpu-cluster
+ cluster: gpu-cluster-a
user: kubernetes-admin
-
- proxy_jump:
- hostname: 204.12.171.137
- port: 32000
+ namespace: dstack
+
+ contexts:
+ - name: gpu-cluster-a
+ proxy_jump:
+ hostname: 204.12.171.137
+ port: 32000
```
+
+??? warning "Legacy configuration (without `contexts`)"
+ If `contexts` is not set, `dstack` falls back to using the kubeconfig's `current-context` as the only cluster, and the top-level `proxy_jump` and `namespace` properties apply:
+
+
+
+ ```yaml
+ projects:
+ - name: main
+ backends:
+ - type: kubernetes
+
+ kubeconfig:
+ filename: ~/.kube/config
+
+ namespace: dstack
+
+ proxy_jump:
+ hostname: 204.12.171.137
+ port: 32000
+ ```
+
+
+
+ This mode is not recommended and may be deprecated and removed in the future. It also has a namespace-handling quirk: the top-level `namespace` property **overrides** the kubeconfig context's namespace (defaulting to `default` if not set in the config), unlike the `contexts` mode where the kubeconfig is authoritative. A warning is logged when the two disagree. To prepare for a possible future change, set the same value in both your kubeconfig context and the backend config.
+
+ With this configuration, the cluster's region is an empty string. When creating a `dstack` volume or gateway, set `region: ''` explicitly in the configuration.
+
+ !!! warning "Migrating from legacy to `contexts`"
+ Switching an existing backend from the legacy mode to `contexts` is not transparent for already-provisioned resources: their region changes from an empty string to the context name, so `dstack` can no longer terminate them. Terminate all jobs, gateways, and volumes managed by the backend before changing the configuration.
??? info "Required operators"
=== "NVIDIA"
@@ -1149,7 +1207,7 @@ projects:
--8<-- "snippets/kubernetes/dstack-backend-role.yaml"
```
- Ensure you've created a ClusterRoleBinding to grant the role to the user or the service account you're using.
+ Ensure you've created a ClusterRoleBinding and a RoleBinding to grant the roles to the user or the service account you're using.
??? info "Resources and offers"
If you use ranges with [`resources`](../concepts/tasks.md#resources) (e.g. `gpu: 1..8` or `memory: 64GB..`) in fleet or run configurations, other backends collect and try all offers that satisfy the range.
diff --git a/mkdocs/docs/reference/dstack.yml/volume.md b/mkdocs/docs/reference/dstack.yml/volume.md
index 2675b684ee..d3f851c8c0 100644
--- a/mkdocs/docs/reference/dstack.yml/volume.md
+++ b/mkdocs/docs/reference/dstack.yml/volume.md
@@ -60,3 +60,5 @@ The `volume` configuration type allows creating, registering, and updating [volu
show_root_heading: false
backend:
required: true
+ region:
+ required: true
diff --git a/mkdocs/docs/reference/server/config.yml.md b/mkdocs/docs/reference/server/config.yml.md
index 80e48b028e..a76edaca58 100644
--- a/mkdocs/docs/reference/server/config.yml.md
+++ b/mkdocs/docs/reference/server/config.yml.md
@@ -278,6 +278,18 @@ to configure [backends](../../concepts/backends.md) and other [server-level sett
yq -o=json ~/.kube/config | jq -c | jq -R
```
+###### `projects[n].backends[type=kubernetes].contexts[n]` { #kubernetes-contexts data-toc-label="contexts" }
+
+#SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesContextConfig
+ overrides:
+ show_root_heading: false
+
+###### `projects[n].backends[type=kubernetes].contexts[n].proxy_jump` { #kubernetes-contexts-proxy_jump data-toc-label="proxy_jump" }
+
+#SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesProxyJumpConfig
+ overrides:
+ show_root_heading: false
+
###### `projects[n].backends[type=kubernetes].proxy_jump` { #kubernetes-proxy_jump data-toc-label="proxy_jump" }
#SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesProxyJumpConfig
diff --git a/scripts/merge_kubeconfigs.sh b/scripts/merge_kubeconfigs.sh
new file mode 100755
index 0000000000..e7087f1aca
--- /dev/null
+++ b/scripts/merge_kubeconfigs.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+set -eu
+
+if [ ${#} -lt 2 ]; then
+ echo "usage: $(basename "${0}") PATH1 PATH2 [PATH3 ...]" >&2
+ exit 1
+fi
+
+# Windows is not supported; on Windows a path separator is ';', not ':'
+KUBECONFIG=$(IFS=':'; echo "${*}")
+export KUBECONFIG
+kubectl config view --raw --flatten | grep -Ev '^current-context: '
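Reviewer note: the script above does two things besides calling `kubectl`. It joins the input paths into a `KUBECONFIG` value, and it filters `current-context` out of the flattened output. The Python sketch below mirrors only those two steps for illustration. `build_kubeconfig_env` and `strip_current_context` are hypothetical names that do not exist in this PR, and the sketch does not invoke `kubectl` itself.

```python
def build_kubeconfig_env(paths: list[str], sep: str = ":") -> str:
    """Join kubeconfig paths the way the script builds $KUBECONFIG.

    Hypothetical helper; `sep` is ':' because the script does not support Windows.
    """
    if len(paths) < 2:
        raise ValueError("expected at least two kubeconfig paths")
    return sep.join(paths)


def strip_current_context(kubeconfig_yaml: str) -> str:
    """Drop lines starting with 'current-context: ', like the grep -Ev filter."""
    kept = [
        line
        for line in kubeconfig_yaml.splitlines()
        if not line.startswith("current-context: ")
    ]
    return "\n".join(kept) + "\n"
```

Dropping `current-context` fits the `contexts` mode documented in this PR, where clusters are selected by name rather than through the kubeconfig's current context.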
diff --git a/scripts/setup_kubernetes.py b/scripts/setup_kubernetes.py
index b38896d715..22295ffebf 100644
--- a/scripts/setup_kubernetes.py
+++ b/scripts/setup_kubernetes.py
@@ -201,7 +201,9 @@ def generate_kubeconfig(
service_account_token: str,
) -> str:
logging.info("generating kubeconfig")
- kubeconfig_content = kubectl.call("config", "view", "--minify", "--raw", capture_stdout=True)
+ kubeconfig_content = kubectl.call(
+ "config", "view", "--minify", "--raw", "--flatten", capture_stdout=True
+ )
with tempfile.NamedTemporaryFile("w+") as f:
f.write(kubeconfig_content)
f.flush()
diff --git a/src/dstack/_internal/core/backends/kubernetes/api_client.py b/src/dstack/_internal/core/backends/kubernetes/api_client.py
new file mode 100644
index 0000000000..29a915d646
--- /dev/null
+++ b/src/dstack/_internal/core/backends/kubernetes/api_client.py
@@ -0,0 +1,46 @@
+from typing import Optional
+
+from kubernetes.client.api_client import ApiClient as _BaseApiClient
+from kubernetes.client.configuration import Configuration as _ClientConfiguration
+from kubernetes.client.exceptions import ApiException
+from kubernetes.config import load_kube_config_from_dict
+from urllib3.exceptions import HTTPError
+
+# 30 * 2 (original request + 1 retry) = 60 seconds total
+DEFAULT_REQUEST_TIMEOUT = 30
+DEFAULT_RETRIES = 1
+
+
+API_CLIENT_EXCEPTIONS: tuple[type[Exception], ...] = (HTTPError, ApiException)
+
+
+class ApiClient(_BaseApiClient):
+ def __init__(self, *, configuration: _ClientConfiguration, request_timeout: int) -> None:
+ self.__request_timeout = request_timeout
+ super().__init__(configuration=configuration)
+
+ def request(self, *args, **kwargs):
+ if kwargs.get("_request_timeout") is None:
+ kwargs["_request_timeout"] = self.__request_timeout
+ return super().request(*args, **kwargs) # pyright: ignore[reportAttributeAccessIssue]
+
+
+def get_api_client_from_kubeconfig_dict(
+ kubeconfig_dict: dict,
+ *,
+ context: str,
+ request_timeout: Optional[int] = None,
+ retries: Optional[int] = None,
+) -> ApiClient:
+ if request_timeout is None:
+ request_timeout = DEFAULT_REQUEST_TIMEOUT
+ if retries is None:
+ retries = DEFAULT_RETRIES
+ client_configuration = _ClientConfiguration()
+ client_configuration.retries = retries # pyright: ignore[reportAttributeAccessIssue]
+ load_kube_config_from_dict(
+ config_dict=kubeconfig_dict,
+ context=context,
+ client_configuration=client_configuration,
+ )
+ return ApiClient(configuration=client_configuration, request_timeout=request_timeout)
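Reviewer note: the `ApiClient.request()` override above fills in a default `_request_timeout` only when the caller has not supplied one. The sketch below isolates that pattern without depending on the `kubernetes` package. `BaseClient` and `TimeoutClient` are hypothetical stand-ins, not classes from this PR or from the Kubernetes client library.

```python
from typing import Any


class BaseClient:
    """Stand-in for the real base client; it just echoes what it receives."""

    def request(self, method: str, url: str, **kwargs: Any) -> dict[str, Any]:
        return {"method": method, "url": url, **kwargs}


class TimeoutClient(BaseClient):
    def __init__(self, *, request_timeout: int) -> None:
        self._request_timeout = request_timeout

    def request(self, method: str, url: str, **kwargs: Any) -> dict[str, Any]:
        # Fill in the default only when the caller passed no timeout
        # (or passed None explicitly); an explicit value always wins.
        if kwargs.get("_request_timeout") is None:
            kwargs["_request_timeout"] = self._request_timeout
        return super().request(method, url, **kwargs)
```

This is why the watch in `_wait_for_pod_scheduled_or_finished` can pass a longer `_request_timeout` explicitly: a per-call value takes precedence over the client-wide default.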
diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py
index 062e458b4b..7a8979ed34 100644
--- a/src/dstack/_internal/core/backends/kubernetes/compute.py
+++ b/src/dstack/_internal/core/backends/kubernetes/compute.py
@@ -1,3 +1,4 @@
+import concurrent.futures
import random
import shlex
import subprocess
@@ -28,10 +29,8 @@
get_dstack_gateway_commands,
merge_tags,
)
-from dstack._internal.core.backends.kubernetes.models import (
- KubernetesConfig,
- KubernetesProxyJumpConfig,
-)
+from dstack._internal.core.backends.kubernetes.api_client import API_CLIENT_EXCEPTIONS
+from dstack._internal.core.backends.kubernetes.models import KubernetesConfig
from dstack._internal.core.backends.kubernetes.resources import (
AMD_GPU_DEVICE_ID_LABEL_PREFIX,
AMD_GPU_NAME_TO_DEVICE_IDS,
@@ -60,15 +59,16 @@
parse_quantity,
)
from dstack._internal.core.backends.kubernetes.utils import (
+ LEGACY_CURRENT_CONTEXT_REGION,
+ Cluster,
+ SkipOfferCache,
call_api_method,
- get_api_from_kubeconfig_dict,
- kubeconfig_data_to_kubeconfig_dict,
- kubeconfig_dict_to_kubeconfig,
+ get_clusters_from_backend_config,
try_delete_object_if_exists,
watch_events,
)
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
-from dstack._internal.core.errors import ComputeError, ProvisioningError
+from dstack._internal.core.errors import ComputeError, ProvisioningError, SkipOffer
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.gateways import (
@@ -136,39 +136,34 @@ class KubernetesCompute(
):
def __init__(self, config: KubernetesConfig):
super().__init__()
- self.config = config.copy()
- proxy_jump = self.config.proxy_jump
- if proxy_jump is None:
- proxy_jump = KubernetesProxyJumpConfig()
- self.proxy_jump = proxy_jump
- kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(config.kubeconfig.data)
- self.api = get_api_from_kubeconfig_dict(kubeconfig_dict)
- kubeconfig = kubeconfig_dict_to_kubeconfig(kubeconfig_dict)
- current_context = kubeconfig.get_context()
- if current_context.namespace != config.namespace:
- logger.warning(
- (
- "Namespace mismatch: kubeconfig -> '%s', backend config -> '%s'."
- " The current dstack version ignores kubeconfig"
- " and uses deprecated namespace property from backend config."
- " Future versions will use namespace from kubeconfig."
- " To keep using '%s' namespace in future versions and suppress this warning,"
- " set namespace to '%s' in kubeconfig context '%s'"
- ),
- current_context.namespace,
- config.namespace,
- config.namespace,
- config.namespace,
- kubeconfig.current_context,
- )
- # TODO: switch to current_context.namespace
- self.namespace = config.namespace
- logger.debug("Using namespace '%s'", self.namespace)
+ self.region_cluster_map = {c.region: c for c in get_clusters_from_backend_config(config)}
+ self.skip_offer_cache = SkipOfferCache(ttl=60)
def get_offers_by_requirements(
self, requirements: Requirements
) -> list[InstanceOfferWithAvailability]:
- return get_instance_offers(self.api, requirements)
+ offers: list[InstanceOfferWithAvailability] = []
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
+ future_cluster_map: dict[
+ concurrent.futures.Future[list[InstanceOfferWithAvailability]], Cluster
+ ] = {}
+ for region, cluster in self.region_cluster_map.items():
+ api = client.CoreV1Api(cluster.api_client)
+ future = executor.submit(get_instance_offers, api, region, requirements)
+ future_cluster_map[future] = cluster
+ for future in concurrent.futures.as_completed(future_cluster_map):
+ try:
+ cluster_offers = future.result()
+ except API_CLIENT_EXCEPTIONS as e:
+ logger.warning(
+ "Failed to get offers from cluster %s: %s: %s",
+ future_cluster_map[future],
+ e.__class__.__name__,
+ e,
+ )
+ continue
+ offers.extend(cluster_offers)
+ return offers
def run_job(
self,
@@ -180,8 +175,13 @@ def run_job(
volumes: list[Volume],
placement_group: Optional[PlacementGroup],
) -> JobProvisioningData:
- api = self.api
- namespace = self.namespace
+ cluster = self.region_cluster_map.get(instance_offer.region)
+ if cluster is None:
+ raise ComputeError(f"Unknown region: {instance_offer.region!r}")
+ if self.skip_offer_cache.check(run, job, instance_offer):
+ raise SkipOffer(f"cluster {cluster} has recently failed to schedule a similar job")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
# There is one jump pod per project that is used as an ssh proxy jump to connect
# to all job pods of the same project.
@@ -193,7 +193,7 @@ def run_job(
namespace=namespace,
jump_pod_name=jump_pod_name,
jump_pod_service_name=jump_pod_service_name,
- jump_pod_port=self.proxy_jump.port,
+ jump_pod_port=cluster.proxy_jump.port,
project_ssh_public_key=project_ssh_public_key.strip(),
)
@@ -246,6 +246,7 @@ def run_job(
timeout_seconds=JOB_POD_SCHEDULING_TIMEOUT,
)
if not is_pod_scheduled_or_finished:
+ self.skip_offer_cache.add(run, job, instance_offer)
reason, message = _get_unscheduled_pod_reason_message(
api=api,
namespace=namespace,
@@ -256,6 +257,7 @@ def run_job(
f" {reason or 'unknown reason'}: {message or 'no message'}"
)
if pod_phase is not None and pod_phase.is_finished():
+ # It's not clear if we should add an entry to the SkipOfferCache in this case.
raise ComputeError(f"Pod {pod_name} already finished: {pod_phase}")
pod_service_name = _get_pod_service_name(pod_name)
@@ -316,15 +318,21 @@ def update_provisioning_data(
project_ssh_public_key: str,
project_ssh_private_key: str,
):
+ cluster = self.region_cluster_map.get(provisioning_data.region)
+ if cluster is None:
+ raise ProvisioningError(f"Unknown region: {provisioning_data.region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
+
if provisioning_data.backend_data is not None:
# Before running a job, ensure the jump pod is running and has user's public SSH key.
backend_data = KubernetesBackendData.load(provisioning_data.backend_data)
ssh_proxy = _check_and_configure_jump_pod_service(
- api=self.api,
- namespace=self.namespace,
+ api=api,
+ namespace=namespace,
jump_pod_name=backend_data.jump_pod_name,
jump_pod_service_name=backend_data.jump_pod_service_name,
- jump_pod_hostname=self.proxy_jump.hostname,
+ jump_pod_hostname=cluster.proxy_jump.hostname,
project_ssh_private_key=project_ssh_private_key,
user_ssh_public_key=backend_data.user_ssh_public_key,
)
@@ -336,9 +344,9 @@ def update_provisioning_data(
# in case update_provisioning_data() is called again.
provisioning_data.backend_data = None
- pod = self.api.read_namespaced_pod(
+ pod = api.read_namespaced_pod(
name=provisioning_data.instance_id,
- namespace=self.namespace,
+ namespace=namespace,
)
if pod.status is None:
return
@@ -346,19 +354,20 @@ def update_provisioning_data(
if not pod_ip:
return
provisioning_data.internal_ip = pod_ip
- service = self.api.read_namespaced_service(
+ service = api.read_namespaced_service(
name=_get_pod_service_name(provisioning_data.instance_id),
- namespace=self.namespace,
+ namespace=namespace,
)
service_spec = get_or_error(service.spec)
provisioning_data.hostname = get_or_error(service_spec.cluster_ip)
pod_spec = get_or_error(pod.spec)
- node = self.api.read_node(name=get_or_error(pod_spec.node_name))
+ node = api.read_node(name=get_or_error(pod_spec.node_name))
# In the original offer, the resources have already been adjusted according to
# the run configuration resource requirements, see get_offers_by_requirements()
original_resources = provisioning_data.instance_type.resources
instance_offer = get_instance_offer_from_node(
node=node,
+ region=cluster.region,
cpu_request=original_resources.cpus,
memory_mib_request=original_resources.memory_mib,
gpu_request=len(original_resources.gpus),
@@ -366,14 +375,30 @@ def update_provisioning_data(
)
if instance_offer is not None:
provisioning_data.instance_type = instance_offer.instance
- provisioning_data.region = instance_offer.region
provisioning_data.price = instance_offer.price
def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
):
- api = self.api
- namespace = self.namespace
+ cluster = self.region_cluster_map.get(region)
+ if cluster is None and region == "-":
+ # legacy DUMMY_REGION
+ cluster = self.region_cluster_map.get(LEGACY_CURRENT_CONTEXT_REGION)
+ if cluster is not None:
+ logger.warning(
+ (
+ "Terminating instance %s in unknown region %s."
+ " Assuming it was created before multi-cluster support was added"
+ " and is located in cluster %s"
+ ),
+ instance_id,
+ repr(region),
+ cluster,
+ )
+ if cluster is None:
+ raise ComputeError(f"Unknown region: {region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
deleted = [
try_delete_object_if_exists(
api.delete_namespaced_service,
@@ -401,6 +426,12 @@ def create_gateway(
self,
configuration: GatewayComputeConfiguration,
) -> GatewayProvisioningData:
+ cluster = self.region_cluster_map.get(configuration.region)
+ if cluster is None:
+ raise ComputeError(f"Unknown region: {configuration.region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
+
# Gateway creation is currently limited to Kubernetes with Load Balancer support.
# If the cluster does not support Load Balancer, the service will be provisioned but
# the external IP/hostname will never be allocated.
@@ -448,8 +479,8 @@ def create_gateway(
]
),
)
- self.api.create_namespaced_pod(
- namespace=self.namespace,
+ api.create_namespaced_pod(
+ namespace=namespace,
body=pod,
)
service = client.V1Service(
@@ -478,18 +509,18 @@ def create_gateway(
],
),
)
- self.api.create_namespaced_service(
- namespace=self.namespace,
+ api.create_namespaced_service(
+ namespace=namespace,
body=service,
)
# address is either a domain name or an IP address
address = _wait_for_load_balancer_address(
- api=self.api,
- namespace=self.namespace,
+ api=api,
+ namespace=namespace,
service_name=_get_pod_service_name(instance_name),
)
if address is None:
- self.terminate_instance(instance_name, region="")
+ self.terminate_instance(instance_name, region=configuration.region)
raise ComputeError(
"Failed to get gateway hostname. "
"Ensure the Kubernetes cluster supports Load Balancer services."
@@ -497,7 +528,7 @@ def create_gateway(
return GatewayProvisioningData(
instance_id=instance_name,
ip_address=address,
- region="",
+ region=cluster.region,
)
def terminate_gateway(
@@ -506,21 +537,50 @@ def terminate_gateway(
configuration: GatewayComputeConfiguration,
backend_data: Optional[str] = None,
):
+ region = configuration.region
+ cluster = self.region_cluster_map.get(region)
+ if cluster is None:
+ # It may be a legacy configuration with the region set to an arbitrary value
+ cluster = self.region_cluster_map.get(LEGACY_CURRENT_CONTEXT_REGION)
+ if cluster is not None:
+ logger.warning(
+ (
+ "Terminating gateway %s in unknown region %s."
+ " Assuming it was created before multi-cluster support was added"
+ " and is located in cluster %s"
+ ),
+ instance_id,
+ repr(region),
+ cluster,
+ )
+ region = LEGACY_CURRENT_CONTEXT_REGION
+ else:
+ raise ComputeError(f"Unknown region: {region!r}")
self.terminate_instance(
instance_id=instance_id,
- region=configuration.region,
+ region=region,
backend_data=backend_data,
)
def register_volume(self, volume: Volume) -> VolumeProvisioningData:
assert isinstance(volume.configuration, KubernetesVolumeConfiguration)
+
+ region = volume.configuration.region
+ cluster = self.region_cluster_map.get(region)
+ if cluster is None:
+ if region == "":
+ raise ComputeError("region is not set")
+ raise ComputeError(f"Unknown region: {region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
+
pvc_name = volume.configuration.claim_name
assert pvc_name is not None
pvc = call_api_method(
- self.api.read_namespaced_persistent_volume_claim,
+ api.read_namespaced_persistent_volume_claim,
expected=404,
- namespace=self.namespace,
+ namespace=namespace,
name=pvc_name,
)
if pvc is None:
@@ -548,7 +608,15 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData:
def create_volume(self, volume: Volume) -> VolumeProvisioningData:
assert isinstance(volume.configuration, KubernetesVolumeConfiguration)
- assert volume.configuration.size is not None
+
+ region = volume.configuration.region
+ cluster = self.region_cluster_map.get(region)
+ if cluster is None:
+ if region == "":
+ raise ComputeError("region is not set")
+ raise ComputeError(f"Unknown region: {region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
labels = {
format_dstack_label_key("owner"): "dstack",
@@ -562,6 +630,8 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
)
labels = filter_invalid_labels(labels)
+ assert volume.configuration.size is not None
+
pvc_name = generate_unique_volume_name(volume, max_length=OBJECT_NAME_MAX_LENGTH)
pvc = client.V1PersistentVolumeClaim(
metadata=client.V1ObjectMeta(
@@ -578,8 +648,8 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
),
),
)
- self.api.create_namespaced_persistent_volume_claim(
- namespace=self.namespace,
+ api.create_namespaced_persistent_volume_claim(
+ namespace=namespace,
body=pvc,
)
logger.debug("Created PVC %s for volume %s", pvc_name, volume.name)
@@ -594,13 +664,21 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
def delete_volume(self, volume: Volume):
assert isinstance(volume.configuration, KubernetesVolumeConfiguration)
+
+ region = volume.configuration.region
+ cluster = self.region_cluster_map.get(region)
+ if cluster is None:
+ raise ComputeError(f"Unknown region: {region!r}")
+ api = client.CoreV1Api(cluster.api_client)
+ namespace = cluster.namespace
+
pvc_name = volume.volume_id
assert pvc_name is not None
pvc = call_api_method(
- self.api.delete_namespaced_persistent_volume_claim,
+ api.delete_namespaced_persistent_volume_claim,
expected=404,
- namespace=self.namespace,
+ namespace=namespace,
name=pvc_name,
)
if pvc is None:
@@ -1170,11 +1248,15 @@ def _wait_for_pod_scheduled_or_finished(
# the scheduler confirmed capacity and that the assigned node is actually Ready and
# working on the pod.
pod_phase: Optional[PodPhase] = None
+ # Ensure that the API's timeoutSeconds fires before the network timeout, which otherwise
+ # defaults to our custom ApiClient's request_timeout (see DEFAULT_REQUEST_TIMEOUT)
+ request_timeout = timeout_seconds + 5
with watch_events(
api.list_namespaced_pod,
namespace=namespace,
field_selector=f"metadata.name={pod_name}",
timeout_seconds=timeout_seconds,
+ _request_timeout=request_timeout,
) as event_iter:
for _, pod in event_iter:
pod_status = pod.status
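Reviewer note: `get_offers_by_requirements()` in this file now queries every cluster concurrently and skips clusters whose request fails, so one unreachable cluster does not hide offers from the others. A minimal, self-contained sketch of that fan-out pattern follows. `collect_offers` and the `fetchers` mapping are hypothetical. The sketch also catches a broad `Exception` and prints, whereas the PR catches only `API_CLIENT_EXCEPTIONS` and logs a warning.

```python
import concurrent.futures
from typing import Callable


def collect_offers(
    fetchers: dict[str, Callable[[], list[str]]],
    max_workers: int = 8,
) -> list[str]:
    """Query every region concurrently; skip regions whose fetcher raises."""
    offers: list[str] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its region so failures can be attributed.
        future_to_region = {
            executor.submit(fetch): region for region, fetch in fetchers.items()
        }
        for future in concurrent.futures.as_completed(future_to_region):
            try:
                region_offers = future.result()
            except Exception as e:  # assumption: the PR narrows this to API errors
                region = future_to_region[future]
                print(f"skipping {region}: {e.__class__.__name__}: {e}")
                continue
            offers.extend(region_offers)
    return offers
```

Because results arrive in completion order, the combined list has no stable ordering across regions. Callers that need determinism should sort it.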
diff --git a/src/dstack/_internal/core/backends/kubernetes/configurator.py b/src/dstack/_internal/core/backends/kubernetes/configurator.py
index 9753294b11..b8872f0211 100644
--- a/src/dstack/_internal/core/backends/kubernetes/configurator.py
+++ b/src/dstack/_internal/core/backends/kubernetes/configurator.py
@@ -3,7 +3,6 @@
Configurator,
raise_invalid_credentials_error,
)
-from dstack._internal.core.backends.kubernetes import utils as kubernetes_utils
from dstack._internal.core.backends.kubernetes.backend import KubernetesBackend
from dstack._internal.core.backends.kubernetes.models import (
KubernetesBackendConfig,
@@ -11,6 +10,11 @@
KubernetesConfig,
KubernetesStoredConfig,
)
+from dstack._internal.core.backends.kubernetes.utils import (
+ check_cluster,
+ get_clusters_from_backend_config,
+)
+from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.utils.logging import get_logger
@@ -29,12 +33,17 @@ class KubernetesConfigurator(
def validate_config(
self, config: KubernetesBackendConfigWithCreds, default_creds_enabled: bool
):
+ self._check_config_contexts(config)
try:
- api = kubernetes_utils.get_api_from_kubeconfig_data(config.kubeconfig.data)
- api.list_node()
+ clusters = get_clusters_from_backend_config(config, request_timeout=10, retries=0)
except Exception as e:
- logger.debug("Invalid kubeconfig: %s", str(e))
- raise_invalid_credentials_error(fields=[["kubeconfig"]])
+ raise ServerClientError(str(e))
+ for cluster in clusters:
+ if not check_cluster(cluster):
+ raise_invalid_credentials_error(
+ fields=[["kubeconfig"]],
+ details=f"Failed to validate cluster {cluster}",
+ )
def create_backend(
self, project_name: str, config: KubernetesBackendConfigWithCreds
@@ -59,3 +68,24 @@ def get_backend(self, record: BackendRecord) -> KubernetesBackend:
def _get_config(self, record: BackendRecord) -> KubernetesConfig:
return KubernetesConfig.__response__.parse_raw(record.config)
+
+ def _check_config_contexts(self, config: KubernetesBackendConfig):
+ if config.contexts is None:
+ return
+ if config.proxy_jump is not None:
+ raise ServerClientError("proxy_jump must not be set if contexts is set")
+ if config.namespace is not None:
+ raise ServerClientError("namespace must not be set if contexts is set")
+ seen: set[str] = set()
+ duplicates: set[str] = set()
+ for context in config.contexts:
+ if isinstance(context, str):
+ name = context
+ else:
+ name = context.name
+ if name in seen:
+ duplicates.add(name)
+ else:
+ seen.add(name)
+ if duplicates:
+ raise ServerClientError(f"duplicate contexts: {', '.join(sorted(duplicates))}")
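Reviewer note: `_check_config_contexts()` accepts each context either as a plain string or as an object with a `name`, and reports every duplicated name in a single error. The standalone sketch below shows the same two-set technique. `find_duplicate_contexts` is a hypothetical helper, and plain dicts stand in for `KubernetesContextConfig`.

```python
from typing import Union


def find_duplicate_contexts(contexts: list[Union[str, dict]]) -> list[str]:
    """Return the sorted names that appear more than once."""
    seen: set[str] = set()
    duplicates: set[str] = set()
    for context in contexts:
        # A context may be given as a bare name or as a mapping with a "name" key.
        name = context if isinstance(context, str) else context["name"]
        if name in seen:
            duplicates.add(name)
        else:
            seen.add(name)
    return sorted(duplicates)
```

Tracking `duplicates` as a set keeps a name that appears three or more times from being reported more than once, and sorting makes the error message deterministic.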
diff --git a/src/dstack/_internal/core/backends/kubernetes/models.py b/src/dstack/_internal/core/backends/kubernetes/models.py
index bb1609733e..eb92982e45 100644
--- a/src/dstack/_internal/core/backends/kubernetes/models.py
+++ b/src/dstack/_internal/core/backends/kubernetes/models.py
@@ -5,8 +5,6 @@
from dstack._internal.core.backends.base.models import fill_data
from dstack._internal.core.models.common import CoreModel
-DEFAULT_NAMESPACE = "default"
-
class KubernetesProxyJumpConfig(CoreModel):
hostname: Annotated[
@@ -17,6 +15,13 @@ class KubernetesProxyJumpConfig(CoreModel):
] = None
+class KubernetesContextConfig(CoreModel):
+ name: Annotated[str, Field(description="The name of the context")]
+ proxy_jump: Annotated[
+ Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration")
+ ] = None
+
+
class KubeconfigConfig(CoreModel):
filename: Annotated[str, Field(description="The path to the kubeconfig file")] = ""
data: Annotated[str, Field(description="The contents of the kubeconfig file")]
@@ -24,22 +29,45 @@ class KubeconfigConfig(CoreModel):
class KubernetesBackendConfig(CoreModel):
type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes"
+ contexts: Annotated[
+ Optional[list[Union[KubernetesContextConfig, str]]],
+ Field(
+ description=(
+ "Enabled contexts (clusters). Each context should map to a separate cluster."
+ " The context name becomes the region name."
+ " If `contexts` is set, top-level `proxy_jump` and `namespace` must not be set."
+ " `proxy_jump`, if necessary, should be configured per-context;"
+ " `namespace` is taken from the corresponding kubeconfig context's property."
+ " If `contexts` is not set (not recommended), the kubeconfig's `current-context`"
+ " is used as the only context, with an empty string as the region name"
+ ),
+ ),
+ ] = None
proxy_jump: Annotated[
- Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration")
+ Optional[KubernetesProxyJumpConfig],
+ Field(
+ description=(
+ "Only used if `contexts` is not set; must not be set otherwise."
+ " The SSH proxy jump configuration"
+ ),
+ ),
] = None
namespace: Annotated[
- str,
+ Optional[str],
Field(
description=(
- "The namespace for resources managed by `dstack`."
- " Always overrides the namespace set in the kubeconfig, even if not set. "
- " Deprecated and will be eventually removed in futute versions, but"
- " in the current version must be set unless equals to `default`."
+ "Only used if `contexts` is not set; must not be set otherwise."
+ " The namespace for resources managed by `dstack`."
+ " If `contexts` is not set, overrides the namespace set in the kubeconfig,"
+ " even when this property itself is not set, in which case `default` is used."
+ " Deprecated; will eventually be removed in future versions,"
+ " but in the current version must be set if `contexts` is not set and the value"
+ " is not equal to `default`."
" Future versions will use the namespace from the kubeconfig instead."
" To prepare for future versions, set the same value in the kubeconfig"
)
),
- ] = DEFAULT_NAMESPACE
+ ] = None
"""`namespace` is formally deprecated since 0.20.20 but still used. Future versions will switch
to namespace from kubeconfig context, which is currently ignored"""
diff --git a/src/dstack/_internal/core/backends/kubernetes/resources.py b/src/dstack/_internal/core/backends/kubernetes/resources.py
index fa29d7b513..4d9d266ae2 100644
--- a/src/dstack/_internal/core/backends/kubernetes/resources.py
+++ b/src/dstack/_internal/core/backends/kubernetes/resources.py
@@ -240,7 +240,7 @@ def is_taint_tolerated(taint: V1Taint) -> bool:
def get_instance_offers(
- api: CoreV1Api, requirements: Requirements
+ api: CoreV1Api, region: str, requirements: Requirements
) -> list[InstanceOfferWithAvailability]:
resources_spec = requirements.resources
assert isinstance(resources_spec.cpu, CPUSpec)
@@ -262,6 +262,7 @@ def get_instance_offers(
node=node,
node_name=node_name,
node_allocated_resources=nodes_allocated_resources.get(node_name),
+ region=region,
cpu_request=cpu_request,
memory_mib_request=memory_mib_request,
gpu_request=gpu_request,
@@ -275,6 +276,7 @@ def get_instance_offers(
def get_instance_offer_from_node(
node: V1Node,
*,
+ region: str,
cpu_request: int,
memory_mib_request: int,
gpu_request: int,
@@ -287,6 +289,7 @@ def get_instance_offer_from_node(
node=node,
node_name=node_name,
node_allocated_resources=None,
+ region=region,
cpu_request=cpu_request,
memory_mib_request=memory_mib_request,
gpu_request=gpu_request,
@@ -342,6 +345,7 @@ def _get_instance_offer_from_node(
node: V1Node,
node_name: str,
node_allocated_resources: Optional[KubernetesResources],
+ region: str,
cpu_request: int,
memory_mib_request: int,
gpu_request: int,
@@ -384,7 +388,7 @@ def _get_instance_offer_from_node(
),
),
price=0,
- region="",
+ region=region,
availability=InstanceAvailability.AVAILABLE,
instance_runtime=InstanceRuntime.RUNNER,
)
diff --git a/src/dstack/_internal/core/backends/kubernetes/utils.py b/src/dstack/_internal/core/backends/kubernetes/utils.py
index 054e48eba3..a90f750f1b 100644
--- a/src/dstack/_internal/core/backends/kubernetes/utils.py
+++ b/src/dstack/_internal/core/backends/kubernetes/utils.py
@@ -1,5 +1,6 @@
from collections.abc import Generator
from contextlib import contextmanager
+from dataclasses import dataclass
from typing import (
Annotated,
Any,
@@ -12,20 +13,28 @@
Union,
cast,
)
+from uuid import UUID
import yaml
-from kubernetes.client import CoreV1Api, V1Status
+from cachetools import TTLCache
+from kubernetes.client import V1Status, VersionApi
from kubernetes.client.exceptions import ApiException
-from kubernetes.config import (
- # XXX: This function is missing in the stubs package
- new_client_from_config_dict, # pyright: ignore[reportAttributeAccessIssue]
-)
from kubernetes.watch import Watch
from pydantic import Field
from typing_extensions import ParamSpec, TypedDict
-from urllib3.exceptions import HTTPError
+from dstack._internal.core.backends.kubernetes.api_client import (
+ API_CLIENT_EXCEPTIONS,
+ ApiClient,
+ get_api_client_from_kubeconfig_dict,
+)
+from dstack._internal.core.backends.kubernetes.models import (
+ KubernetesBackendConfigWithCreds,
+ KubernetesProxyJumpConfig,
+)
from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.instances import InstanceOffer
+from dstack._internal.core.models.runs import Job, Run
from dstack._internal.utils.logging import get_logger
logger = get_logger(__name__)
@@ -34,6 +43,122 @@
P = ParamSpec("P")
+LEGACY_CURRENT_CONTEXT_REGION = ""
+
+
+@dataclass
+class Cluster:
+ context_name: str
+ region: str
+ api_client: ApiClient
+ namespace: str
+ proxy_jump: KubernetesProxyJumpConfig
+
+ def __str__(self) -> str:
+ parts: list[str] = []
+ parts.append(f"context={self.context_name!r}")
+ if self.context_name != self.region:
+ parts.append(f"region={self.region!r}")
+ return f"({' '.join(parts)})"
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}{self}"
+
+
+def check_cluster(cluster: Cluster) -> bool:
+ version_api = VersionApi(cluster.api_client)
+ try:
+ version_info = version_api.get_code()
+ except API_CLIENT_EXCEPTIONS as e:
+ logger.debug("cluster %s check failed: %s: %s", cluster, e.__class__.__name__, e)
+ return False
+ logger.debug("cluster %s gitVersion: %s", cluster, version_info.git_version)
+ return True
+
+
+def get_clusters_from_backend_config(
+ config: KubernetesBackendConfigWithCreds,
+ *,
+ request_timeout: Optional[int] = None,
+ retries: Optional[int] = None,
+) -> list[Cluster]:
+ clusters: list[Cluster] = []
+ kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(config.kubeconfig.data)
+ kubeconfig = kubeconfig_dict_to_kubeconfig(kubeconfig_dict)
+ if config.contexts is not None:
+ for context in config.contexts:
+ if isinstance(context, str):
+ context_name = context
+ proxy_jump = None
+ else:
+ context_name = context.name
+ proxy_jump = context.proxy_jump
+ kubeconfig_context = kubeconfig.get_context(context_name)
+ api_client = get_api_client_from_kubeconfig_dict(
+ kubeconfig_dict,
+ context=context_name,
+ request_timeout=request_timeout,
+ retries=retries,
+ )
+ namespace = kubeconfig_context.namespace
+ if proxy_jump is None:
+ proxy_jump = KubernetesProxyJumpConfig()
+ clusters.append(
+ Cluster(
+ context_name=context_name,
+ region=context_name,
+ api_client=api_client,
+ namespace=namespace,
+ proxy_jump=proxy_jump,
+ )
+ )
+ else:
+ current_kubeconfig_context = kubeconfig.get_context()
+ context_name = kubeconfig.current_context
+ # Already checked by Kubeconfig.get_context()
+ assert context_name is not None
+ api_client = get_api_client_from_kubeconfig_dict(
+ kubeconfig_dict,
+ context=context_name,
+ request_timeout=request_timeout,
+ retries=retries,
+ )
+ config_namespace = config.namespace
+ if config_namespace is None:
+ config_namespace = "default"
+ context_namespace = current_kubeconfig_context.namespace
+ if context_namespace != config_namespace:
+ logger.warning(
+ (
+ "Namespace mismatch: kubeconfig -> '%s', backend config -> '%s'."
+ " The current dstack version ignores the kubeconfig namespace"
+ " and uses the deprecated namespace property from the backend config."
+ " Future versions will use the namespace from the kubeconfig."
+ " To keep using the '%s' namespace in future versions and suppress this warning,"
+ " set namespace to '%s' in kubeconfig context '%s'"
+ ),
+ context_namespace,
+ config_namespace,
+ config_namespace,
+ config_namespace,
+ context_name,
+ )
+ proxy_jump = config.proxy_jump
+ if proxy_jump is None:
+ proxy_jump = KubernetesProxyJumpConfig()
+ clusters.append(
+ Cluster(
+ context_name=context_name,
+ region=LEGACY_CURRENT_CONTEXT_REGION,
+ api_client=api_client,
+ # TODO: switch to context_namespace
+ namespace=config_namespace,
+ proxy_jump=proxy_jump,
+ )
+ )
+ return clusters
+
+
class KubeconfigContext(CoreModel):
namespace: str = "default"
@@ -74,18 +199,28 @@ def kubeconfig_dict_to_kubeconfig(kubeconfig_dict: dict) -> Kubeconfig:
return Kubeconfig.__response__.parse_obj(kubeconfig_dict)
-def get_api_from_kubeconfig_data(
- kubeconfig_data: str, *, context: Optional[str] = None
-) -> CoreV1Api:
- kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(kubeconfig_data)
- return get_api_from_kubeconfig_dict(kubeconfig_dict, context=context)
+class SkipOfferCache:
+ """
+ `SkipOfferCache` is used to track (run/job, offer) pairs that failed to provision.
+
+ The current implementation keys entries by run (`Run.id`) and cluster (`InstanceOffer.region`,
+ that is, a kubeconfig context), so a failure of _any_ job of a run applies to the whole run there.
+ """
+
+ def __init__(self, *, ttl: int, maxsize: int = 1000) -> None:
+ self._cache = TTLCache[tuple[UUID, str], Literal[True]](maxsize=maxsize, ttl=ttl)
+
+ def add(self, run: Run, job: Job, offer: InstanceOffer) -> None:
+ self._cache[self._build_key(run, job, offer)] = True
+ def check(self, run: Run, job: Job, offer: InstanceOffer) -> bool:
+ return self._build_key(run, job, offer) in self._cache
-def get_api_from_kubeconfig_dict(
- kubeconfig_dict: dict, *, context: Optional[str] = None
-) -> CoreV1Api:
- api_client = new_client_from_config_dict(config_dict=kubeconfig_dict, context=context)
- return CoreV1Api(api_client=api_client)
+ def _build_key(self, run: Run, job: Job, offer: InstanceOffer) -> tuple[UUID, str]:
+ # The current implementation uses only Run.id, ignoring the job/job spec.
+ # A more sophisticated implementation could use some parts of the job spec
+ # (e.g., requirements, volumes) instead.
+ return (run.id, offer.region)
def call_api_method(
@@ -136,7 +271,7 @@ def try_delete_object_if_exists(
namespace=namespace,
name=name,
)
- except (HTTPError, ApiException) as e:
+ except API_CLIENT_EXCEPTIONS as e:
if should_delete_manually_if_failed:
logger.exception(
"Failed to delete %s %s in namespace %s. Please delete it manually",
diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py
index 44ee882511..a0afabf1c6 100644
--- a/src/dstack/_internal/core/compatibility/volumes.py
+++ b/src/dstack/_internal/core/compatibility/volumes.py
@@ -33,7 +33,10 @@ def _get_volume_configuration_excludes(
) -> IncludeExcludeDictType:
configuration_excludes: IncludeExcludeDictType = {}
- if isinstance(configuration, KubernetesVolumeConfiguration) and not configuration.read_only:
- configuration_excludes["read_only"] = True
+ if isinstance(configuration, KubernetesVolumeConfiguration):
+ if not configuration.read_only:
+ configuration_excludes["read_only"] = True
+ if configuration.region == "":
+ configuration_excludes["region"] = True
return configuration_excludes
diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py
index 5182d063b2..e2685a64e4 100644
--- a/src/dstack/_internal/core/errors.py
+++ b/src/dstack/_internal/core/errors.py
@@ -128,6 +128,19 @@ def __init__(self, details: str) -> None:
return super().__init__(details)
+class SkipOffer(ComputeError):
+ """
+ Used by Compute.run_job and Compute.create_instance to signal that the offer should be skipped.
+ """
+
+ def __init__(self, details: str) -> None:
+ """
+ Args:
+ details: details about why the offer should be skipped
+ """
+ return super().__init__(details)
+
+
class CLIError(DstackError):
pass
diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py
index 03d52d2ac9..1b96331903 100644
--- a/src/dstack/_internal/core/models/volumes.py
+++ b/src/dstack/_internal/core/models/volumes.py
@@ -134,10 +134,12 @@ class RunpodVolumeConfiguration(VolumeConfigurationWithRegion, VolumeConfigurati
"""Runpod doesn't have AZs but we accept this field for compatibility with older clients."""
-class KubernetesVolumeConfiguration(BaseVolumeConfiguration):
+class KubernetesVolumeConfiguration(VolumeConfigurationWithRegion):
backend: Annotated[
Literal[BackendType.KUBERNETES], Field(description="The volume backend")
] = BackendType.KUBERNETES
+ region: Annotated[str, Field(description="The volume region (cluster)")] = ""
+ """`region` uses a default value for backward compatibility."""
size: Annotated[
Optional[Memory],
Field(
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
index b1ceac47d6..0673b6b3c1 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
@@ -20,6 +20,7 @@
from dstack._internal.core.errors import (
BackendError,
PlacementGroupNotSupportedError,
+ SkipOffer,
)
from dstack._internal.core.models.instances import (
InstanceOfferWithAvailability,
@@ -114,8 +115,15 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
include_only_create_instance_supported_backends=True,
)
+ offers_iter = iter(offers)
+ offers_tried = 0
# Limit number of offers tried to prevent long-running processing in case all offers fail.
- for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
+ while offers_tried < server_settings.MAX_OFFERS_TRIED:
+ backend_with_instance_offer = next(offers_iter, None)
+ if backend_with_instance_offer is None:
+ break
+ backend, instance_offer = backend_with_instance_offer
+
if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
continue
compute = backend.compute()
@@ -159,6 +167,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
instance_offer.region,
instance_offer.price,
)
+ offers_tried += 1
try:
job_provisioning_data = await run_async(
compute.create_instance,
@@ -166,6 +175,17 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
instance_configuration,
placement_group_model_to_placement_group_optional(placement_group_model),
)
+ except SkipOffer as exc:
+ offers_tried -= 1
+ logger.info(
+ "%s launch in %s/%s skipped: %s",
+ instance_offer.instance.name,
+ instance_offer.backend.value,
+ instance_offer.region,
+ exc,
+ extra={"instance_name": instance_model.name},
+ )
+ continue
except BackendError as exc:
logger.warning(
"%s launch in %s/%s failed: %s",
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
index 8ecdc0e030..2161eb8958 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
@@ -20,7 +20,7 @@
BACKENDS_WITH_GROUP_PROVISIONING_SUPPORT,
BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
)
-from dstack._internal.core.errors import BackendError, ServerClientError
+from dstack._internal.core.errors import BackendError, ServerClientError, SkipOffer
from dstack._internal.core.models.common import NetworkMode
from dstack._internal.core.models.compute_groups import (
ComputeGroupProvisioningData,
@@ -2120,8 +2120,13 @@ async def _provision_new_capacity(
instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec),
placement_group=placement_group_model_to_placement_group_optional(placement_group_model),
)
+ offers_iter = iter(offers)
offers_tried = 0
- for backend, offer in offers[: settings.MAX_OFFERS_TRIED]:
+ while offers_tried < settings.MAX_OFFERS_TRIED:
+ backend_with_offer = next(offers_iter, None)
+ if backend_with_offer is None:
+ break
+ backend, offer = backend_with_offer
logger.debug(
"%s: trying %s in %s/%s for $%0.4f per hour",
fmt(job_model),
@@ -2214,6 +2219,17 @@ async def _provision_new_capacity(
new_placement_group_models=new_placement_group_models,
),
)
+ except SkipOffer as e:
+ offers_tried -= 1
+ logger.info(
+ "%s: %s launch in %s/%s skipped: %s",
+ fmt(job_model),
+ offer.instance.name,
+ offer.backend.value,
+ offer.region,
+ e,
+ )
+ continue
except BackendError as e:
logger.warning(
"%s: %s launch in %s/%s failed: %s",
diff --git a/src/tests/_internal/core/backends/kubernetes/test_configurator.py b/src/tests/_internal/core/backends/kubernetes/test_configurator.py
index 2c5e665d53..36f32a08b9 100644
--- a/src/tests/_internal/core/backends/kubernetes/test_configurator.py
+++ b/src/tests/_internal/core/backends/kubernetes/test_configurator.py
@@ -1,43 +1,95 @@
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
import pytest
-from dstack._internal.core.backends.kubernetes.configurator import (
- KubernetesConfigurator,
-)
+from dstack._internal.core.backends.kubernetes.configurator import KubernetesConfigurator
from dstack._internal.core.backends.kubernetes.models import (
KubeconfigConfig,
KubernetesBackendConfigWithCreds,
+ KubernetesContextConfig,
KubernetesProxyJumpConfig,
)
-from dstack._internal.core.errors import BackendInvalidCredentialsError
+from dstack._internal.core.errors import ServerClientError
+
+
+@pytest.fixture
+def get_clusters_mock(monkeypatch: pytest.MonkeyPatch) -> Mock:
+ mock = Mock(return_value=[])
+ monkeypatch.setattr(
+ "dstack._internal.core.backends.kubernetes.configurator.get_clusters_from_backend_config",
+ mock,
+ )
+ return mock
class TestKubernetesConfigurator:
- def test_validate_config_valid(self):
+ @pytest.mark.usefixtures("get_clusters_mock")
+ def test_validate_config_valid_current_context(self):
+ config = KubernetesBackendConfigWithCreds(
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=30022),
+ namespace="ns",
+ )
+ KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+ @pytest.mark.usefixtures("get_clusters_mock")
+ def test_validate_config_valid_explicit_contexts(self):
config = KubernetesBackendConfigWithCreds(
- kubeconfig=KubeconfigConfig(data="valid", filename="-"),
- proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=None),
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ contexts=["ctx"],
)
- with patch(
- "dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
- ) as get_api_mock:
- api_mock = Mock()
- api_mock.list_node.return_value = Mock()
- get_api_mock.return_value = api_mock
+ KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+ @pytest.mark.usefixtures("get_clusters_mock")
+ def test_validate_config_contexts_proxy_jump_mutually_exclusive(self):
+ config = KubernetesBackendConfigWithCreds(
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=30022),
+ contexts=["ctx"],
+ )
+ with pytest.raises(ServerClientError, match="proxy_jump must not be set"):
KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
- def test_validate_config_invalid_config(self):
+ @pytest.mark.usefixtures("get_clusters_mock")
+ def test_validate_config_contexts_namespace_mutually_exclusive(self):
config = KubernetesBackendConfigWithCreds(
- kubeconfig=KubeconfigConfig(data="invalid", filename="-"),
- proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=None),
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ namespace="ns",
+ contexts=["ctx"],
+ )
+ with pytest.raises(ServerClientError, match="namespace must not be set"):
+ KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+ @pytest.mark.usefixtures("get_clusters_mock")
+ def test_validate_config_duplicate_contexts(self):
+ config = KubernetesBackendConfigWithCreds(
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ contexts=[
+ "ctx-3",
+ KubernetesContextConfig(name="ctx-4"),
+ "ctx-1",
+ KubernetesContextConfig(name="ctx-1"),
+ "ctx-2",
+ KubernetesContextConfig(name="ctx-3"),
+ ],
+ )
+ with pytest.raises(ServerClientError, match="duplicate contexts: ctx-1, ctx-3"):
+ KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+ def test_validate_config_cluster_check_failed(
+ self, monkeypatch: pytest.MonkeyPatch, get_clusters_mock: Mock
+ ):
+ config = KubernetesBackendConfigWithCreds(
+ kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+ contexts=["ctx"],
+ )
+
+ monkeypatch.setattr(
+ "dstack._internal.core.backends.kubernetes.configurator.check_cluster",
+ Mock(return_value=False),
)
- with (
- patch(
- "dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
- ) as get_api_mock,
- pytest.raises(BackendInvalidCredentialsError) as exc_info,
- ):
- get_api_mock.side_effect = Exception("Invalid config")
+ cluster_mock = Mock()
+ get_clusters_mock.return_value = [cluster_mock]
+ with pytest.raises(ServerClientError, match="Failed to validate cluster") as exc_info:
KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
assert exc_info.value.fields == [["kubeconfig"]]
diff --git a/src/tests/_internal/core/backends/kubernetes/test_utils.py b/src/tests/_internal/core/backends/kubernetes/test_utils.py
new file mode 100644
index 0000000000..f58868217f
--- /dev/null
+++ b/src/tests/_internal/core/backends/kubernetes/test_utils.py
@@ -0,0 +1,318 @@
+import logging
+from textwrap import dedent
+from typing import Optional, Union
+
+import pytest
+
+from dstack._internal.core.backends.kubernetes.models import (
+ KubeconfigConfig,
+ KubernetesBackendConfigWithCreds,
+ KubernetesContextConfig,
+ KubernetesProxyJumpConfig,
+)
+from dstack._internal.core.backends.kubernetes.utils import (
+ Cluster,
+ get_clusters_from_backend_config,
+)
+
+
+class TestGetClustersFromBackendConfig:
+ def make_config(
+ self,
+ kubeconfig_data: str,
+ *,
+ contexts: Optional[list[Union[KubernetesContextConfig, str]]] = None,
+ namespace: Optional[str] = None,
+ proxy_jump: Optional[KubernetesProxyJumpConfig] = None,
+ ) -> KubernetesBackendConfigWithCreds:
+ return KubernetesBackendConfigWithCreds(
+ kubeconfig=KubeconfigConfig(data=kubeconfig_data, filename="-"),
+ contexts=contexts,
+ namespace=namespace,
+ proxy_jump=proxy_jump,
+ )
+
+ def make_kubeconfig(
+ self,
+ *,
+ current_context: str = "ctx-a",
+ # (context name, namespace) pairs
+ contexts: tuple[tuple[str, str], ...] = (("ctx-a", "default"),),
+ ) -> str:
+ clusters_yaml = "\n".join(
+ dedent(f"""
+ - name: cluster-{name}
+ cluster:
+ server: https://{name}.example.com:6443
+ """)
+ for name, _ in contexts
+ )
+ users_yaml = "\n".join(
+ dedent(f"""
+ - name: user-{name}
+ user:
+ token: token-{name}
+ """)
+ for name, _ in contexts
+ )
+ contexts_yaml = "\n".join(
+ dedent(f"""
+ - name: {name}
+ context:
+ cluster: cluster-{name}
+ user: user-{name}
+ namespace: {namespace}
+ """)
+ for name, namespace in contexts
+ )
+ return dedent("""
+ apiVersion: v1
+ kind: Config
+ current-context: {current_context}
+ clusters:
+ {clusters}
+ contexts:
+ {contexts}
+ users:
+ {users}
+ """).format(
+ current_context=current_context,
+ clusters=clusters_yaml,
+ contexts=contexts_yaml,
+ users=users_yaml,
+ )
+
+ def test_returns_single_cluster_using_current_context(self):
+ config = self.make_config(
+ self.make_kubeconfig(
+ current_context="ctx-a",
+ contexts=(
+ ("ctx-b", "team-b"),
+ ("ctx-a", "default"),
+ ),
+ ),
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert len(clusters) == 1
+ cluster = clusters[0]
+ assert isinstance(cluster, Cluster)
+ assert cluster.context_name == "ctx-a"
+ assert cluster.region == ""
+ assert cluster.namespace == "default"
+ assert cluster.proxy_jump == KubernetesProxyJumpConfig()
+ assert cluster.api_client.configuration.host == "https://ctx-a.example.com:6443" # pyright: ignore[reportAttributeAccessIssue]
+
+ def test_single_context_uses_namespace_from_backend_config(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+ namespace="team-a",
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].namespace == "team-a"
+
+ def test_single_context_defaults_namespace_when_not_set(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+ namespace=None,
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].namespace == "default"
+
+ def test_single_context_uses_proxy_jump_from_backend_config(self):
+ proxy_jump = KubernetesProxyJumpConfig(hostname="1.2.3.4", port=2222)
+ config = self.make_config(
+ self.make_kubeconfig(),
+ proxy_jump=proxy_jump,
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].proxy_jump == proxy_jump
+
+ def test_single_context_uses_default_proxy_jump_when_unset(self):
+ config = self.make_config(self.make_kubeconfig(), proxy_jump=None)
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+ def test_single_context_warns_on_namespace_mismatch(self, caplog: pytest.LogCaptureFixture):
+ caplog.set_level(logging.WARNING)
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+ namespace="config-ns",
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].namespace == "config-ns"
+ assert "Namespace mismatch" in caplog.text
+ assert "kube-ns" in caplog.text
+ assert "config-ns" in caplog.text
+
+ def test_single_context_does_not_warn_when_namespace_matches(
+ self, caplog: pytest.LogCaptureFixture
+ ):
+ caplog.set_level(logging.WARNING)
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+ namespace="team-a",
+ )
+
+ get_clusters_from_backend_config(config)
+
+ assert "Namespace mismatch" not in caplog.text
+
+ def test_single_context_raises_when_current_context_missing(self):
+ kubeconfig = dedent("""
+ apiVersion: v1
+ kind: Config
+ clusters:
+ - name: cluster-a
+ cluster:
+ server: https://a.example.com:6443
+ contexts:
+ - name: ctx-a
+ context:
+ cluster: cluster-a
+ user: user-a
+ users:
+ - name: user-a
+ user:
+ token: t
+ """)
+ config = self.make_config(kubeconfig)
+
+ with pytest.raises(ValueError, match="current-context is not set"):
+ get_clusters_from_backend_config(config)
+
+ def test_contexts_as_strings(self):
+ config = self.make_config(
+ self.make_kubeconfig(
+ current_context="ctx-a",
+ contexts=(
+ ("ctx-a", "ns-a"),
+ ("ctx-b", "ns-b"),
+ ),
+ ),
+ contexts=["ctx-a", "ctx-b"],
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert [c.context_name for c in clusters] == ["ctx-a", "ctx-b"]
+ assert [c.region for c in clusters] == ["ctx-a", "ctx-b"]
+ assert [c.namespace for c in clusters] == ["ns-a", "ns-b"]
+ assert all(c.proxy_jump == KubernetesProxyJumpConfig() for c in clusters)
+ assert clusters[0].api_client.configuration.host == "https://ctx-a.example.com:6443" # pyright: ignore[reportAttributeAccessIssue]
+ assert clusters[1].api_client.configuration.host == "https://ctx-b.example.com:6443" # pyright: ignore[reportAttributeAccessIssue]
+
+ def test_contexts_with_per_context_proxy_jump(self):
+ proxy_jump_a = KubernetesProxyJumpConfig(hostname="a.example.com", port=2201)
+ proxy_jump_b = KubernetesProxyJumpConfig(hostname="b.example.com", port=2202)
+ config = self.make_config(
+ self.make_kubeconfig(
+ contexts=(
+ ("ctx-a", "ns-a"),
+ ("ctx-b", "ns-b"),
+ ),
+ ),
+ contexts=[
+ KubernetesContextConfig(name="ctx-a", proxy_jump=proxy_jump_a),
+ KubernetesContextConfig(name="ctx-b", proxy_jump=proxy_jump_b),
+ ],
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].proxy_jump == proxy_jump_a
+ assert clusters[1].proxy_jump == proxy_jump_b
+
+ def test_contexts_mix_string_and_object(self):
+ proxy_jump = KubernetesProxyJumpConfig(hostname="b.example.com", port=2222)
+ config = self.make_config(
+ self.make_kubeconfig(
+ contexts=(
+ ("ctx-a", "ns-a"),
+ ("ctx-b", "ns-b"),
+ ),
+ ),
+ contexts=[
+ "ctx-a",
+ KubernetesContextConfig(name="ctx-b", proxy_jump=proxy_jump),
+ ],
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+ assert clusters[1].proxy_jump == proxy_jump
+
+ def test_contexts_object_without_proxy_jump_uses_default(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+ contexts=[KubernetesContextConfig(name="ctx-a", proxy_jump=None)],
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+ def test_contexts_ignores_backend_namespace_and_proxy_jump(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+ contexts=["ctx-a"],
+ namespace="config-ns",
+ proxy_jump=KubernetesProxyJumpConfig(hostname="ignored", port=1),
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters[0].namespace == "kube-ns"
+ assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+ def test_contexts_does_not_warn_on_namespace_mismatch(self, caplog: pytest.LogCaptureFixture):
+ caplog.set_level(logging.WARNING)
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+ contexts=["ctx-a"],
+ namespace="config-ns",
+ )
+
+ get_clusters_from_backend_config(config)
+
+ assert "Namespace mismatch" not in caplog.text
+
+ def test_contexts_raises_for_unknown_context(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+ contexts=["ctx-missing"],
+ )
+
+ with pytest.raises(ValueError, match="context ctx-missing not found"):
+ get_clusters_from_backend_config(config)
+
+ def test_empty_contexts_returns_no_clusters(self):
+ config = self.make_config(
+ self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+ contexts=[],
+ )
+
+ clusters = get_clusters_from_backend_config(config)
+
+ assert clusters == []
+
+ def test_request_timeout_and_retries_propagate_to_client(self):
+ config = self.make_config(self.make_kubeconfig())
+
+ clusters = get_clusters_from_backend_config(config, request_timeout=7, retries=5)
+
+ api_client = clusters[0].api_client
+ assert api_client.configuration.retries == 5 # pyright: ignore[reportAttributeAccessIssue]
+ assert getattr(api_client, "_ApiClient__request_timeout", None) == 7
diff --git a/src/tests/_internal/server/services/test_backend_configs.py b/src/tests/_internal/server/services/test_backend_configs.py
index 833dda2758..d99bdcb985 100644
--- a/src/tests/_internal/server/services/test_backend_configs.py
+++ b/src/tests/_internal/server/services/test_backend_configs.py
@@ -183,11 +183,12 @@ def test_ui_config_embedded_kubeconfig_initializes_backend(self):
backend_config = config_yaml_to_backend_config(config_yaml)
backend = KubernetesBackend(backend_config)
- assert backend.compute().api.api_client.configuration.host == (
+ cluster = backend.compute().region_cluster_map[""]
+ assert cluster.api_client.configuration.host == (
"https://gpu-cluster.internal.example.com:6443"
)
- assert backend.compute().proxy_jump.hostname == "204.12.171.137"
- assert backend.compute().proxy_jump.port == 32000
+ assert cluster.proxy_jump.hostname == "204.12.171.137"
+ assert cluster.proxy_jump.port == 32000
def test_kubeconfig_context_namespace_does_not_set_backend_namespace(self):
config_yaml = dedent(
@@ -226,4 +227,5 @@ def test_kubeconfig_context_namespace_does_not_set_backend_namespace(self):
backend_config = config_yaml_to_backend_config(config_yaml)
backend = KubernetesBackend(backend_config)
- assert backend.compute().config.namespace == "default"
+ cluster = backend.compute().region_cluster_map[""]
+ assert cluster.namespace == "default"