Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ write_to = "src/vllm_router/_version.py"
[tool.isort]
profile = "black"

[tool.ruff]
target-version = "py312"

[tool.pytest.ini_options]
asyncio_mode = "auto"

Expand Down
60 changes: 60 additions & 0 deletions src/vllm_router/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The router can be configured using command-line arguments. Below are the availab
- `--static-models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`).
- `--static-aliases`: The aliases of the models running in the static serving engines, separated by commas and associated using colons (e.g., `model_alias1:model,mode_alias2:model`).
- `--static-backend-health-checks`: Enable this flag to make vllm-router check periodically if the models work by sending dummy requests to their endpoints.
- `--static-fallback-models`: Fallback model mappings, separated by commas (e.g., `model1:fallback1,model2:fallback2`). When all backends for a model are unavailable, requests are retried on the fallback model.
- `--k8s-port`: The port of vLLM processes when using K8s service discovery. Default is `8000`.
- `--k8s-namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`.
- `--k8s-label-selector`: The label selector to filter vLLM pods when using K8s service discovery.
Expand Down Expand Up @@ -108,6 +109,64 @@ different endpoints for each model type.
> Enabling this flag will put some load on your backend every minute as real requests are send to the nodes
> to test their functionality.

## Fallback models

When all backends for a model become unavailable (e.g. during node reboots), the
router can automatically retry the request on a different **fallback model**. The
model name in the request body is rewritten to the fallback model name before
forwarding, so the fallback backend receives the correct model identifier.

Fallback triggers in two situations:

1. **No healthy endpoints** -- all backends have been marked unhealthy by the
periodic health check. The router switches to the fallback model immediately
without attempting the primary backends.
2. **All instance-level failover attempts failed** -- the primary backends were
still considered healthy but every attempt returned a connection error (e.g.
the node went down between health checks). After exhausting
`--max-instance-failover-reroute-attempts`, the router retries once on the
fallback model.

### Configuration

**In a YAML config file**, add `fallback_model` to any model entry. The value
must be the name of another model defined in `static_models`:

```yaml
static_models:
glm-5:
static_backends:
- https://gpu-node-1/glm-5
- https://gpu-node-2/glm-5
static_model_type: chat
fallback_model: glm-5-cloud # fall back to the cloud-hosted variant
glm-5-cloud:
static_backends:
- http://cloud-gateway:1975
static_model_type: chat
healthcheck_disabled: true
```

**Via CLI**, use `--static-fallback-models` with comma-separated
`model:fallback` pairs:

```bash
vllm-router --port 8000 \
--service-discovery static \
--static-backends "https://gpu-node-1/glm-5,https://gpu-node-2/glm-5,http://cloud-gateway:1975" \
--static-models "glm-5,glm-5,glm-5-cloud" \
--static-model-types "chat,chat,chat" \
--static-fallback-models "glm-5:glm-5-cloud" \
--static-backend-health-checks \
--max-instance-failover-reroute-attempts 2 \
--routing-logic roundrobin
```

Combining `fallback_model` with `--max-instance-failover-reroute-attempts` and a
short `--static-backend-health-check-interval` gives the best resilience: failed
requests are retried on other instances first, then on the fallback model, while
the health check quickly removes dead backends from future routing decisions.

## Dynamic Router Config

The router can be configured dynamically using a config file when passing the `--dynamic-config-yaml` or
Expand All @@ -128,6 +187,7 @@ Currently, the dynamic config supports the following fields:
- (When using `static` service discovery) `static_models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`).
- (When using `static` service discovery) `static_aliases`: The aliases of the models running in the static serving engines, separated by commas and associated using colons (e.g., `model_alias1:model,mode_alias2:model`).
- (When using `static` service discovery and if you enable the `--static-backend-health-checks` flag) `static_model_types`: The model types running in the static serving engines, separated by commas (e.g., `chat,chat`).
- (When using `static` service discovery) `fallback_model`: A per-model string in the YAML config (under each model entry) specifying another model to fall back to when all backends are unavailable.
- (When using `k8s` service discovery) `k8s_port`: The port of vLLM processes when using K8s service discovery. Default is `8000`.
- (When using `k8s` service discovery) `k8s_namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`.
- (When using `k8s` service discovery) `k8s_label_selector`: The label selector to filter vLLM pods when using K8s service discovery.
Expand Down
5 changes: 5 additions & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ def initialize_all(app: FastAPI, args):
static_backend_health_check_timeout_seconds=args.static_backend_health_check_timeout_seconds,
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
fallback_models=(
parse_static_aliases(args.static_fallback_models)
if args.static_fallback_models
else None
),
)
elif args.service_discovery == "k8s":
initialize_service_discovery(
Expand Down
7 changes: 7 additions & 0 deletions src/vllm_router/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DynamicRouterConfig:
static_aliases: Optional[str] = None
static_model_labels: Optional[str] = None
static_model_types: Optional[str] = None
static_fallback_models: Optional[str] = None
static_backend_health_checks: Optional[bool] = False
static_backend_health_check_interval: Optional[int] = 60
static_backend_health_check_timeout_seconds: Optional[int] = 10
Expand Down Expand Up @@ -97,6 +98,7 @@ def from_args(args) -> "DynamicRouterConfig":
static_backend_health_checks=args.static_backend_health_checks,
static_backend_health_check_interval=args.static_backend_health_check_interval,
static_backend_health_check_timeout_seconds=args.static_backend_health_check_timeout_seconds,
static_fallback_models=getattr(args, "static_fallback_models", None),
k8s_port=args.k8s_port,
k8s_namespace=args.k8s_namespace,
k8s_label_selector=args.k8s_label_selector,
Expand Down Expand Up @@ -185,6 +187,11 @@ def reconfigure_service_discovery(self, config: DynamicRouterConfig):
decode_model_labels=parse_comma_separated_args(
config.decode_model_labels
),
fallback_models=(
parse_static_aliases(config.static_fallback_models)
if config.static_fallback_models
else None
),
)
elif config.service_discovery == "k8s":
reconfigure_service_discovery(
Expand Down
7 changes: 7 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ def parse_args():
default=None,
help="The model labels of static backends, separated by commas. E.g., model1,model2",
)
parser.add_argument(
"--static-fallback-models",
type=str,
default=None,
help="Fallback model mappings, separated by commas. E.g., model1:fallback1,model2:fallback2. "
"When all backends for a model are unavailable, requests are retried on the fallback model.",
)
parser.add_argument(
"--static-backend-health-checks",
action="store_true",
Expand Down
15 changes: 15 additions & 0 deletions src/vllm_router/parsers/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def generate_static_model_types(models: dict[str, Any]) -> str:
return ",".join(static_model_types)


def generate_static_fallback_models(models: dict[str, Any]) -> str | None:
"""Generate comma-separated fallback model mappings.

Format: model1:fallback1,model2:fallback2
"""
fallback_models = []
for name, details in models.items():
if "fallback_model" in details:
fallback_models.append(f"{name}:{details['fallback_model']}")
return ",".join(fallback_models) if fallback_models else None


def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]:
with open(config_path, encoding="utf-8") as f:
try:
Expand All @@ -49,6 +61,9 @@ def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]:
yaml_config["static_backends"] = generate_static_backends(models)
yaml_config["static_models"] = generate_static_models(models)
yaml_config["static_model_types"] = generate_static_model_types(models)
fallback_models = generate_static_fallback_models(models)
if fallback_models:
yaml_config["static_fallback_models"] = fallback_models
if aliases:
yaml_config["static_aliases"] = generate_static_aliases(aliases)
return yaml_config
Expand Down
2 changes: 2 additions & 0 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
static_backend_health_check_timeout_seconds: int = 10,
prefill_model_labels: List[str] | None = None,
decode_model_labels: List[str] | None = None,
fallback_models: Dict[str, str] | None = None,
):
self.app = app
assert len(urls) == len(models), "URLs and models should have the same length"
Expand All @@ -225,6 +226,7 @@ def __init__(
self.aliases = aliases
self.model_labels = model_labels
self.model_types = model_types
self.fallback_models = fallback_models or {}
self.engines_id = [str(uuid.uuid4()) for i in range(0, len(urls))]
self.added_timestamp = int(time.time())
self.unhealthy_endpoint_hashes = []
Expand Down
Loading
Loading