diff --git a/docs/src/content/docs/configuration/invokeai-yaml.mdx b/docs/src/content/docs/configuration/invokeai-yaml.mdx index 987c8eb98a2..6ac56053928 100644 --- a/docs/src/content/docs/configuration/invokeai-yaml.mdx +++ b/docs/src/content/docs/configuration/invokeai-yaml.mdx @@ -114,6 +114,39 @@ Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These These options set the paths of various directories and files used by InvokeAI. Any user-defined paths should be absolute paths. +#### Multi-GPU Generation + +On a machine with more than one GPU, InvokeAI can run several generation sessions at the same time — one per GPU — instead of processing the queue one job at a time. Jobs are distributed fairly across users, so a single user's large batch cannot monopolize every GPU while others wait. + +This is controlled by the `generation_devices` setting: + +```yaml +generation_devices: auto # default value +``` + +| Value | Behavior | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | +| `auto` | Use every available CUDA GPU, running one generation session per GPU concurrently. This is the default. | +| `[cuda:0, cuda:1]` | Use the specific devices listed, one session per device. Useful for reserving a GPU for other work. | +| `[cuda:0]` | Use a single specific device. Generation runs serially, as it did before multi-GPU support. | +| `[]` | Not allowed. An empty list is rejected by configuration validation; use `auto` or list at least one device. | + +Each entry in the list must be one of `cpu`, `cuda`, `mps`, or `cuda:N`, where `N` is a zero-based device number (`cuda:0` is the first GPU, `cuda:1` the second, and so on). 
+ +```yaml +# Use the first and third GPUs, leaving the second free for other tasks +generation_devices: [cuda:0, cuda:2] +``` + +Notes: + +- On a system without a CUDA GPU, `auto` resolves to the single best available device (`mps` on Apple Silicon, otherwise `cpu`), so generation runs serially. +- Each active GPU gets its own model cache, and model weights are duplicated in system RAM for every device. Running many GPUs in parallel therefore increases RAM usage — ensure you have ample system memory before enabling a large device list. +- Duplicate entries are ignored; `[cuda:0, cuda:0]` is treated as `[cuda:0]`. +- You can restrict which physical GPUs InvokeAI sees with the `CUDA_VISIBLE_DEVICES` environment variable. When set, `auto` only enumerates the visible subset, and `cuda:N` indices refer to positions within that subset. + +During parallel generation, the progress display shows one progress bar per active session, stacked vertically, each disappearing as its session completes. + #### Image Subfolder Strategy By default, generated images are stored in a single flat directory under `outputs/images/`. The `image_subfolder_strategy` setting lets you organize newly-created images into subfolders automatically. You can edit this setting in `invokeai.yaml` or, as an admin user, in the Settings panel. diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index fcb47dbfb23..1987a90abce 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -490,6 +490,17 @@ "type": "", "validation": {} }, + { + "category": "DEVICE", + "default": "auto", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. 
On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "env_var": "INVOKEAI_GENERATION_DEVICES", + "literal_values": [], + "name": "generation_devices", + "required": false, + "type": "typing.Union[typing.Literal['auto'], list[str]]", + "validation": {} + }, { "category": "DEVICE", "default": "auto", diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 832e58f5e24..a8e0c68d781 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,15 +1,16 @@ import locale +import re from enum import Enum from importlib.metadata import distributions from pathlib import Path as FilePath from threading import Lock -from typing import Any +from typing import Any, Literal, Union import torch import yaml from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies @@ -118,6 +119,16 @@ def _remove_nullable_default_from_schema(schema: dict[str, Any]) -> None: schema.update(non_null_schemas[0]) +_GENERATION_DEVICE_PATTERN = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + + +class GenerationDeviceOption(BaseModel): + """A device that may be selected for generation.""" + + device: str = Field(description="The device identifier, e.g. 'cuda:0', 'mps', or 'cpu'") + name: str = Field(description="Human-readable device name") + + class UpdateAppGenerationSettingsRequest(BaseModel): """Writable generation-related app settings.""" @@ -131,14 +142,59 @@ class UpdateAppGenerationSettingsRequest(BaseModel): ge=0, description="Keep the last N completed, failed, and canceled queue items on startup. 
Set to 0 to prune all terminal items.", ) + generation_devices: Union[Literal["auto"], list[str]] | None = Field( + default=None, + description="Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI.", + json_schema_extra=_remove_nullable_default_from_schema, + ) + + @field_validator("generation_devices") + @classmethod + def validate_generation_devices( + cls, v: Union[Literal["auto"], list[str], None] + ) -> Union[Literal["auto"], list[str], None]: + if v is None or v == "auto": + return v + for device in v: + if not _GENERATION_DEVICE_PATTERN.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v @model_validator(mode="after") def validate_explicit_nulls(self) -> "UpdateAppGenerationSettingsRequest": if "image_subfolder_strategy" in self.model_fields_set and self.image_subfolder_strategy is None: raise ValueError("image_subfolder_strategy may not be null") + if "generation_devices" in self.model_fields_set and self.generation_devices is None: + raise ValueError("generation_devices may not be null") return self +@app_router.get( + "/generation_device_options", + operation_id="get_generation_device_options", + status_code=200, + response_model=list[GenerationDeviceOption], +) +async def get_generation_device_options() -> list[GenerationDeviceOption]: + """List the devices available for generation, for use with the `generation_devices` setting.""" + options: list[GenerationDeviceOption] = [] + if torch.cuda.is_available(): + for index in range(torch.cuda.device_count()): + device = f"cuda:{index}" + try: + name = torch.cuda.get_device_name(index) + except Exception: + name = device + options.append(GenerationDeviceOption(device=device, name=name)) + elif torch.backends.mps.is_available(): + 
options.append(GenerationDeviceOption(device="mps", name="Apple MPS")) + else: + options.append(GenerationDeviceOption(device="cpu", name="CPU")) + return options + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index bdd2e406444..53c4c68981f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -443,7 +443,11 @@ async def update_model_record( # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds. if _load_settings_changed(previous_config, config): - dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key) + # Drop the model from every per-device cache so the next load on any GPU rebuilds it. + dropped = sum( + cache.drop_model(key) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values() + ) if dropped: logger.info( f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change." @@ -1304,9 +1308,10 @@ async def get_stats() -> Optional[CacheStats]: ) async def empty_model_cache(current_admin: AdminUserOrDefault) -> None: """Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped.""" - # Request 1000GB of room in order to force the cache to drop all models. + # Request 1000GB of room in order to force each per-device cache to drop all models. 
ApiDependencies.invoker.services.logger.info("Emptying model cache.") - ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values(): + cache.make_room(1000 * 2**30) class HFTokenStatus(str, Enum): diff --git a/invokeai/app/invocations/anima_denoise.py b/invokeai/app/invocations/anima_denoise.py index 9fa4b3fb07a..b301e817f9c 100644 --- a/invokeai/app/invocations/anima_denoise.py +++ b/invokeai/app/invocations/anima_denoise.py @@ -608,7 +608,7 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor if driver is not None: user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising (Anima)") + pbar = tqdm(total=total_steps, desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}") for it in driver.iterations(): timestep = torch.tensor( [it.sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype @@ -655,7 +655,9 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor pbar.close() else: # Built-in Euler implementation (default for Anima) - for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"): + for step_idx in tqdm( + range(total_steps), desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}" + ): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index c04210401be..cb06d2b3ff6 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -294,7 +294,7 @@ def _run_diffusion( assert isinstance(transformer, CogView4Transformer2DModel) # Denoising loop - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): t_curr = timesteps[step_idx] sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff 
--git a/invokeai/app/invocations/qwen_image_image_to_latents.py b/invokeai/app/invocations/qwen_image_image_to_latents.py index ef88e03082b..cac536f00c9 100644 --- a/invokeai/app/invocations/qwen_image_image_to_latents.py +++ b/invokeai/app/invocations/qwen_image_image_to_latents.py @@ -18,6 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image @invocation( @@ -44,7 +45,10 @@ class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard) @staticmethod def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: - with vae_info.model_on_device() as (_, vae): + # Reserve working memory for the encode so the cache offloads any large resident model first; + # otherwise the encode's activations OOM (the VAE weights themselves are tiny). 
+ estimated_working_memory = estimate_vae_working_memory_qwen_image("encode", image_tensor, vae_info.model) + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoencoderKLQwenImage) vae.disable_tiling() diff --git a/invokeai/app/invocations/qwen_image_latents_to_image.py b/invokeai/app/invocations/qwen_image_latents_to_image.py index b3ea39c4bbf..c418fe43cbe 100644 --- a/invokeai/app/invocations/qwen_image_latents_to_image.py +++ b/invokeai/app/invocations/qwen_image_latents_to_image.py @@ -19,6 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image @invocation( @@ -41,15 +42,26 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae_info = context.models.load(self.vae.vae) assert isinstance(vae_info.model, AutoencoderKLQwenImage) + # Reserve working memory for the decode so the cache offloads any large resident model (e.g. + # the transformer) first; otherwise the decode's activations OOM. See estimator for details. + estimated_working_memory = estimate_vae_working_memory_qwen_image("decode", latents, vae_info.model) with ( SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), - vae_info.model_on_device() as (_, vae), + vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae), ): context.util.signal_progress("Running VAE") assert isinstance(vae, AutoencoderKLQwenImage) latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype) - vae.disable_tiling() + # Honor the global force_tiled_decode setting, like the SD/SDXL l2i node. 
Tiling bounds the + # VAE's per-tile memory, which is the scalable way to decode very large outputs that would + # exceed VRAM even after offloading the transformer/text encoder. For normal sizes, leave + # it off (faster, no tile blending) — the reserved working memory offloads other models so + # the full-frame decode fits. + if context.config.get().force_tiled_decode: + vae.enable_tiling() + else: + vae.disable_tiling() tiling_context = nullcontext() diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index f6c90b9690c..10c9080ac5e 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -284,7 +284,10 @@ def _run_diffusion( assert isinstance(transformer, SD3Transformer2DModel) # 6. Denoising loop - for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_idx, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # Expand the latents if we are doing CFG. latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # Expand the timestep to match the latent model input. diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index 576d10ac9a1..50b41f6121e 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -570,7 +570,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. 
even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): sched_timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized sigma (0-1) @@ -687,7 +687,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: pbar.close() else: # Original Euler implementation (default, optimized for Z-Image) - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/run_app.py b/invokeai/app/run_app.py index febd4f4d4b1..389b61e7347 100644 --- a/invokeai/app/run_app.py +++ b/invokeai/app/run_app.py @@ -41,7 +41,7 @@ def run_app() -> None: from invokeai.app.invocations.load_custom_nodes import load_custom_nodes from invokeai.backend.util.devices import TorchDevice - torch_device_name = TorchDevice.get_torch_device_name() + torch_device_name = TorchDevice.get_generation_devices_summary(app_config.generation_devices) logger.info(f"Using torch device: {torch_device_name}") # Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called. 
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index e6cc7c2798c..8c07c2139f4 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -11,7 +11,7 @@ import shutil from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -205,6 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") + generation_devices: Union[Literal["auto"], list[str]] = Field(default="auto", description="Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -257,6 +258,27 @@ class InvokeAIAppConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) + @field_validator("generation_devices") + @classmethod + def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, list[str]]: + if v == "auto": + return v + # A non-"auto" string would otherwise be iterated character-by-character below (rejecting + # 'c' from "cuda:0"), producing a confusing error. Require an explicit list instead. + if isinstance(v, str): + raise ValueError( + f"Invalid generation_devices value '{v}'. Use 'auto' or a list of devices, e.g. ['cuda:0', 'cuda:1']." + ) + if len(v) == 0: + raise ValueError("generation_devices cannot be an empty list. Use 'auto' or a list of devices.") + pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + for device in v: + if not pattern.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v + def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None: """Updates the config, overwriting existing values. 
diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0c530f9a2f7..c30fa31b75c 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -138,6 +138,10 @@ class InvocationProgressEvent(InvocationEventBase): image: ProgressImage | None = Field( default=None, description="An image representing the current state of the progress" ) + device: str | None = Field( + default=None, + description="The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) @classmethod def build( @@ -148,6 +152,13 @@ def build( percentage: float | None = None, image: ProgressImage | None = None, ) -> "InvocationProgressEvent": + # This is emitted from the session-processor worker thread, which pins its CUDA device via + # TorchDevice.set_session_device(). Resolve that here so the UI can label progress by GPU. + from invokeai.backend.util.devices import TorchDevice + + session_device = TorchDevice.get_session_device() + device = str(session_device) if session_device is not None and session_device.type == "cuda" else None + return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -161,6 +172,7 @@ def build( percentage=percentage, image=image, message=message, + device=device, ) diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 12b737a7cf1..ec84439547a 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -1,4 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import threading from pathlib import Path from queue import Queue from typing import Optional, Union @@ -23,6 +24,9 @@ def __init__(self, output_folder: Union[str, Path]): self.__cache: dict[Path, PILImageType] = {} self.__cache_ids = Queue[Path]() self.__max_cache_size = 10 # 
TODO: get this from config + # Guards the cache structures (__cache / __cache_ids), which are read and mutated from + # multiple session-processor worker threads in multi-GPU parallel mode. + self.__cache_lock = threading.Lock() self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__thumbnails_folder = self.__output_folder / "thumbnails" @@ -41,6 +45,13 @@ def get(self, image_name: str, image_subfolder: str = "") -> PILImageType: return cache_item image = Image.open(image_path) + # Image.open() is lazy: it reads the header but defers pixel decoding (and holds the + # file handle open) until the first .load()/.copy()/.convert(). The opened object is + # cached and the SAME object is handed to every caller, so in multi-GPU parallel mode + # two worker threads can call .copy() on it concurrently and race on the shared file + # handle and decoder state, producing "broken data stream" / "self.png is not None" + # errors. Forcing the decode here makes the cached object safe for concurrent reads. 
+ image.load() self.__set_cache(image_path, image) return image except FileNotFoundError as e: @@ -105,16 +116,18 @@ def delete(self, image_name: str, image_subfolder: str = "") -> None: if image_path.exists(): image_path.unlink() - if image_path in self.__cache: - del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, True, image_subfolder=image_subfolder) if thumbnail_path.exists(): thumbnail_path.unlink() - if thumbnail_path in self.__cache: - del self.__cache[thumbnail_path] + + with self.__cache_lock: + if image_path in self.__cache: + del self.__cache[image_path] + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] except Exception as e: raise ImageFileDeleteException from e @@ -185,13 +198,15 @@ def __validate_storage_folders(self) -> None: folder.mkdir(parents=True, exist_ok=True) def __get_cache(self, image_name: Path) -> Optional[PILImageType]: - return None if image_name not in self.__cache else self.__cache[image_name] + with self.__cache_lock: + return None if image_name not in self.__cache else self.__cache[image_name] def __set_cache(self, image_name: Path, image: PILImageType): - if image_name not in self.__cache: - self.__cache[image_name] = image - self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache - if len(self.__cache) > self.__max_cache_size: - cache_id = self.__cache_ids.get() - if cache_id in self.__cache: - del self.__cache[cache_id] + with self.__cache_lock: + if image_name not in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + if cache_id in self.__cache: + del self.__cache[cache_id] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 3baf11029ff..7c9fdeee11b 
100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -116,6 +116,11 @@ def __init__( self._restore_completed_event.set() self._download_queue = download_queue self._download_cache: Dict[int, ModelInstallJob] = {} + # Per-source locks serializing download_and_cache_model() so parallel (multi-GPU) sessions + # that need the same remote model (e.g. the LaMa infill model) don't race to download into + # the same cache directory. _download_cache_locks_guard protects the dict itself. + self._download_cache_locks: Dict[str, threading.Lock] = {} + self._download_cache_locks_guard = threading.Lock() self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -724,27 +729,47 @@ def download_and_cache_model( if len(contents) > 0: return contents[0] - model_path.mkdir(parents=True, exist_ok=True) - model_source = self._guess_source(str(source)) - remote_files, _ = self._remote_files_from_source(model_source) - # Handle multiple subfolders for HFModelSource - subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] - job = self._multifile_download( - dest=model_path, - remote_files=remote_files, - subfolder=model_source.subfolder - if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 - else None, - subfolders=subfolders if len(subfolders) > 1 else None, - ) - files_string = "file" if len(remote_files) == 1 else "files" - self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") - self._download_queue.wait_for_job(job) - if job.complete: - assert job.download_path is not None - return job.download_path - else: - raise Exception(job.error) + # Serialize concurrent downloads of the same source. Parallel multi-GPU sessions can each + # request the same remote model (e.g. 
the LaMa infill model) at once; without this lock they + # both download into the same cache directory and collide on the final rename, which fails on + # Windows with "WinError 32: the file is being used by another process". The other waiters + # find the completed download on the post-lock re-check below and skip downloading. + with self._download_cache_lock(str(source)): + if model_path.exists(): + contents = list(model_path.iterdir()) + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + # Handle multiple subfolders for HFModelSource + subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] + job = self._multifile_download( + dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder + if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 + else None, + subfolders=subfolders if len(subfolders) > 1 else None, + ) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + + def _download_cache_lock(self, source: str) -> threading.Lock: + """Return the lock that serializes downloads for a given source, creating it on first use.""" + with self._download_cache_locks_guard: + lock = self._download_cache_locks.get(source) + if lock is None: + lock = threading.Lock() + self._download_cache_locks[source] = lock + return lock def _remote_files_from_source( self, source: ModelSource diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 87a405b4ea4..8fc9823328d 100644 --- a/invokeai/app/services/model_load/model_load_base.py 
+++ b/invokeai/app/services/model_load/model_load_base.py @@ -26,7 +26,21 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo @property @abstractmethod def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" + """Return the RAM cache for the current thread's execution device. + + In multi-GPU mode, each session-processor worker is pinned to a device and gets its own + cache; this resolves to the calling thread's cache. Outside a worker (e.g. API threads), + it resolves to the default device's cache. + """ + + @property + @abstractmethod + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string. + + Use this for maintenance operations that must apply to every device (clear cache, drop a + model from all devices, shutdown). + """ @abstractmethod def load_model_from_path( diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 2e2d2ae219d..33c7ef6108c 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -18,7 +18,7 @@ ModelLoaderRegistry, ModelLoaderRegistryBase, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType from invokeai.backend.util.devices import TorchDevice @@ -33,13 +33,25 @@ def __init__( app_config: InvokeAIAppConfig, ram_cache: ModelCache, registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, + ram_caches: Optional[dict[str, ModelCache]] = None, ): - """Initialize the model load service.""" + """Initialize the model load service. 
+ + Args: + ram_cache: The default RAM cache, used when no per-device cache matches the calling + thread (e.g. single-device installs, or API threads). + ram_caches: Optional map of normalized device string -> ModelCache for multi-GPU mode. + One cache per generation device. The default `ram_cache` is always included. + """ logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) self._logger = logger self._app_config = app_config - self._ram_cache = ram_cache + self._default_ram_cache = ram_cache + # Map normalized device string -> cache. Always includes the default cache so that callers + # without a pinned device (API threads) resolve to a valid cache. + self._ram_caches: dict[str, ModelCache] = dict(ram_caches) if ram_caches else {} + self._ram_caches.setdefault(str(TorchDevice.normalize(ram_cache.execution_device)), ram_cache) self._registry = registry def start(self, invoker: Invoker) -> None: @@ -47,8 +59,18 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" - return self._ram_cache + """Return the RAM cache for the calling thread's execution device. + + `choose_torch_device()` is thread-local-aware: a session-processor worker pinned to a GPU + gets that GPU's cache; everything else falls back to the default cache. 
+ """ + key = str(TorchDevice.choose_torch_device()) + return self._ram_caches.get(key, self._default_ram_cache) + + @property + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string.""" + return dict(self._ram_caches) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -67,7 +89,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo loaded_model: LoadedModel = implementation( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=self.ram_cache, ).load_model(model_config, submodel_type) if hasattr(self, "_invoker"): @@ -78,9 +100,11 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: + # Resolve the calling thread's cache once so the whole load uses a single device's cache. 
+ ram_cache = self.ram_cache cache_key = str(model_path) try: - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) except IndexError: pass @@ -110,7 +134,7 @@ def diffusers_load_directory(directory: Path) -> AnyModel: load_class = GenericDiffusersLoader( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=ram_cache, convert_cache=self.convert_cache, ).get_hf_load_class(directory) return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) @@ -123,6 +147,15 @@ def diffusers_load_directory(directory: Path) -> AnyModel: else lambda path: safetensors_load_file(path, device="cpu") ) assert loader is not None - raw_model = loader(model_path) - self._ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + # Serialize construction (see MODEL_LOAD_LOCK): the diffusers loader path uses the same + # process-global, non-thread-safe monkey-patches as the main loader, so it takes the write + # lock to exclude concurrent VRAM moves. Re-check the cache after acquiring the lock in case + # a worker sharing this cache built it while we waited. 
+ with MODEL_LOAD_LOCK.write_lock(): + try: + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + except IndexError: + pass + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 6141a635f4d..176b61ddcab 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -17,6 +17,8 @@ from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -60,9 +62,10 @@ def start(self, invoker: Invoker) -> None: service.start(invoker) def stop(self, invoker: Invoker) -> None: - # Shutdown the model cache to cancel any pending timers - if hasattr(self._load, "ram_cache"): - self._load.ram_cache.shutdown() + # Shutdown every per-device model cache to cancel any pending keep-alive timers. 
+ if hasattr(self._load, "ram_caches"): + for cache in self._load.ram_caches.values(): + cache.shutdown() for service in [self._store, self._install, self._load]: if hasattr(service, "stop"): @@ -85,22 +88,76 @@ def build_model_manager( logger = InvokeAILogger.get_logger(cls.__name__) logger.setLevel(app_config.log_level.upper()) - ram_cache = ModelCache( - execution_device_working_mem_gb=app_config.device_working_mem_gb, - enable_partial_loading=app_config.enable_partial_loading, - keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, - max_ram_cache_size_gb=app_config.max_cache_ram_gb, - max_vram_cache_size_gb=app_config.max_cache_vram_gb, - execution_device=execution_device or TorchDevice.choose_torch_device(), - storage_device="cpu", - log_memory_usage=app_config.log_memory_usage, - logger=logger, - keep_alive_minutes=app_config.model_cache_keep_alive_min, + # One store + budget shared by every per-device cache. The store deduplicates each model's CPU + # weights to a single copy across GPUs (see SharedCpuWeightsStore); the budget is the single + # system-wide RAM authority so per-device caches stop double-counting shared weights when they + # decide what to evict (see RamBudget). + shared_store = SharedCpuWeightsStore() + + def build_cache(device: torch.device) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=app_config.device_working_mem_gb, + enable_partial_loading=app_config.enable_partial_loading, + keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, + max_ram_cache_size_gb=app_config.max_cache_ram_gb, + max_vram_cache_size_gb=app_config.max_cache_vram_gb, + execution_device=device, + storage_device="cpu", + log_memory_usage=app_config.log_memory_usage, + logger=logger, + keep_alive_minutes=app_config.model_cache_keep_alive_min, + shared_cpu_weights=shared_store, + ) + + # The default cache for callers without a pinned device (API threads, single-device installs). 
+ default_device = execution_device or TorchDevice.choose_torch_device() + ram_cache = build_cache(default_device) + + # In multi-GPU mode, build one independent cache per generation device. Each session-processor + # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own + # cache. The default cache is always included by ModelLoadService. + ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} + for device in TorchDevice.get_generation_devices(app_config.generation_devices): + key = str(device) + if key not in ram_caches: + ram_caches[key] = build_cache(device) + + # Attach the single global RAM budget. The cap is the user's max_cache_ram_gb interpreted as a + # true system-wide limit; when unset, it is the sum of the caches' individually-calculated + # sizes, so each device keeps its effective capacity and weight deduplication becomes headroom. + # That sum is then clamped to a safe fraction of system RAM: each per-device heuristic already + # allows up to ~half of system RAM, so summing across N GPUs would otherwise claim ~N× that and + # leave nothing for the OS, causing swap thrashing. The clamp leaves real headroom; shared-weight + # dedup means the true footprint usually stays well under the cap regardless. + gb = 2**30 + distinct_caches = list(dict.fromkeys(ram_caches.values())) + # Cross-device weight adoption (and its per-model meta-shell capture) only pays off with more + # than one device cache; disable the capture cost otherwise. 
+ shared_store.enable_shell_capture = len(distinct_caches) > 1 + if app_config.max_cache_ram_gb is not None: + global_ram_budget_bytes = int(app_config.max_cache_ram_gb * gb) + else: + summed_cache_bytes = sum(c.local_ram_cache_size_bytes for c in distinct_caches) + system_ram_headroom_bytes = ModelCache.calc_system_ram_headroom_bytes() + global_ram_budget_bytes = min(summed_cache_bytes, system_ram_headroom_bytes) + if global_ram_budget_bytes < summed_cache_bytes: + logger.info( + f"Capping model cache RAM budget at {global_ram_budget_bytes / gb:.2f} GB to leave system " + f"headroom (sum of per-device caches was {summed_cache_bytes / gb:.2f} GB)." + ) + ram_budget = RamBudget(max_bytes=global_ram_budget_bytes, shared_store=shared_store) + for cache in distinct_caches: + cache.set_ram_budget(ram_budget) + logger.info( + f"Model cache global RAM budget: {global_ram_budget_bytes / gb:.2f} GB " + f"across {len(distinct_caches)} device cache(s)." ) + loader = ModelLoadService( app_config=app_config, ram_cache=ram_cache, registry=ModelLoaderRegistry, + ram_caches=ram_caches, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index b361259a4b1..ae00173e422 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,4 +1,5 @@ from queue import Queue +from threading import Lock from typing import TYPE_CHECKING, Optional, TypeVar from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase @@ -21,6 +22,9 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache: dict[str, T] = {} self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size + # Guards the in-memory cache so concurrent session-processor workers 
(multi-GPU) can't race + # the check-then-evict in `_set_cache` (which could otherwise raise KeyError on eviction). + self._cache_lock = Lock() def start(self, invoker: "Invoker") -> None: self._invoker = invoker @@ -50,16 +54,19 @@ def save(self, obj: T) -> str: def delete(self, name: str) -> None: self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] + with self._cache_lock: + if name in self._cache: + del self._cache[name] self._on_deleted(name) def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] + with self._cache_lock: + return None if name not in self._cache else self._cache[name] def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) + with self._cache_lock: + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7159c19e746..93c4554b1fe 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -5,6 +5,8 @@ from threading import Event as ThreadEvent from typing import Optional +import torch + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, @@ -31,6 +33,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.devices 
import TorchDevice class DefaultSessionRunner(SessionRunnerBase): @@ -305,6 +308,26 @@ def _on_node_error( ) +class _SessionWorker: + """A single generation worker: one thread, optionally pinned to one device. + + In single-device (legacy) mode there is exactly one worker with `device=None`. In multi-GPU + mode there is one worker per configured device, each with its own session runner and cancel + event so concurrent sessions can be canceled independently. + """ + + def __init__(self, device: Optional[torch.device], runner: SessionRunnerBase) -> None: + self.device = device + self.runner = runner + self.cancel_event = ThreadEvent() + self.queue_item: Optional[SessionQueueItem] = None + self.thread: Optional[Thread] = None + + @property + def label(self) -> str: + return str(self.device) if self.device is not None else "default device" + + class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, @@ -319,57 +342,113 @@ def __init__( self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval + self._workers: list[_SessionWorker] = [] + + def _resolve_devices(self) -> list[Optional[torch.device]]: + """Determine the per-worker devices from config. + + Resolves `generation_devices` (which defaults to `"auto"` — every available GPU) into one + normalized device per worker. Returns a single `None` (legacy single-worker, device chosen by + the global config) only if the resolution is empty (e.g. `generation_devices` set to an empty + list). + """ + generation_devices = self._invoker.services.configuration.generation_devices + devices = TorchDevice.get_generation_devices(generation_devices) + if not devices: + return [None] + return list(devices) + + def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: + """Create an independent runner for an additional worker. 
+ + Each worker needs its own runner because the runner stores its session's cancel event. + We carry over the template's callbacks so all workers behave identically. + """ + if isinstance(template, DefaultSessionRunner): + return DefaultSessionRunner( + on_before_run_session_callbacks=list(template._on_before_run_session_callbacks), + on_before_run_node_callbacks=list(template._on_before_run_node_callbacks), + on_after_run_node_callbacks=list(template._on_after_run_node_callbacks), + on_node_error_callbacks=list(template._on_node_error_callbacks), + on_after_run_session_callbacks=list(template._on_after_run_session_callbacks), + ) + # Unknown runner implementation — only safe to reuse in single-worker mode. + return template def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker - self._queue_item: Optional[SessionQueueItem] = None - self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() self._poll_now_event = ThreadEvent() - self._cancel_event = ThreadEvent() register_events(QueueClearedEvent, self._on_queue_cleared) register_events(BatchEnqueuedEvent, self._on_batch_enqueued) register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_semaphore = BoundedSemaphore(self._thread_limit) + devices = self._resolve_devices() # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. + # the profiler will create a new profile for each session. Profiling uses a process-global cProfile, which + # cannot cleanly attribute work when multiple sessions run concurrently, so it is disabled in multi-GPU mode. + profiler_enabled = self._invoker.services.configuration.profile_graphs + if profiler_enabled and len(devices) > 1: + self._invoker.services.logger.warning( + "Graph profiling is disabled because multiple generation devices are configured." 
+ ) + profiler_enabled = False self._profiler = ( Profiler( logger=self._invoker.services.logger, output_dir=self._invoker.services.configuration.profiles_path, prefix=self._invoker.services.configuration.profile_prefix, ) - if self._invoker.services.configuration.profile_graphs + if profiler_enabled else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) - self._thread = Thread( - name="session_processor", - target=self._process, - daemon=True, - kwargs={ - "stop_event": self._stop_event, - "poll_now_event": self._poll_now_event, - "resume_event": self._resume_event, - "cancel_event": self._cancel_event, - }, - ) - self._thread.start() + self._thread_semaphore = BoundedSemaphore(len(devices)) + + # Start in the running (resumed) state. + self._stop_event.clear() + self._resume_event.set() + + self._workers = [] + for index, device in enumerate(devices): + runner = self.session_runner if index == 0 else self._clone_session_runner(self.session_runner) + worker = _SessionWorker(device=device, runner=runner) + runner.start(services=invoker.services, cancel_event=worker.cancel_event, profiler=self._profiler) + self._workers.append(worker) + + if len(self._workers) > 1: + self._invoker.services.logger.info( + f"Starting session processor with {len(self._workers)} parallel workers on devices: " + f"{', '.join(w.label for w in self._workers)}" + ) + + for index, worker in enumerate(self._workers): + worker.thread = Thread( + name=f"session_processor_{index}", + target=self._process, + daemon=True, + kwargs={ + "worker": worker, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + }, + ) + worker.thread.start() def stop(self, *args, **kwargs) -> None: self._stop_event.set() # Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at - # the next step boundary instead of running to completion. 
Without this, the generation + # the next step boundary instead of running to completion. Without this, a generation # thread may still be executing CUDA operations when Python teardown begins, which can # cause a C++ std::terminate() crash ("terminate called without an active exception"). - self._cancel_event.set() - # Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). + for worker in self._workers: + worker.cancel_event.set() + # Wake any worker sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). self._poll_now_event.set() self._resume_event.set() @@ -377,28 +456,31 @@ def _poll_now(self) -> None: self._poll_now_event.set() async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: - if self._queue_item and self._queue_item.queue_id == event[1].queue_id: - self._cancel_event.set() + # Cancel every worker currently running an item from the cleared queue. + canceled = False + for worker in self._workers: + if worker.queue_item and worker.queue_item.queue_id == event[1].queue_id: + worker.cancel_event.set() + canceled = True + if canceled: self._poll_now() async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: self._poll_now() async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: - # Make sure the cancel event is for the currently processing queue item - if self._queue_item and self._queue_item.item_id != event[1].item_id: - return - if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: - # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is - # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel - # event, which the session runner checks between invocations. If set, the session runner loop is broken. 
- # - # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such - # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item - # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. - if event[1].status == "canceled": - self._cancel_event.set() - self._poll_now() + # Find the worker (if any) currently running the item whose status changed. + for worker in self._workers: + if worker.queue_item and worker.queue_item.item_id == event[1].item_id: + if event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the status is set to "canceled" and this event is + # emitted. We respond by setting that worker's cancel event, which its session runner checks + # between invocations (and which denoise_latents' step callback checks mid-node, raising + # CanceledException to stop immediately). + if event[1].status == "canceled": + worker.cancel_event.set() + self._poll_now() + return def resume(self) -> SessionProcessorStatus: if not self._resume_event.is_set(): @@ -413,22 +495,41 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( is_started=self._resume_event.is_set(), - is_processing=self._queue_item is not None, + is_processing=any(worker.queue_item is not None for worker in self._workers), ) + def _is_queue_item_terminal(self, item_id: int) -> bool: + """Return True if the queue item is already finished (canceled/failed/completed) or gone. + + Checked right after a worker claims an item to catch a cancellation that raced the claim and + so never reached this worker's cancel_event — e.g. the status-changed handler ran before the + worker recorded `queue_item` and so couldn't match a worker to signal. 
+ """ + try: + status = self._invoker.services.session_queue.get_queue_item(item_id).status + except SessionQueueItemNotFoundError: + return True + return status in ("canceled", "failed", "completed") + def _process( self, + worker: _SessionWorker, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, - cancel_event: ThreadEvent, ): try: - # Any unhandled exception in this block is a fatal processor error and will stop the processor. + # Any unhandled exception in this block is a fatal processor error and will stop this worker. self._thread_semaphore.acquire() - stop_event.clear() - resume_event.set() - cancel_event.clear() + + # Pin this worker thread to its device so all device-selecting code (TorchDevice.choose_torch_device, + # which nodes and the model loader consult) resolves to this GPU. CUDA's current device is per-thread. + if worker.device is not None: + TorchDevice.set_session_device(worker.device) + if worker.device.type == "cuda": + torch.cuda.set_device(worker.device) + + worker.cancel_event.clear() while not stop_event.is_set(): poll_now_event.clear() @@ -437,15 +538,39 @@ def _process( # If we are paused, wait for resume event resume_event.wait() - # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + if stop_event.is_set(): + break + + # Clear any stale cancel signal from the previous item BEFORE claiming the next + # one. Clearing it after dequeue (as before) could wipe a cancel that arrived for + # the item we just claimed — e.g. during the gc.collect() below — silently losing + # the cancellation. Any cancel that arrives after this point for the claimed item + # stays set and is caught by the runner's _is_canceled() check. + worker.cancel_event.clear() + + # Get the next session to process. dequeue() atomically claims the item, so concurrent + # workers never receive the same item. 
Pass this worker's device so the item is + # tagged with the GPU that ran it (None in single-device/legacy mode). + worker.queue_item = self._invoker.services.session_queue.dequeue( + device=str(worker.device) if worker.device is not None else None + ) - if self._queue_item is None: + if worker.queue_item is None: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) continue + # A cancellation can race the claim: it may have marked the row terminal before + # this worker recorded `queue_item`, so _on_queue_item_status_changed couldn't set + # our cancel_event. Re-check (cancel_event + a fresh DB status read) and skip + # running if the item is already finished, so the cancel is never lost. + if worker.cancel_event.is_set() or self._is_queue_item_terminal(worker.queue_item.item_id): + self._invoker.services.logger.debug( + f"Queue item {worker.queue_item.item_id} was canceled before it started; skipping." + ) + continue + # GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks. # Most queue items take seconds to execute, so the relative cost of a GC is very small. 
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak @@ -453,19 +578,19 @@ def _process( gc.collect() self._invoker.services.logger.info( - f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}" + f"Executing queue item {worker.queue_item.item_id}, session {worker.queue_item.session_id} " + f"on {worker.label}" ) - cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) + worker.runner.run(queue_item=worker.queue_item) except Exception as e: error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() self._on_non_fatal_processor_error( - queue_item=self._queue_item, + queue_item=worker.queue_item, error_type=error_type, error_message=error_message, error_traceback=error_traceback, @@ -474,7 +599,7 @@ def _process( poll_now_event.wait(self._polling_interval) continue except Exception as e: - # Fatal error in processor, log and pass - we're done here + # Fatal error in this worker, log and pass - we're done here error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() @@ -482,9 +607,9 @@ def _process( self._invoker.services.logger.error(error_traceback) pass finally: - stop_event.clear() - poll_now_event.clear() - self._queue_item = None + worker.queue_item = None + if worker.device is not None: + TorchDevice.clear_session_device() self._thread_semaphore.release() def _on_non_fatal_processor_error( diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index bb7971cf90d..4ed4f1a62c1 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -31,8 +31,8 @@ class SessionQueueBase(ABC): """Base class for session queue""" @abstractmethod - def dequeue(self) -> Optional[SessionQueueItem]: - """Dequeues the next session queue 
item.""" + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: + """Dequeues the next session queue item, recording the processing device (e.g. 'cuda:1') if given.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 54bf9158afc..63324337aa1 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -262,6 +262,10 @@ class SessionQueueItem(BaseModel): retried_from_item_id: Optional[int] = Field( default=None, description="The item_id of the queue item that this item was retried from" ) + device: Optional[str] = Field( + default=None, + description="The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) session: GraphExecutionState = Field(description="The fully-populated session to be executed") workflow: Optional[WorkflowWithoutID] = Field( default=None, description="The workflow associated with this queue item" diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index c29ed9b0038..e6d8229860d 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -1,7 +1,8 @@ import asyncio import json import sqlite3 -from typing import Optional, Union, cast +import threading +from typing import Any, Optional, Union, cast from pydantic_core import to_jsonable_python @@ -42,6 +43,12 @@ class SqliteSessionQueue(SessionQueueBase): __invoker: Invoker + # Serializes the select-candidate-then-claim sequence in `dequeue()`. 
The DB connection's + # RLock serializes individual statements, but the gap between selecting the next pending item + # and marking it 'in_progress' is a race: with multiple session-processor workers (multi-GPU), + # two workers could select the same item. Holding this lock across the whole claim prevents it. + _dequeue_lock = threading.Lock() + def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() @@ -209,28 +216,34 @@ async def enqueue_batch( self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result - def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT - sq.*, - u.display_name as user_display_name, - u.email as user_email - FROM session_queue sq - LEFT JOIN users u ON sq.user_id = u.user_id - WHERE sq.status = 'pending' - ORDER BY - sq.priority DESC, - sq.item_id ASC - LIMIT 1 - """ - ) - result = cast(Union[sqlite3.Row, None], cursor.fetchone()) - if result is None: - return None - queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: + # Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU) + # cannot select and claim the same pending item. `_set_queue_item_status` already no-ops + # if the item was concurrently moved to a terminal state (e.g. canceled), so we only need + # to guard against two dequeues racing for the same pending row. 
+ with self._dequeue_lock: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT + sq.*, + u.display_name as user_display_name, + u.email as user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + WHERE sq.status = 'pending' + ORDER BY + sq.priority DESC, + sq.item_id ASC + LIMIT 1 + """ + ) + result = cast(Union[sqlite3.Row, None], cursor.fetchone()) + if result is None: + return None + queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) + # Record the claiming worker's device so the UI can label the item by GPU. + queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress", device=device) return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: @@ -287,6 +300,7 @@ def _set_queue_item_status( error_type: Optional[str] = None, error_message: Optional[str] = None, error_traceback: Optional[str] = None, + device: Optional[str] = None, ) -> SessionQueueItem: with self._db.transaction() as cursor: cursor.execute( @@ -308,10 +322,10 @@ def _set_queue_item_status( cursor.execute( """--sql UPDATE session_queue - SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ? + SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ?, device = COALESCE(?, device) WHERE item_id = ? """, - (status, error_type, error_message, error_traceback, item_id), + (status, error_type, error_message, error_traceback, device, item_id), ) queue_item = self.get_queue_item(item_id) @@ -461,30 +475,61 @@ def fail_queue_item( ) return queue_item + def _cancel_in_progress_matching(self, match_filter: str, params: list[Any]) -> int: + """Cancel every in-progress item matching `match_filter`, emitting a cancel event for each. 
+ + The bulk-cancel methods exclude in-progress items from their single UPDATE statement, because + a running item must be canceled via `_set_queue_item_status()` so that its + `QueueItemStatusChangedEvent` is emitted — the session processor responds to that event by + setting the cancel event of the worker running that exact item_id. With multiple workers + (multi-GPU) more than one item can be in_progress at once, so each matching item is canceled + individually here rather than relying on a single `get_current()` (which returns only one). + + `match_filter` is a WHERE fragment without the leading WHERE (e.g. + "queue_id == ? AND batch_id IN (?, ?)"); `params` are its bound values. + + Returns the number of in-progress items actually canceled. + """ + with self._db.transaction() as cursor: + cursor.execute( + f"""--sql + SELECT item_id + FROM session_queue + WHERE status == 'in_progress' AND {match_filter}; + """, + tuple(params), + ) + item_ids = [row[0] for row in cursor.fetchall()] + + canceled = 0 + for item_id in item_ids: + # _set_queue_item_status no-ops (and returns the existing item) if the item finished + # between the SELECT and now, so count only the ones we actually moved to 'canceled'. + if self._set_queue_item_status(item_id, "canceled").status == "canceled": + canceled += 1 + return canceled + def cancel_by_batch_ids( self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None ) -> CancelByBatchIDsResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) - placeholders = ", ".join(["?" for _ in batch_ids]) + placeholders = ", ".join(["?" for _ in batch_ids]) + # Build the match filter (with optional user_id filter) shared by the bulk update and the + # in-progress cancellation below. + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? 
AND batch_id IN ({placeholders}) {user_filter}" + params: list[Any] = [queue_id] + batch_ids + if user_id is not None: + params.append(user_id) - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" if user_id is not None else "" + with self._db.transaction() as cursor: where = f"""--sql - WHERE - queue_id == ? - AND batch_id IN ({placeholders}) + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' - {user_filter} """ - params = [queue_id] + batch_ids - if user_id is not None: - params.append(user_id) - cursor.execute( f"""--sql SELECT COUNT(*) @@ -504,36 +549,28 @@ def cancel_by_batch_ids( tuple(params), ) - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - if user_id is None or current_queue_item.user_id == user_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") - + # Cancel every in-progress item matching the same filter (multi-GPU: possibly several at once). + count += self._cancel_in_progress_matching(match_filter, params) return CancelByBatchIDsResult(canceled=count) def cancel_by_destination( self, queue_id: str, destination: str, user_id: Optional[str] = None ) -> CancelByDestinationResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? AND destination == ? {user_filter}" + params: list[Any] = [queue_id, destination] + if user_id is not None: + params.append(user_id) - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" 
if user_id is not None else "" + with self._db.transaction() as cursor: where = f"""--sql - WHERE - queue_id == ? - AND destination == ? + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' - {user_filter} """ - params = [queue_id, destination] - if user_id is not None: - params.append(user_id) - cursor.execute( f"""--sql SELECT COUNT(*) @@ -553,30 +590,26 @@ def cancel_by_destination( tuple(params), ) - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.destination == destination: - if user_id is None or current_queue_item.user_id == user_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") - + # Cancel every in-progress item matching the same filter (multi-GPU: possibly several at once). + count += self._cancel_in_progress_matching(match_filter, params) return CancelByDestinationResult(canceled=count) def delete_by_destination( self, queue_id: str, destination: str, user_id: Optional[str] = None ) -> DeleteByDestinationResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? AND destination == ? {user_filter}" + params: list[Any] = [queue_id, destination] + if user_id is not None: + params.append(user_id) - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.destination == destination: - if user_id is None or current_queue_item.user_id == user_id: - self.cancel_queue_item(current_queue_item.item_id) - - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" 
if user_id is not None else "" - params = [queue_id, destination] - if user_id is not None: - params.append(user_id) + # Cancel every in-progress item first so each running worker is signaled to stop before we + # delete its row. With multiple workers (multi-GPU) more than one item can be in_progress; + # canceling only get_current() would leave the others running (and then failing to update a + # deleted row). See _cancel_in_progress_matching. + self._cancel_in_progress_matching(match_filter, params) + with self._db.transaction() as cursor: cursor.execute( f"""--sql SELECT COUNT(*) @@ -635,18 +668,18 @@ def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None return DeleteAllExceptCurrentResult(deleted=count) def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: + match_filter = "queue_id == ?" + params: list[Any] = [queue_id] + with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) - where = """--sql - WHERE - queue_id is ? + where = f"""--sql + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' """ - params = [queue_id] cursor.execute( f"""--sql SELECT COUNT(*) @@ -666,8 +699,8 @@ def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: tuple(params), ) - if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") + # Cancel every in-progress item in the queue (multi-GPU: possibly several at once). 
+ count += self._cancel_in_progress_matching(match_filter, params) return CancelByQueueIDResult(canceled=count) def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult: diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 3e1d5c53f3e..14b6c61a85a 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -35,6 +35,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_32 import build_migration_32 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_33 import build_migration_33 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -87,6 +88,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_30()) migrator.register_migration(build_migration_31()) migrator.register_migration(build_migration_32()) + migrator.register_migration(build_migration_33()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_33.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_33.py new file mode 100644 index 00000000000..d99e1897137 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_33.py @@ -0,0 +1,36 @@ +"""Migration 33: Add device column to session_queue table. + +This records which device (e.g. 'cuda:1') processed a queue item, so the UI can show a per-item +GPU number in the Session Queue. Existing rows get NULL (unknown device). 
+""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration33Callback: + """Migration to add a device column to the session_queue table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(session_queue);") + columns = [row[1] for row in cursor.fetchall()] + + if "device" not in columns: + cursor.execute("ALTER TABLE session_queue ADD COLUMN device TEXT;") + + +def build_migration_33() -> Migration: + """Builds the migration object for migrating from version 32 to version 33. + + This migration adds a device column to the session_queue table. + """ + return Migration( + from_version=32, + to_version=33, + callback=Migration33Callback(), + ) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 0f4cf07ee5b..7b29a58d44f 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -15,6 +15,7 @@ from invokeai.backend.flux.model import Flux from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -95,7 +96,7 @@ def denoise( # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. 
even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -266,7 +267,10 @@ def denoise( return img # Original Euler implementation (when scheduler is None) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # DyPE: Update step state for timestep-dependent scaling if dype_extension is not None and dype_embedder is not None: dype_extension.update_step_state( diff --git a/invokeai/backend/flux2/denoise.py b/invokeai/backend/flux2/denoise.py index 30c287840b3..561a844bfb8 100644 --- a/invokeai/backend/flux2/denoise.py +++ b/invokeai/backend/flux2/denoise.py @@ -14,6 +14,7 @@ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -118,7 +119,7 @@ def denoise( is_heun = hasattr(scheduler, "state_in_first_order") user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -230,7 +231,10 @@ def denoise( pbar.close() else: # Manual Euler stepping (original behavior) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for 
step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False) diff --git a/invokeai/backend/flux2/ref_image_extension.py b/invokeai/backend/flux2/ref_image_extension.py index 9cc6240db66..368f3c4452f 100644 --- a/invokeai/backend/flux2/ref_image_extension.py +++ b/invokeai/backend/flux2/ref_image_extension.py @@ -208,16 +208,32 @@ def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]: vae_dtype = next(iter(vae.parameters())).dtype image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) - # FLUX.2 VAE uses diffusers API - latent_dist = vae.encode(image_tensor, return_dict=False)[0] - - # Use mode() for deterministic encoding (no sampling) - if hasattr(latent_dist, "mode"): - ref_image_latents_unpacked = latent_dist.mode() - elif hasattr(latent_dist, "sample"): - ref_image_latents_unpacked = latent_dist.sample() - else: - ref_image_latents_unpacked = latent_dist + # The FLUX.2 VAE encoder's mid-block self-attention scales quadratically with the + # input's spatial size (and on ROCm, SDPA falls back to a *materialized* attention + # matrix), so encoding a reference image at full size OOMs VRAM — ~15GB at 1024px, + # hundreds of GB at the 2024px reference cap. Tile the encode to bound peak memory + # regardless of reference resolution. The VAE's default tile size equals its + # sample_size (1024), which still OOMs per tile, so force a smaller 512px tile. + # Save/restore the tiling config because this VAE is a shared, cached instance (e.g. + # the final image decode must not inherit these settings). 
+ downsample = 2 ** (len(vae.config.block_out_channels) - 1) + prev_tiling = (vae.use_tiling, vae.tile_sample_min_size, vae.tile_latent_min_size) + vae.use_tiling = True + vae.tile_sample_min_size = 512 + vae.tile_latent_min_size = 512 // downsample + try: + # FLUX.2 VAE uses diffusers API + latent_dist = vae.encode(image_tensor, return_dict=False)[0] + + # Use mode() for deterministic encoding (no sampling) + if hasattr(latent_dist, "mode"): + ref_image_latents_unpacked = latent_dist.mode() + elif hasattr(latent_dist, "sample"): + ref_image_latents_unpacked = latent_dist.sample() + else: + ref_image_latents_unpacked = latent_dist + finally: + vae.use_tiling, vae.tile_sample_min_size, vae.tile_latent_min_size = prev_tiling TorchDevice.empty_cache() diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4609a2e92ab..984362f185d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -17,7 +17,7 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -57,7 +57,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache): self._cache = cache def __enter__(self) -> AnyModel: - self._cache.lock(self._cache_record, None) + # Hold the MODEL_LOAD_LOCK read lock across the VRAM load (lock() runs + # load_state_dict(assign=True), which calls register_parameter) so it can't overlap a + # concurrent model construction that has the global register_parameter -> meta patch active. + # Acquired before the cache's own lock to keep a consistent lock order (see MODEL_LOAD_LOCK). 
+ with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, None) try: self.repair_required_tensors_on_device() return self.model @@ -77,7 +82,9 @@ def model_on_device( :param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the model. """ - self._cache.lock(self._cache_record, working_mem_bytes) + # See __enter__ for why the VRAM load is wrapped in the read lock. + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, working_mem_bytes) try: self.repair_required_tensors_on_device() yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) @@ -94,7 +101,12 @@ def repair_required_tensors_on_device(self) -> int: cached_model = self._cache_record.cached_model if not isinstance(cached_model, CachedModelWithPartialLoad): return 0 - return cached_model.repair_required_tensors_on_compute_device() + # Repair runs load_state_dict(assign=True) -> register_parameter, so it must hold the read + # lock to avoid being hijacked onto the `meta` device by a concurrent construction. This is + # also called directly (outside __enter__/model_on_device) by some text-encoder invocations, + # so the guard lives here rather than only at the call sites. + with MODEL_LOAD_LOCK.read_lock(): + return cached_model.repair_required_tensors_on_compute_device() class LoadedModel(LoadedModelWithoutConfig): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 040b55cb6ec..de87c797e8e 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -1,6 +1,8 @@ # Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team """Default implementation of model loading in InvokeAI.""" +import copy +import itertools import re from logging import Logger from pathlib import Path @@ -13,7 +15,11 @@ from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + MODEL_LOAD_LOCK, + ModelCache, + get_model_cache_key, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.taxonomy import ( @@ -52,7 +58,9 @@ ) -# TO DO: The loader is not thread safe! +# The construction path is not thread-safe on its own; it monkey-patches process-global torch state +# (see MODEL_LOAD_LOCK). Concurrent callers must hold the MODEL_LOAD_LOCK write lock (see +# _load_and_cache). 
class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -85,8 +93,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if not model_path.exists(): raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") - with skip_torch_weight_init(): - cache_record = self._load_and_cache(model_config, submodel_type) + cache_record = self._load_and_cache(model_config, submodel_type) return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) @property @@ -124,25 +131,70 @@ def _get_execution_device( def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord: stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")]) + cache_key = get_model_cache_key(config.key, submodel_type) try: - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + return self._ram_cache.get(key=cache_key, stats_name=stats_name) except IndexError: pass - config.path = str(self._get_model_path(config)) - self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) - loaded_model = self._load_model(config, submodel_type) - - # Determine execution device from model config, considering submodel type - execution_device = self._get_execution_device(config, submodel_type) - - self._ram_cache.put( - get_model_cache_key(config.key, submodel_type), - model=loaded_model, - execution_device=execution_device, - ) - - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + # Cache miss: construct the model from disk. This path holds the MODEL_LOAD_LOCK *write* + # lock because it relies on process-global, non-thread-safe monkey-patches + # (skip_torch_weight_init and, inside the loaders, accelerate.init_empty_weights / diffusers + # low_cpu_mem_usage). 
The write lock excludes both other constructions AND concurrent VRAM + # load/unload on other workers (which take the read lock); without that, a concurrent move's + # load_state_dict(assign=True) -> register_parameter gets hijacked onto the `meta` device. + # See MODEL_LOAD_LOCK for the full explanation. + # + # Lock-ordering: the write lock is acquired before any ModelCache._lock taken below + # (get/make_room/put), matching the readers' order, so there is no AB-BA deadlock. + with MODEL_LOAD_LOCK.write_lock(): + # Double-checked locking: another worker sharing this cache may have loaded the same + # entry while we waited for the mutex. (Workers on other devices use a different cache, + # so they will still miss here and construct their own copy — which is intended.) + try: + return self._ram_cache.get(key=cache_key, stats_name=stats_name) + except IndexError: + pass + + config.path = str(self._get_model_path(config)) + + # Fast path (multi-GPU): if another device already loaded this exact model, its canonical + # CPU weights are still resident in the shared store along with an empty (meta-weight) + # clone of the built module. Adopt those weights instead of re-reading the model from + # disk — this avoids both the redundant disk read and the large transient second copy + # that would otherwise spike RAM (and, on a RAM-constrained box, drive the system into + # swap). Any failure falls back to a normal load, so it can never change the result. + loaded_model = self._try_adopt_shared_weights(cache_key) + + shell_to_register: Optional[torch.nn.Module] = None + if loaded_model is None: + self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + with skip_torch_weight_init(): + loaded_model = self._load_model(config, submodel_type) + # Snapshot a meta-weight clone now — before put() applies custom layers or any VRAM + # move — so the next device to load this model can adopt these weights (see above). 
+ # Skipped in single-device setups, where no other cache will ever adopt it. + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is not None and shared_store.enable_shell_capture: + shell_to_register = self._build_meta_shell(loaded_model) + + # Determine execution device from model config, considering submodel type + execution_device = self._get_execution_device(config, submodel_type) + + self._ram_cache.put( + cache_key, + model=loaded_model, + execution_device=execution_device, + ) + + # Register the shell only after put() has created the shared entry (via the wrapper's + # acquire); it is dropped automatically when that entry's last reference is released. + if shell_to_register is not None: + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is not None: + shared_store.set_shell(cache_key, shell_to_register) + + return self._ram_cache.get(key=cache_key, stats_name=stats_name) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None @@ -303,6 +355,86 @@ def post_hook(mod: torch.nn.Module, _args: object, _output: object) -> None: module.register_forward_pre_hook(pre_hook) module.register_forward_hook(post_hook, always_call=True) + def _try_adopt_shared_weights(self, cache_key: str) -> Optional[AnyModel]: + """Build this model by adopting another device's already-resident CPU weights, skipping the + disk read entirely. + + Returns the constructed model, or None if adoption is unavailable or fails for any reason (in + which case the caller loads the model from disk normally). Loader-agnostic: it deep-copies the + meta-weight shell that the first device registered (`_build_meta_shell`) and assigns the + shared canonical weights into the copy — no per-loader architecture knowledge required, and + fp8 cast hooks carried by the shell are preserved automatically. 
+ + Must be called while holding the MODEL_LOAD_LOCK write lock (as `_load_and_cache` does), so + the peeked canonical weights and shell cannot be evicted between the peek and the adopt. + """ + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is None: + return None + canonical = shared_store.peek(cache_key) + shell = shared_store.get_shell(cache_key) + if canonical is None or shell is None: + return None + + try: + # Independent module per device (its params will be moved to its own GPU); deep-copying an + # all-meta shell is cheap (no weight data). assign=True then re-points the copy's + # parameters at the shared canonical tensors with no allocation. + model = copy.deepcopy(shell) + model.load_state_dict(canonical, assign=True) + # Safety net: if anything is left on the meta device (e.g. a persistent buffer somehow + # missing from the canonical state dict) the model would silently produce wrong results. + for tensor in itertools.chain(model.parameters(), model.buffers()): + if tensor.is_meta: + raise RuntimeError("adopted model has tensors left on the meta device") + except Exception as e: + # Adoption is best-effort; never let it break a load. Fall back to a normal disk load. + self._logger.warning( + f"Could not adopt shared CPU weights for '{cache_key}' ({e!r}); loading from disk instead." + ) + return None + + self._logger.info( + f"Adopted shared CPU weights for '{cache_key}' from another device's cache (skipped disk load)." + ) + return model + + @staticmethod + def _build_meta_shell(model: AnyModel) -> Optional[torch.nn.Module]: + """Return an empty, meta-weight structural clone of `model`, or None if it can't be cloned. + + The clone has the identical module structure, registered hooks (e.g. the fp8 layerwise-cast + hooks), and non-persistent buffers as `model`, but every parameter and persistent buffer is + replaced by a 0-byte tensor on the `meta` device. 
A second device adopts it by deep-copying + and assigning the shared canonical weights — so this works for every model family (diffusers, + single-file checkpoint, GGUF, transformers) without any per-loader code. + + Best-effort: returns None on any failure (the model then simply isn't adoptable, and the next + device loads it from disk as before). + """ + if not isinstance(model, torch.nn.Module): + return None + try: + # Persistent buffers come from the canonical state dict on adoption, so they (like params) + # are replaced by meta placeholders. Non-persistent buffers are NOT in the state dict, so + # they must be carried over with real data (deepcopy copies them); they are typically + # small (e.g. rotary-embedding tables, attention masks). + persistent_names = set(model.state_dict().keys()) + persistent_buffer_ids = {id(b) for n, b in model.named_buffers() if n in persistent_names} + + memo: dict[int, object] = {} + for param in model.parameters(recurse=True): + memo[id(param)] = torch.nn.Parameter( + torch.empty_like(param, device="meta"), requires_grad=param.requires_grad + ) + for buffer in model.buffers(recurse=True): + if id(buffer) in persistent_buffer_ids: + memo[id(buffer)] = torch.empty_like(buffer, device="meta") + + return copy.deepcopy(model, memo) + except Exception: + return None + # This needs to be implemented in the subclass def _load_model( self, diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index bb04edef9b5..243a00015d6 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -2,6 +2,7 @@ import torch +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from 
invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor @@ -12,7 +13,13 @@ class CachedModelOnlyFullLoad: """ def __init__( - self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False + self, + model: torch.nn.Module | Any, + compute_device: torch.device, + total_bytes: int, + keep_ram_copy: bool = False, + shared_store: SharedCpuWeightsStore | None = None, + cache_key: str | None = None, ): """Initialize a CachedModelOnlyFullLoad. Args: @@ -22,16 +29,40 @@ def __init__( keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is sufficient RAM). + shared_store (SharedCpuWeightsStore | None): If provided (along with cache_key), share a single canonical + CPU copy of the weights across per-device caches instead of one copy per device. + cache_key (str | None): The model cache key used to identify shared weights in `shared_store`. """ # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases. self._model = model self._compute_device = compute_device self._offload_device = torch.device("cpu") + # When set, this model's CPU weights are a shared canonical copy owned by `shared_store` + # under `cache_key`; `release_shared_weights()` must be called exactly once on eviction. + self._shared_store: SharedCpuWeightsStore | None = None + self._shared_key: str | None = None # A CPU read-only copy of the model's state dict. self._cpu_state_dict: dict[str, torch.Tensor] | None = None if isinstance(model, torch.nn.Module) and keep_ram_copy: - self._cpu_state_dict = model.state_dict() + cpu_state_dict = model.state_dict() + # In multi-GPU mode, share one canonical CPU copy across the per-device caches (see + # SharedCpuWeightsStore). 
If another device already registered this key, re-point our + # module at the shared tensors and drop our duplicate so the weights live once in RAM. + if shared_store is not None and cache_key is not None: + canonical = shared_store.acquire(cache_key, cpu_state_dict) + self._shared_store = shared_store + self._shared_key = cache_key + try: + if canonical is not cpu_state_dict: + model.load_state_dict(canonical, assign=True) + cpu_state_dict = canonical + except Exception: + # The re-point failed after acquiring a reference; release it so the shared + # entry's refcount isn't leaked (this wrapper will never enter the cache). + self.release_shared_weights() + raise + self._cpu_state_dict = cpu_state_dict self._total_bytes = total_bytes self._is_in_vram = False @@ -45,6 +76,27 @@ def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: # TODO(ryand): Document this better. return self._cpu_state_dict + @property + def uses_shared_weights(self) -> bool: + """True if this model's CPU weights are deduplicated in a SharedCpuWeightsStore. + + When True, its RAM is accounted by the store (counted once across devices); when False, its + RAM is per-instance and must be counted by the RamBudget's non-shared total. + """ + return self._shared_store is not None + + def release_shared_weights(self) -> None: + """Release this model's reference to its shared canonical CPU weights, if any. + + Must be called exactly once when the cache entry is evicted. Idempotent: a second call is a + no-op. After release, the shared store frees the canonical tensors once the last device that + held this key releases it. 
+ """ + if self._shared_store is not None and self._shared_key is not None: + self._shared_store.release(self._shared_key) + self._shared_store = None + self._shared_key = None + def total_bytes(self) -> int: """Get the total size (in bytes) of all the weights in the model.""" return self._total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 328978b45b1..2a1d83cb011 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -1,5 +1,6 @@ import torch +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) @@ -14,31 +15,65 @@ class CachedModelWithPartialLoad: MPS memory, etc. """ - def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False): + def __init__( + self, + model: torch.nn.Module, + compute_device: torch.device, + keep_ram_copy: bool = False, + shared_store: SharedCpuWeightsStore | None = None, + cache_key: str | None = None, + ): self._model = model self._compute_device = compute_device + # When set, this model's CPU weights are a shared canonical copy owned by `shared_store` + # under `cache_key`; `release_shared_weights()` must be called exactly once on eviction. + self._shared_store: SharedCpuWeightsStore | None = None + self._shared_key: str | None = None model_state_dict = model.state_dict() # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA # patching. Set to `None` if keep_ram_copy is False. 
- self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None + cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None # A dictionary of the size of each tensor in the state dict. # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for # consistency in case the application code has modified the model's size (e.g. by casting to a different # precision). Of course, this means that we are making model cache load/unload decisions based on model size # data that may not be fully accurate. + # + # Note: these are computed from the model's own state dict *before* the shared-weights re-point + # below. The re-point only swaps tensor storage; keys, shapes and dtypes are unchanged, so the + # metadata is identical either way. Computing it first keeps the acquire the last (and only + # failure-prone) step, so a failure there can release the reference cleanly without leaking it. self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()} - self._total_bytes = sum(self._state_dict_bytes.values()) self._cur_vram_bytes: int | None = None - self._modules_that_support_autocast = self._find_modules_that_support_autocast() self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast( model_state_dict ) self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict) + # In multi-GPU mode, share a single canonical CPU copy of the weights across the per-device + # caches instead of keeping one copy per device (see SharedCpuWeightsStore). If another + # device already registered this key, re-point our module's params at the shared tensors and + # drop our freshly-built duplicate so the weights live once in RAM. 
+ if cpu_state_dict is not None and shared_store is not None and cache_key is not None: + canonical = shared_store.acquire(cache_key, cpu_state_dict) + self._shared_store = shared_store + self._shared_key = cache_key + try: + if canonical is not cpu_state_dict: + self._model.load_state_dict(canonical, assign=True) + cpu_state_dict = canonical + except Exception: + # The re-point failed after acquiring a reference; release it so the shared entry's + # refcount isn't leaked (this wrapper will never be inserted into the cache). + self.release_shared_weights() + raise + + self._cpu_state_dict: dict[str, torch.Tensor] | None = cpu_state_dict + def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore @@ -121,6 +156,27 @@ def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: # TODO(ryand): Document this better. return self._cpu_state_dict + @property + def uses_shared_weights(self) -> bool: + """True if this model's CPU weights are deduplicated in a SharedCpuWeightsStore. + + When True, its RAM is accounted by the store (counted once across devices); when False, its + RAM is per-instance and must be counted by the RamBudget's non-shared total. + """ + return self._shared_store is not None + + def release_shared_weights(self) -> None: + """Release this model's reference to its shared canonical CPU weights, if any. + + Must be called exactly once when the cache entry is evicted. Idempotent: a second call is a + no-op. After release, the shared store frees the canonical tensors once the last device that + held this key releases it. 
+ """ + if self._shared_store is not None and self._shared_key is not None: + self._shared_store.release(self._shared_key) + self._shared_store = None + self._shared_key = None + def total_bytes(self) -> int: """Get the total size (in bytes) of all the weights in the model.""" return self._total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index e3a0928e52b..7bc6931e5a3 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -2,10 +2,11 @@ import logging import threading import time +from contextlib import contextmanager from dataclasses import dataclass from functools import wraps from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, Generator, List, Optional, Protocol import psutil import torch @@ -19,6 +20,11 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import ( + SHARED_CPU_WEIGHTS, + SharedCpuWeightsStore, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( apply_custom_layers_to_model, ) @@ -34,6 +40,84 @@ # Size of a MB in bytes. MB = 2**20 +# Default RAM-cache sizing constants. These are used both by the per-device heuristic +# (_calc_ram_available_to_model_cache) and by the multi-GPU global budget cap +# (ModelManagerService.build_model_manager), so the two stay consistent. +# +# - RAM_CACHE_SYSTEM_FRACTION: fraction of total system RAM the model cache may use by default. +# - RAM_CACHE_BASELINE_BYTES: assumed non-model RAM used by InvokeAI itself, reserved before sizing. 
+# - MIN_RAM_CACHE_BYTES: absolute floor so the cache is never sized uselessly small. +RAM_CACHE_SYSTEM_FRACTION = 0.5 +RAM_CACHE_BASELINE_BYTES = 2 * GB +MIN_RAM_CACHE_BYTES = 4 * GB + + +class _ModelLoadReadWriteLock: + """A write-preferring readers-writer lock that serializes model construction against VRAM moves. + + The model load machinery depends on PROCESS-GLOBAL monkey-patches that are not thread-safe: + model CONSTRUCTION (diffusers `from_pretrained` / `accelerate.init_empty_weights`) temporarily + replaces `torch.nn.Module.register_parameter` so that every newly-registered parameter is routed + to the `meta` device. While that patch is installed, ANY `register_parameter` call in ANY thread + is hijacked onto `meta`. VRAM load/unload uses `nn.Module.load_state_dict(assign=True)`, which + assigns `Parameter`s via `__setattr__` -> `register_parameter` — so if it runs concurrently with + a construction on another worker thread, its real weights get stranded on `meta`. That surfaces + later as "Cannot copy out of meta tensor; no data!" or "unrecognized device meta". + + - Construction takes the WRITE lock (exclusive — no reader and no other writer may run). + - VRAM load/unload takes the READ lock (shared, so concurrent moves on different GPUs still + overlap each other; they only block while a construction holds the write lock). + + Write-preferring: once a construction is waiting, new readers queue behind it, so a steady stream + of VRAM moves from busy workers can't starve a pending load. + + Lock-ordering contract: callers MUST acquire this lock *before* any `ModelCache._lock`, never + after. Readers do so by taking the read lock around the outer `ModelCache.lock()` call (see + `LoadedModelWithoutConfig`), and writers around the whole construction (see + `ModelLoader._load_and_cache`). Acquiring it in the other order — cache lock first, then this + lock — would risk an AB-BA deadlock with a writer that takes a cache lock during `put()`. 
+ """ + + def __init__(self) -> None: + self._cond = threading.Condition(threading.Lock()) + self._readers = 0 + self._writers_waiting = 0 + self._writer_active = False + + @contextmanager + def read_lock(self) -> Generator[None, None, None]: + with self._cond: + # Defer to any active or waiting writer (write-preferring). + while self._writer_active or self._writers_waiting > 0: + self._cond.wait() + self._readers += 1 + try: + yield + finally: + with self._cond: + self._readers -= 1 + if self._readers == 0: + self._cond.notify_all() + + @contextmanager + def write_lock(self) -> Generator[None, None, None]: + with self._cond: + self._writers_waiting += 1 + while self._writer_active or self._readers > 0: + self._cond.wait() + self._writers_waiting -= 1 + self._writer_active = True + try: + yield + finally: + with self._cond: + self._writer_active = False + self._cond.notify_all() + + +# Process-global lock guarding the non-thread-safe model load machinery. See _ModelLoadReadWriteLock. +MODEL_LOAD_LOCK = _ModelLoadReadWriteLock() + # TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels. def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str: @@ -148,6 +232,8 @@ def __init__( log_memory_usage: bool = False, logger: Optional[Logger] = None, keep_alive_minutes: float = 0, + shared_cpu_weights: SharedCpuWeightsStore | None = SHARED_CPU_WEIGHTS, + ram_budget: RamBudget | None = None, ): """Initialize the model RAM cache. @@ -168,7 +254,15 @@ def __init__( behaviour. :param logger: InvokeAILogger to use (otherwise creates one) :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely. + :param shared_cpu_weights: Process-global store that lets per-device caches share a single CPU copy of each + model's weights (see SharedCpuWeightsStore). 
Defaults to the global store so that, in multi-GPU mode, a + model loaded on multiple GPUs occupies RAM only once. Pass None to disable sharing for this cache. + :param ram_budget: Optional shared RamBudget used as the single global RAM authority across all per-device + caches. When provided, eviction decisions are made against the deduplicated, system-wide RAM total rather + than this cache's local (double-counted) sum. When None, the cache uses its own local RAM accounting. """ + self._shared_cpu_weights = shared_cpu_weights + self._ram_budget = ram_budget self._enable_partial_loading = enable_partial_loading self._keep_ram_copy_of_weights = keep_ram_copy_of_weights self._execution_device_working_mem_gb = execution_device_working_mem_gb @@ -229,6 +323,36 @@ def unsubscribe() -> None: return unsubscribe + @property + def execution_device(self) -> torch.device: + """Return the default execution device this cache loads models onto.""" + return self._execution_device + + @property + def shared_cpu_weights(self) -> SharedCpuWeightsStore | None: + """The process-global store this cache deduplicates CPU weights into, or None if disabled. + + Exposed so the loader can check (via `peek`) whether another device already holds a model's + canonical CPU weights and adopt them at construction time instead of re-reading from disk. + """ + return self._shared_cpu_weights + + def set_ram_budget(self, ram_budget: RamBudget) -> None: + """Attach the shared global RamBudget after construction. + + Used by the model manager once all per-device caches exist and the global cap has been + computed from their individual sizes (see ModelManagerService.build_model_manager). + """ + self._ram_budget = ram_budget + + @property + def local_ram_cache_size_bytes(self) -> int: + """The RAM cache size this cache computed for itself (from max_cache_ram_gb or the heuristic). + + Used by the model manager to seed the global RamBudget cap when no explicit limit is set. 
+ """ + return self._ram_cache_size_bytes + @property @synchronized def stats(self) -> Optional[CacheStats]: @@ -240,9 +364,12 @@ def stats(self) -> Optional[CacheStats]: def stats(self, stats: CacheStats) -> None: """Set the CacheStats object for collecting cache statistics.""" self._stats = stats - # Populate the cache size in the stats object when it's set + # Populate the cache size in the stats object when it's set. Prefer the global budget cap + # (the real system-wide limit) when one is attached. if self._stats is not None: - self._stats.cache_size = self._ram_cache_size_bytes + self._stats.cache_size = ( + self._ram_budget.max_bytes if self._ram_budget is not None else self._ram_cache_size_bytes + ) def _record_activity(self) -> None: """Record model activity and reset the timeout timer if configured. @@ -350,16 +477,30 @@ def put(self, key: str, model: AnyModel, execution_device: Optional[torch.device # Wrap model. if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading: wrapped_model = CachedModelWithPartialLoad( - model, effective_execution_device, keep_ram_copy=self._keep_ram_copy_of_weights + model, + effective_execution_device, + keep_ram_copy=self._keep_ram_copy_of_weights, + shared_store=self._shared_cpu_weights, + cache_key=key, ) else: wrapped_model = CachedModelOnlyFullLoad( - model, effective_execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights + model, + effective_execution_device, + size, + keep_ram_copy=self._keep_ram_copy_of_weights, + shared_store=self._shared_cpu_weights, + cache_key=key, ) cache_record = CacheRecord(key=key, cached_model=wrapped_model) self._cached_models[key] = cache_record self._cache_stack.append(key) + # Account this model's RAM in the global budget. Shared weights are tracked once by the + # SharedCpuWeightsStore; only non-deduplicated models are added to the budget's non-shared + # total (a non-shared model resident on N devices correctly counts N times). 
+ if self._ram_budget is not None and not wrapped_model.uses_shared_weights: + self._ram_budget.add_non_shared(wrapped_model.total_bytes()) self._logger.debug( f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)" ) @@ -546,9 +687,13 @@ def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Option loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0 # Use the model's actual compute_device for logging, not the cache's default model_device = cache_entry.cached_model.compute_device + if model_device.type == "cuda": + device_label = f"cuda device #{model_device.index}" if model_device.index is not None else "cuda device" + else: + device_label = f"{model_device.type} device" self._logger.info( f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto " - f"{model_device.type} device in {(time.time() - start_time):.2f}s. " + f"{device_label} in {(time.time() - start_time):.2f}s. " f"Total model size: {model_total_bytes / MB:.2f}MB, " f"VRAM: {model_cur_vram_bytes / MB:.2f}MB ({loaded_percent:.1%})" ) @@ -625,7 +770,13 @@ def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int: def _get_vram_in_use(self) -> int: """Get the amount of VRAM currently in use by the cache.""" if self._execution_device.type == "cuda": - return torch.cuda.memory_allocated() + # Must be queried for THIS cache's execution device, not the process-current device. In + # multi-GPU mode each worker calls torch.cuda.set_device for its own GPU, so the current + # device flips between workers; querying without the device argument can read a different + # (e.g. idle) GPU's allocation. That breaks the cancellation in _get_vram_available + # (which adds vram_allocated(execution_device)), inflating "available" toward total VRAM + # so the cache never offloads — causing VRAM OOMs that ignore device_working_mem_gb. 
+ return torch.cuda.memory_allocated(self._execution_device) elif self._execution_device.type == "mps": return torch.mps.current_allocated_memory() else: @@ -670,8 +821,10 @@ def _calc_ram_available_to_model_cache(self) -> int: heuristics_applied = [1] total_system_ram_bytes = psutil.virtual_memory().total # Assumed baseline RAM used by InvokeAI for non-model stuff. - baseline_ram_used_by_invokeai = 2 * GB - ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai) + baseline_ram_used_by_invokeai = RAM_CACHE_BASELINE_BYTES + ram_available_to_model_cache = int( + total_system_ram_bytes * RAM_CACHE_SYSTEM_FRACTION - baseline_ram_used_by_invokeai + ) # Apply heuristic 2. # ------------------ @@ -687,21 +840,49 @@ def _calc_ram_available_to_model_cache(self) -> int: # Apply heuristic 3. # ------------------ - if ram_available_to_model_cache < 4 * GB: + if ram_available_to_model_cache < MIN_RAM_CACHE_BYTES: heuristics_applied.append(3) - ram_available_to_model_cache = 4 * GB + ram_available_to_model_cache = MIN_RAM_CACHE_BYTES self._logger.info( f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}." ) return ram_available_to_model_cache + @staticmethod + def calc_system_ram_headroom_bytes() -> int: + """The default system-wide cap on TOTAL model-cache RAM, leaving headroom for the OS. + + This is the maximum RAM the model caches should collectively use when the user has not set an + explicit `max_cache_ram_gb`. It mirrors heuristic 1 of `_calc_ram_available_to_model_cache` + (a fraction of system RAM, less InvokeAI's baseline) with the same minimum floor. + + In multi-GPU mode there is one cache per device, and each device's heuristic independently + allows up to this fraction of system RAM; summed across N devices that would claim ~N× as + much RAM and cause the system to swap. 
The model manager uses this value to cap that sum so a + safe amount of RAM is always left for the OS and other processes. + """ + total_system_ram_bytes = psutil.virtual_memory().total + return max( + int(total_system_ram_bytes * RAM_CACHE_SYSTEM_FRACTION) - RAM_CACHE_BASELINE_BYTES, + MIN_RAM_CACHE_BYTES, + ) + def _get_ram_in_use(self) -> int: - """Get the amount of RAM currently in use.""" + """Get the amount of RAM currently in use. + + With a shared RamBudget attached, this returns the deduplicated, system-wide total across all + per-device caches (shared model weights counted once). Without one, it returns this cache's + local sum. + """ + if self._ram_budget is not None: + return self._ram_budget.total_in_use() return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values()) def _get_ram_available(self) -> int: """Get the amount of RAM available for the cache to use.""" + if self._ram_budget is not None: + return self._ram_budget.available() return self._ram_cache_size_bytes - self._get_ram_in_use() def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: @@ -792,7 +973,12 @@ def _log_cache_state(self, title: str = "Model cache state:", include_entry_deta ) if torch.cuda.is_available(): - log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB) + # Query this cache's execution device (not the process-current one) for correct + # per-device numbers in multi-GPU mode. See _get_vram_in_use. 
+ allocated = ( + torch.cuda.memory_allocated(self._execution_device) if self._execution_device.type == "cuda" else 0 + ) + log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", allocated / MB) log += " {:<30} {}\n".format("Total models:", len(self._cached_models)) if include_entry_details and len(self._cached_models) > 0: @@ -840,7 +1026,18 @@ def _make_room_internal(self, bytes_needed: int) -> None: ram_bytes_freed = 0 pos = 0 models_cleared = 0 - while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack): + while pos < len(self._cache_stack): + # Stop once there is enough room. With a shared RamBudget, re-check the global, + # deduplicated availability each iteration: evicting a model that other devices still + # hold frees no RAM (its shared weights stay live until the last reference is released), + # so a fixed "bytes freed" tally would be wrong. Without a budget, the local tally is + # exact, so the original cheaper check is kept. + if self._ram_budget is not None: + if bytes_needed <= self._get_ram_available(): + break + elif ram_bytes_freed >= ram_bytes_to_free: + break + model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] @@ -884,8 +1081,21 @@ def _make_room_internal(self, bytes_needed: int) -> None: def _delete_cache_entry(self, cache_entry: CacheRecord) -> None: """Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist.""" + was_present = cache_entry.key in self._cached_models self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key] self._cached_models.pop(cache_entry.key, None) + # Drop this device's reference to the shared canonical CPU weights so they can be freed once + # the last device releases them. Guard on was_present so a double-delete doesn't + # double-release (release_shared_weights is itself idempotent, but a re-added entry under the + # same key must not be released by a stale delete). 
+ if was_present: + uses_shared = cache_entry.cached_model.uses_shared_weights + total_bytes = cache_entry.cached_model.total_bytes() + cache_entry.cached_model.release_shared_weights() + # Drop the matching non-shared contribution from the global budget (shared weights are + # released via the store above). Captured before release_shared_weights() flips the flag. + if self._ram_budget is not None and not uses_shared: + self._ram_budget.remove_non_shared(total_bytes) @synchronized def drop_model(self, model_key: str) -> int: diff --git a/invokeai/backend/model_manager/load/model_cache/ram_budget.py b/invokeai/backend/model_manager/load/model_cache/ram_budget.py new file mode 100644 index 00000000000..6428c646753 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/ram_budget.py @@ -0,0 +1,64 @@ +import threading +from typing import Optional + +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +class RamBudget: + """The single global authority for how much RAM the model caches are actually using. + + In multi-GPU mode there is one `ModelCache` per device. Each cache independently sums the + `total_bytes()` of the models it holds, so a model resident on N devices is counted N times — + even though Phase 1/2 made its CPU weights live only ONCE in RAM (see SharedCpuWeightsStore). + That per-cache double-count makes the caches believe RAM is fuller than it is, causing premature + eviction and reload churn, and makes `max_cache_ram_gb` meaningless as a system-wide cap. + + RamBudget fixes the accounting by separating RAM into two non-overlapping parts: + + - Shared weights: model weights that are deduplicated in the SharedCpuWeightsStore. Counted + exactly once via `store.total_bytes_in_use()`, regardless of how many devices hold them. + - Non-shared RAM: models that are NOT deduplicated (keep_ram_copy disabled, or non-Module + models whose single in-RAM copy is per-device). 
These are tracked here as an explicit running + total; a model resident on N devices contributes N times, which is correct because it really + does occupy N copies of RAM. + + `total_in_use()` is the sum of the two and reflects the true RAM footprint. All per-device caches + share one RamBudget and make their eviction decisions against it. + + Thread-safety / lock ordering: RamBudget guards its own counter with an internal lock and NEVER + acquires a ModelCache lock (it only reads the store, which has its own lock). Callers update it + while holding their cache lock, so the only lock order is cache-lock -> (store-lock | budget-lock), + never the reverse — so it cannot deadlock against the per-device caches. + """ + + def __init__(self, max_bytes: int, shared_store: Optional[SharedCpuWeightsStore]): + self._max_bytes = max_bytes + self._store = shared_store + self._non_shared_bytes = 0 + self._lock = threading.Lock() + + @property + def max_bytes(self) -> int: + """The global cap on actual model-cache RAM, in bytes.""" + return self._max_bytes + + def add_non_shared(self, nbytes: int) -> None: + """Record `nbytes` of newly-resident non-deduplicated model RAM.""" + with self._lock: + self._non_shared_bytes += nbytes + + def remove_non_shared(self, nbytes: int) -> None: + """Record the release of `nbytes` of non-deduplicated model RAM.""" + with self._lock: + self._non_shared_bytes = max(0, self._non_shared_bytes - nbytes) + + def total_in_use(self) -> int: + """The true total RAM used by the model caches: shared weights (counted once) + non-shared.""" + shared = self._store.total_bytes_in_use() if self._store is not None else 0 + with self._lock: + non_shared = self._non_shared_bytes + return shared + non_shared + + def available(self) -> int: + """Bytes remaining under the global cap (may be negative if over budget).""" + return self._max_bytes - self.total_in_use() diff --git a/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py 
b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py new file mode 100644 index 00000000000..4ce456e45b6 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py @@ -0,0 +1,152 @@ +import threading +from dataclasses import dataclass, field + +import torch + +from invokeai.backend.util.calc_tensor_size import calc_tensor_size + + +@dataclass +class _SharedWeightsEntry: + """A single canonical CPU state dict shared across per-device caches.""" + + state_dict: dict[str, torch.Tensor] + total_bytes: int + # Number of per-device cached models currently aliasing this entry. The entry is freed + # (its RAM released) when this drops to zero. + refcount: int = 0 + # An empty (meta-weight) structural clone of the first-built module, used so a second device can + # adopt the canonical weights without re-reading the model from disk. None until registered (and + # for entries whose model isn't an nn.Module). Holds ~no real RAM: its weights are on `meta`. + shell: object | None = None + _key_bytes: dict[str, int] = field(default_factory=dict) + + +class SharedCpuWeightsStore: + """Process-global store of canonical CPU weight tensors, shared across per-device model caches. + + In multi-GPU mode there is one `ModelCache` per generation device. Without coordination each + cache keeps its own CPU copy of every model's weights, so a model loaded on N GPUs occupies N + copies in RAM. The cached-model wrappers cannot simply share a single `torch.nn.Module`, because + loading to VRAM mutates a module's parameters in place (`load_state_dict(assign=True)` / `.to`), + and two GPUs running the same model concurrently need their params on two different devices at + once. The CPU weight tensors, however, are read-only and device-agnostic, so they can be shared. + + This store keeps a single canonical CPU `state_dict` per cache key. 
The first device to load a + key registers its freshly-built state dict as canonical; subsequent devices `acquire()` the + canonical and re-point their own module's CPU parameters at the shared tensors (via + `load_state_dict(assign=True)`), discarding their private duplicate. The result: model weights + live once in RAM regardless of how many GPUs hold the model. + + Lifetime is reference-counted. Each per-device cached model that adopts an entry must call + `release()` exactly once when it is evicted; the canonical tensors are dropped only when the + last device releases them. + + Thread-safety: `acquire()`/`release()` are guarded by an internal lock. Note that model + construction (where `acquire()` is normally called) is already serialized process-globally by + `MODEL_LOAD_LOCK.write_lock()`; the internal lock here additionally protects `release()`, which + runs under a per-cache lock off the global construction lock. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._entries: dict[str, _SharedWeightsEntry] = {} + # Whether to capture per-model meta-weight shells for cross-device adoption. Only useful with + # more than one device cache, so the model manager disables it in single-device setups to + # avoid the (small) per-first-load clone cost. See ModelLoader._build_meta_shell. + self.enable_shell_capture: bool = True + + def acquire(self, key: str, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Adopt the canonical CPU state dict for `key`, registering `state_dict` as canonical if + this is the first acquire. + + Increments the entry's refcount. The caller MUST pair every `acquire()` with exactly one + `release()`. + + Returns: + The canonical state dict. If this call registered the entry, the returned object is the + same `state_dict` that was passed in (the caller keeps using its own tensors). 
Otherwise + it is the previously-registered canonical dict, and the caller is responsible for + re-pointing its module at these tensors and dropping the `state_dict` it passed in. + """ + with self._lock: + entry = self._entries.get(key) + if entry is None: + entry = _SharedWeightsEntry( + state_dict=state_dict, + total_bytes=sum(calc_tensor_size(v) for v in state_dict.values()), + ) + self._entries[key] = entry + entry.refcount += 1 + return entry.state_dict + + def peek(self, key: str) -> dict[str, torch.Tensor] | None: + """Return the canonical state dict for `key` WITHOUT changing its refcount, or None if absent. + + Used by the loader to adopt already-resident weights at construction time (skipping the disk + read) when another device has already loaded this model. The reference is taken later, in the + cached-model wrapper's `acquire()`, exactly as for a normal load — so this peek must not + itself increment the count. + """ + with self._lock: + entry = self._entries.get(key) + return entry.state_dict if entry is not None else None + + def set_shell(self, key: str, shell: object) -> None: + """Register the empty (meta-weight) structural clone for `key`, if an entry exists and none + is set yet. A no-op when the key has no canonical entry (e.g. keep_ram_copy disabled).""" + with self._lock: + entry = self._entries.get(key) + if entry is not None and entry.shell is None: + entry.shell = shell + + def get_shell(self, key: str) -> object | None: + """Return the registered meta-weight shell for `key`, or None if absent.""" + with self._lock: + entry = self._entries.get(key) + return entry.shell if entry is not None else None + + def release(self, key: str) -> None: + """Release one reference to `key`'s canonical state dict, freeing it when the count hits 0. + + A `release()` for a key that is not present is a no-op (e.g. a cached model that never + acquired shared weights, or a double eviction guard). 
+ """ + with self._lock: + entry = self._entries.get(key) + if entry is None: + return + entry.refcount -= 1 + if entry.refcount <= 0: + del self._entries[key] + + # -- Introspection / accounting (also used by tests) ---------------------------------------- + + def __contains__(self, key: str) -> bool: + with self._lock: + return key in self._entries + + def refcount(self, key: str) -> int: + """Return the current refcount for `key`, or 0 if not present.""" + with self._lock: + entry = self._entries.get(key) + return entry.refcount if entry is not None else 0 + + def total_bytes_in_use(self) -> int: + """Return the total size (in bytes) of all canonical state dicts currently held. + + This counts each shared model's weights exactly once, regardless of how many devices alias + it — i.e. the true RAM footprint of cached weights, not the per-device double-count. + """ + with self._lock: + return sum(entry.total_bytes for entry in self._entries.values()) + + def keys(self) -> list[str]: + with self._lock: + return list(self._entries.keys()) + + +# Process-global default store. Per-device caches share this instance so that the same model loaded +# on multiple GPUs keeps a single CPU copy. Tests may construct isolated `SharedCpuWeightsStore` +# instances instead. +SHARED_CPU_WEIGHTS = SharedCpuWeightsStore() diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index b3c46d04db3..739ba458888 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -1080,7 +1080,13 @@ def _dequantize_fp8_weights(self, sd: dict) -> dict: if block_size > 1: scale = scale.repeat_interleave(block_size, dim=dim) - sd[weight_key] = weight_float * scale + # Do the multiply in float32 for precision, but store bf16 (FLUX.2's compute dtype) + # immediately so the *whole* model is never materialized in float32. 
Holding every + # dequantized weight as float32 here doubled RAM transiently (~36GB vs ~17GB for a 9B + # model) and was the dominant cold-load spike, especially with two GPUs. The result is + # identical to the previous code, which cast the same values to bf16 a few steps later. + sd[weight_key] = (weight_float * scale).to(torch.bfloat16) + del weight_float # Filter out scale metadata keys and other FP8 metadata keys_to_remove = [ @@ -1110,8 +1116,9 @@ def _dequantize_fp8_weights(self, sd: dict) -> dict: del sd[k] for key in keys_to_convert: - # Convert FP8 tensor to float32 - sd[key] = sd[key].float() + # Convert native FP8 tensors straight to bf16 (FLUX.2's compute dtype) rather than float32, + # so a cold load never transiently holds the whole model in float32 (see the scaled path). + sd[key] = sd[key].to(torch.bfloat16) return sd diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index fbfcd04de20..e14e35baa08 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -216,7 +216,10 @@ def _apply_model_layer_patch( param_name, torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad), ) - module_param = expanded_weight + # Point at the module's live (expanded) parameter so the out-of-place weight + # update below lands on the module. `expanded_weight` is a detached raw tensor; + # reassigning its `.data` would not propagate to the newly-set Parameter. + module_param = module_to_patch.get_parameter(param_name) else: # For other LoRAs, shape mismatch indicates architecture incompatibility - skip the layer logger = InvokeAILogger.get_logger(LayerPatcher.__name__) @@ -227,9 +230,17 @@ def _apply_model_layer_patch( ) continue - # Convert param_weight to the correct device and dtype, then apply to model weights + # Convert param_weight to the correct device and dtype, then apply to model weights. 
param_weight_converted = param_weight.to(device=device, dtype=dtype) - module_param.data.copy_(module_param.data + param_weight_converted) + # Apply out-of-place (assign a new tensor) rather than an in-place `copy_`. The weight we + # are patching may be the model's canonical CPU copy, which is shared across the + # per-device model caches in multi-GPU mode (see SharedCpuWeightsStore) and is also the + # cache's keep_ram_copy used to restore the model after unpatching. An in-place mutation + # here would corrupt that shared/cached tensor — and every other device's view of it. + # Reassigning `.data` leaves the original tensor untouched while giving this module the + # patched weights, and is memory-equivalent (the in-place form already allocated the + # `module_param.data + param_weight_converted` temporary). + module_param.data = module_param.data + param_weight_converted patch.to(device=TorchDevice.CPU_DEVICE) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4191db734f9..be3800411ad 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -10,6 +10,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager +from invokeai.backend.util.devices import TorchDevice class StableDiffusionBackend: @@ -44,7 +45,9 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa # ext: preview[pre_denoise_loop, priority=low] ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx) - for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 + for ctx.step_index, ctx.timestep in enumerate( # noqa: B020 + tqdm(ctx.inputs.timesteps, 
desc=f"Denoising{TorchDevice.get_session_device_label()}") + ): # ext: inpaint (apply mask to latents on non-inpaint models) ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 359ce45dc4f..7a5e8f3e8b9 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,3 +1,5 @@ +import threading +from collections import Counter, defaultdict from typing import Dict, Literal, Optional, Union import torch @@ -46,9 +48,52 @@ class TorchDevice: CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") + # Per-thread execution device. When set (by a session-processor worker thread bound to a + # specific GPU), `choose_torch_device()` returns it instead of consulting the global config. + # This is the lynchpin that makes the ~79 `choose_torch_device()` call sites (nodes, model + # patcher, etc.) resolve to the calling worker's GPU without per-call-site changes. + _session_device = threading.local() + + @classmethod + def set_session_device(cls, device: Union[str, torch.device]) -> None: + """Pin the calling thread's execution device. Used by multi-GPU session workers.""" + cls._session_device.device = cls.normalize(device) + + @classmethod + def get_session_device(cls) -> Optional[torch.device]: + """Return the calling thread's pinned execution device, or None if unset.""" + return getattr(cls._session_device, "device", None) + + @classmethod + def clear_session_device(cls) -> None: + """Remove the calling thread's pinned execution device, reverting to global config.""" + if hasattr(cls._session_device, "device"): + del cls._session_device.device + + @classmethod + def get_session_device_index(cls) -> Optional[int]: + """Return the CUDA index of the calling thread's effective device, or None if not on CUDA. 
+ + Resolves the thread-local session device when a worker has pinned one (multi-GPU), otherwise + falls back to the globally-configured device. Used to annotate logs/progress with the GPU + number so concurrent sessions can be told apart. + """ + device = cls.get_session_device() or cls.choose_torch_device() + return device.index if device.type == "cuda" else None + + @classmethod + def get_session_device_label(cls) -> str: + """Return a ``" (#N)"`` suffix for the calling thread's CUDA device, or ``""`` when not on CUDA.""" + index = cls.get_session_device_index() + return f" (#{index})" if index is not None else "" + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" + # A worker thread pinned to a specific GPU takes precedence over the global config. + session_device = cls.get_session_device() + if session_device is not None: + return session_device app_config = get_config() if app_config.device != "auto": device = torch.device(app_config.device) @@ -87,11 +132,83 @@ def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtyp # CPU / safe fallback return cls._to_dtype("float32") + @classmethod + def get_device_name(cls, device: torch.device) -> str: + """Return the human-readable name for a torch device (e.g. 'AMD Radeon PRO W7900', 'CPU').""" + return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + @classmethod def get_torch_device_name(cls) -> str: """Return the device name for the current torch device.""" - device = cls.choose_torch_device() - return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + return cls.get_device_name(cls.choose_torch_device()) + + @classmethod + def get_generation_devices_summary(cls, generation_devices: Union[str, list[str], None]) -> str: + """Build a human-readable summary of the devices that will be used for generation. 
+ + For a single device, returns just its name (e.g. ``'AMD Radeon PRO W7900'`` or ``'CPU'``). For + multiple devices, returns a bracketed list annotating each with its GPU number and device id, + e.g. ``'[AMD Radeon PRO W7900 #1 (cuda:0), AMD Radeon PRO W7900 #2 (cuda:1)]'``. Identically + named GPUs get a 1-based ``#N`` suffix so they can be told apart; a uniquely named device gets + no suffix. + """ + devices = cls.get_generation_devices(generation_devices) + if not devices: + # Empty resolution (e.g. `generation_devices` set to an empty list) falls back to the + # single globally-configured device. + devices = [cls.choose_torch_device()] + + names = [cls.get_device_name(device) for device in devices] + if len(devices) == 1: + return names[0] + + name_counts = Counter(names) + ordinals: dict[str, int] = defaultdict(int) + parts: list[str] = [] + for device, name in zip(devices, names, strict=True): + ordinals[name] += 1 + label = f"{name} #{ordinals[name]}" if name_counts[name] > 1 else name + parts.append(f"{label} ({device})") + return "[" + ", ".join(parts) + "]" + + @classmethod + def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) -> list[torch.device]: + """Resolve the configured `generation_devices` into a concrete, deduplicated device list. + + - ``"auto"`` (the default) expands to every visible CUDA device, or the single best available + device (mps/cpu) when CUDA is unavailable. + - An explicit list is normalized and deduplicated, with order preserved. + - ``None`` or an empty list yields an empty list; the caller decides the single-device fallback. 
+ """ + if generation_devices == "auto": + if torch.cuda.is_available(): + device_strs: list[str] = [f"cuda:{index}" for index in range(torch.cuda.device_count())] + else: + device_strs = [str(cls.choose_torch_device())] + elif not generation_devices: + return [] + else: + device_strs = list(generation_devices) + + devices: list[torch.device] = [] + seen: set[str] = set() + for device_str in device_strs: + device = cls.normalize(device_str) + # Fail fast on a CUDA device that doesn't exist, rather than starting a worker pinned to + # it that only errors cryptically at the first tensor allocation. ("auto" only generates + # valid indices, so this just validates explicitly-configured devices.) + if device.type == "cuda": + if not torch.cuda.is_available(): + raise ValueError(f"generation_devices requested '{device_str}', but no CUDA device is available.") + if device.index is not None and device.index >= torch.cuda.device_count(): + raise ValueError( + f"generation_devices requested '{device_str}', but only {torch.cuda.device_count()} " + f"CUDA device(s) are available (valid indices 0-{torch.cuda.device_count() - 1})." 
def estimate_vae_working_memory_qwen_image(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKLQwenImage
) -> int:
    """Estimate, in bytes, the working memory a Qwen Image VAE encode/decode pass needs.

    Without an estimate the model cache reserves only its small default and never offloads a
    large resident transformer (the VAE weights themselves are tiny), and the decode then OOMs on
    its activations. Like the other VAE estimators, the result scales linearly with the number of
    output pixels and the parameter element size. Qwen Image latents are 5D
    (B, C, frames, H, W); only the trailing two (spatial) dims are read here. See #8414.

    Args:
        operation: "decode" scales the spatial dims by `LATENT_SCALE_FACTOR`; "encode" uses them as-is.
        image_tensor: Tensor whose last two dimensions are height and width.
        vae: The VAE; the element size of its first parameter is used.

    Returns:
        The estimated working memory in bytes.
    """
    is_decode = operation == "decode"
    spatial_scale = LATENT_SCALE_FACTOR if is_decode else 1

    out_height = spatial_scale * image_tensor.shape[-2]
    out_width = spatial_scale * image_tensor.shape[-1]
    bytes_per_element = next(vae.parameters()).element_size()

    # Calibrated for the Qwen Image VAE, a 3D-conv (video) VAE whose decode allocates large conv3d
    # feature maps — a ~1MP decode was measured to peak at ~17 GiB of VRAM, far above the 2D
    # SD/FLUX VAEs the generic 2200/1100 constants were tuned for. The reservation must cover that
    # peak AND be large enough to make the cache offload an otherwise-resident transformer + text
    # encoder: the offload only frees ~(working_mem - free) bytes, so under-reserving leaves the
    # big models resident and the decode OOMs. Over-reserving is safe here (it just offloads models
    # the decode doesn't use). Encoding uses ~half the working memory of decoding.
    # NOTE: this is linear in output pixels; a sufficiently large output (>~1.5MP) can still exceed
    # the card even after offloading everything — that case needs tiled decode, handled separately.
    per_pixel_constant = 13000 if is_decode else 6500

    return int(out_height * out_width * bytes_per_element * per_pixel_constant)
+ }, "GetMaskBoundingBoxInvocation": { "category": "mask", "class": "invocation", @@ -39926,6 +39968,19 @@ ], "default": null, "description": "An image representing the current state of the progress" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + "title": "Device" } }, "required": [ @@ -39941,7 +39996,8 @@ "invocation_source_id", "message", "percentage", - "image" + "image", + "device" ], "title": "InvocationProgressEvent", "type": "object" @@ -41153,6 +41209,23 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "type": "string", + "const": "auto" + }, + { + "items": { + "type": "string" + }, + "type": "array" + } + ], + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "default": "auto" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], @@ -65354,6 +65427,18 @@ "title": "Retried From Item Id", "description": "The item_id of the queue item that this item was retried from" }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Device", + "description": "The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)" + }, "session": { "$ref": "#/components/schemas/GraphExecutionState", "description": "The fully-populated session to be executed" @@ -70355,6 +70440,10 @@ ], "title": "Max Queue History", "description": "Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items." + }, + "generation_devices": { + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI." } }, "type": "object", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index cb0226c2c44..2ba5a4828d7 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -444,6 +444,7 @@ "next": "Next", "status": "Status", "total": "Total", + "gpu": "GPU #", "time": "Time", "credits": "Credits", "pending": "Pending", @@ -1837,6 +1838,11 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "generationDevices": "Generation Devices", + "generationDevicesAuto": "Auto (all GPUs)", + "generationDevicesHelp": "Select which devices to use for parallel generation, one session per device. 
\"Auto\" uses every available GPU.", + "generationDevicesRestart": "Changes take effect after restarting InvokeAI.", + "generationDevicesSaveFailed": "Failed to save Generation Devices", "imageSubfolderStrategy": "Image Subfolder Strategy", "imageSubfolderStrategyDate": "Date", "imageSubfolderStrategyFlat": "Flat", diff --git a/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts b/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts new file mode 100644 index 00000000000..701f7ce1ae9 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts @@ -0,0 +1,39 @@ +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; +import { getDeviceNameLabels } from 'common/util/getDeviceNameLabels'; +import { useMemo } from 'react'; +import { useGetGenerationDeviceOptionsQuery } from 'services/api/endpoints/appInfo'; + +type ProgressDeviceLabel = { + /** The CUDA device index, shown in the center of the progress circle (e.g. `0`). */ + index: number; + /** The human-readable device name and number, shown on hover (e.g. `"AMD Radeon PRO W7900 #1"`). */ + name: string; +}; + +/** + * Resolve a device string (e.g. `"cuda:0"`) to the GPU index + name used to annotate progress + * previews. + * + * Returns `null` when there is nothing to show: the device is not a CUDA GPU, or only a single GPU + * is available (single-GPU setups show neither the index nor the hover tooltip). + */ +export const useProgressDeviceLabel = (device: string | null | undefined): ProgressDeviceLabel | null => { + const { data: deviceOptions } = useGetGenerationDeviceOptionsQuery(); + + return useMemo(() => { + const index = getCudaDeviceIndex(device); + if (index === null) { + return null; + } + const options = deviceOptions ?? []; + // With a single GPU there is no ambiguity to resolve, so we show nothing. + if (options.length <= 1) { + return null; + } + const name = device ? 
getDeviceNameLabels(options)[device] : undefined; + if (!name) { + return null; + } + return { index, name }; + }, [device, deviceOptions]); +}; diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts new file mode 100644 index 00000000000..3348ae14a2f --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from 'vitest'; + +import { getCudaDeviceIndex } from './getCudaDeviceIndex'; + +describe('getCudaDeviceIndex', () => { + it('parses the index from a cuda device string', () => { + expect(getCudaDeviceIndex('cuda:0')).toBe(0); + expect(getCudaDeviceIndex('cuda:1')).toBe(1); + expect(getCudaDeviceIndex('cuda:11')).toBe(11); + }); + + it('returns null for non-cuda devices', () => { + expect(getCudaDeviceIndex('cpu')).toBeNull(); + expect(getCudaDeviceIndex('mps')).toBeNull(); + }); + + it('returns null for null/undefined/empty', () => { + expect(getCudaDeviceIndex(null)).toBeNull(); + expect(getCudaDeviceIndex(undefined)).toBeNull(); + expect(getCudaDeviceIndex('')).toBeNull(); + }); + + it('returns null for malformed cuda strings', () => { + expect(getCudaDeviceIndex('cuda')).toBeNull(); + expect(getCudaDeviceIndex('cuda:')).toBeNull(); + expect(getCudaDeviceIndex('cuda:x')).toBeNull(); + expect(getCudaDeviceIndex('cuda:0:0')).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts new file mode 100644 index 00000000000..d4a394b48fc --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts @@ -0,0 +1,13 @@ +/** + * Parse the CUDA device index from a device string (e.g. `"cuda:1"` → `1`). + * + * Returns `null` when the device is null/undefined or is not a CUDA device (e.g. `"cpu"`, `"mps"`). 
/**
 * Extract the CUDA device index from a device string (e.g. `"cuda:1"` → `1`).
 *
 * Yields `null` for a null/undefined/empty device and for anything that is not exactly
 * `cuda:<digits>` (e.g. `"cpu"`, `"mps"`, `"cuda"`, `"cuda:x"`). Used to label progress previews
 * and queue items with their GPU number in multi-GPU setups.
 */
export const getCudaDeviceIndex = (device: string | null | undefined): number | null => {
  const captured = device ? /^cuda:(\d+)$/.exec(device) : null;
  if (captured === null) {
    return null;
  }
  return Number(captured[1]);
};
/**
 * Build a map of device id (e.g. `"cuda:0"`) → human-readable label (e.g. `"AMD Radeon PRO W7900 #1"`).
 *
 * Devices that share a name get a 1-based `#N` suffix so identical GPUs can be told apart; a
 * uniquely-named device gets no suffix. The ordinal is assigned in the order the options are
 * provided. Used to label progress previews with the GPU they are rendering on in multi-GPU setups.
 *
 * @param options The selectable generation devices, each with a `device` id and a `name`.
 * @returns A record keyed by device id; empty when `options` is empty.
 */
export const getDeviceNameLabels = (options: S['GenerationDeviceOption'][]): Record<string, string> => {
  // First pass: count how often each name occurs, so we know which names need a suffix.
  const nameCounts = new Map<string, number>();
  for (const option of options) {
    nameCounts.set(option.name, (nameCounts.get(option.name) ?? 0) + 1);
  }

  // Second pass: hand out 1-based ordinals per name, in input order.
  const ordinals = new Map<string, number>();
  const labels: Record<string, string> = {};
  for (const option of options) {
    const ordinal = (ordinals.get(option.name) ?? 0) + 1;
    ordinals.set(option.name, ordinal);
    labels[option.device] = (nameCounts.get(option.name) ?? 0) > 1 ? `${option.name} #${ordinal}` : option.name;
  }
  return labels;
};
+const labelStyles: SystemStyleObject = { + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + fontSize: '0.6rem', + lineHeight: 1, + fontWeight: 'bold', + color: 'invokeBlue.300', + textShadow: '0 0 3px var(--invoke-colors-base-900)', + pointerEvents: 'none', +}; + type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & CircularProgressProps; export const QueueItemCircularProgress = memo(({ itemId, status, ...rest }: Props) => { const { progressEvent } = useProgressDatum(itemId); + const deviceLabel = useProgressDeviceLabel(progressEvent?.device); if (status !== 'in_progress') { return null; } + const message = getProgressMessage(progressEvent); + return ( - + + > + {deviceLabel && {deviceLabel.index}} + ); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index ceaf6c5f435..1978a7fc1ab 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -22,6 +22,7 @@ import type { ImageDTO } from 'services/api/types'; import { useImageViewerContext } from './context'; import { NoContentForViewer } from './NoContentForViewer'; import { ProgressImage } from './ProgressImage2'; +import { ProgressImageTiles } from './ProgressImageTiles'; import { ProgressIndicator } from './ProgressIndicator2'; export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => { @@ -30,10 +31,17 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails); const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer); const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation(); - const 
{ onLoadImage, $progressEvent, $progressImage, $isProgressImageResolving, $isTemporarilyShowingSelectedImage } = - useImageViewerContext(); + const { + onLoadImage, + $progressEvent, + $progressImage, + $activeProgressData, + $isProgressImageResolving, + $isTemporarilyShowingSelectedImage, + } = useImageViewerContext(); const progressEvent = useStore($progressEvent); const progressImage = useStore($progressImage); + const activeProgressData = useStore($activeProgressData); const isProgressImageResolving = useStore($isProgressImageResolving); const isTemporarilyShowingSelectedImage = useStore($isTemporarilyShowingSelectedImage); const [imageToRender, setImageToRender] = useState(null); @@ -186,6 +194,9 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu }); const withProgress = shouldShowProgressInViewer && hasProgressImage && !isTemporarilyShowingSelectedImage; + // When more than one session is generating concurrently (multi-GPU), tile their previews instead of + // showing only the most recent one. + const withTiledProgress = withProgress && activeProgressData.length > 1; return ( } {withProgress && ( - - {progressEvent && ( - + {withTiledProgress ? ( + + ) : ( + <> + + {progressEvent && ( + + )} + )} )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx new file mode 100644 index 00000000000..6f66c02e929 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx @@ -0,0 +1,39 @@ +import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; +import { memo, useMemo } from 'react'; + +import type { ViewerProgressDatum } from './context'; +import { ProgressImage } from './ProgressImage2'; +import { ProgressIndicator } from './ProgressIndicator2'; + +/** + * Renders one tile per concurrently-running session (multi-GPU). 
Each tile shows that session's live + * preview image plus a small progress indicator. Used by the viewer when more than one session is + * active; a single active session uses the full-size preview instead. + */ +export const ProgressImageTiles = memo(({ data }: { data: ViewerProgressDatum[] }) => { + // Lay the tiles out in a roughly-square grid that grows with the number of active sessions. + const columns = useMemo(() => Math.ceil(Math.sqrt(data.length)), [data.length]); + + return ( + + {data.map((datum) => ( + + + + + + + ))} + + ); +}); +ProgressImageTiles.displayName = 'ProgressImageTiles'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx index b635c37d804..dde90ead5fd 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx @@ -1,27 +1,64 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; -import { CircularProgress, Tooltip } from '@invoke-ai/ui-library'; -import { memo } from 'react'; +import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; +import { useProgressDeviceLabel } from 'common/hooks/useProgressDeviceLabel'; +import type { ComponentRef } from 'react'; +import { forwardRef, memo } from 'react'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; const circleStyles: SystemStyleObject = { + // The callers position this circle with `position="absolute"`, which makes it the containing + // block for the absolutely-centered GPU label below. Do NOT set `position` here — an `sx` value + // would override the caller's prop and break the circle's corner anchoring. 
circle: { transitionProperty: 'none', transitionDuration: '0s', }, }; +// Centered GPU-number label drawn inside the ring (CircularProgressLabel isn't exported by the ui-library). +const labelStyles: SystemStyleObject = { + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + fontSize: '0.6rem', + lineHeight: 1, + fontWeight: 'bold', + color: 'invokeBlue.300', + textShadow: '0 0 3px var(--invoke-colors-base-900)', + pointerEvents: 'none', +}; + +type ProgressDeviceLabel = ReturnType; + +// The circle is split out and memoized so it does NOT re-render when only the tooltip message +// changes. Every progress event re-renders the parent, and during the indeterminate phases +// (everything except denoising) those events keep the same `isIndeterminate`/`value` — but +// re-rendering the CircularProgress restarts its CSS spin animation, which reads as the disk +// "flashing". Memoizing on the visual props keeps the animation continuous. forwardRef so the +// wrapping Tooltip can still anchor to it. 
+const ProgressCircle = memo( + forwardRef, { deviceLabel: ProgressDeviceLabel } & CircularProgressProps>( + ({ deviceLabel, ...rest }, ref) => ( + + {deviceLabel && {deviceLabel.index}} + + ) + ) +); +ProgressCircle.displayName = 'ProgressCircle'; + export const ProgressIndicator = memo( ({ progressEvent, ...rest }: { progressEvent: S['InvocationProgressEvent'] } & CircularProgressProps) => { + const deviceLabel = useProgressDeviceLabel(progressEvent?.device); + const message = formatProgressMessage(progressEvent); return ( - - + diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx index d502deb4498..6f6a95d4f29 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors'; import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common'; import { LRUCache } from 'lru-cache'; -import { type Atom, atom, computed, type WritableAtom } from 'nanostores'; +import { type Atom, atom, computed, map, type MapStore, type WritableAtom } from 'nanostores'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react'; import type { S } from 'services/api/types'; @@ -12,10 +12,24 @@ import { $socket } from 'services/events/stores'; import { assert } from 'tsafe'; import type { JsonObject } from 'type-fest'; +/** Live progress for a single in-flight session (queue item). Used to tile the viewer when several + * sessions run concurrently (multi-GPU). Only items that have produced a preview image are tracked. 
*/ +export type ViewerProgressDatum = { + itemId: number; + progressEvent: S['InvocationProgressEvent']; + progressImage: ProgressImageType; +}; + +type ViewerProgressDataMap = Record; + type ImageViewerContextValue = { $progressEvent: Atom; $progressImage: Atom; $hasProgressImage: Atom; + /** Per-session progress, keyed by queue item id. Drives the tiled multi-session preview. */ + $progressData: MapStore; + /** Active sessions (those with a preview image), sorted by item id for a stable tile order. */ + $activeProgressData: Atom; $isProgressImageResolving: Atom; $isTemporarilyShowingSelectedImage: WritableAtom; onLoadImage: () => void; @@ -31,6 +45,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { const $progressEvent = useState(() => atom(null))[0]; const $progressImage = useState(() => atom(null))[0]; const $hasProgressImage = useState(() => computed($progressImage, (progressImage) => progressImage !== null))[0]; + // Per-session progress, keyed by queue item id, for the tiled multi-session preview (multi-GPU). + const $progressData = useState(() => map({}))[0]; + const $activeProgressData = useState(() => + computed($progressData, (progressData) => + Object.values(progressData) + .filter((datum): datum is ViewerProgressDatum => datum !== undefined) + .sort((a, b) => a.itemId - b.itemId) + ) + )[0]; const $isProgressImageResolving = useState(() => atom(false))[0]; const $isTemporarilyShowingSelectedImage = useState(() => atom(false))[0]; const shouldClearProgressImageOnLoadRef = useRef(false); @@ -56,6 +79,12 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent.set(data); if (data.image) { $progressImage.set(data.image); + // Track per-session so the viewer can tile concurrent sessions (multi-GPU). 
+ $progressData.setKey(data.item_id, { + itemId: data.item_id, + progressEvent: data, + progressImage: data.image, + }); } }; @@ -64,7 +93,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('invocation_progress', onInvocationProgress); }; - }, [$isProgressImageResolving, $progressEvent, $progressImage, finishedQueueItemIds, socket]); + }, [$isProgressImageResolving, $progressData, $progressEvent, $progressImage, finishedQueueItemIds, socket]); useEffect(() => { if (!socket) { @@ -81,6 +110,9 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { } if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') { finishedQueueItemIds.set(data.item_id, true); + // Remove this session's tile from the multi-session preview as soon as it reaches a terminal + // state. The single-image "resolve" illusion below is handled separately via onLoadImage. + $progressData.setKey(data.item_id, undefined); // Completed queue items have the progress event cleared by the onLoadImage callback. This allows the viewer to // create the illusion of the progress image "resolving" into the final image. 
If we cleared the progress image // now, there would be a flicker where the progress image disappears before the final image appears, and the @@ -115,7 +147,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('queue_item_status_changed', onQueueItemStatusChanged); }; - }, [$isProgressImageResolving, $progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); + }, [ + $isProgressImageResolving, + $progressData, + $progressEvent, + $progressImage, + autoSwitch, + finishedQueueItemIds, + socket, + ]); const onLoadImage = useCallback(() => { if (!shouldClearProgressImageOnLoadRef.current) { @@ -133,12 +173,16 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent, $progressImage, $hasProgressImage, + $progressData, + $activeProgressData, $isProgressImageResolving, $isTemporarilyShowingSelectedImage, onLoadImage, }), [ $hasProgressImage, + $progressData, + $activeProgressData, $isProgressImageResolving, $isTemporarilyShowingSelectedImage, $progressEvent, diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx index e1c5f4ec973..6d3f773a2e9 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx @@ -1,6 +1,7 @@ import type { ChakraProps, CollapseProps, FlexProps } from '@invoke-ai/ui-library'; import { ButtonGroup, Collapse, Flex, IconButton, Text } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; import { selectCurrentUser } from 'features/auth/store/authSlice'; import QueueStatusBadge from 'features/queue/components/common/QueueStatusBadge'; import { useDestinationText } from 
'features/queue/components/QueueList/useDestinationText'; @@ -95,6 +96,8 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { return `${seconds}s`; }, [item]); + const gpuIndex = useMemo(() => getCudaDeviceIndex(item.device), [item.device]); + const isCanceled = useMemo(() => ['canceled', 'completed', 'failed'].includes(item.status), [item.status]); const isFailed = useMemo(() => ['canceled', 'failed'].includes(item.status), [item.status]); const originText = useOriginText(item.origin); @@ -140,6 +143,9 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { + + {gpuIndex !== null ? gpuIndex : '-'} + {executionTime || '-'} diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx index 4cd3397d217..9f6e2fa5458 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx @@ -33,6 +33,7 @@ const QueueListHeader = () => { w={COLUMN_WIDTHS.statusBadge} alignItems="center" /> + { +// In "fit" mode (e.g. the strip below a dockview tab label) the stack is constrained to a fixed height. +// Bars stay at FIT_BAR_HEIGHT_PX while they fit, then shrink to share the available space so they never +// overlap the label, no matter how many sessions are running. +const FIT_BAR_HEIGHT_PX = 4; +const FIT_BAR_GAP_PX = 1; + +type ProgressBarProps = ProgressProps & { + /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ + containerProps?: FlexProps; + /** + * When set, the stacked bars are constrained to this total height (in px) and shrink to share it, so + * they never grow past the available space (e.g. the strip below a dockview tab label). 
+ */ + fitHeightPx?: number; +}; + +type BarDescriptor = { + key: number | string; + value: number; + isIndeterminate: boolean; +}; + +const ProgressBar = ({ containerProps, fitHeightPx, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); - const lastProgressEvent = useStore($lastProgressEvent); + const activeProgressEvents = useStore($activeProgressEvents); const loadingModelsCount = useStore($loadingModelsCount); - const value = useMemo(() => { - if (!lastProgressEvent) { - return 0; - } - return (lastProgressEvent.percentage ?? 0) * 100; - }, [lastProgressEvent]); - - const isIndeterminate = useMemo(() => { - if (!isConnected) { - return false; - } - - if (loadingModelsCount > 0) { - return true; - } - - if (!queueStatus?.queue.in_progress) { - return false; - } - if (!lastProgressEvent) { - return true; + const bars = useMemo(() => { + // One bar per in-flight session (multi-GPU). Each session's progress is tracked independently, so + // the bars no longer jump back and forth when several sessions render simultaneously. + if (activeProgressEvents.length > 0) { + return activeProgressEvents.map((event) => ({ + key: event.item_id, + value: (event.percentage ?? 0) * 100, + isIndeterminate: isConnected && (loadingModelsCount > 0 || event.percentage === null || event.percentage === 0), + })); } - if (lastProgressEvent.percentage === null) { - return true; + // Fallback single bar: idle, or generation has started but no progress event has arrived yet (e.g. + // while models are loading). Mirrors the previous single-bar indeterminate behavior. 
+ let isIndeterminate = false; + if (isConnected && (loadingModelsCount > 0 || Boolean(queueStatus?.queue.in_progress))) { + isIndeterminate = true; } + return [{ key: 'idle', value: 0, isIndeterminate }]; + }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); - if (lastProgressEvent.percentage === 0) { - return true; + // In fit mode, cap the whole stack to the available strip and let the bars flex to share it. When the + // bars fit at their natural height the stack is shorter than the cap; once they don't, they shrink. + const isFit = fitHeightPx !== undefined; + const fitContainerProps = useMemo(() => { + if (!isFit) { + return undefined; } + const naturalHeight = bars.length * FIT_BAR_HEIGHT_PX + Math.max(0, bars.length - 1) * FIT_BAR_GAP_PX; + return { h: `${Math.min(naturalHeight, fitHeightPx)}px`, gap: `${FIT_BAR_GAP_PX}px` }; + }, [bars.length, fitHeightPx, isFit]); - return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); + const fitBarProps: ProgressProps | undefined = isFit ? 
{ flex: '1 1 0', minH: 0, h: 'auto' } : undefined; return ( - + + {bars.map((bar) => ( + + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx new file mode 100644 index 00000000000..2980fb85c73 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx @@ -0,0 +1,249 @@ +import { + Flex, + FormControl, + FormHelperText, + FormLabel, + Tag, + TagCloseButton, + Text, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; +import { toast } from 'features/toast/toast'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + useGetGenerationDeviceOptionsQuery, + useGetRuntimeConfigQuery, + useUpdateRuntimeConfigMutation, +} from 'services/api/endpoints/appInfo'; + +const AUTO = 'auto'; + +type GenerationDevicesValue = 'auto' | string[]; + +/** Drop the verbose vendor prefix so e.g. "NVIDIA GeForce RTX 3090" reads as "RTX 3090". */ +const shortenDeviceName = (name: string): string => name.replace(/^NVIDIA GeForce /, '').replace(/^NVIDIA /, ''); + +type DeviceBadge = { + /** The device identifier, or 'auto' for the special "use all GPUs" badge. */ + device: string; + /** The label shown on the badge. */ + label: string; + /** A human-readable description shown on hover (e.g. the GPU model name). 
*/ + tooltip?: string; +}; + +export const SettingsGenerationDevices = memo(() => { + const { t } = useTranslation(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: runtimeConfig } = useGetRuntimeConfigQuery(); + const { data: deviceOptions } = useGetGenerationDeviceOptionsQuery(); + const [updateRuntimeConfig, { isLoading }] = useUpdateRuntimeConfigMutation(); + + const generationDevices: GenerationDevicesValue = runtimeConfig?.config.generation_devices ?? AUTO; + const isAuto = generationDevices === AUTO; + const selectedDevices = useMemo(() => (isAuto ? [] : [...generationDevices]), [generationDevices, isAuto]); + + const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false; + const isDisabled = !runtimeConfig || !canEditRuntimeConfig || isLoading; + + const save = useCallback( + async (value: GenerationDevicesValue) => { + try { + await updateRuntimeConfig({ generation_devices: value }).unwrap(); + } catch { + toast({ + id: 'SETTINGS_GENERATION_DEVICES_SAVE_FAILED', + title: t('settings.generationDevicesSaveFailed'), + status: 'error', + }); + } + }, + [t, updateRuntimeConfig] + ); + + const autoBadge = useMemo(() => ({ device: AUTO, label: t('settings.generationDevicesAuto') }), [t]); + + // Build a per-device badge (label + tooltip) keyed by device id, e.g. "cuda:0 (RTX 3090 #1)". + // Cards sharing a name get a 1-based "#N" suffix so identical GPUs can be told apart. + const deviceBadges = useMemo>(() => { + const options = deviceOptions ?? []; + const nameCounts = new Map(); + for (const option of options) { + const name = shortenDeviceName(option.name); + nameCounts.set(name, (nameCounts.get(name) ?? 0) + 1); + } + const ordinals = new Map(); + const badges: Record = {}; + for (const option of options) { + const name = shortenDeviceName(option.name); + const ordinal = (ordinals.get(name) ?? 0) + 1; + ordinals.set(name, ordinal); + const namePart = (nameCounts.get(name) ?? 0) > 1 ? 
`${name} #${ordinal}` : name; + badges[option.device] = { device: option.device, label: `${option.device} (${namePart})`, tooltip: option.name }; + } + return badges; + }, [deviceOptions]); + + // Fall back to a bare device id when a configured device isn't in the current options (e.g. a + // GPU that's no longer present). + const getDeviceBadge = useCallback( + (device: string): DeviceBadge => deviceBadges[device] ?? { device, label: device }, + [deviceBadges] + ); + + // The active badges: the `auto` pseudo-device, or the explicitly-selected devices in config order. + const activeBadges = useMemo(() => { + if (isAuto) { + return [autoBadge]; + } + return selectedDevices.map(getDeviceBadge); + }, [autoBadge, getDeviceBadge, isAuto, selectedDevices]); + + // The inactive badges: `auto` (when an explicit list is active) plus any unselected devices. + const inactiveBadges = useMemo(() => { + const badges: DeviceBadge[] = []; + if (!isAuto) { + badges.push(autoBadge); + } + for (const option of deviceOptions ?? []) { + if (!selectedDevices.includes(option.device)) { + badges.push(getDeviceBadge(option.device)); + } + } + return badges; + }, [autoBadge, deviceOptions, getDeviceBadge, isAuto, selectedDevices]); + + const onActivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + if (device === AUTO) { + save(AUTO); + return; + } + // Switching from `auto` starts a fresh explicit list; otherwise append to the current selection. + const next = isAuto ? [device] : Array.from(new Set([...selectedDevices, device])); + save(next); + }, + [isAuto, isDisabled, save, selectedDevices] + ); + + const onDeactivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + const next = selectedDevices.filter((d) => d !== device); + // Never leave an empty selection — fall back to `auto`, which is always meaningful. + save(next.length > 0 ? 
next : AUTO); + }, + [isDisabled, save, selectedDevices] + ); + + return ( + + {t('settings.generationDevices')} + + {activeBadges.map((badge) => ( + + ))} + + {inactiveBadges.length > 0 && ( + + {inactiveBadges.map((badge) => ( + + ))} + + )} + + {t('settings.generationDevicesHelp')}{' '} + + {t('settings.generationDevicesRestart')} + + + + ); +}); + +SettingsGenerationDevices.displayName = 'SettingsGenerationDevices'; + +type DeviceTagProps = { + badge: DeviceBadge; + isActive: boolean; + isClosable: boolean; + isDisabled: boolean; + onActivate: (device: string) => void; + onDeactivate: (device: string) => void; +}; + +const DeviceTag = memo(({ badge, isActive, isClosable, isDisabled, onActivate, onDeactivate }: DeviceTagProps) => { + const onClick = useCallback(() => { + if (isDisabled) { + return; + } + if (isActive) { + // An active, non-closable badge (the exclusive `auto`) is a no-op when clicked. + if (isClosable) { + onDeactivate(badge.device); + } + } else { + onActivate(badge.device); + } + }, [badge.device, isActive, isClosable, isDisabled, onActivate, onDeactivate]); + + const isInteractive = !isDisabled && (!isActive || isClosable); + + const tag = ( + + + {badge.label} + + {isActive && isClosable && } + + ); + + if (!badge.tooltip) { + return tag; + } + + return ( + + {tag} + + ); +}); + +DeviceTag.displayName = 'DeviceTag'; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index 8e331645a9b..725eede4725 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -28,6 +28,7 @@ import { useRefreshAfterResetModal } from 'features/system/components/SettingsMo import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; import { 
SettingsDeveloperLogLevel } from 'features/system/components/SettingsModal/SettingsDeveloperLogLevel'; import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces'; +import { SettingsGenerationDevices } from 'features/system/components/SettingsModal/SettingsGenerationDevices'; import { SettingsImageSubfolderStrategySelect } from 'features/system/components/SettingsModal/SettingsImageSubfolderStrategySelect'; import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates'; import { StickyScrollable } from 'features/system/components/StickyScrollable'; @@ -327,6 +328,7 @@ const SettingsModal = (props: { children: ReactElement<{ onClick?: () => void }> + diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index 80f851ab7af..62246faa0f8 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -34,7 +34,11 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index 440847d7451..285afa3a1b6 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -37,7 +37,11 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {t(props.params.i18nKey)} {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx 
b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 1d997caaf78..c89f682e66a 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -32,7 +32,11 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index 653f458dde8..d8801fe9845 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -58,6 +58,16 @@ export const appInfoApi = api.injectEndpoints({ }), providesTags: ['AppConfig'], }), + getGenerationDeviceOptions: build.query< + paths['/api/v1/app/generation_device_options']['get']['responses']['200']['content']['application/json'], + void + >({ + query: () => ({ + url: buildAppInfoUrl('generation_device_options'), + method: 'GET', + }), + providesTags: ['FetchOnReconnect'], + }), updateRuntimeConfig: build.mutation< paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'], paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json'] @@ -149,6 +159,7 @@ export const { useGetAppDepsQuery, useGetPatchmatchStatusQuery, useGetRuntimeConfigQuery, + useGetGenerationDeviceOptionsQuery, useGetExternalProviderStatusesQuery, useGetExternalProviderConfigsQuery, useSetExternalProviderConfigMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 7864579706a..75dafa37f34 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1652,6 +1652,26 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/app/generation_device_options": { + parameters: { + query?: 
never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Generation Device Options + * @description List the devices available for generation, for use with the `generation_devices` setting. + */ + get: operations["get_generation_device_options"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/app/runtime_config": { parameters: { query?: never; @@ -12192,6 +12212,22 @@ export type components = { */ password: string; }; + /** + * GenerationDeviceOption + * @description A device that may be selected for generation. + */ + GenerationDeviceOption: { + /** + * Device + * @description The device identifier, e.g. 'cuda:0', 'mps', or 'cpu' + */ + device: string; + /** + * Name + * @description Human-readable device name + */ + name: string; + }; /** * Get Image Mask Bounding Box * @description Gets the bounding box of the given mask image. @@ -16095,6 +16131,12 @@ export type components = { * @default null */ image: components["schemas"]["ProgressImage"] | null; + /** + * Device + * @description The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU) + * @default null + */ + device: string | null; }; /** * InvocationStartedEvent @@ -16508,6 +16550,12 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number) + * @default auto + */ + generation_devices?: "auto" | string[]; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. @@ -28271,6 +28319,11 @@ export type components = { * @description The item_id of the queue item that this item was retried from */ retried_from_item_id?: number | null; + /** + * Device + * @description The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU) + */ + device?: string | null; /** @description The fully-populated session to be executed */ session: components["schemas"]["GraphExecutionState"]; /** @description The workflow associated with this queue item */ @@ -30995,6 +31048,11 @@ export type components = { * @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items. */ max_queue_history?: number | null; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI. 
+ */ + generation_devices?: unknown; }; /** * UserDTO @@ -36457,6 +36515,26 @@ export interface operations { }; }; }; + get_generation_device_options: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["GenerationDeviceOption"][]; + }; + }; + }; + }; get_runtime_config: { parameters: { query?: never; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index e6010ce4ca1..da8f114a0e2 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -43,7 +43,13 @@ import { createWorkflowExecutionCoordinator } from 'services/events/workflowExec import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent, $loadingModelsCount } from './stores'; +import { + $lastProgressEvent, + $loadingModelsCount, + clearAllProgressEvents, + clearProgressEvent, + setProgressEvent, +} from './stores'; const log = logger('events'); @@ -86,6 +92,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.emit('subscribe_queue', { queue_id: 'default' }); socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' }); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); }); @@ -93,6 +100,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; @@ -111,6 +119,7 @@ export const 
setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Disconnected'); workflowExecutionCoordinator.cancelPendingWorkflowReconciliations(); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); setIsConnected(false); }); @@ -140,6 +149,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, _message); $lastProgressEvent.set(data); + setProgressEvent(data); }); socket.on('invocation_error', (data) => { @@ -448,11 +458,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } // If the queue item is completed, failed, or cancelled, we want to clear the last progress event $lastProgressEvent.set(null); + // Also remove this session's per-item progress so its stacked progress bar disappears. + clearProgressEvent(item_id); } }); socket.on('queue_cleared', (data) => { log.debug({ data }, 'Queue cleared'); + clearAllProgressEvents(); dispatch( queueApi.util.invalidateTags([ 'SessionQueueStatus', diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 180f4a3a636..7c7630e2019 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -1,5 +1,5 @@ import { round } from 'es-toolkit/compat'; -import { atom, computed } from 'nanostores'; +import { atom, computed, map } from 'nanostores'; import type { S } from 'services/api/types'; import type { AppSocket } from 'services/events/types'; @@ -8,6 +8,33 @@ export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); export const $loadingModelsCount = atom(0); +/** + * Live progress events keyed by queue item id. 
Unlike `$lastProgressEvent` (a single global value that + * is overwritten by whichever session reported last), this tracks each in-flight session separately so + * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress + * events arrive and removed when the session reaches a terminal state. + */ +const $progressEvents = map>({}); + +/** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ +export const $activeProgressEvents = computed($progressEvents, (events) => + Object.values(events) + .filter((event): event is S['InvocationProgressEvent'] => event !== undefined) + .sort((a, b) => a.item_id - b.item_id) +); + +export const setProgressEvent = (event: S['InvocationProgressEvent']) => { + $progressEvents.setKey(event.item_id, event); +}; + +export const clearProgressEvent = (itemId: number) => { + $progressEvents.setKey(itemId, undefined); +}; + +export const clearAllProgressEvents = () => { + $progressEvents.set({}); +}; + export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) { return null; diff --git a/scripts/multigpu_ram_driver.py b/scripts/multigpu_ram_driver.py new file mode 100755 index 00000000000..9d985b4d98c --- /dev/null +++ b/scripts/multigpu_ram_driver.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python +"""Driver to exercise the multi-GPU shared-RAM model cache under real, concurrent generations. + +It repeatedly enqueues N batches at once (so the multi-GPU session processor runs them in parallel +across devices), polls the queue until each round drains, and samples the InvokeAI server process's +RAM (RSS) the whole time. It then reports: + + - baseline (idle) RSS, + - peak RSS during generation (this is the text/reference-encode spike you care about), and + - idle RSS after each round -> a leak verdict (does RAM return to baseline, or creep up?). 
+ +This automates the two manual checks from the test plan: + #1 "dual concurrent encode RAM" -> run with --rounds 1 --pairs <#gpus> and read the peak. + #5 "leak check over many gens" -> run with --rounds 25+ and read the idle drift. + +------------------------------------------------------------------------------------------------ +Getting a batch file +------------------------------------------------------------------------------------------------ +The script needs the exact body InvokeAI's UI sends to enqueue a generation. Easiest way to capture +it: + 1. Open InvokeAI in the browser with devtools -> Network open. + 2. Click Invoke once. + 3. Find the POST to `.../queue/default/enqueue_batch`, copy its JSON request body, save to a file + (e.g. batch.json). It looks like {"prepend": false, "batch": {"graph": {...}, "runs": 1}}. + +The script busts the node cache by default (sets use_cache=false on every node and randomizes any +"seed" fields) so every submission actually runs the model instead of returning a cached result. 
+ +------------------------------------------------------------------------------------------------ +Examples +------------------------------------------------------------------------------------------------ + # Headline dual-GPU encode RAM (2 GPUs -> 2 concurrent jobs), one round: + python scripts/multigpu_ram_driver.py --graph batch.json --pairs 2 --rounds 1 + + # Leak soak: 30 rounds of 2 concurrent jobs, save timeline for plotting: + python scripts/multigpu_ram_driver.py --graph batch.json --pairs 2 --rounds 30 --csv ram.csv + + # If PID auto-detection fails, point it at the server explicitly: + python scripts/multigpu_ram_driver.py --graph batch.json --pid 12345 +""" + +from __future__ import annotations + +import argparse +import copy +import json +import random +import sys +import threading +import time +import urllib.error +import urllib.parse +import urllib.request +from dataclasses import dataclass, field + +import psutil + + +# -------------------------------------------------------------------------------------------------- +# HTTP helpers (stdlib only) +# -------------------------------------------------------------------------------------------------- +def _request(method: str, url: str, body: dict | None = None, timeout: float = 60.0) -> dict: + data = json.dumps(body).encode() if body is not None else None + req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}, method=method) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read() + return json.loads(raw) if raw else {} + except urllib.error.HTTPError as e: + detail = e.read().decode(errors="replace") + raise SystemExit(f"HTTP {e.code} on {method} {url}\n{detail}") from e + except urllib.error.URLError as e: + raise SystemExit(f"Could not reach {url}: {e.reason}. 
Is the server running?") from e + + +def enqueue(base: str, queue_id: str, body: dict) -> dict: + return _request("POST", f"{base}/api/v1/queue/{queue_id}/enqueue_batch", body) + + +def queue_counts(base: str, queue_id: str) -> tuple[int, int]: + """Return (pending, in_progress), searching the response defensively for those keys.""" + resp = _request("GET", f"{base}/api/v1/queue/{queue_id}/status") + # The status payload nests the queue counts under "queue"; fall back to top-level. + node = resp.get("queue", resp) if isinstance(resp, dict) else {} + return int(node.get("pending", 0)), int(node.get("in_progress", 0)) + + +# -------------------------------------------------------------------------------------------------- +# Batch preparation +# -------------------------------------------------------------------------------------------------- +def normalize_body(loaded: dict) -> dict: + """Accept either the full {"prepend":..., "batch": {...}} body or a bare Batch ({"graph":...}).""" + if "batch" in loaded: + return copy.deepcopy(loaded) + if "graph" in loaded: + return {"prepend": False, "batch": copy.deepcopy(loaded)} + raise SystemExit("Batch file must contain either a top-level 'batch' or 'graph' key.") + + +def bust_cache(body: dict, mutate_seed: bool, disable_cache: bool) -> dict: + """Return a copy of the body with the node cache busted so the submission really computes.""" + body = copy.deepcopy(body) + nodes = body.get("batch", {}).get("graph", {}).get("nodes", {}) + if not isinstance(nodes, dict): + return body + for node in nodes.values(): + if not isinstance(node, dict): + continue + if disable_cache: + node["use_cache"] = False + if mutate_seed and "seed" in node: + node["seed"] = random.randint(0, 2**31 - 1) + return body + + +# -------------------------------------------------------------------------------------------------- +# Process discovery + RSS sampling +# 
-------------------------------------------------------------------------------------------------- +def find_server_pid(port: int) -> int: + """Best-effort: find the PID listening on `port`, else a process whose cmdline looks like the server.""" + for conn in psutil.net_connections(kind="inet"): + if conn.laddr and conn.laddr.port == port and conn.pid: + return conn.pid + needles = ("invokeai-web", "invokeai.app.run_app", "invokeai_web", "uvicorn") + for proc in psutil.process_iter(["pid", "cmdline"]): + cmd = " ".join(proc.info.get("cmdline") or []) + if any(n in cmd for n in needles): + return proc.info["pid"] + raise SystemExit(f"Could not auto-detect the InvokeAI server PID on port {port}. Pass --pid explicitly.") + + +def tree_rss(proc: psutil.Process, use_uss: bool) -> int: + """RSS (or USS) of the process and its children, in bytes.""" + procs = [proc] + proc.children(recursive=True) + total = 0 + for p in procs: + try: + if use_uss: + total += p.memory_full_info().uss + else: + total += p.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + return total + + +@dataclass +class Sampler: + proc: psutil.Process + hz: float + use_uss: bool + samples: list[tuple[float, int]] = field(default_factory=list) + _stop: threading.Event = field(default_factory=threading.Event) + _thread: threading.Thread | None = None + + def start(self) -> None: + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def _run(self) -> None: + period = 1.0 / self.hz + while not self._stop.is_set(): + self.samples.append((time.monotonic(), tree_rss(self.proc, self.use_uss))) + time.sleep(period) + + def stop(self) -> None: + self._stop.set() + if self._thread: + self._thread.join(timeout=2.0) + + def current(self) -> int: + return self.samples[-1][1] if self.samples else tree_rss(self.proc, self.use_uss) + + def peak_between(self, t0: float, t1: float) -> int: + vals = [rss for t, rss in self.samples if t0 <= t <= t1] + 
return max(vals) if vals else 0 + + +# -------------------------------------------------------------------------------------------------- +# Round loop +# -------------------------------------------------------------------------------------------------- +GB = 1024**3 + + +def gb(n: int) -> float: + return n / GB + + +def wait_drained(base: str, queue_id: str, timeout: float) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + pending, in_progress = queue_counts(base, queue_id) + if pending == 0 and in_progress == 0: + return + time.sleep(0.5) + raise SystemExit(f"Queue did not drain within {timeout}s. Aborting.") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--graph", required=True, help="Path to a captured enqueue_batch body (JSON).") + ap.add_argument("--url", default="http://127.0.0.1:9090", help="Server base URL.") + ap.add_argument("--queue-id", default="default") + ap.add_argument("--pairs", type=int, default=2, help="Concurrent batches per round (>= #GPUs).") + ap.add_argument("--rounds", type=int, default=1, help="Number of rounds (use 25+ for leak soak).") + ap.add_argument("--pid", type=int, default=None, help="Server PID (auto-detected if omitted).") + ap.add_argument("--hz", type=float, default=10.0, help="RSS sampling rate.") + ap.add_argument("--uss", action="store_true", help="Sample USS instead of RSS (more accurate, slower).") + ap.add_argument("--settle", type=float, default=4.0, help="Seconds to wait after each round for RAM to release.") + ap.add_argument("--timeout", type=float, default=1800.0, help="Per-round drain timeout (s).") + ap.add_argument("--warmup", action="store_true", help="Run one un-measured round first (loads models from disk).") + ap.add_argument("--keep-cache", action="store_true", help="Do NOT set use_cache=false on nodes.") + ap.add_argument("--no-seed-mutate", action="store_true", 
help="Do NOT randomize node 'seed' fields.") + ap.add_argument("--csv", default=None, help="Write the full (t, rss_gb) timeline here.") + args = ap.parse_args() + + with open(args.graph) as f: + body = normalize_body(json.load(f)) + + base = args.url.rstrip("/") + port = urllib.parse.urlparse(base).port or 9090 + pid = args.pid or find_server_pid(port) + proc = psutil.Process(pid) + print(f"Server PID {pid}: {' '.join(proc.cmdline()[:3])} ...") + print(f"Metric: {'USS' if args.uss else 'RSS'} (process tree) | pairs/round={args.pairs} rounds={args.rounds}") + + def submit_round() -> tuple[float, float]: + t0 = time.monotonic() + for _ in range(args.pairs): + prepared = bust_cache(body, mutate_seed=not args.no_seed_mutate, disable_cache=not args.keep_cache) + res = enqueue(base, args.queue_id, prepared) + if res.get("enqueued", 0) < 1: + raise SystemExit(f"Enqueue returned nothing useful: {res}") + wait_drained(base, args.queue_id, args.timeout) + return t0, time.monotonic() + + sampler = Sampler(proc=proc, hz=args.hz, use_uss=args.uss) + sampler.start() + try: + if args.warmup: + print("Warmup round (not measured)...") + submit_round() + time.sleep(args.settle) + + time.sleep(2.0) # settle before baseline + baseline = sampler.current() + print(f"\nBaseline idle {('USS' if args.uss else 'RSS')}: {gb(baseline):.2f} GB\n") + print(f"{'round':>5} {'peak_GB':>9} {'Δpeak_GB':>9} {'idle_after_GB':>14} {'Δidle_GB':>9}") + + idle_after_first = None + overall_peak = baseline + for r in range(1, args.rounds + 1): + t0, t1 = submit_round() + peak = sampler.peak_between(t0, t1) + overall_peak = max(overall_peak, peak) + time.sleep(args.settle) + idle_after = sampler.current() + if idle_after_first is None: + idle_after_first = idle_after + print( + f"{r:>5} {gb(peak):>9.2f} {gb(peak - baseline):>9.2f} " + f"{gb(idle_after):>14.2f} {gb(idle_after - baseline):>9.2f}" + ) + finally: + sampler.stop() + + # Summary + idle_drift = sampler.current() - (idle_after_first or baseline) + 
print("\n--- Summary ---") + print(f"Baseline idle: {gb(baseline):.2f} GB") + print(f"Overall peak: {gb(overall_peak):.2f} GB (Δ {gb(overall_peak - baseline):+.2f} GB over baseline)") + print(f"Idle drift (leak): {gb(idle_drift):+.2f} GB across {args.rounds} rounds") + verdict = "LIKELY LEAK" if idle_drift > 0.5 * GB else "no leak detected" + print(f"Leak verdict: {verdict} (threshold 0.50 GB)") + print("Interpretation: peak Δ should be ~1x the encoder size (not Nx). Idle drift should be ~0.") + + if args.csv: + t_start = sampler.samples[0][0] if sampler.samples else 0.0 + with open(args.csv, "w") as f: + f.write("t_seconds,rss_gb\n") + for t, rss in sampler.samples: + f.write(f"{t - t_start:.3f},{gb(rss):.4f}\n") + print(f"\nTimeline written to {args.csv} ({len(sampler.samples)} samples).") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted.", file=sys.stderr) + sys.exit(130) diff --git a/tests/app/routers/test_app_info.py b/tests/app/routers/test_app_info.py index da493cee457..96eb23f1342 100644 --- a/tests/app/routers/test_app_info.py +++ b/tests/app/routers/test_app_info.py @@ -225,6 +225,64 @@ def test_update_runtime_config_image_subfolder_strategy_schema() -> None: } +def test_update_runtime_config_persists_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["cuda:0", "cuda:1"]}) + + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == ["cuda:0", "cuda:1"] + + config_path = get_config().config_file_path + file_config = load_and_migrate_config(config_path) + assert file_config.generation_devices == ["cuda:0", "cuda:1"] + assert get_config().generation_devices == ["cuda:0", "cuda:1"] + + # "auto" round-trips back to the default. 
+ response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": "auto"}) + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == "auto" + assert get_config().generation_devices == "auto" + + +def test_update_runtime_config_rejects_invalid_generation_device( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["gpu0"]}) + + assert response.status_code == 422 + + +def test_update_runtime_config_rejects_null_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": None}) + + assert response.status_code == 422 + + +def test_get_generation_device_options_lists_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr(app_info.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(app_info.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(app_info.torch.cuda, "get_device_name", lambda index: f"GPU {index}") + + response = client.get("/api/v1/app/generation_device_options") + + assert response.status_code == 200 + assert response.json() == [ + {"device": "cuda:0", "name": "GPU 0"}, + {"device": "cuda:1", "name": "GPU 1"}, + ] + + def test_update_runtime_config_reads_and_writes_yaml_under_config_lock( monkeypatch: Any, mock_invoker: Invoker, client: TestClient ) -> None: diff --git a/tests/app/services/config/test_config_generation_devices.py b/tests/app/services/config/test_config_generation_devices.py new file mode 100644 index 00000000000..e589b35dd3d --- /dev/null +++ 
b/tests/app/services/config/test_config_generation_devices.py @@ -0,0 +1,38 @@ +"""Validation tests for the multi-GPU `generation_devices` config field.""" + +import pytest +from pydantic import ValidationError + +from invokeai.app.services.config.config_default import InvokeAIAppConfig + + +@pytest.mark.parametrize( + "value", + [ + "auto", + ["cuda:0"], + ["cuda:0", "cuda:1"], + ["cpu"], + ["mps"], + ["cuda"], + ], +) +def test_valid_generation_devices(value): + cfg = InvokeAIAppConfig(generation_devices=value) + assert cfg.generation_devices == value + + +def test_non_auto_string_is_rejected(): + # A bare string (other than "auto") would otherwise be iterated character-by-character. + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices="cuda:0") + + +def test_empty_list_is_rejected(): + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices=[]) + + +def test_invalid_device_name_is_rejected(): + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices=["gpu0"]) diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py new file mode 100644 index 00000000000..85b3868b92f --- /dev/null +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -0,0 +1,96 @@ +"""Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" + +import threading +from collections.abc import Iterator + +import pytest +import torch + +from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.model_load.model_load_default import ModelLoadService +from invokeai.backend.util.devices import TorchDevice + + +@pytest.fixture(autouse=True) +def restore_global_device() -> Iterator[None]: + """`get_config()` is a process-wide singleton; restore `device` so we don't leak a CUDA device + into later CPU-only tests (e.g. 
the model-loading suite on the CUDA-less CI runner).""" + config = get_config() + original_device = config.device + try: + yield + finally: + config.device = original_device + TorchDevice.clear_session_device() + + +class _FakeCache: + """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" + + def __init__(self, device: str): + self.execution_device = torch.device(device) + + +def _build_service() -> tuple[ModelLoadService, _FakeCache, _FakeCache]: + cache0 = _FakeCache("cuda:0") + cache1 = _FakeCache("cuda:1") + service = ModelLoadService( + app_config=InvokeAIAppConfig(), + ram_cache=cache0, # type: ignore[arg-type] + ram_caches={"cuda:0": cache0, "cuda:1": cache1}, # type: ignore[arg-type] + ) + return service, cache0, cache1 + + +def test_ram_cache_routes_to_pinned_device(): + """A thread pinned to cuda:1 resolves to that device's cache; the default thread to cuda:0.""" + service, cache0, cache1 = _build_service() + + # The default thread has no session device; point config.device at cuda:0 so it resolves there. + get_config().device = "cuda:0" + assert service.ram_cache is cache0 + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:1") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache1 + # Main thread is unaffected by the worker's pinning. 
+ assert service.ram_cache is cache0 + + +def test_ram_caches_exposes_all_devices(): + service, cache0, cache1 = _build_service() + caches = service.ram_caches + assert set(caches.keys()) == {"cuda:0", "cuda:1"} + assert caches["cuda:0"] is cache0 + assert caches["cuda:1"] is cache1 + + +def test_unknown_device_falls_back_to_default(): + """A thread pinned to a device with no cache falls back to the default cache.""" + service, cache0, _ = _build_service() + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:7") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache0 diff --git a/tests/app/services/session_processor/test_session_processor_cancel_guard.py b/tests/app/services/session_processor/test_session_processor_cancel_guard.py new file mode 100644 index 00000000000..b99f19a3068 --- /dev/null +++ b/tests/app/services/session_processor/test_session_processor_cancel_guard.py @@ -0,0 +1,51 @@ +"""Tests for the post-dequeue cancellation guard that closes the multi-GPU cancel-loss race. + +A cancellation can mark a queue item terminal in the window between dequeue claiming it and the +worker recording `queue_item` (so the status-changed handler can't set the worker's cancel_event). +`_is_queue_item_terminal` is the fresh DB re-check the worker uses to skip running such an item. 
+""" + +from types import SimpleNamespace + +import pytest + +from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItemNotFoundError + + +class _Queue: + def __init__(self, status: str | None = None, raise_not_found: bool = False): + self._status = status + self._raise = raise_not_found + + def get_queue_item(self, item_id: int): + if self._raise: + raise SessionQueueItemNotFoundError("gone") + return SimpleNamespace(item_id=item_id, status=self._status) + + +def _processor_with_queue(queue: _Queue) -> DefaultSessionProcessor: + processor = DefaultSessionProcessor() + processor._invoker = SimpleNamespace(services=SimpleNamespace(session_queue=queue)) # type: ignore[attr-defined] + return processor + + +@pytest.mark.parametrize( + ("status", "expected"), + [ + ("in_progress", False), + ("pending", False), + ("canceled", True), + ("failed", True), + ("completed", True), + ], +) +def test_is_queue_item_terminal_status(status: str, expected: bool): + processor = _processor_with_queue(_Queue(status=status)) + assert processor._is_queue_item_terminal(1) is expected + + +def test_is_queue_item_terminal_treats_missing_as_terminal(): + # A deleted row (e.g. queue cleared during the race) should be treated as terminal, not run. 
+ processor = _processor_with_queue(_Queue(raise_not_found=True)) + assert processor._is_queue_item_terminal(1) is True diff --git a/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py new file mode 100644 index 00000000000..2d61bf0f37a --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py @@ -0,0 +1,90 @@ +"""Tests that concurrent dequeue() calls (multi-GPU session workers) never claim the same item twice.""" + +import threading +import uuid + +import pytest + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item(session_queue: SqliteSessionQueue, user_id: str = "system") -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ("default", session_json, session.id, batch_id, None, 0, None, None, None, None, user_id), + ) + return cursor.lastrowid + + +def test_concurrent_dequeue_never_claims_same_item_twice(session_queue: SqliteSessionQueue) -> None: + item_count = 50 + worker_count = 8 + for _ in range(item_count): + _insert_queue_item(session_queue) + + claimed_ids: list[int] = [] + claimed_lock = threading.Lock() + start_barrier = threading.Barrier(worker_count) + + def worker() -> None: + # Release all workers at once to maximize contention on the dequeue path. + start_barrier.wait() + while True: + item = session_queue.dequeue() + if item is None: + break + with claimed_lock: + claimed_ids.append(item.item_id) + + threads = [threading.Thread(target=worker) for _ in range(worker_count)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every item is claimed exactly once: no duplicates, none lost. + assert len(claimed_ids) == item_count + assert len(set(claimed_ids)) == item_count + + +def test_dequeue_records_processing_device(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue(device="cuda:1") + assert item is not None + assert item.device == "cuda:1" + + # The device persists across later status transitions (which pass device=None). 
+ completed = session_queue._set_queue_item_status(item.item_id, "completed") + assert completed.device == "cuda:1" + + +def test_dequeue_without_device_leaves_device_unset(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue() + assert item is not None + assert item.device is None diff --git a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py new file mode 100644 index 00000000000..0d97ad6ab03 --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py @@ -0,0 +1,142 @@ +"""Regression tests for multi-GPU bulk cancellation. + +With one session-processor worker per device, several queue items can be `in_progress` at the same +time. The bulk-cancel APIs must cancel ALL matching in-progress items (each emitting a cancel event +so its worker stops), not just the single `get_current()` item. See JPPhoto's review on PR #9263. 
+""" + +import uuid + +import pytest + +from invokeai.app.services.events.events_common import QueueItemStatusChangedEvent +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItemNotFoundError +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation, TestEventService + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert( + session_queue: SqliteSessionQueue, + batch_id: str, + destination: str | None = None, + user_id: str = "system", + queue_id: str = "default", +) -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + (queue_id, session_json, session.id, batch_id, None, 0, None, None, destination, None, user_id), + ) + return cursor.lastrowid + + +def _canceled_event_item_ids(mock_invoker: Invoker) -> set[int]: + event_bus: TestEventService = mock_invoker.services.events + return { + e.item_id for e in event_bus.events if isinstance(e, QueueItemStatusChangedEvent) and e.status == "canceled" + } + + +def _dequeue_two_on_separate_devices(session_queue: SqliteSessionQueue) -> tuple[int, int]: + a = session_queue.dequeue(device="cuda:0") + b = session_queue.dequeue(device="cuda:1") + assert a is not None and b is not None + assert a.status == "in_progress" and b.status == "in_progress" + return a.item_id, b.item_id + + +def test_cancel_by_batch_ids_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + batch_id = str(uuid.uuid4()) + _insert(session_queue, batch_id=batch_id) + _insert(session_queue, batch_id=batch_id) + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_batch_ids("default", [batch_id]) + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + # Each worker must have received a cancel event for its item. 
+ assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_cancel_by_destination_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_destination("default", "canvas") + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_cancel_by_queue_id_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + _insert(session_queue, batch_id=str(uuid.uuid4())) + _insert(session_queue, batch_id=str(uuid.uuid4())) + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_queue_id("default") + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_delete_by_destination_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + """delete_by_destination must signal every running worker (not just get_current()) before + deleting their rows, or the un-canceled workers keep running and then fail to update a deleted + row.""" + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.delete_by_destination("default", "canvas") + + assert result.deleted == 2 + # Both in-progress workers were signaled to cancel before deletion. 
+ assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + # Rows are gone. + for item_id in (id_a, id_b): + with pytest.raises(SessionQueueItemNotFoundError): + session_queue.get_queue_item(item_id) + + +def test_cancel_by_batch_ids_respects_user_scope(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + """A user-scoped cancel must not cancel another user's in-progress item in the same batch.""" + batch_id = str(uuid.uuid4()) + _insert(session_queue, batch_id=batch_id, user_id="alice") + _insert(session_queue, batch_id=batch_id, user_id="bob") + alice_item = session_queue.dequeue(device="cuda:0") + bob_item = session_queue.dequeue(device="cuda:1") + assert alice_item is not None and bob_item is not None + + result = session_queue.cancel_by_batch_ids("default", [batch_id], user_id="alice") + + assert result.canceled == 1 + assert session_queue.get_queue_item(alice_item.item_id).status == "canceled" + assert session_queue.get_queue_item(bob_item.item_id).status == "in_progress" + assert _canceled_event_item_ids(mock_invoker) == {alice_item.item_id} diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py new file mode 100644 index 00000000000..79a34ff0d96 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py @@ -0,0 +1,126 @@ +"""Tests for sharing a single canonical CPU copy of model weights across per-device cached models. + +These exercise the multi-GPU RAM-dedup path: two cached models built for the same cache key (as +would happen on two GPUs) must end up aliasing one set of CPU tensors instead of holding two +copies. They run on CPU — the wrapper constructors never touch VRAM, so no GPU is required. 
+""" + +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + +CPU = torch.device("cpu") + + +def _data_ptrs(state_dict: dict[str, torch.Tensor]) -> dict[str, int]: + return {k: v.data_ptr() for k, v in state_dict.items()} + + +def test_partial_load_shares_cpu_weights_across_devices(): + store = SharedCpuWeightsStore() + # Two independently-initialised modules (distinct weights), as two devices would build. + model_a = DummyModule() + model_b = DummyModule() + a_ptrs = _data_ptrs(model_a.state_dict()) + + cached_a = CachedModelWithPartialLoad(model_a, CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_b = CachedModelWithPartialLoad(model_b, CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + + # Both cached models expose the SAME canonical CPU tensors. + assert cached_a.get_cpu_state_dict() is cached_b.get_cpu_state_dict() + assert _data_ptrs(cached_b.get_cpu_state_dict()) == a_ptrs + + # model_b's own parameters were re-pointed at the canonical tensors (b's originals are gone). + assert _data_ptrs(model_b.state_dict()) == a_ptrs + + assert store.refcount("m") == 2 + # Counted once despite two devices holding it. 
+ assert store.total_bytes_in_use() == cached_a.total_bytes() + + +def test_full_load_shares_cpu_weights_across_devices(): + store = SharedCpuWeightsStore() + model_a = DummyModule() + model_b = DummyModule() + a_ptrs = _data_ptrs(model_a.state_dict()) + + cached_a = CachedModelOnlyFullLoad( + model_a, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m" + ) + cached_b = CachedModelOnlyFullLoad( + model_b, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m" + ) + + assert cached_a.get_cpu_state_dict() is cached_b.get_cpu_state_dict() + assert _data_ptrs(model_b.state_dict()) == a_ptrs + assert store.refcount("m") == 2 + + +def test_release_shared_weights_frees_at_last_reference(): + store = SharedCpuWeightsStore() + cached_a = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_b = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + assert store.refcount("m") == 2 + + cached_a.release_shared_weights() + assert store.refcount("m") == 1 + assert "m" in store + + cached_b.release_shared_weights() + assert "m" not in store + assert store.total_bytes_in_use() == 0 + + +def test_release_shared_weights_is_idempotent(): + store = SharedCpuWeightsStore() + cached = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached.release_shared_weights() + cached.release_shared_weights() # second call must not double-decrement + assert store.refcount("m") == 0 + assert "m" not in store + + +def test_no_store_means_no_sharing_and_no_release_error(): + # Without a shared store, behaviour is unchanged: each model keeps its own CPU state dict. + model = DummyModule() + cached = CachedModelWithPartialLoad(model, CPU, keep_ram_copy=True) + assert cached.get_cpu_state_dict() is not None + # release is a safe no-op when nothing was shared. 
+ cached.release_shared_weights() + + +def test_keep_ram_copy_false_does_not_touch_store(): + store = SharedCpuWeightsStore() + cached = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=False, shared_store=store, cache_key="m") + assert cached.get_cpu_state_dict() is None + assert "m" not in store + assert store.refcount("m") == 0 + + +class _RepointFailsModule(DummyModule): + """A model whose load_state_dict raises, to simulate a re-point failure during construction.""" + + def load_state_dict(self, *args, **kwargs): # type: ignore[override] + raise RuntimeError("simulated re-point failure") + + +def test_acquire_is_released_if_repoint_fails(): + # First device registers the canonical weights (refcount 1). + store = SharedCpuWeightsStore() + CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + assert store.refcount("m") == 1 + + # Second device adopts the canonical copy, but its re-point throws. The just-acquired reference + # must be released so the store's refcount is not leaked (the wrapper never enters the cache). + with pytest.raises(RuntimeError, match="simulated re-point failure"): + CachedModelWithPartialLoad(_RepointFailsModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + + assert store.refcount("m") == 1 # back to just the first device, not leaked at 2 diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py new file mode 100644 index 00000000000..3ceb36c8dd6 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py @@ -0,0 +1,193 @@ +"""End-to-end tests of the global RamBudget driving eviction across per-device caches. 
+ +Validates that the budget counts a shared model once (not once-per-GPU), counts non-deduplicated +models per-instance, and that eviction is made against the global deduplicated total — including the +case where a cache cannot free RAM because another device still holds the model. Runs on CPU. +""" + +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + GB, + MIN_RAM_CACHE_BYTES, + RAM_CACHE_BASELINE_BYTES, + RAM_CACHE_SYSTEM_FRACTION, + ModelCache, +) +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from invokeai.backend.util.calc_tensor_size import calc_tensor_size +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + +# Persistent state-dict bytes of one DummyModule (what the shared store accounts for a shared model). 
+S = sum(calc_tensor_size(v) for v in DummyModule().state_dict().values()) + + +@pytest.fixture +def mock_logger(): + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store, budget, logger, keep_ram_copy=True) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=False, + keep_ram_copy_of_weights=keep_ram_copy, + execution_device="cpu", + storage_device="cpu", + logger=logger, + shared_cpu_weights=store, + ram_budget=budget, + ) + + +def test_shared_model_counts_once_in_global_budget(mock_logger): + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=10**12, shared_store=store) + cache_a = _make_cache(store, budget, mock_logger) + cache_b = _make_cache(store, budget, mock_logger) + try: + cache_a.put("m", DummyModule()) + one_device = budget.total_in_use() + assert one_device == S + + cache_b.put("m", DummyModule()) + # Second device shares the weights -> the global budget total does NOT grow. + assert budget.total_in_use() == one_device + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_non_shared_model_counts_per_device(mock_logger): + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=10**12, shared_store=store) + # keep_ram_copy=False -> not deduplicated, so each device's copy is real RAM. + cache_a = _make_cache(store, budget, mock_logger, keep_ram_copy=False) + cache_b = _make_cache(store, budget, mock_logger, keep_ram_copy=False) + try: + cache_a.put("m", DummyModule()) + one = budget.total_in_use() + assert one > 0 + cache_b.put("m", DummyModule()) + # Two independent copies -> counted twice. + assert budget.total_in_use() == 2 * one + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_global_budget_evicts_lru_in_single_cache(mock_logger): + # Budget fits one model but not two -> putting the second evicts the first. 
+ store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=int(S * 1.4), shared_store=store) + cache = _make_cache(store, budget, mock_logger) + try: + cache.put("a", DummyModule()) + cache.put("b", DummyModule()) + assert "a" not in cache._cached_models # evicted to make room for b + assert "b" in cache._cached_models + assert "a" not in store and store.refcount("b") == 1 + assert budget.total_in_use() == S + finally: + cache.shutdown() + + +def test_get_vram_in_use_queries_this_caches_execution_device(mock_logger): + """Regression: _get_vram_in_use must query its own execution device, not the process-current one. + + In multi-GPU mode each worker calls torch.cuda.set_device for its GPU, so a no-argument + memory_allocated() can read a different device. That breaks the cancellation in + _get_vram_available and inflates "available" VRAM, so the cache never offloads and OOMs while + ignoring device_working_mem_gb. + """ + import torch + + mc = "invokeai.backend.model_manager.load.model_cache.model_cache" + with ( + patch(f"{mc}.torch.cuda.mem_get_info", return_value=(10 * GB, 48 * GB)), + patch(f"{mc}.torch.cuda.memory_allocated", return_value=42) as mock_alloc, + ): + cache = ModelCache( + execution_device_working_mem_gb=3.0, + enable_partial_loading=True, + keep_ram_copy_of_weights=True, + execution_device="cuda:1", + storage_device="cpu", + logger=mock_logger, + ) + try: + assert cache._get_vram_in_use() == 42 + mock_alloc.assert_called_with(torch.device("cuda:1")) + finally: + cache.shutdown() + + +def _mock_total_ram(total_bytes: int): + """Patch psutil.virtual_memory().total as seen by model_cache.""" + vm = MagicMock() + vm.total = total_bytes + return patch( + "invokeai.backend.model_manager.load.model_cache.model_cache.psutil.virtual_memory", + return_value=vm, + ) + + +def test_system_ram_headroom_is_fraction_minus_baseline(): + # On a 96 GB box, the default cap is 50% - 2 GB = 46 GB, leaving real headroom for the OS. 
+ with _mock_total_ram(96 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + assert headroom == int(96 * GB * RAM_CACHE_SYSTEM_FRACTION) - RAM_CACHE_BASELINE_BYTES + assert headroom == 46 * GB + # And it must leave at least half the system for everything else. + assert headroom <= 96 * GB * 0.5 + + +def test_system_ram_headroom_respects_floor_on_tiny_systems(): + # A machine with almost no RAM still gets the absolute minimum, never a negative/zero budget. + with _mock_total_ram(2 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + assert headroom == MIN_RAM_CACHE_BYTES + + +def test_headroom_clamps_summed_multi_gpu_budget(): + # Reproduces the multi-GPU blowup: two 45 GB per-device caches sum to 90 GB, which would leave + # only ~6 GB on a 96 GB machine. The headroom cap must clamp the budget below that sum. + per_device_cache_bytes = 45 * GB + summed = 2 * per_device_cache_bytes # 90 GB, as the old code used verbatim + with _mock_total_ram(96 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + clamped = min(summed, headroom) + assert clamped == headroom < summed + assert clamped == 46 * GB + + +def test_eviction_cannot_free_ram_held_by_another_device(mock_logger): + """If a cache's only droppable model is still held by another device, eviction frees nothing + globally (the shared weights stay live) and the new model is still admitted -> transiently over + budget until the other device releases. 
The eviction loop must handle this without spinning.""" + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=int(S * 1.4), shared_store=store) + cache_a = _make_cache(store, budget, mock_logger) + cache_b = _make_cache(store, budget, mock_logger) + try: + cache_a.put("shared", DummyModule()) + cache_b.put("shared", DummyModule()) # both devices hold "shared" (refcount 2, counted once) + assert budget.total_in_use() == S + + cache_a.put("new", DummyModule()) # triggers make_room; "shared" is a's only droppable entry + # a dropped its ref to "shared", but b still holds it, so the shared weights weren't freed. + assert "shared" not in cache_a._cached_models + assert "shared" in cache_b._cached_models + assert store.refcount("shared") == 1 + assert "new" in cache_a._cached_models + # "shared" (still alive via b) + "new" -> over the 1.4*S cap, as expected. + assert budget.total_in_use() == 2 * S + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py b/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py new file mode 100644 index 00000000000..1b6eee525ab --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py @@ -0,0 +1,86 @@ +"""End-to-end test of CPU-weight sharing through ModelCache.put()/eviction. + +Simulates the multi-GPU topology — one ModelCache per device, all sharing a single +SharedCpuWeightsStore — and asserts that the same model loaded into both caches keeps exactly one +CPU copy, with RAM freed only when the last device evicts it. Runs on CPU (no VRAM moves). 
+""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + + +@pytest.fixture +def mock_logger(): + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store: SharedCpuWeightsStore, logger: MagicMock) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=False, + keep_ram_copy_of_weights=True, + execution_device="cpu", + storage_device="cpu", + logger=logger, + shared_cpu_weights=store, + ) + + +def test_two_device_caches_share_one_cpu_copy(mock_logger: MagicMock): + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, mock_logger) + cache_b = _make_cache(store, mock_logger) + try: + cache_a.put("m", DummyModule()) + ram_one_device = store.total_bytes_in_use() + assert ram_one_device > 0 + + cache_b.put("m", DummyModule()) + + # One canonical CPU copy shared by both "devices": the second device's put adds NO RAM. + assert store.refcount("m") == 2 + assert store.total_bytes_in_use() == ram_one_device + sd_a = cache_a.get("m").cached_model.get_cpu_state_dict() + sd_b = cache_b.get("m").cached_model.get_cpu_state_dict() + assert sd_a is sd_b + + # Evicting from one device drops only its reference; the weights stay for the other. + cache_a.make_room(10**12) + assert "m" not in cache_a._cached_models + assert store.refcount("m") == 1 + assert "m" in store + + # Evicting from the last device frees the shared RAM. 
+ cache_b.make_room(10**12) + assert store.refcount("m") == 0 + assert "m" not in store + assert store.total_bytes_in_use() == 0 + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_drop_model_releases_shared_weights(mock_logger: MagicMock): + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, mock_logger) + cache_b = _make_cache(store, mock_logger) + try: + cache_a.put("m", DummyModule()) + cache_b.put("m", DummyModule()) + assert store.refcount("m") == 2 + + assert cache_a.drop_model("m") == 1 + assert store.refcount("m") == 1 + assert cache_b.drop_model("m") == 1 + assert "m" not in store + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/model_manager/load/model_cache/test_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_ram_budget.py new file mode 100644 index 00000000000..d8704fffad5 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_ram_budget.py @@ -0,0 +1,48 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +def test_total_in_use_sums_store_and_non_shared(): + store = SharedCpuWeightsStore() + store.acquire("k", {"a": torch.ones(100, dtype=torch.float32)}) # 400 bytes + budget = RamBudget(max_bytes=10_000, shared_store=store) + + assert budget.total_in_use() == 400 # store only + budget.add_non_shared(600) + assert budget.total_in_use() == 1000 + assert budget.available() == 9000 + budget.remove_non_shared(600) + assert budget.total_in_use() == 400 + + +def test_shared_weights_counted_once_regardless_of_refcount(): + store = SharedCpuWeightsStore() + sd = {"a": torch.ones(100, dtype=torch.float32)} # 400 bytes + store.acquire("k", sd) + store.acquire("k", sd) # second device acquires the same key + budget = RamBudget(max_bytes=10_000, shared_store=store) + # Two references, one physical copy -> counted once. 
+ assert budget.total_in_use() == 400 + + +def test_remove_non_shared_floors_at_zero(): + budget = RamBudget(max_bytes=10_000, shared_store=None) + budget.add_non_shared(100) + budget.remove_non_shared(500) + assert budget.total_in_use() == 0 + + +def test_available_can_go_negative_when_over_budget(): + budget = RamBudget(max_bytes=100, shared_store=None) + budget.add_non_shared(250) + assert budget.available() == -150 + + +def test_no_store_tracks_only_non_shared(): + budget = RamBudget(max_bytes=1000, shared_store=None) + assert budget.total_in_use() == 0 + budget.add_non_shared(300) + assert budget.total_in_use() == 300 + assert budget.max_bytes == 1000 diff --git a/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py b/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py new file mode 100644 index 00000000000..23d8fe875ef --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py @@ -0,0 +1,106 @@ +import threading + +import torch + +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +def _state_dict() -> dict[str, torch.Tensor]: + return { + "a": torch.ones(10, 10, dtype=torch.float32), # 400 bytes + "b": torch.ones(5, dtype=torch.float32), # 20 bytes + } + + +def test_first_acquire_registers_and_returns_same_object(): + store = SharedCpuWeightsStore() + sd = _state_dict() + canonical = store.acquire("k", sd) + # The first acquire keeps the caller's own dict as canonical. + assert canonical is sd + assert store.refcount("k") == 1 + assert "k" in store + + +def test_second_acquire_returns_canonical_not_the_new_dict(): + store = SharedCpuWeightsStore() + first = _state_dict() + second = _state_dict() # distinct tensors, same shapes + canonical_first = store.acquire("k", first) + canonical_second = store.acquire("k", second) + + # The second caller gets the originally-registered tensors, not its own. 
+ assert canonical_second is canonical_first + assert canonical_second["a"].data_ptr() == first["a"].data_ptr() + assert canonical_second["a"].data_ptr() != second["a"].data_ptr() + assert store.refcount("k") == 2 + + +def test_total_bytes_counts_each_key_once(): + store = SharedCpuWeightsStore() + # Two devices acquire the same key -> counted once. + store.acquire("k", _state_dict()) + store.acquire("k", _state_dict()) + assert store.total_bytes_in_use() == 420 + # A different key adds its own bytes. + store.acquire("k2", {"x": torch.ones(100, dtype=torch.float32)}) # 400 bytes + assert store.total_bytes_in_use() == 820 + + +def test_release_frees_only_at_zero(): + store = SharedCpuWeightsStore() + store.acquire("k", _state_dict()) + store.acquire("k", _state_dict()) + assert store.refcount("k") == 2 + + store.release("k") + assert store.refcount("k") == 1 + assert "k" in store + assert store.total_bytes_in_use() == 420 + + store.release("k") + assert store.refcount("k") == 0 + assert "k" not in store + assert store.total_bytes_in_use() == 0 + + +def test_release_unknown_key_is_noop(): + store = SharedCpuWeightsStore() + store.release("missing") # must not raise + assert store.total_bytes_in_use() == 0 + + +def test_reacquire_after_full_release_registers_fresh(): + store = SharedCpuWeightsStore() + first = _state_dict() + store.acquire("k", first) + store.release("k") + assert "k" not in store + + second = _state_dict() + canonical = store.acquire("k", second) + # After a full release the next caller becomes the new canonical. + assert canonical is second + assert store.refcount("k") == 1 + + +def test_concurrent_acquire_release_is_consistent(): + store = SharedCpuWeightsStore() + sd = _state_dict() + # Pre-register so the key exists for the whole run and the count never hits zero. 
+ store.acquire("k", sd) + + def worker(): + for _ in range(200): + store.acquire("k", _state_dict()) + store.release("k") + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every acquire was paired with a release, so only the pre-registration reference remains. + assert store.refcount("k") == 1 + assert store.total_bytes_in_use() == 420 diff --git a/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py new file mode 100644 index 00000000000..4f5b700cbb3 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py @@ -0,0 +1,234 @@ +"""Real-GPU validation of cross-device CPU-weight sharing. + +These require two CUDA (incl. ROCm/HIP) devices. They prove the properties the CPU-only unit tests +cannot: that a module re-pointed at shared canonical CPU weights (a) loads onto its GPU and produces +correct inference output, and (b) survives two GPUs loading/unloading from the *same* shared CPU +state dict concurrently without corrupting each other's results. 
+""" + +import copy +import logging +import threading +from unittest.mock import MagicMock + +import gguf +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule +from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor + +requires_two_gpus = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices.") + +DEVICE_A = "cuda:0" +DEVICE_B = "cuda:1" + + +def _mock_logger() -> MagicMock: + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store: SharedCpuWeightsStore, device: str, partial: bool) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=partial, + keep_ram_copy_of_weights=True, + execution_device=device, + storage_device="cpu", + logger=_mock_logger(), + shared_cpu_weights=store, + ) + + +@requires_two_gpus +@pytest.mark.parametrize("partial", [False, True]) +def test_shared_weights_produce_correct_output_on_both_gpus(partial: bool): + """A model loaded on two GPUs from one shared CPU copy must compute correct results on both.""" + torch.manual_seed(0) + model_a = DummyModule() + # model_b starts with DIFFERENT weights; sharing must overwrite them with model_a's canonical + # weights (both keys map to the same logical model). 
+ torch.manual_seed(1) + model_b = DummyModule() + + x = torch.randn(4, 10) + # Reference output from model_a's original weights, computed before any cache/device mutation. + reference = copy.deepcopy(model_a)(x) + + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, DEVICE_A, partial) + cache_b = _make_cache(store, DEVICE_B, partial) + try: + cache_a.put("m", model_a) + cache_b.put("m", model_b) + + # Single shared CPU copy across both devices. + assert store.refcount("m") == 2 + assert cache_a.get("m").cached_model.get_cpu_state_dict() is cache_b.get("m").cached_model.get_cpu_state_dict() + + rec_a = cache_a.get("m") + rec_b = cache_b.get("m") + cache_a.lock(rec_a, None) + cache_b.lock(rec_b, None) + try: + out_a = rec_a.cached_model.model(x.to(DEVICE_A)) + out_b = rec_b.cached_model.model(x.to(DEVICE_B)) + finally: + cache_a.unlock(rec_a) + cache_b.unlock(rec_b) + + # Both devices reproduce model_a's output (so model_b really adopted the shared weights). + assert torch.allclose(out_a.cpu(), reference, atol=1e-5) + assert torch.allclose(out_b.cpu(), reference, atol=1e-5) + finally: + cache_a.shutdown() + cache_b.shutdown() + + +@requires_two_gpus +@pytest.mark.parametrize("wrapper_cls", [CachedModelOnlyFullLoad, CachedModelWithPartialLoad]) +def test_concurrent_load_unload_from_shared_state_dict(wrapper_cls): + """Two GPUs repeatedly loading/unloading from one shared CPU state dict must not corrupt each + other. Each thread drives its own device's wrapper; the canonical CPU tensors are read-only and + must stay intact across concurrent .to(device) reads and load_state_dict(assign=True) restores. 
+ """ + torch.manual_seed(0) + model_a = DummyModule() + torch.manual_seed(1) + model_b = DummyModule() + + x = torch.randn(4, 10) + reference = copy.deepcopy(model_a)(x) + + store = SharedCpuWeightsStore() + + def build(model, device): + if wrapper_cls is CachedModelWithPartialLoad: + return CachedModelWithPartialLoad( + model, torch.device(device), keep_ram_copy=True, shared_store=store, cache_key="m" + ) + return CachedModelOnlyFullLoad( + model, torch.device(device), total_bytes=1000, keep_ram_copy=True, shared_store=store, cache_key="m" + ) + + cached_a = build(model_a, DEVICE_A) + cached_b = build(model_b, DEVICE_B) + + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def run(cached, device): + try: + xd = x.to(device) + for _ in range(20): + barrier.wait() # maximise overlap of the two devices' loads + cached.full_load_to_vram() + out = cached.model(xd) + assert torch.allclose(out.cpu(), reference, atol=1e-5) + cached.full_unload_from_vram() + except Exception as e: # noqa: BLE001 - surface to main thread + errors.append(e) + try: + barrier.abort() + except Exception: + pass + + t_a = threading.Thread(target=run, args=(cached_a, DEVICE_A)) + t_b = threading.Thread(target=run, args=(cached_b, DEVICE_B)) + t_a.start() + t_b.start() + t_a.join() + t_b.join() + + assert not errors, f"Concurrent load/unload corrupted results: {errors[0]!r}" + # Canonical CPU weights survived and are still shared. + assert store.refcount("m") == 2 + cached_a.release_shared_weights() + cached_b.release_shared_weights() + assert "m" not in store + + +class _GGUFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def _build_gguf_model(seed: int) -> _GGUFModel: + """A small model whose linear weight is a Q8_0 GGML-quantized (CPU-resident) tensor. 
+ + This mirrors how large quantized transformers/encoders are stored: the weights live on the CPU + as GGMLTensors and are dequantized on the fly during the forward pass. It is the path that goes + through the shared-CPU-weights mechanism, so it validates that re-pointing a quantized state + dict across devices preserves correct dequantized inference. + """ + torch.manual_seed(seed) + model = _GGUFModel() + model.linear.weight = torch.nn.Parameter(quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0)) + return model + + +@requires_two_gpus +def test_shared_gguf_quantized_weights_correct_on_both_gpus(): + """A GGUF-quantized model loaded on two GPUs from one shared CPU copy must dequantize and + compute correct results on both devices.""" + x = torch.randn(1, 32, dtype=torch.float32) + + # Reference: a standalone copy of the same (seed-0) quantized weights, run via the autocast + # custom layers. Weights stay on CPU; compute happens on the device. + reference_model = _build_gguf_model(0) + apply_custom_layers_to_model(reference_model, device_autocasting_enabled=True) + reference = reference_model(x.to(DEVICE_A)).cpu() + + model_a = _build_gguf_model(0) + model_b = _build_gguf_model(1) # different weights; sharing must overwrite with canonical + + store = SharedCpuWeightsStore() + # enable_partial_loading=True routes quantized nn.Modules through CachedModelWithPartialLoad. + cache_a = _make_cache(store, DEVICE_A, partial=True) + cache_b = _make_cache(store, DEVICE_B, partial=True) + try: + cache_a.put("m", model_a) + ram_one_device = store.total_bytes_in_use() + cache_b.put("m", model_b) + + # One shared CPU copy of the quantized weights; second device adds no RAM. 
+ assert store.refcount("m") == 2 + assert store.total_bytes_in_use() == ram_one_device + rec_a = cache_a.get("m") + rec_b = cache_b.get("m") + assert rec_a.cached_model.get_cpu_state_dict() is rec_b.cached_model.get_cpu_state_dict() + # model_b's quantized weight was re-pointed at model_a's canonical tensor. + assert rec_b.cached_model.model.linear.weight.data_ptr() == rec_a.cached_model.model.linear.weight.data_ptr() + + cache_a.lock(rec_a, None) + cache_b.lock(rec_b, None) + try: + out_a = rec_a.cached_model.model(x.to(DEVICE_A)) + out_b = rec_b.cached_model.model(x.to(DEVICE_B)) + finally: + cache_a.unlock(rec_a) + cache_b.unlock(rec_b) + + # Both GPUs reproduce the reference dequantized output. + assert torch.allclose(out_a.cpu(), reference, atol=1e-5) + assert torch.allclose(out_b.cpu(), reference, atol=1e-5) + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/model_manager/load/test_shared_weight_adoption.py b/tests/backend/model_manager/load/test_shared_weight_adoption.py new file mode 100644 index 00000000000..3393e2af91b --- /dev/null +++ b/tests/backend/model_manager/load/test_shared_weight_adoption.py @@ -0,0 +1,140 @@ +"""Tests for load-time adoption of shared CPU weights (multi-GPU RAM-spike fix). + +When a second device loads a model that another device already holds, the loader deep-copies the +empty (meta-weight) structural shell the first device registered and assigns the canonical CPU +weights into it — instead of re-reading the model from disk and materializing a full transient +second copy. This is loader-agnostic (no per-model-family code): it works by cloning a built module, +so it covers diffusers, single-file checkpoints, GGUF and transformers models alike, and preserves +any registered hooks (e.g. fp8 layerwise-cast hooks). 
+"""
+
+from unittest.mock import MagicMock
+
+import torch
+
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore
+
+
+class _TinyModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin = torch.nn.Linear(4, 4)
+        # A non-persistent buffer: not in the state dict, so adoption must carry it over with data.
+        self.register_buffer("scale", torch.tensor([2.0]), persistent=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin(x) * self.scale
+
+
+def _loader_with_store(store: SharedCpuWeightsStore | None) -> ModelLoader:
+    loader = ModelLoader.__new__(ModelLoader)  # bypass __init__ (needs app deps we don't use here)
+    loader._logger = MagicMock()
+    loader._ram_cache = MagicMock()
+    loader._ram_cache.shared_cpu_weights = store
+    return loader
+
+
+def _populate(store: SharedCpuWeightsStore, key: str, model: torch.nn.Module) -> None:
+    """Mimic the first device's load: register canonical weights + a meta shell for `model`."""
+    store.acquire(key, model.state_dict())
+    shell = ModelLoader._build_meta_shell(model)
+    assert shell is not None
+    store.set_shell(key, shell)
+
+
+def test_meta_shell_has_no_real_weight_storage():
+    model = _TinyModel()
+    shell = ModelLoader._build_meta_shell(model)
+    assert shell is not None
+    # Parameters are on meta (0 bytes); the non-persistent buffer keeps real data.
+    assert all(p.is_meta for p in shell.parameters())
+    assert not shell.scale.is_meta
+    assert torch.equal(shell.scale, model.scale)
+
+
+def test_build_meta_shell_returns_none_for_non_module():
+    assert ModelLoader._build_meta_shell({"not": "a module"}) is None  # type: ignore[arg-type]
+
+
+def test_adopts_canonical_weights_without_copying():
+    store = SharedCpuWeightsStore()
+    source = _TinyModel()
+    _populate(store, "m", source)
+    canonical = store.peek("m")
+    refcount_before = store.refcount("m")
+
+    model = _loader_with_store(store)._try_adopt_shared_weights("m")
+
+    assert model is not None
+    # The adopted params ARE the canonical tensors (assign=True, no copy) -> no extra RAM.
+    assert model.lin.weight.data_ptr() == canonical["lin.weight"].data_ptr()
+    assert model.lin.bias.data_ptr() == canonical["lin.bias"].data_ptr()
+    assert not any(t.is_meta for t in model.parameters())
+    assert not any(t.is_meta for t in model.buffers())
+    # peek()/get_shell() must not have taken a reference -- the wrapper's acquire() does that later.
+    assert store.refcount("m") == refcount_before
+
+
+def test_adopted_model_produces_correct_output():
+    store = SharedCpuWeightsStore()
+    source = _TinyModel()
+    _populate(store, "m", source)
+    x = torch.randn(3, 4)
+
+    model = _loader_with_store(store)._try_adopt_shared_weights("m")
+    assert model is not None
+    assert torch.allclose(model(x), source(x), atol=1e-6)
+
+
+def test_adoption_preserves_forward_hooks():
+    # fp8 layerwise casting is implemented as forward hooks; cloning the built module must keep them.
+    store = SharedCpuWeightsStore()
+    source = _TinyModel()
+    fired: list[str] = []
+    source.lin.register_forward_pre_hook(lambda mod, args: fired.append("pre"))
+    _populate(store, "m", source)
+
+    model = _loader_with_store(store)._try_adopt_shared_weights("m")
+    assert model is not None
+    model(torch.randn(1, 4))
+    assert fired == ["pre"]  # the cloned module's hook fired
+
+
+def test_no_shell_means_no_adoption():
+    # Canonical present but no shell registered (e.g. first device couldn't clone) -> fall back.
+    store = SharedCpuWeightsStore()
+    store.acquire("m", _TinyModel().state_dict())
+    assert _loader_with_store(store)._try_adopt_shared_weights("m") is None
+
+
+def test_absent_key_means_no_adoption():
+    assert _loader_with_store(SharedCpuWeightsStore())._try_adopt_shared_weights("missing") is None
+
+
+def test_no_shared_store_means_no_adoption():
+    assert _loader_with_store(None)._try_adopt_shared_weights("m") is None
+
+
+def test_mismatched_canonical_falls_back_safely():
+    # If the canonical weights don't match the shell's structure, adoption must fail soft (-> None),
+    # not raise, so the caller can load normally.
+    store = SharedCpuWeightsStore()
+    source = _TinyModel()
+    shell = ModelLoader._build_meta_shell(source)
+    assert shell is not None
+    store.acquire("m", {"unexpected.key": torch.zeros(2)})  # wrong state dict
+    store.set_shell("m", shell)
+
+    loader = _loader_with_store(store)
+    assert loader._try_adopt_shared_weights("m") is None
+    loader._logger.warning.assert_called_once()
+
+
+def test_shell_dropped_when_entry_released():
+    store = SharedCpuWeightsStore()
+    _populate(store, "m", _TinyModel())
+    assert store.get_shell("m") is not None
+    store.release("m")  # last reference -> entry (and its shell) gone
+    assert store.get_shell("m") is None
+    assert "m" not in store
diff --git a/tests/backend/patches/test_layer_patcher_shared_weights.py b/tests/backend/patches/test_layer_patcher_shared_weights.py
new file mode 100644
index 00000000000..328f93b04cc
--- /dev/null
+++ b/tests/backend/patches/test_layer_patcher_shared_weights.py
@@ -0,0 +1,111 @@
+"""Regression tests: LoRA direct patching must not mutate the model's canonical CPU weights.
+
+In multi-GPU mode the per-device caches share one canonical CPU state_dict (SharedCpuWeightsStore),
+and that same dict is the keep_ram_copy used to restore a model after unpatching. Direct patching
+must therefore never mutate those tensors in place — otherwise a LoRA applied on one GPU would
+corrupt the weights seen by the other GPU (and taint the cached "clean" copy even with one GPU).
+
+These run on CPU and force direct patching, which is the path that touches CPU-resident weights.
+"""
+
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    apply_custom_layers_to_model,
+)
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from tests.backend.patches.test_layer_patcher import DummyModuleWithOneLayer
+
+
+def _make_loras(num_loras: int, in_features: int, out_features: int, rank: int):
+    """Build `num_loras` identical all-ones LoRA patches for `linear_layer_1`, each paired with weight 0.5."""
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
+    for _ in range(num_loras):
+        layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((rank, in_features), device="cpu", dtype=torch.float32),
+                    "lora_up.weight": torch.ones((out_features, rank), device="cpu", dtype=torch.float32),
+                },
+            )
+        }
+        lora_models.append((ModelPatchRaw(layers), 0.5))
+    return lora_models
+
+
+@torch.no_grad()
+def test_force_direct_patch_does_not_mutate_canonical_cpu_weights():
+    """Direct patching must not mutate the `cached_weights` tensors, during or after the patch."""
+    in_features, out_features, rank = 4, 8, 2
+    model = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32)
+    apply_custom_layers_to_model(model)
+
+    # `canonical` holds references to the model's actual parameter tensors — exactly what the shared
+    # store would hand out as the canonical CPU copy and what model_on_device() passes as
+    # cached_weights. We snapshot their values to detect any in-place mutation.
+    canonical = dict(model.state_dict())
+    snapshot = {k: v.detach().clone() for k, v in canonical.items()}
+
+    lora_models = _make_loras(num_loras=2, in_features=in_features, out_features=out_features, rank=rank)
+    x = torch.randn(1, in_features, dtype=torch.float32)
+    out_before = model(x)
+
+    with LayerPatcher.apply_smart_model_patches(
+        model=model,
+        patches=lora_models,
+        prefix="",
+        dtype=torch.float32,
+        cached_weights=canonical,
+        force_direct_patching=True,
+    ):
+        # Sanity: this really is the direct path (no sidecar wrappers), so weights were applied
+        # directly — and the patch actually changed the output.
+        assert model.linear_layer_1.get_num_patches() == 0
+        out_during = model(x)
+        assert not torch.allclose(out_before, out_during)
+
+        # The canonical tensors must be untouched even while the patch is active.
+        for k in canonical:
+            torch.testing.assert_close(canonical[k], snapshot[k])
+
+    # ...and after unpatching.
+    for k in canonical:
+        torch.testing.assert_close(canonical[k], snapshot[k])
+    assert torch.allclose(out_before, model(x))
+
+
+@torch.no_grad()
+def test_two_models_sharing_canonical_are_isolated_under_direct_patch():
+    """Patch one model built from the shared canonical weights; a second model built from the same
+    canonical tensors must be unaffected (no cross-device bleed)."""
+    in_features, out_features, rank = 4, 8, 2
+    model_a = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32)
+    apply_custom_layers_to_model(model_a)
+    canonical = dict(model_a.state_dict())
+
+    # model_b shares the canonical tensors (as a second device's cache would via load_state_dict).
+    model_b = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32)
+    apply_custom_layers_to_model(model_b)
+    model_b.load_state_dict(canonical, assign=True)
+
+    x = torch.randn(1, in_features, dtype=torch.float32)
+    out_a_before = model_a(x)
+    out_b_before = model_b(x)
+
+    lora_models = _make_loras(num_loras=1, in_features=in_features, out_features=out_features, rank=rank)
+    with LayerPatcher.apply_smart_model_patches(
+        model=model_a,
+        patches=lora_models,
+        prefix="",
+        dtype=torch.float32,
+        cached_weights=canonical,
+        force_direct_patching=True,
+    ):
+        # Sanity: model_a really is patched, otherwise the isolation check below would pass vacuously.
+        assert not torch.allclose(model_a(x), out_a_before)
+        # model_a is patched; model_b (sharing the canonical weights) must be unchanged.
+        assert torch.allclose(model_b(x), out_b_before)
+
+    assert torch.allclose(model_b(x), out_b_before)
diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py
index 3f134e3c3da..1d7dfa75614 100644
--- a/tests/backend/util/test_devices.py
+++ b/tests/backend/util/test_devices.py
@@ -2,6 +2,7 @@
 Test abstract device class.
 """
 
+import threading
 from unittest.mock import patch
 
 import pytest
@@ -24,6 +25,96 @@ def test_device_choice(device_name):
     assert torch_device == torch.device(device_name)
 
 
+# ===== per-thread session device (multi-GPU worker pinning) ================
+
+
+def test_session_device_overrides_config():
+    """A per-thread session device takes precedence over the global config.device."""
+    config = get_config()
+    config.device = "cpu"
+    try:
+        TorchDevice.set_session_device("cuda:1")
+        assert TorchDevice.choose_torch_device() == torch.device("cuda:1")
+    finally:
+        TorchDevice.clear_session_device()
+    # Once cleared, we fall back to the global config.
+    assert TorchDevice.choose_torch_device() == torch.device("cpu")
+
+
+def test_session_device_is_thread_local():
+    """Each thread sees only its own pinned device; the main thread is unaffected."""
+    config = get_config()
+    config.device = "cpu"
+    results: dict[str, torch.device] = {}
+    barrier = threading.Barrier(2, timeout=10)  # fail instead of hanging if a worker dies early
+
+    def worker(name: str, device: str):
+        TorchDevice.set_session_device(device)
+        # Wait so both threads have set their device before either reads it, proving isolation.
+        barrier.wait()
+        results[name] = TorchDevice.choose_torch_device()
+        TorchDevice.clear_session_device()
+
+    t0 = threading.Thread(target=worker, args=("a", "cuda:0"))
+    t1 = threading.Thread(target=worker, args=("b", "cuda:1"))
+    t0.start()
+    t1.start()
+    t0.join()
+    t1.join()
+
+    assert results["a"] == torch.device("cuda:0")
+    assert results["b"] == torch.device("cuda:1")
+    # The main thread never set a session device, so it still uses the global config.
+    assert TorchDevice.get_session_device() is None
+    assert TorchDevice.choose_torch_device() == torch.device("cpu")
+
+
+# ===== generation_devices resolution (config -> concrete device list) =======
+
+
+def test_get_generation_devices_auto_expands_to_all_cuda():
+    """`auto` enumerates every visible CUDA device."""
+    with (
+        patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True),
+        patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=3),
+    ):
+        assert TorchDevice.get_generation_devices("auto") == [
+            torch.device("cuda:0"),
+            torch.device("cuda:1"),
+            torch.device("cuda:2"),
+        ]
+
+
+def test_get_generation_devices_auto_without_cuda():
+    """`auto` falls back to the single best device when CUDA is unavailable."""
+    config = get_config()
+    config.device = "cpu"
+    with (
+        patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=False),
+        patch("invokeai.backend.util.devices.torch.backends.mps.is_available", return_value=False),
+    ):
+        assert TorchDevice.get_generation_devices("auto") == [torch.device("cpu")]
+
+
+def test_get_generation_devices_explicit_list_is_deduplicated():
+    """An explicit list is normalized and deduplicated, preserving order."""
+    # Mock CUDA as present so the device-existence validation passes on CPU-only runners.
+    with (
+        patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True),
+        patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=2),
+    ):
+        assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [
+            torch.device("cuda:0"),
+            torch.device("cuda:1"),
+        ]
+
+
+@pytest.mark.parametrize("value", [None, []])
+def test_get_generation_devices_empty(value):
+    """`None` or an empty list resolves to an empty list (caller handles the single-device fallback)."""
+    assert TorchDevice.get_generation_devices(value) == []
+
+
 @pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
 def test_device_dtype_cpu(device_dtype_pair):
     with (
@@ -169,3 +260,25 @@ def test_choose_anima_inference_dtype_auto_delegates_to_safe_dtype():
         result = TorchDevice.choose_anima_inference_dtype(device)
         assert result is sentinel
         mock_safe.assert_called_once_with(device)
+
+
+@patch("torch.cuda.device_count", return_value=2)
+@patch("torch.cuda.is_available", return_value=True)
+def test_get_generation_devices_rejects_out_of_range_cuda(mock_avail, mock_count):
+    # cuda:2 does not exist on a 2-GPU machine — fail fast instead of deferring to first allocation.
+    with pytest.raises(ValueError, match="only 2 CUDA"):
+        TorchDevice.get_generation_devices(["cuda:2"])
+
+
+@patch("torch.cuda.device_count", return_value=2)
+@patch("torch.cuda.is_available", return_value=True)
+def test_get_generation_devices_accepts_in_range_cuda(mock_avail, mock_count):
+    """An index below the mocked device count passes validation and resolves to that device."""
+    assert [str(d) for d in TorchDevice.get_generation_devices(["cuda:1"])] == ["cuda:1"]
+
+
+@patch("torch.cuda.is_available", return_value=False)
+def test_get_generation_devices_rejects_cuda_when_unavailable(mock_avail):
+    """Requesting a CUDA device while CUDA is reported unavailable raises ValueError."""
+    with pytest.raises(ValueError, match="no CUDA"):
+        TorchDevice.get_generation_devices(["cuda:0"])