diff --git a/mkdocs/docs/concepts/backends.md b/mkdocs/docs/concepts/backends.md index 8f7b3325d8..e0724f840c 100644 --- a/mkdocs/docs/concepts/backends.md +++ b/mkdocs/docs/concepts/backends.md @@ -1051,9 +1051,9 @@ Compared to [VM-based](#vm-based) backends, they offer less fine-grained control ### Kubernetes -Regardless of whether it’s on-prem Kubernetes or managed, `dstack` can orchestrate container-based runs across your clusters. +Regardless of whether it’s on-prem Kubernetes or managed, `dstack` can orchestrate container-based runs across your clusters. A single `kubernetes` backend can manage one or many clusters — each cluster is selected via a kubeconfig [context](https://kubernetes.io/docs/concepts/configuration/organize-cluster-access-kubeconfig/#context). -To use the `kubernetes` backend with `dstack`, you need to configure it with the path to the kubeconfig file, the IP address of any node in the cluster, and the port that `dstack` will use for proxying SSH traffic. +The recommended way is to enable clusters explicitly via the `contexts` property:
@@ -1066,22 +1066,48 @@ projects: kubeconfig: filename: ~/.kube/config - proxy_jump: - hostname: 204.12.171.137 - port: 32000 + contexts: + - name: gpu-cluster-a + - name: gpu-cluster-b ```
!!! info "Proxy jump" - To allow the `dstack` server and CLI to access runs via SSH, `dstack` requires a node that acts as a jump host to proxy SSH traffic into containers. + To allow the `dstack` server and CLI to access runs via SSH, `dstack` uses a node in each cluster as a jump host to proxy SSH traffic into containers. No additional setup is required — `dstack` configures and manages the proxy automatically. - To configure this node, specify `hostname` and `port` under the `proxy_jump` property: + By default, `dstack` autodetects the jump host: - - `hostname` — the IP address of any cluster node selected as the jump host. Both the `dstack` server and CLI must be able to reach it. This node can be either a GPU node or a CPU-only node — it makes no difference. - - `port` — any accessible port on that node, which `dstack` uses to forward SSH traffic. + - `hostname` — picks the `ExternalIP` of the jump pod's node, or a random node `ExternalIP` from the cluster if the jump pod's node has none. If no node in the cluster has an `ExternalIP`, provisioning fails and you must set `hostname` explicitly. + - `port` — Kubernetes allocates a port from the cluster's NodePort range. - No additional setup is required — `dstack` configures and manages the proxy automatically. + Set `proxy_jump.hostname` and `proxy_jump.port` per context to override autodetection — useful when nodes lack `ExternalIP`s, or when you want a stable, firewall-friendly port: + + ```yaml + contexts: + - name: gpu-cluster-a + proxy_jump: + hostname: 204.12.171.137 + port: 32000 + ``` + + Both fields are independent — you can set just one. + + The jump host can be a GPU node or a CPU-only node — it makes no difference. The only requirement is that both the `dstack` server and CLI can reach `hostname:port`. + +!!! info "Region and namespace" + Each enabled context becomes its own `dstack` region, named after the context. 
When creating a `dstack` [volume](volumes.md) or [gateway](gateways.md), the `region` field selects which cluster the resource is provisioned in. + + The namespace `dstack` uses for managed resources is taken from each kubeconfig context's `namespace` property, defaulting to `default` if not set: + + ```yaml + contexts: + - name: gpu-cluster-a + context: + cluster: gpu-cluster-a + user: kubernetes-admin + namespace: dstack + ``` ??? info "User interface" If you are configuring the `kubernetes` backend on the [project settings page](projects.md#backends), @@ -1091,17 +1117,16 @@ projects: ```yaml type: kubernetes - + kubeconfig: data: | apiVersion: v1 kind: Config - current-context: kubernetes-admin@gpu-cluster clusters: - - name: gpu-cluster + - name: gpu-cluster-a cluster: - server: https://gpu-cluster.internal.example.com:6443 + server: https://gpu-cluster-a.internal.example.com:6443 certificate-authority-data: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0t...LS0tLQo= users: @@ -1111,17 +1136,50 @@ projects: client-key-data: LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0t...LS0tLQo= contexts: - - name: kubernetes-admin@gpu-cluster + - name: gpu-cluster-a context: - cluster: gpu-cluster + cluster: gpu-cluster-a user: kubernetes-admin - - proxy_jump: - hostname: 204.12.171.137 - port: 32000 + namespace: dstack + + contexts: + - name: gpu-cluster-a + proxy_jump: + hostname: 204.12.171.137 + port: 32000 ``` + +??? warning "Legacy configuration (without `contexts`)" + If `contexts` is not set, `dstack` falls back to using the kubeconfig's `current-context` as the only cluster, and the top-level `proxy_jump` and `namespace` properties apply: + +
+ + ```yaml + projects: + - name: main + backends: + - type: kubernetes + + kubeconfig: + filename: ~/.kube/config + + namespace: dstack + + proxy_jump: + hostname: 204.12.171.137 + port: 32000 + ``` + +
+ + This mode is not recommended and may be deprecated and removed in the future. It also has a namespace-handling quirk: the top-level `namespace` property **overrides** the kubeconfig context's namespace (defaulting to `default` if not set in the config), unlike the `contexts` mode where the kubeconfig is authoritative. A warning is logged when the two disagree. To prepare for a possible future change, set the same value in both your kubeconfig context and the backend config. + + With this configuration, the cluster's region is an empty string. When creating a `dstack` volume or gateway, set `region: ''` explicitly in the configuration. + + !!! warning "Migrating from legacy to `contexts`" + Switching an existing backend from the legacy mode to `contexts` is not transparent for already-provisioned resources: their region changes from an empty string to the context name, so `dstack` can no longer terminate them. Terminate all jobs, gateways, and volumes managed by the backend before changing the configuration. ??? info "Required operators" === "NVIDIA" @@ -1149,7 +1207,7 @@ projects: --8<-- "snippets/kubernetes/dstack-backend-role.yaml" ``` - Ensure you've created a ClusterRoleBinding to grant the role to the user or the service account you're using. + Ensure you've created a ClusterRoleBinding and RoleBinding to grant the roles to the user or the service account you're using. ??? info "Resources and offers" If you use ranges with [`resources`](../concepts/tasks.md#resources) (e.g. `gpu: 1..8` or `memory: 64GB..`) in fleet or run configurations, other backends collect and try all offers that satisfy the range. 
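The namespace rules documented above for the two configuration modes can be condensed into a small sketch. This is only an illustration of the documented behavior, not code from the `dstack` codebase: `resolve_namespace` and its parameters are hypothetical names, and the `default` fallback is taken from the prose in this diff.

```python
from typing import Optional

DEFAULT_NAMESPACE = "default"


def resolve_namespace(
    *,
    contexts_set: bool,
    kubeconfig_namespace: Optional[str],
    backend_namespace: Optional[str],
) -> str:
    """Hypothetical helper (not part of dstack) mirroring the documented rules."""
    if contexts_set:
        # `contexts` mode: the kubeconfig context's namespace is authoritative.
        return kubeconfig_namespace or DEFAULT_NAMESPACE
    # Legacy mode: the top-level `namespace` wins, even when it is not set,
    # so the kubeconfig context's namespace is ignored entirely.
    return backend_namespace or DEFAULT_NAMESPACE
```

Under these assumptions, a kubeconfig context with `namespace: dstack` yields `dstack` in `contexts` mode but `default` in legacy mode unless the backend config repeats the value, which is the mismatch the legacy-mode warning is about.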
diff --git a/mkdocs/docs/reference/dstack.yml/volume.md b/mkdocs/docs/reference/dstack.yml/volume.md index 2675b684ee..d3f851c8c0 100644 --- a/mkdocs/docs/reference/dstack.yml/volume.md +++ b/mkdocs/docs/reference/dstack.yml/volume.md @@ -60,3 +60,5 @@ The `volume` configuration type allows creating, registering, and updating [volu show_root_heading: false backend: required: true + region: + required: true diff --git a/mkdocs/docs/reference/server/config.yml.md b/mkdocs/docs/reference/server/config.yml.md index 80e48b028e..a76edaca58 100644 --- a/mkdocs/docs/reference/server/config.yml.md +++ b/mkdocs/docs/reference/server/config.yml.md @@ -278,6 +278,18 @@ to configure [backends](../../concepts/backends.md) and other [server-level sett yq -o=json ~/.kube/config | jq -c | jq -R ``` +###### `projects[n].backends[type=kubernetes].contexts[n]` { #kubernetes-contexts data-toc-label="contexts" } + +#SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesContextConfig + overrides: + show_root_heading: false + +###### `projects[n].backends[type=kubernetes].contexts[n].proxy_jump` { #kubernetes-contexts-proxy_jump data-toc-label="proxy_jump" } + +#SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesProxyJumpConfig + overrides: + show_root_heading: false + ###### `projects[n].backends[type=kubernetes].proxy_jump` { #kubernetes-proxy_jump data-toc-label="proxy_jump" } #SCHEMA# dstack._internal.core.backends.kubernetes.models.KubernetesProxyJumpConfig diff --git a/scripts/merge_kubeconfigs.sh b/scripts/merge_kubeconfigs.sh new file mode 100755 index 0000000000..e7087f1aca --- /dev/null +++ b/scripts/merge_kubeconfigs.sh @@ -0,0 +1,12 @@ +#!/bin/sh +set -eu + +if [ ${#} -lt 2 ]; then + echo "usage: $(basename "${0}") PATH1 PATH2 [PATH3 ...]" >&2 + exit 1 +fi + +# Windows is not supported; on Windows a path separator is ';', not ':' +KUBECONFIG=$(IFS=':'; echo "${*}") +export KUBECONFIG +kubectl config view --raw --flatten | grep -Ev 
'^current-context: ' diff --git a/scripts/setup_kubernetes.py b/scripts/setup_kubernetes.py index b38896d715..22295ffebf 100644 --- a/scripts/setup_kubernetes.py +++ b/scripts/setup_kubernetes.py @@ -201,7 +201,9 @@ def generate_kubeconfig( service_account_token: str, ) -> str: logging.info("generating kubeconfig") - kubeconfig_content = kubectl.call("config", "view", "--minify", "--raw", capture_stdout=True) + kubeconfig_content = kubectl.call( + "config", "view", "--minify", "--raw", "--flatten", capture_stdout=True + ) with tempfile.NamedTemporaryFile("w+") as f: f.write(kubeconfig_content) f.flush() diff --git a/src/dstack/_internal/core/backends/kubernetes/api_client.py b/src/dstack/_internal/core/backends/kubernetes/api_client.py new file mode 100644 index 0000000000..29a915d646 --- /dev/null +++ b/src/dstack/_internal/core/backends/kubernetes/api_client.py @@ -0,0 +1,46 @@ +from typing import Optional + +from kubernetes.client.api_client import ApiClient as _BaseApiClient +from kubernetes.client.configuration import Configuration as _ClientConfiguration +from kubernetes.client.exceptions import ApiException +from kubernetes.config import load_kube_config_from_dict +from urllib3.exceptions import HTTPError + +# 30 * 2 (original request + 1 retry) = 60 seconds total +DEFAULT_REQUEST_TIMEOUT = 30 +DEFAULT_RETRIES = 1 + + +API_CLIENT_EXCEPTIONS: tuple[type[Exception], ...] 
= (HTTPError, ApiException) + + +class ApiClient(_BaseApiClient): + def __init__(self, *, configuration: _ClientConfiguration, request_timeout: int) -> None: + self.__request_timeout = request_timeout + super().__init__(configuration=configuration) + + def request(self, *args, **kwargs): + if kwargs.get("_request_timeout") is None: + kwargs["_request_timeout"] = self.__request_timeout + return super().request(*args, **kwargs) # pyright: ignore[reportAttributeAccessIssue] + + +def get_api_client_from_kubeconfig_dict( + kubeconfig_dict: dict, + *, + context: str, + request_timeout: Optional[int] = None, + retries: Optional[int] = None, +) -> ApiClient: + if request_timeout is None: + request_timeout = DEFAULT_REQUEST_TIMEOUT + if retries is None: + retries = DEFAULT_RETRIES + client_configuration = _ClientConfiguration() + client_configuration.retries = retries # pyright: ignore[reportAttributeAccessIssue] + load_kube_config_from_dict( + config_dict=kubeconfig_dict, + context=context, + client_configuration=client_configuration, + ) + return ApiClient(configuration=client_configuration, request_timeout=request_timeout) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 062e458b4b..7a8979ed34 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -1,3 +1,4 @@ +import concurrent.futures import random import shlex import subprocess @@ -28,10 +29,8 @@ get_dstack_gateway_commands, merge_tags, ) -from dstack._internal.core.backends.kubernetes.models import ( - KubernetesConfig, - KubernetesProxyJumpConfig, -) +from dstack._internal.core.backends.kubernetes.api_client import API_CLIENT_EXCEPTIONS +from dstack._internal.core.backends.kubernetes.models import KubernetesConfig from dstack._internal.core.backends.kubernetes.resources import ( AMD_GPU_DEVICE_ID_LABEL_PREFIX, AMD_GPU_NAME_TO_DEVICE_IDS, @@ -60,15 
+59,16 @@ parse_quantity, ) from dstack._internal.core.backends.kubernetes.utils import ( + LEGACY_CURRENT_CONTEXT_REGION, + Cluster, + SkipOfferCache, call_api_method, - get_api_from_kubeconfig_dict, - kubeconfig_data_to_kubeconfig_dict, - kubeconfig_dict_to_kubeconfig, + get_clusters_from_backend_config, try_delete_object_if_exists, watch_events, ) from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT -from dstack._internal.core.errors import ComputeError, ProvisioningError +from dstack._internal.core.errors import ComputeError, ProvisioningError, SkipOffer from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.gateways import ( @@ -136,39 +136,34 @@ class KubernetesCompute( ): def __init__(self, config: KubernetesConfig): super().__init__() - self.config = config.copy() - proxy_jump = self.config.proxy_jump - if proxy_jump is None: - proxy_jump = KubernetesProxyJumpConfig() - self.proxy_jump = proxy_jump - kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(config.kubeconfig.data) - self.api = get_api_from_kubeconfig_dict(kubeconfig_dict) - kubeconfig = kubeconfig_dict_to_kubeconfig(kubeconfig_dict) - current_context = kubeconfig.get_context() - if current_context.namespace != config.namespace: - logger.warning( - ( - "Namespace mismatch: kubeconfig -> '%s', backend config -> '%s'." - " The current dstack version ignores kubeconfig" - " and uses deprecated namespace property from backend config." - " Future versions will use namespace from kubeconfig." 
- " To keep using '%s' namespace in future versions and suppress this warning," - " set namespace to '%s' in kubeconfig context '%s'" - ), - current_context.namespace, - config.namespace, - config.namespace, - config.namespace, - kubeconfig.current_context, - ) - # TODO: switch to current_context.namespace - self.namespace = config.namespace - logger.debug("Using namespace '%s'", self.namespace) + self.region_cluster_map = {c.region: c for c in get_clusters_from_backend_config(config)} + self.skip_offer_cache = SkipOfferCache(ttl=60) def get_offers_by_requirements( self, requirements: Requirements ) -> list[InstanceOfferWithAvailability]: - return get_instance_offers(self.api, requirements) + offers: list[InstanceOfferWithAvailability] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + future_cluster_map: dict[ + concurrent.futures.Future[list[InstanceOfferWithAvailability]], Cluster + ] = {} + for region, cluster in self.region_cluster_map.items(): + api = client.CoreV1Api(cluster.api_client) + future = executor.submit(get_instance_offers, api, region, requirements) + future_cluster_map[future] = cluster + for future in concurrent.futures.as_completed(future_cluster_map): + try: + cluster_offers = future.result() + except API_CLIENT_EXCEPTIONS as e: + logger.warning( + "Failed to get offers from cluster %s: %s: %s", + future_cluster_map[future], + e.__class__.__name__, + e, + ) + continue + offers.extend(cluster_offers) + return offers def run_job( self, @@ -180,8 +175,13 @@ def run_job( volumes: list[Volume], placement_group: Optional[PlacementGroup], ) -> JobProvisioningData: - api = self.api - namespace = self.namespace + cluster = self.region_cluster_map.get(instance_offer.region) + if cluster is None: + raise ComputeError(f"Unknown region: {instance_offer.region!r}") + if self.skip_offer_cache.check(run, job, instance_offer): + raise SkipOffer(f"cluster {cluster} has recently failed to schedule a similar job") + api = 
client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace # There is one jump pod per project that is used as an ssh proxy jump to connect # to all job pods of the same project. @@ -193,7 +193,7 @@ def run_job( namespace=namespace, jump_pod_name=jump_pod_name, jump_pod_service_name=jump_pod_service_name, - jump_pod_port=self.proxy_jump.port, + jump_pod_port=cluster.proxy_jump.port, project_ssh_public_key=project_ssh_public_key.strip(), ) @@ -246,6 +246,7 @@ def run_job( timeout_seconds=JOB_POD_SCHEDULING_TIMEOUT, ) if not is_pod_scheduled_or_finished: + self.skip_offer_cache.add(run, job, instance_offer) reason, message = _get_unscheduled_pod_reason_message( api=api, namespace=namespace, @@ -256,6 +257,7 @@ def run_job( f" {reason or 'unknown reason'}: {message or 'no message'}" ) if pod_phase is not None and pod_phase.is_finished(): + # It's not clear if we should add an entry to the SkipOfferCache in this case. raise ComputeError(f"Pod {pod_name} already finished: {pod_phase}") pod_service_name = _get_pod_service_name(pod_name) @@ -316,15 +318,21 @@ def update_provisioning_data( project_ssh_public_key: str, project_ssh_private_key: str, ): + cluster = self.region_cluster_map.get(provisioning_data.region) + if cluster is None: + raise ProvisioningError(f"Unknown region: {provisioning_data.region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace + if provisioning_data.backend_data is not None: # Before running a job, ensure the jump pod is running and has user's public SSH key. 
backend_data = KubernetesBackendData.load(provisioning_data.backend_data) ssh_proxy = _check_and_configure_jump_pod_service( - api=self.api, - namespace=self.namespace, + api=api, + namespace=namespace, jump_pod_name=backend_data.jump_pod_name, jump_pod_service_name=backend_data.jump_pod_service_name, - jump_pod_hostname=self.proxy_jump.hostname, + jump_pod_hostname=cluster.proxy_jump.hostname, project_ssh_private_key=project_ssh_private_key, user_ssh_public_key=backend_data.user_ssh_public_key, ) @@ -336,9 +344,9 @@ def update_provisioning_data( # in case update_provisioning_data() is called again. provisioning_data.backend_data = None - pod = self.api.read_namespaced_pod( + pod = api.read_namespaced_pod( name=provisioning_data.instance_id, - namespace=self.namespace, + namespace=namespace, ) if pod.status is None: return @@ -346,19 +354,20 @@ def update_provisioning_data( if not pod_ip: return provisioning_data.internal_ip = pod_ip - service = self.api.read_namespaced_service( + service = api.read_namespaced_service( name=_get_pod_service_name(provisioning_data.instance_id), - namespace=self.namespace, + namespace=namespace, ) service_spec = get_or_error(service.spec) provisioning_data.hostname = get_or_error(service_spec.cluster_ip) pod_spec = get_or_error(pod.spec) - node = self.api.read_node(name=get_or_error(pod_spec.node_name)) + node = api.read_node(name=get_or_error(pod_spec.node_name)) # In the original offer, the resources have already been adjusted according to # the run configuration resource requirements, see get_offers_by_requirements() original_resources = provisioning_data.instance_type.resources instance_offer = get_instance_offer_from_node( node=node, + region=cluster.region, cpu_request=original_resources.cpus, memory_mib_request=original_resources.memory_mib, gpu_request=len(original_resources.gpus), @@ -366,14 +375,30 @@ def update_provisioning_data( ) if instance_offer is not None: provisioning_data.instance_type = instance_offer.instance - 
provisioning_data.region = instance_offer.region provisioning_data.price = instance_offer.price def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): - api = self.api - namespace = self.namespace + cluster = self.region_cluster_map.get(region) + if cluster is None and region == "-": + # legacy DUMMY_REGION + cluster = self.region_cluster_map.get(LEGACY_CURRENT_CONTEXT_REGION) + if cluster is not None: + logger.warning( + ( + "Terminating instance %s in unknown region %s." + " Assuming it was created before multi-cluster support was added" + " and is located in cluster %s" + ), + instance_id, + repr(region), + cluster, + ) + if cluster is None: + raise ComputeError(f"Unknown region: {region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace deleted = [ try_delete_object_if_exists( api.delete_namespaced_service, @@ -401,6 +426,12 @@ def create_gateway( self, configuration: GatewayComputeConfiguration, ) -> GatewayProvisioningData: + cluster = self.region_cluster_map.get(configuration.region) + if cluster is None: + raise ComputeError(f"Unknown region: {configuration.region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace + # Gateway creation is currently limited to Kubernetes with Load Balancer support. # If the cluster does not support Load Balancer, the service will be provisioned but # the external IP/hostname will never be allocated. 
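The region lookup in `terminate_instance` above, including the fallback for the legacy `-` region, can be sketched in isolation as follows. This is a simplified illustration rather than the actual implementation: plain strings stand in for `Cluster` objects, `LookupError` stands in for `ComputeError`, and `LEGACY_DUMMY_REGION` and `resolve_cluster_for_termination` are hypothetical names.

```python
LEGACY_CURRENT_CONTEXT_REGION = ""  # same value as in this diff
LEGACY_DUMMY_REGION = "-"  # hypothetical name for the literal "-" checked in the diff


def resolve_cluster_for_termination(region_cluster_map: dict[str, str], region: str) -> str:
    """Return the cluster for `region`, falling back only for the legacy "-" region."""
    cluster = region_cluster_map.get(region)
    if cluster is None and region == LEGACY_DUMMY_REGION:
        # Instances created before multi-cluster support carry the "-" region;
        # they can only live in the legacy current-context cluster, if one exists.
        cluster = region_cluster_map.get(LEGACY_CURRENT_CONTEXT_REGION)
    if cluster is None:
        raise LookupError(f"Unknown region: {region!r}")
    return cluster
```

With a legacy map such as `{"": "current-context"}`, both `""` and `"-"` resolve to the same cluster. With a `contexts`-based map keyed by context names, neither resolves, which matches the documented advice to terminate existing resources before migrating.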
@@ -448,8 +479,8 @@ def create_gateway( ] ), ) - self.api.create_namespaced_pod( - namespace=self.namespace, + api.create_namespaced_pod( + namespace=namespace, body=pod, ) service = client.V1Service( @@ -478,18 +509,18 @@ def create_gateway( ], ), ) - self.api.create_namespaced_service( - namespace=self.namespace, + api.create_namespaced_service( + namespace=namespace, body=service, ) # address is either a domain name or an IP address address = _wait_for_load_balancer_address( - api=self.api, - namespace=self.namespace, + api=api, + namespace=namespace, service_name=_get_pod_service_name(instance_name), ) if address is None: - self.terminate_instance(instance_name, region="") + self.terminate_instance(instance_name, region=configuration.region) raise ComputeError( "Failed to get gateway hostname. " "Ensure the Kubernetes cluster supports Load Balancer services." @@ -497,7 +528,7 @@ def create_gateway( return GatewayProvisioningData( instance_id=instance_name, ip_address=address, - region="", + region=cluster.region, ) def terminate_gateway( @@ -506,21 +537,50 @@ def terminate_gateway( configuration: GatewayComputeConfiguration, backend_data: Optional[str] = None, ): + region = configuration.region + cluster = self.region_cluster_map.get(region) + if cluster is None: + # It may be a legacy configuration with the region set to an arbitrary value + cluster = self.region_cluster_map.get(LEGACY_CURRENT_CONTEXT_REGION) + if cluster is not None: + logger.warning( + ( + "Terminating gateway %s in unknown region %s." 
+ " Assuming it was created before multi-cluster support was added" + " and is located in cluster %s" + ), + instance_id, + repr(region), + cluster, + ) + region = LEGACY_CURRENT_CONTEXT_REGION + else: + raise ComputeError(f"Unknown region: {region!r}") self.terminate_instance( instance_id=instance_id, - region=configuration.region, + region=region, backend_data=backend_data, ) def register_volume(self, volume: Volume) -> VolumeProvisioningData: assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + + region = volume.configuration.region + cluster = self.region_cluster_map.get(region) + if cluster is None: + if region == "": + raise ComputeError("region is not set") + raise ComputeError(f"Unknown region: {region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace + pvc_name = volume.configuration.claim_name assert pvc_name is not None pvc = call_api_method( - self.api.read_namespaced_persistent_volume_claim, + api.read_namespaced_persistent_volume_claim, expected=404, - namespace=self.namespace, + namespace=namespace, name=pvc_name, ) if pvc is None: @@ -548,7 +608,15 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData: def create_volume(self, volume: Volume) -> VolumeProvisioningData: assert isinstance(volume.configuration, KubernetesVolumeConfiguration) - assert volume.configuration.size is not None + + region = volume.configuration.region + cluster = self.region_cluster_map.get(region) + if cluster is None: + if region == "": + raise ComputeError("region is not set") + raise ComputeError(f"Unknown region: {region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace labels = { format_dstack_label_key("owner"): "dstack", @@ -562,6 +630,8 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: ) labels = filter_invalid_labels(labels) + assert volume.configuration.size is not None + pvc_name = generate_unique_volume_name(volume, 
max_length=OBJECT_NAME_MAX_LENGTH) pvc = client.V1PersistentVolumeClaim( metadata=client.V1ObjectMeta( @@ -578,8 +648,8 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: ), ), ) - self.api.create_namespaced_persistent_volume_claim( - namespace=self.namespace, + api.create_namespaced_persistent_volume_claim( + namespace=namespace, body=pvc, ) logger.debug("Created PVC %s for volume %s", pvc_name, volume.name) @@ -594,13 +664,21 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: def delete_volume(self, volume: Volume): assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + + region = volume.configuration.region + cluster = self.region_cluster_map.get(region) + if cluster is None: + raise ComputeError(f"Unknown region: {region!r}") + api = client.CoreV1Api(cluster.api_client) + namespace = cluster.namespace + pvc_name = volume.volume_id assert pvc_name is not None pvc = call_api_method( - self.api.delete_namespaced_persistent_volume_claim, + api.delete_namespaced_persistent_volume_claim, expected=404, - namespace=self.namespace, + namespace=namespace, name=pvc_name, ) if pvc is None: @@ -1170,11 +1248,15 @@ def _wait_for_pod_scheduled_or_finished( # the scheduler confirmed capacity and that the assigned node is actually Ready and # working on the pod. 
pod_phase: Optional[PodPhase] = None + # Ensure that API's timeoutSeconds fires earlier than the network timeout, which defaults to + # our custom ApiClient's constructor parameter, see DEFAULT_REQUEST_TIMEOUT + request_timeout = timeout_seconds + 5 with watch_events( api.list_namespaced_pod, namespace=namespace, field_selector=f"metadata.name={pod_name}", timeout_seconds=timeout_seconds, + _request_timeout=request_timeout, ) as event_iter: for _, pod in event_iter: pod_status = pod.status diff --git a/src/dstack/_internal/core/backends/kubernetes/configurator.py b/src/dstack/_internal/core/backends/kubernetes/configurator.py index 9753294b11..b8872f0211 100644 --- a/src/dstack/_internal/core/backends/kubernetes/configurator.py +++ b/src/dstack/_internal/core/backends/kubernetes/configurator.py @@ -3,7 +3,6 @@ Configurator, raise_invalid_credentials_error, ) -from dstack._internal.core.backends.kubernetes import utils as kubernetes_utils from dstack._internal.core.backends.kubernetes.backend import KubernetesBackend from dstack._internal.core.backends.kubernetes.models import ( KubernetesBackendConfig, @@ -11,6 +10,11 @@ KubernetesConfig, KubernetesStoredConfig, ) +from dstack._internal.core.backends.kubernetes.utils import ( + check_cluster, + get_clusters_from_backend_config, +) +from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.utils.logging import get_logger @@ -29,12 +33,17 @@ class KubernetesConfigurator( def validate_config( self, config: KubernetesBackendConfigWithCreds, default_creds_enabled: bool ): + self._check_config_contexts(config) try: - api = kubernetes_utils.get_api_from_kubeconfig_data(config.kubeconfig.data) - api.list_node() + clusters = get_clusters_from_backend_config(config, request_timeout=10, retries=0) except Exception as e: - logger.debug("Invalid kubeconfig: %s", str(e)) - raise_invalid_credentials_error(fields=[["kubeconfig"]]) + raise 
ServerClientError(str(e)) + for cluster in clusters: + if not check_cluster(cluster): + raise_invalid_credentials_error( + fields=[["kubeconfig"]], + details=f"Failed to validate cluster {cluster}", + ) def create_backend( self, project_name: str, config: KubernetesBackendConfigWithCreds @@ -59,3 +68,24 @@ def get_backend(self, record: BackendRecord) -> KubernetesBackend: def _get_config(self, record: BackendRecord) -> KubernetesConfig: return KubernetesConfig.__response__.parse_raw(record.config) + + def _check_config_contexts(self, config: KubernetesBackendConfig): + if config.contexts is None: + return + if config.proxy_jump is not None: + raise ServerClientError("proxy_jump must not be set if contexts is set") + if config.namespace is not None: + raise ServerClientError("namespace must not be set if contexts is set") + seen: set[str] = set() + duplicates: set[str] = set() + for context in config.contexts: + if isinstance(context, str): + name = context + else: + name = context.name + if name in seen: + duplicates.add(name) + else: + seen.add(name) + if duplicates: + raise ServerClientError(f"duplicate contexts: {', '.join(sorted(duplicates))}") diff --git a/src/dstack/_internal/core/backends/kubernetes/models.py b/src/dstack/_internal/core/backends/kubernetes/models.py index bb1609733e..eb92982e45 100644 --- a/src/dstack/_internal/core/backends/kubernetes/models.py +++ b/src/dstack/_internal/core/backends/kubernetes/models.py @@ -5,8 +5,6 @@ from dstack._internal.core.backends.base.models import fill_data from dstack._internal.core.models.common import CoreModel -DEFAULT_NAMESPACE = "default" - class KubernetesProxyJumpConfig(CoreModel): hostname: Annotated[ @@ -17,6 +15,13 @@ class KubernetesProxyJumpConfig(CoreModel): ] = None +class KubernetesContextConfig(CoreModel): + name: Annotated[str, Field(description="The name of the context")] + proxy_jump: Annotated[ + Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration") + ] = 
None + + class KubeconfigConfig(CoreModel): filename: Annotated[str, Field(description="The path to the kubeconfig file")] = "" data: Annotated[str, Field(description="The contents of the kubeconfig file")] @@ -24,22 +29,45 @@ class KubeconfigConfig(CoreModel): class KubernetesBackendConfig(CoreModel): type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes" + contexts: Annotated[ + Optional[list[Union[KubernetesContextConfig, str]]], + Field( + description=( + "Enabled contexts (clusters). Each context should map to a separate cluster." + " The context name becomes the region name." + " If `contexts` is set, top-level `proxy_jump` and `namespace` must not be set." + " `proxy_jump`, if necessary, should be configured per-context;" + " `namespace` is taken from the corresponding kubeconfig context's property." + " If `contexts` is not set (not recommended), the kubeconfig's `current-context`" + " is used as the only context, with an empty string as the region name" + ), + ), + ] = None proxy_jump: Annotated[ - Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration") + Optional[KubernetesProxyJumpConfig], + Field( + description=( + "Only used if `contexts` is not set; must not be set otherwise." + " The SSH proxy jump configuration" + ), + ), ] = None namespace: Annotated[ - str, + Optional[str], Field( description=( - "The namespace for resources managed by `dstack`." - " Always overrides the namespace set in the kubeconfig, even if not set. " - " Deprecated and will be eventually removed in futute versions, but" - " in the current version must be set unless equals to `default`." + "Only used if `contexts` is not set; must not be set otherwise." + " The namespace for resources managed by `dstack`." + " If `contexts` is not set, overrides the namespace set in the kubeconfig," + " even if not set. Defaults to `default`." 
+ " Deprecated; will eventually be removed in future versions," + " but in the current version must be set if `contexts` is not set and the value" + " is not equal to `default`." " Future versions will use the namespace from the kubeconfig instead." " To prepare for future versions, set the same value in the kubeconfig" ) ), - ] = DEFAULT_NAMESPACE + ] = None """`namespace` is formally deprecated since 0.20.20 but still used. Future versions will switch to namespace from kubeconfig context, which is currently ignored""" diff --git a/src/dstack/_internal/core/backends/kubernetes/resources.py b/src/dstack/_internal/core/backends/kubernetes/resources.py index fa29d7b513..4d9d266ae2 100644 --- a/src/dstack/_internal/core/backends/kubernetes/resources.py +++ b/src/dstack/_internal/core/backends/kubernetes/resources.py @@ -240,7 +240,7 @@ def is_taint_tolerated(taint: V1Taint) -> bool: def get_instance_offers( - api: CoreV1Api, requirements: Requirements + api: CoreV1Api, region: str, requirements: Requirements ) -> list[InstanceOfferWithAvailability]: resources_spec = requirements.resources assert isinstance(resources_spec.cpu, CPUSpec) @@ -262,6 +262,7 @@ def get_instance_offers( node=node, node_name=node_name, node_allocated_resources=nodes_allocated_resources.get(node_name), + region=region, cpu_request=cpu_request, memory_mib_request=memory_mib_request, gpu_request=gpu_request, @@ -275,6 +276,7 @@ def get_instance_offers( def get_instance_offer_from_node( node: V1Node, *, + region: str, cpu_request: int, memory_mib_request: int, gpu_request: int, @@ -287,6 +289,7 @@ def get_instance_offer_from_node( node=node, node_name=node_name, node_allocated_resources=None, + region=region, cpu_request=cpu_request, memory_mib_request=memory_mib_request, gpu_request=gpu_request, @@ -342,6 +345,7 @@ def _get_instance_offer_from_node( node: V1Node, node_name: str, node_allocated_resources: Optional[KubernetesResources], + region: str, cpu_request: int, memory_mib_request: int, 
     gpu_request: int,
@@ -384,7 +388,7 @@ def _get_instance_offer_from_node(
             ),
         ),
         price=0,
-        region="",
+        region=region,
         availability=InstanceAvailability.AVAILABLE,
         instance_runtime=InstanceRuntime.RUNNER,
     )
diff --git a/src/dstack/_internal/core/backends/kubernetes/utils.py b/src/dstack/_internal/core/backends/kubernetes/utils.py
index 054e48eba3..a90f750f1b 100644
--- a/src/dstack/_internal/core/backends/kubernetes/utils.py
+++ b/src/dstack/_internal/core/backends/kubernetes/utils.py
@@ -1,5 +1,6 @@
 from collections.abc import Generator
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import (
     Annotated,
     Any,
@@ -12,20 +13,28 @@
     Union,
     cast,
 )
+from uuid import UUID
 
 import yaml
-from kubernetes.client import CoreV1Api, V1Status
+from cachetools import TTLCache
+from kubernetes.client import V1Status, VersionApi
 from kubernetes.client.exceptions import ApiException
-from kubernetes.config import (
-    # XXX: This function is missing in the stubs package
-    new_client_from_config_dict,  # pyright: ignore[reportAttributeAccessIssue]
-)
 from kubernetes.watch import Watch
 from pydantic import Field
 from typing_extensions import ParamSpec, TypedDict
-from urllib3.exceptions import HTTPError
 
+from dstack._internal.core.backends.kubernetes.api_client import (
+    API_CLIENT_EXCEPTIONS,
+    ApiClient,
+    get_api_client_from_kubeconfig_dict,
+)
+from dstack._internal.core.backends.kubernetes.models import (
+    KubernetesBackendConfigWithCreds,
+    KubernetesProxyJumpConfig,
+)
 from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.instances import InstanceOffer
+from dstack._internal.core.models.runs import Job, Run
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -34,6 +43,122 @@
 P = ParamSpec("P")
 
 
+LEGACY_CURRENT_CONTEXT_REGION = ""
+
+
+@dataclass
+class Cluster:
+    context_name: str
+    region: str
+    api_client: ApiClient
+    namespace: str
+    proxy_jump: KubernetesProxyJumpConfig
+
+    def __str__(self) -> str:
+        parts: list[str] = []
+        parts.append(f"context={self.context_name!r}")
+        if self.context_name != self.region:
+            parts.append(f"region={self.region!r}")
+        return f"({' '.join(parts)})"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}{self}"
+
+
+def check_cluster(cluster: Cluster) -> bool:
+    version_api = VersionApi(cluster.api_client)
+    try:
+        version_info = version_api.get_code()
+    except API_CLIENT_EXCEPTIONS as e:
+        logger.debug("cluster %s check failed: %s: %s", cluster, e.__class__.__name__, e)
+        return False
+    logger.debug("cluster %s gitVersion: %s", cluster, version_info.git_version)
+    return True
+
+
+def get_clusters_from_backend_config(
+    config: KubernetesBackendConfigWithCreds,
+    *,
+    request_timeout: Optional[int] = None,
+    retries: Optional[int] = None,
+) -> list[Cluster]:
+    clusters: list[Cluster] = []
+    kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(config.kubeconfig.data)
+    kubeconfig = kubeconfig_dict_to_kubeconfig(kubeconfig_dict)
+    if config.contexts is not None:
+        for context in config.contexts:
+            if isinstance(context, str):
+                context_name = context
+                proxy_jump = None
+            else:
+                context_name = context.name
+                proxy_jump = context.proxy_jump
+            kubeconfig_context = kubeconfig.get_context(context_name)
+            api_client = get_api_client_from_kubeconfig_dict(
+                kubeconfig_dict,
+                context=context_name,
+                request_timeout=request_timeout,
+                retries=retries,
+            )
+            namespace = kubeconfig_context.namespace
+            if proxy_jump is None:
+                proxy_jump = KubernetesProxyJumpConfig()
+            clusters.append(
+                Cluster(
+                    context_name=context_name,
+                    region=context_name,
+                    api_client=api_client,
+                    namespace=namespace,
+                    proxy_jump=proxy_jump,
+                )
+            )
+    else:
+        current_kubeconfig_context = kubeconfig.get_context()
+        context_name = kubeconfig.current_context
+        # Already checked by Kubeconfig.get_context()
+        assert context_name is not None
+        api_client = get_api_client_from_kubeconfig_dict(
+            kubeconfig_dict,
+            context=context_name,
+            request_timeout=request_timeout,
+            retries=retries,
+        )
+        config_namespace = config.namespace
+        if config_namespace is None:
+            config_namespace = "default"
+        context_namespace = current_kubeconfig_context.namespace
+        if context_namespace != config_namespace:
+            logger.warning(
+                (
+                    "Namespace mismatch: kubeconfig -> '%s', backend config -> '%s'."
+                    " The current dstack version ignores kubeconfig"
+                    " and uses deprecated namespace property from backend config."
+                    " Future versions will use namespace from kubeconfig."
+                    " To keep using '%s' namespace in future versions and suppress this warning,"
+                    " set namespace to '%s' in kubeconfig context '%s'"
+                ),
+                context_namespace,
+                config_namespace,
+                config_namespace,
+                config_namespace,
+                context_name,
+            )
+        proxy_jump = config.proxy_jump
+        if proxy_jump is None:
+            proxy_jump = KubernetesProxyJumpConfig()
+        clusters.append(
+            Cluster(
+                context_name=context_name,
+                region=LEGACY_CURRENT_CONTEXT_REGION,
+                api_client=api_client,
+                # TODO: switch to context_namespace
+                namespace=config_namespace,
+                proxy_jump=proxy_jump,
+            )
+        )
+    return clusters
+
+
 class KubeconfigContext(CoreModel):
     namespace: str = "default"
 
@@ -74,18 +199,28 @@ def kubeconfig_dict_to_kubeconfig(kubeconfig_dict: dict) -> Kubeconfig:
     return Kubeconfig.__response__.parse_obj(kubeconfig_dict)
 
 
-def get_api_from_kubeconfig_data(
-    kubeconfig_data: str, *, context: Optional[str] = None
-) -> CoreV1Api:
-    kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(kubeconfig_data)
-    return get_api_from_kubeconfig_dict(kubeconfig_dict, context=context)
+class SkipOfferCache:
+    """
+    `SkipOfferCache` is used to track (run/job, offer) pairs that failed to provision.
+
+    The current implementation tracks _any_ job of the specific run (identified by `Run.id`)
+    on the specific cluster (identified by `InstanceOffer.region`, that is, a kubeconfig context).
+    """
+
+    def __init__(self, *, ttl: int, maxsize: int = 1000) -> None:
+        self._cache = TTLCache[tuple[UUID, str], Literal[True]](maxsize=maxsize, ttl=ttl)
+
+    def add(self, run: Run, job: Job, offer: InstanceOffer) -> None:
+        self._cache[self._build_key(run, job, offer)] = True
 
+    def check(self, run: Run, job: Job, offer: InstanceOffer) -> bool:
+        return self._build_key(run, job, offer) in self._cache
 
-def get_api_from_kubeconfig_dict(
-    kubeconfig_dict: dict, *, context: Optional[str] = None
-) -> CoreV1Api:
-    api_client = new_client_from_config_dict(config_dict=kubeconfig_dict, context=context)
-    return CoreV1Api(api_client=api_client)
+    def _build_key(self, run: Run, job: Job, offer: InstanceOffer) -> tuple[UUID, str]:
+        # The current implementation uses only Run.id ignoring the job/job spec.
+        # A more sophisticated implementation could use some parts of the job spec
+        # (e.g., requirements, volumes) instead.
+        return (run.id, offer.region)
 
 
 def call_api_method(
@@ -136,7 +271,7 @@ def try_delete_object_if_exists(
             namespace=namespace,
             name=name,
         )
-    except (HTTPError, ApiException) as e:
+    except API_CLIENT_EXCEPTIONS as e:
         if should_delete_manually_if_failed:
             logger.exception(
                 "Failed to delete %s %s in namespace %s. Please delete it manually",
diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py
index 44ee882511..a0afabf1c6 100644
--- a/src/dstack/_internal/core/compatibility/volumes.py
+++ b/src/dstack/_internal/core/compatibility/volumes.py
@@ -33,7 +33,10 @@ def _get_volume_configuration_excludes(
 ) -> IncludeExcludeDictType:
     configuration_excludes: IncludeExcludeDictType = {}
 
-    if isinstance(configuration, KubernetesVolumeConfiguration) and not configuration.read_only:
-        configuration_excludes["read_only"] = True
+    if isinstance(configuration, KubernetesVolumeConfiguration):
+        if not configuration.read_only:
+            configuration_excludes["read_only"] = True
+        if configuration.region == "":
+            configuration_excludes["region"] = True
 
     return configuration_excludes
diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py
index 5182d063b2..e2685a64e4 100644
--- a/src/dstack/_internal/core/errors.py
+++ b/src/dstack/_internal/core/errors.py
@@ -128,6 +128,19 @@ def __init__(self, details: str) -> None:
         return super().__init__(details)
 
 
+class SkipOffer(ComputeError):
+    """
+    Used by Compute.run_job and Compute.create_instance to signal that the offer should be skipped.
+    """
+
+    def __init__(self, details: str) -> None:
+        """
+        Args:
+            details: details about why the offer should be skipped
+        """
+        return super().__init__(details)
+
+
 class CLIError(DstackError):
     pass
 
diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py
index 03d52d2ac9..1b96331903 100644
--- a/src/dstack/_internal/core/models/volumes.py
+++ b/src/dstack/_internal/core/models/volumes.py
@@ -134,10 +134,12 @@ class RunpodVolumeConfiguration(VolumeConfigurationWithRegion, VolumeConfigurati
     """Runpod doesn't have AZs but we accept this field for compatibility with older clients."""
 
 
-class KubernetesVolumeConfiguration(BaseVolumeConfiguration):
+class KubernetesVolumeConfiguration(VolumeConfigurationWithRegion):
     backend: Annotated[
         Literal[BackendType.KUBERNETES], Field(description="The volume backend")
     ] = BackendType.KUBERNETES
+    region: Annotated[str, Field(description="The volume region (cluster)")] = ""
+    """`region` uses a default value for backward compatibility."""
     size: Annotated[
         Optional[Memory],
         Field(
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
index b1ceac47d6..0673b6b3c1 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py
@@ -20,6 +20,7 @@
 from dstack._internal.core.errors import (
     BackendError,
     PlacementGroupNotSupportedError,
+    SkipOffer,
 )
 from dstack._internal.core.models.instances import (
     InstanceOfferWithAvailability,
@@ -114,8 +115,15 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
         include_only_create_instance_supported_backends=True,
     )
 
+    offers_iter = iter(offers)
+    offers_tried = 0
     # Limit number of offers tried to prevent long-running processing in case all offers fail.
-    for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
+    while offers_tried < server_settings.MAX_OFFERS_TRIED:
+        backend_with_instance_offer = next(offers_iter, None)
+        if backend_with_instance_offer is None:
+            break
+        backend, instance_offer = backend_with_instance_offer
+
         if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
             continue
         compute = backend.compute()
@@ -159,6 +167,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
             instance_offer.region,
             instance_offer.price,
         )
+        offers_tried += 1
         try:
             job_provisioning_data = await run_async(
                 compute.create_instance,
@@ -166,6 +175,17 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
                 instance_configuration,
                 placement_group_model_to_placement_group_optional(placement_group_model),
             )
+        except SkipOffer as exc:
+            offers_tried -= 1
+            logger.info(
+                "%s launch in %s/%s skipped: %s",
+                instance_offer.instance.name,
+                instance_offer.backend.value,
+                instance_offer.region,
+                exc,
+                extra={"instance_name": instance_model.name},
+            )
+            continue
         except BackendError as exc:
             logger.warning(
                 "%s launch in %s/%s failed: %s",
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
index 8ecdc0e030..2161eb8958 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py
@@ -20,7 +20,7 @@
     BACKENDS_WITH_GROUP_PROVISIONING_SUPPORT,
     BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
 )
-from dstack._internal.core.errors import BackendError, ServerClientError
+from dstack._internal.core.errors import BackendError, ServerClientError, SkipOffer
 from dstack._internal.core.models.common import NetworkMode
 from dstack._internal.core.models.compute_groups import (
     ComputeGroupProvisioningData,
@@ -2120,8 +2120,13 @@ async def _provision_new_capacity(
         instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec),
         placement_group=placement_group_model_to_placement_group_optional(placement_group_model),
     )
+    offers_iter = iter(offers)
     offers_tried = 0
-    for backend, offer in offers[: settings.MAX_OFFERS_TRIED]:
+    while offers_tried < settings.MAX_OFFERS_TRIED:
+        backend_with_offer = next(offers_iter, None)
+        if backend_with_offer is None:
+            break
+        backend, offer = backend_with_offer
         logger.debug(
             "%s: trying %s in %s/%s for $%0.4f per hour",
             fmt(job_model),
@@ -2214,6 +2219,17 @@ async def _provision_new_capacity(
                         new_placement_group_models=new_placement_group_models,
                     ),
                 )
+        except SkipOffer as e:
+            offers_tried -= 1
+            logger.info(
+                "%s: %s launch in %s/%s skipped: %s",
+                fmt(job_model),
+                offer.instance.name,
+                offer.backend.value,
+                offer.region,
+                e,
+            )
+            continue
         except BackendError as e:
             logger.warning(
                 "%s: %s launch in %s/%s failed: %s",
diff --git a/src/tests/_internal/core/backends/kubernetes/test_configurator.py b/src/tests/_internal/core/backends/kubernetes/test_configurator.py
index 2c5e665d53..36f32a08b9 100644
--- a/src/tests/_internal/core/backends/kubernetes/test_configurator.py
+++ b/src/tests/_internal/core/backends/kubernetes/test_configurator.py
@@ -1,43 +1,95 @@
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
 
 import pytest
 
-from dstack._internal.core.backends.kubernetes.configurator import (
-    KubernetesConfigurator,
-)
+from dstack._internal.core.backends.kubernetes.configurator import KubernetesConfigurator
 from dstack._internal.core.backends.kubernetes.models import (
     KubeconfigConfig,
     KubernetesBackendConfigWithCreds,
+    KubernetesContextConfig,
     KubernetesProxyJumpConfig,
 )
-from dstack._internal.core.errors import BackendInvalidCredentialsError
+from dstack._internal.core.errors import ServerClientError
+
+
+@pytest.fixture
+def get_clusters_mock(monkeypatch: pytest.MonkeyPatch) -> Mock:
+    mock = Mock(return_value=[])
+    monkeypatch.setattr(
+        "dstack._internal.core.backends.kubernetes.configurator.get_clusters_from_backend_config",
+        mock,
+    )
+    return mock
 
 
 class TestKubernetesConfigurator:
-    def test_validate_config_valid(self):
+    @pytest.mark.usefixtures("get_clusters_mock")
+    def test_validate_config_valid_current_context(self):
+        config = KubernetesBackendConfigWithCreds(
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=30022),
+            namespace="ns",
+        )
+        KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+    @pytest.mark.usefixtures("get_clusters_mock")
+    def test_validate_config_valid_explicit_contexts(self):
         config = KubernetesBackendConfigWithCreds(
-            kubeconfig=KubeconfigConfig(data="valid", filename="-"),
-            proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=None),
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            contexts=["ctx"],
         )
-        with patch(
-            "dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
-        ) as get_api_mock:
-            api_mock = Mock()
-            api_mock.list_node.return_value = Mock()
-            get_api_mock.return_value = api_mock
+        KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+    @pytest.mark.usefixtures("get_clusters_mock")
+    def test_validate_config_contexts_proxy_jump_mutually_exclusive(self):
+        config = KubernetesBackendConfigWithCreds(
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=30022),
+            contexts=["ctx"],
+        )
+        with pytest.raises(ServerClientError, match="proxy_jump must not be set"):
             KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
 
-    def test_validate_config_invalid_config(self):
+    @pytest.mark.usefixtures("get_clusters_mock")
+    def test_validate_config_contexts_namespace_mutually_exclusive(self):
         config = KubernetesBackendConfigWithCreds(
-            kubeconfig=KubeconfigConfig(data="invalid", filename="-"),
-            proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=None),
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            namespace="ns",
+            contexts=["ctx"],
+        )
+        with pytest.raises(ServerClientError, match="namespace must not be set"):
+            KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+    @pytest.mark.usefixtures("get_clusters_mock")
+    def test_validate_config_duplicate_contexts(self):
+        config = KubernetesBackendConfigWithCreds(
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            contexts=[
+                "ctx-3",
+                KubernetesContextConfig(name="ctx-4"),
+                "ctx-1",
+                KubernetesContextConfig(name="ctx-1"),
+                "ctx-2",
+                KubernetesContextConfig(name="ctx-3"),
+            ],
+        )
+        with pytest.raises(ServerClientError, match="duplicate contexts: ctx-1, ctx-3"):
+            KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
+
+    def test_validate_config_cluster_check_failed(
+        self, monkeypatch: pytest.MonkeyPatch, get_clusters_mock: Mock
+    ):
+        config = KubernetesBackendConfigWithCreds(
+            kubeconfig=KubeconfigConfig(data="mocked", filename="-"),
+            contexts=["ctx"],
+        )
+
+        monkeypatch.setattr(
+            "dstack._internal.core.backends.kubernetes.configurator.check_cluster",
+            Mock(return_value=False),
         )
-        with (
-            patch(
-                "dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
-            ) as get_api_mock,
-            pytest.raises(BackendInvalidCredentialsError) as exc_info,
-        ):
-            get_api_mock.side_effect = Exception("Invalid config")
+        cluster_mock = Mock()
+        get_clusters_mock.return_value = [cluster_mock]
+        with pytest.raises(ServerClientError, match="Failed to validate cluster") as exc_info:
             KubernetesConfigurator().validate_config(config, default_creds_enabled=True)
         assert exc_info.value.fields == [["kubeconfig"]]
diff --git a/src/tests/_internal/core/backends/kubernetes/test_utils.py b/src/tests/_internal/core/backends/kubernetes/test_utils.py
new file mode 100644
index 0000000000..f58868217f
--- /dev/null
+++ b/src/tests/_internal/core/backends/kubernetes/test_utils.py
@@ -0,0 +1,318 @@
+import logging
+from textwrap import dedent
+from typing import Optional, Union
+
+import pytest
+
+from dstack._internal.core.backends.kubernetes.models import (
+    KubeconfigConfig,
+    KubernetesBackendConfigWithCreds,
+    KubernetesContextConfig,
+    KubernetesProxyJumpConfig,
+)
+from dstack._internal.core.backends.kubernetes.utils import (
+    Cluster,
+    get_clusters_from_backend_config,
+)
+
+
+class TestGetClustersFromBackendConfig:
+    def make_config(
+        self,
+        kubeconfig_data: str,
+        *,
+        contexts: Optional[list[Union[KubernetesContextConfig, str]]] = None,
+        namespace: Optional[str] = None,
+        proxy_jump: Optional[KubernetesProxyJumpConfig] = None,
+    ) -> KubernetesBackendConfigWithCreds:
+        return KubernetesBackendConfigWithCreds(
+            kubeconfig=KubeconfigConfig(data=kubeconfig_data, filename="-"),
+            contexts=contexts,
+            namespace=namespace,
+            proxy_jump=proxy_jump,
+        )
+
+    def make_kubeconfig(
+        self,
+        *,
+        current_context: str = "ctx-a",
+        # (context name, namespace) pairs
+        contexts: tuple[tuple[str, str], ...] = (("ctx-a", "default"),),
+    ) -> str:
+        clusters_yaml = "\n".join(
+            dedent(f"""
+                - name: cluster-{name}
+                  cluster:
+                    server: https://{name}.example.com:6443
+            """)
+            for name, _ in contexts
+        )
+        users_yaml = "\n".join(
+            dedent(f"""
+                - name: user-{name}
+                  user:
+                    token: token-{name}
+            """)
+            for name, _ in contexts
+        )
+        contexts_yaml = "\n".join(
+            dedent(f"""
+                - name: {name}
+                  context:
+                    cluster: cluster-{name}
+                    user: user-{name}
+                    namespace: {namespace}
+            """)
+            for name, namespace in contexts
+        )
+        return dedent("""
+            apiVersion: v1
+            kind: Config
+            current-context: {current_context}
+            clusters:
+            {clusters}
+            contexts:
+            {contexts}
+            users:
+            {users}
+        """).format(
+            current_context=current_context,
+            clusters=clusters_yaml,
+            contexts=contexts_yaml,
+            users=users_yaml,
+        )
+
+    def test_returns_single_cluster_using_current_context(self):
+        config = self.make_config(
+            self.make_kubeconfig(
+                current_context="ctx-a",
+                contexts=(
+                    ("ctx-b", "team-b"),
+                    ("ctx-a", "default"),
+                ),
+            ),
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert len(clusters) == 1
+        cluster = clusters[0]
+        assert isinstance(cluster, Cluster)
+        assert cluster.context_name == "ctx-a"
+        assert cluster.region == ""
+        assert cluster.namespace == "default"
+        assert cluster.proxy_jump == KubernetesProxyJumpConfig()
+        assert cluster.api_client.configuration.host == "https://ctx-a.example.com:6443"  # pyright: ignore[reportAttributeAccessIssue]
+
+    def test_single_context_uses_namespace_from_backend_config(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+            namespace="team-a",
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].namespace == "team-a"
+
+    def test_single_context_defaults_namespace_when_not_set(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+            namespace=None,
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].namespace == "default"
+
+    def test_single_context_uses_proxy_jump_from_backend_config(self):
+        proxy_jump = KubernetesProxyJumpConfig(hostname="1.2.3.4", port=2222)
+        config = self.make_config(
+            self.make_kubeconfig(),
+            proxy_jump=proxy_jump,
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].proxy_jump == proxy_jump
+
+    def test_single_context_uses_default_proxy_jump_when_unset(self):
+        config = self.make_config(self.make_kubeconfig(), proxy_jump=None)
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+    def test_single_context_warns_on_namespace_mismatch(self, caplog: pytest.LogCaptureFixture):
+        caplog.set_level(logging.WARNING)
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+            namespace="config-ns",
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].namespace == "config-ns"
+        assert "Namespace mismatch" in caplog.text
+        assert "kube-ns" in caplog.text
+        assert "config-ns" in caplog.text
+
+    def test_single_context_does_not_warn_when_namespace_matches(
+        self, caplog: pytest.LogCaptureFixture
+    ):
+        caplog.set_level(logging.WARNING)
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "team-a"),)),
+            namespace="team-a",
+        )
+
+        get_clusters_from_backend_config(config)
+
+        assert "Namespace mismatch" not in caplog.text
+
+    def test_single_context_raises_when_current_context_missing(self):
+        kubeconfig = dedent("""
+            apiVersion: v1
+            kind: Config
+            clusters:
+            - name: cluster-a
+              cluster:
+                server: https://a.example.com:6443
+            contexts:
+            - name: ctx-a
+              context:
+                cluster: cluster-a
+                user: user-a
+            users:
+            - name: user-a
+              user:
+                token: t
+        """)
+        config = self.make_config(kubeconfig)
+
+        with pytest.raises(ValueError, match="current-context is not set"):
+            get_clusters_from_backend_config(config)
+
+    def test_contexts_as_strings(self):
+        config = self.make_config(
+            self.make_kubeconfig(
+                current_context="ctx-a",
+                contexts=(
+                    ("ctx-a", "ns-a"),
+                    ("ctx-b", "ns-b"),
+                ),
+            ),
+            contexts=["ctx-a", "ctx-b"],
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert [c.context_name for c in clusters] == ["ctx-a", "ctx-b"]
+        assert [c.region for c in clusters] == ["ctx-a", "ctx-b"]
+        assert [c.namespace for c in clusters] == ["ns-a", "ns-b"]
+        assert all(c.proxy_jump == KubernetesProxyJumpConfig() for c in clusters)
+        assert clusters[0].api_client.configuration.host == "https://ctx-a.example.com:6443"  # pyright: ignore[reportAttributeAccessIssue]
+        assert clusters[1].api_client.configuration.host == "https://ctx-b.example.com:6443"  # pyright: ignore[reportAttributeAccessIssue]
+
+    def test_contexts_with_per_context_proxy_jump(self):
+        proxy_jump_a = KubernetesProxyJumpConfig(hostname="a.example.com", port=2201)
+        proxy_jump_b = KubernetesProxyJumpConfig(hostname="b.example.com", port=2202)
+        config = self.make_config(
+            self.make_kubeconfig(
+                contexts=(
+                    ("ctx-a", "ns-a"),
+                    ("ctx-b", "ns-b"),
+                ),
+            ),
+            contexts=[
+                KubernetesContextConfig(name="ctx-a", proxy_jump=proxy_jump_a),
+                KubernetesContextConfig(name="ctx-b", proxy_jump=proxy_jump_b),
+            ],
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].proxy_jump == proxy_jump_a
+        assert clusters[1].proxy_jump == proxy_jump_b
+
+    def test_contexts_mix_string_and_object(self):
+        proxy_jump = KubernetesProxyJumpConfig(hostname="b.example.com", port=2222)
+        config = self.make_config(
+            self.make_kubeconfig(
+                contexts=(
+                    ("ctx-a", "ns-a"),
+                    ("ctx-b", "ns-b"),
+                ),
+            ),
+            contexts=[
+                "ctx-a",
+                KubernetesContextConfig(name="ctx-b", proxy_jump=proxy_jump),
+            ],
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+        assert clusters[1].proxy_jump == proxy_jump
+
+    def test_contexts_object_without_proxy_jump_uses_default(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+            contexts=[KubernetesContextConfig(name="ctx-a", proxy_jump=None)],
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+    def test_contexts_ignores_backend_namespace_and_proxy_jump(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+            contexts=["ctx-a"],
+            namespace="config-ns",
+            proxy_jump=KubernetesProxyJumpConfig(hostname="ignored", port=1),
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters[0].namespace == "kube-ns"
+        assert clusters[0].proxy_jump == KubernetesProxyJumpConfig()
+
+    def test_contexts_does_not_warn_on_namespace_mismatch(self, caplog: pytest.LogCaptureFixture):
+        caplog.set_level(logging.WARNING)
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "kube-ns"),)),
+            contexts=["ctx-a"],
+            namespace="config-ns",
+        )
+
+        get_clusters_from_backend_config(config)
+
+        assert "Namespace mismatch" not in caplog.text
+
+    def test_contexts_raises_for_unknown_context(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+            contexts=["ctx-missing"],
+        )
+
+        with pytest.raises(ValueError, match="context ctx-missing not found"):
+            get_clusters_from_backend_config(config)
+
+    def test_empty_contexts_returns_no_clusters(self):
+        config = self.make_config(
+            self.make_kubeconfig(contexts=(("ctx-a", "ns-a"),)),
+            contexts=[],
+        )
+
+        clusters = get_clusters_from_backend_config(config)
+
+        assert clusters == []
+
+    def test_request_timeout_and_retries_propagate_to_client(self):
+        config = self.make_config(self.make_kubeconfig())
+
+        clusters = get_clusters_from_backend_config(config, request_timeout=7, retries=5)
+
+        api_client = clusters[0].api_client
+        assert api_client.configuration.retries == 5  # pyright: ignore[reportAttributeAccessIssue]
+        assert getattr(api_client, "_ApiClient__request_timeout", None) == 7
diff --git a/src/tests/_internal/server/services/test_backend_configs.py b/src/tests/_internal/server/services/test_backend_configs.py
index 833dda2758..d99bdcb985 100644
--- a/src/tests/_internal/server/services/test_backend_configs.py
+++ b/src/tests/_internal/server/services/test_backend_configs.py
@@ -183,11 +183,12 @@ def test_ui_config_embedded_kubeconfig_initializes_backend(self):
         backend_config = config_yaml_to_backend_config(config_yaml)
         backend = KubernetesBackend(backend_config)
 
-        assert backend.compute().api.api_client.configuration.host == (
+        cluster = backend.compute().region_cluster_map[""]
+        assert cluster.api_client.configuration.host == (
             "https://gpu-cluster.internal.example.com:6443"
         )
-        assert backend.compute().proxy_jump.hostname == "204.12.171.137"
-        assert backend.compute().proxy_jump.port == 32000
+        assert cluster.proxy_jump.hostname == "204.12.171.137"
+        assert cluster.proxy_jump.port == 32000
 
     def test_kubeconfig_context_namespace_does_not_set_backend_namespace(self):
         config_yaml = dedent(
@@ -226,4 +227,5 @@ def test_kubeconfig_context_namespace_does_not_set_backend_namespace(self):
         backend_config = config_yaml_to_backend_config(config_yaml)
         backend = KubernetesBackend(backend_config)
 
-        assert backend.compute().config.namespace == "default"
+        cluster = backend.compute().region_cluster_map[""]
+        assert cluster.namespace == "default"