From 1f9e7df52ac07ee6956500d02ba85ab09082e9f9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 4 Jun 2026 18:24:22 +0300 Subject: [PATCH 01/45] [Partner Nodes] feat: add Krea 2 Medium Turbo model (#14280) --- comfy_api_nodes/nodes_krea.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py index 34369f05f20c..b9e6268f2bd4 100644 --- a/comfy_api_nodes/nodes_krea.py +++ b/comfy_api_nodes/nodes_krea.py @@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima _MODEL_MEDIUM = "Krea 2 Medium" +_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo" _MODEL_LARGE = "Krea 2 Large" _MODEL_ENDPOINTS: dict[str, str] = { _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium", + _MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo", _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large", } @@ -57,7 +59,7 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima def _krea_model_inputs() -> list: - """Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo.""" + """Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo.""" return [ IO.Combo.Input( "aspect_ratio", @@ -123,6 +125,7 @@ def define_schema(cls) -> IO.Schema: "model", options=[ IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()), IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()), ], tooltip="Krea 2 Medium is best for expressive illustrations; " @@ -151,14 +154,15 @@ def define_schema(cls) -> IO.Schema: ), expr=""" ( - $isLarge := widgets.model = "krea 2 large"; + $rates := { + "krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02}, + "krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04}, + "krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07} + }; + $r := $lookup($rates, widgets.model); $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0; $hasStyle := $lookup(inputs, "model.style_reference").connected; - $usd := $hasMoodboard - ? ($isLarge ? 0.07 : 0.04) - : ($hasStyle - ? ($isLarge ? 0.065 : 0.035) - : ($isLarge ? 0.06 : 0.03)); + $usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text); {"type":"usd","usd": $usd} ) """, From 27b5c423a6ff44a0b90f879b05d58b9a17d86528 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 4 Jun 2026 19:32:15 +0300 Subject: [PATCH 02/45] [Partner Nodes] feat: add seed input to Flux Erase node (#14283) Signed-off-by: bigcat88 --- comfy_api_nodes/apis/bfl.py | 1 + comfy_api_nodes/nodes_bfl.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index 2ad651122e05..4c950da84574 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel): "white (255) marks areas to remove, black (0) marks areas to preserve.", ) dilate_pixels: int = Field(10) + seed: int | None = Field(None) output_format: str = Field("png") diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 79961ff9df74..259c54ef97e2 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -534,6 +534,15 @@ def define_schema(cls) -> IO.Schema: max=25, tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.", ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), ], outputs=[IO.Image.Output()], hidden=[ @@ -553,6 +562,7 @@ async def execute( image: Input.Image, mask: Input.Image, dilate_pixels: int = 10, + seed: int = 0, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=256, min_height=256) mask = resize_mask_to_image(mask, image) @@ -565,6 +575,7 @@ async def execute( image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, dilate_pixels=dilate_pixels, + seed=seed, ), ) From 6ecca5f468ac5a24c450d601867de4d366f701fc Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Fri, 5 Jun 2026 00:40:44 +0800 Subject: [PATCH 03/45] chore: update workflow templates to v0.9.98 (#14284) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 79d38fc066fd..6f88daafabb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.94 +comfyui-workflow-templates==0.9.98 comfyui-embedded-docs==0.5.2 torch torchsde From 4e1f7cb1db1c26bb9ee61cf1875776517e2abae8 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Fri, 5 Jun 2026 03:41:33 +0900 Subject: [PATCH 04/45] Bump comfyui-frontend-package to 1.45.15 (#14265) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6f88daafabb0..8b64c60a9ece 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.44.19 +comfyui-frontend-package==1.45.15 comfyui-workflow-templates==0.9.98 comfyui-embedded-docs==0.5.2 torch From 514bb8ba21626da3126847321513f9863c88ce2c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Jun 2026 19:20:22 -0700 Subject: [PATCH 05/45] Fix ideogram if model dtype gets set to fp8. (#14291) --- comfy/ldm/ideogram4/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py index 3b02a243a966..b86c65bf0ae7 100644 --- a/comfy/ldm/ideogram4/model.py +++ b/comfy/ldm/ideogram4/model.py @@ -174,7 +174,7 @@ def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, tran llm = self.llm_cond_proj(llm) * text_mask h[:, :L_text] = h[:, :L_text] + llm - h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long)) + h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype) # Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch). freqs_cis = precompute_freqs_cis( @@ -235,7 +235,7 @@ def _image_position_ids(self, gh, gw, device): def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options): B = x_chunk.shape[0] device = x_chunk.device - img_tokens = self._img_to_tokens(x_chunk).to(self.dtype) + img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] L_text = context_chunk.shape[1] L = L_text + L_img @@ -268,7 +268,7 @@ def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options): B = x_chunk.shape[0] device = x_chunk.device - img_tokens = self._img_to_tokens(x_chunk).to(self.dtype) + img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3) From ab0d8a9203fbad76b0ccca723bbf9ba0c257ddfe Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Thu, 4 Jun 2026 19:29:41 -0700 Subject: [PATCH 06/45] Consolidate audio nodes into SaveAudioAdvanced node (CORE-202) (#13871) --- comfy_api/latest/_ui.py | 2 +- comfy_extras/nodes_audio.py | 58 +++++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6592f6b1d566..b48713d41086 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -285,7 +285,7 @@ def save_audio( results = [] for batch_number, waveform in enumerate(audio["waveform"].cpu()): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" + file = f"{filename_with_batch_num}_{counter:05}.{format}" output_path = os.path.join(full_output_folder, file) # Use original sample rate initially diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index ff078f74cf50..1dc97ecd762e 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -158,7 +158,7 @@ def define_schema(cls): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC)", + display_name="Save Audio (FLAC) (Deprecated)", category="audio", essentials_category="Audio", inputs=[ @@ -167,6 +167,7 @@ def define_schema(cls): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -186,7 +187,7 @@ def define_schema(cls): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio (MP3)", + display_name="Save Audio (MP3) (Deprecated)", category="audio", essentials_category="Audio", inputs=[ @@ -196,6 +197,7 @@ def define_schema(cls): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -217,7 +219,7 @@ def define_schema(cls): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus)", + display_name="Save Audio (Opus) (Deprecated)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -226,6 +228,7 @@ def define_schema(cls): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -241,6 +244,54 @@ def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") save_opus = execute # TODO: remove +class SaveAudioAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioAdvanced", + search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"], + display_name="Save Audio (Advanced)", + description="Saves the input audio to your ComfyUI output directory.", + category="audio", + inputs=[ + IO.Audio.Input("audio", tooltip="The audio to save."), + IO.String.Input( + "filename_prefix", + default="audio/ComfyUI", + tooltip=( + "The prefix for the file to save. May include formatting tokens " + "such as %date:yyyy-MM-dd%." + ), + ), + IO.DynamicCombo.Input( + "format", + options=[ + IO.DynamicCombo.Option("flac", []), + IO.DynamicCombo.Option("mp3", [ + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ]), + IO.DynamicCombo.Option("opus", [ + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ]), + ], + tooltip="The file format in which to save the audio.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput: + file_format = format.get("format", None) + quality = format.get("quality", None) + if quality: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) + else: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) + return IO.NodeOutput(ui=ui) + + class PreviewAudio(IO.ComfyNode): @classmethod def define_schema(cls): @@ -822,6 +873,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: SaveAudio, SaveAudioMP3, SaveAudioOpus, + SaveAudioAdvanced, LoadAudio, PreviewAudio, ConditioningStableAudio, From 5aa71b9bc28809a16596bb9fa3d0a6300d8e3f0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 5 Jun 2026 10:04:10 +0300 Subject: [PATCH 07/45] Enable cfg1 optimization for DualModelGuider with CFGGuider (#14290) * Enable cfg1 optimization for DualModelGuider * Fix CFG Override tooltip --- comfy_extras/nodes_custom_sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 2f4ff1f70708..3e97084a45e8 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -933,9 +933,10 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas def predict_noise(self, x, timestep, model_options={}, seed=None): positive = self.conds.get("positive", None) - if self.uncond_inner is None: # cfg == 1 or no negative -> single model, cond only - return comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0] cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0] + # uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only + if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)): + return cond uncond_model_options = model_options if "multigpu_clones" in model_options: # TODO: support multigpu instead of just running uncond on a single GPU @@ -1140,7 +1141,7 @@ def define_schema(cls) -> io.Schema: return io.Schema( node_id="CFGOverride", display_name="CFG Override", - description="Override cfg to a fixed value over a [start, end] percent slice of the steps. " + description="Override cfg to a fixed value over a [start, end] percent (sigma) range. " "With multiple overrides, the one nearest the sampler wins on overlap.", category="sampling/custom_sampling", inputs=[ From 410df2725336dc0e34214f6a127c761a1879cca8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:39:35 +1000 Subject: [PATCH 08/45] Fix interoperation with external source of pinned memory pressure (#14252) * mm: split off registration helper to doer and headroom calc * pinned_memory: implement registration comfy side Move away from Aimdo buffer registrations which seem fraught with danger and do it comfy side. Just start with the basic move. * pinned_memory: do registrations as portable memory * pinned_memory: discard async errors on registration fail Like the good ol days. * pinned_memory: implement abs shortfall retry If pinned registration happens to fail despite the previous budget ensures, consider the allocation shortfall, ensure it again, and try again. This allows comfy pins to interoperate with other software that might be doing substantive pinning. --- comfy/model_management.py | 6 ++++-- comfy/pinned_memory.py | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dfd58bf1be0d..8e786c0a507b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -651,8 +651,7 @@ def ensure_pin_budget(size, evict_active=False): to_free = shortfall + PIN_PRESSURE_HYSTERESIS return free_pins(to_free, evict_active=evict_active) >= shortfall -def ensure_pin_registerable(size, evict_active=True): - shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY +def free_registrations(shortfall, evict_active=True): if MAX_PINNED_MEMORY <= 0: return False if shortfall <= 0: @@ -674,6 +673,9 @@ def ensure_pin_registerable(size, evict_active=True): return True return shortfall <= REGISTERABLE_PIN_HYSTERESIS +def ensure_pin_registerable(size, evict_active=True): + return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active) + class LoadedModel: def __init__(self, model: ModelPatcher): self._set_model(model) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index ffe12e0dc0e9..cb77c517a283 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None): not comfy.model_management.ensure_pin_registerable(registerable_size)): return _steal_pin(module, stack, buckets, size, priority) + extended = False try: - hostbuf.extend(size=size) + hostbuf.extend(size=size, register=False) + extended = True + pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + pin.untyped_storage()._comfy_hostbuf = hostbuf + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + comfy.model_management.free_registrations(size) + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + del pin + hostbuf.truncate(offset, do_unregister=False) + return _steal_pin(module, stack, buckets, size, priority) except RuntimeError: + if extended: + hostbuf.truncate(offset, do_unregister=False) return _steal_pin(module, stack, buckets, size, priority) - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] - module._pin.untyped_storage()._comfy_hostbuf = hostbuf + module._pin = pin stack.append((module, offset)) module._pin_registered = True module._pin_stack_index = len(stack) - 1 From ec6aa979a627ecbfb7847dad85a4a26a0a3a424a Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:40:03 +1000 Subject: [PATCH 09/45] aimdo 049 (#14300) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8b64c60a9ece..613553d8f614 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=16.0.0 comfy-kitchen==0.2.10 -comfy-aimdo==0.4.8 +comfy-aimdo==0.4.9 requests simpleeval>=1.0.0 blake3 From 4a00126e9cd3ea59e33ff0c75e628cb7585ab164 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 5 Jun 2026 20:31:55 +0300 Subject: [PATCH 10/45] [Partner Nodes] feat: add new Gemini text node (#14299) --- comfy_api_nodes/apis/gemini.py | 13 +- comfy_api_nodes/nodes_gemini.py | 402 +++++++++++++++++++++++++++----- 2 files changed, 356 insertions(+), 59 deletions(-) diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 22879fe181c7..caaba8f36fef 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel): startOffset: GeminiOffset | None = Field(None) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiGenerationConfig(BaseModel): - maxOutputTokens: int | None = Field(None, ge=16, le=8192) + maxOutputTokens: int | None = Field(None, ge=16, le=65536) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) temperature: float | None = Field(None, ge=0.0, le=2.0) topK: int | None = Field(None, ge=1) topP: float | None = Field(None, ge=0.0, le=1.0) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageOutputOptions(BaseModel): @@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) -class GeminiThinkingConfig(BaseModel): - includeThoughts: bool | None = Field(None) - thinkingLevel: str = Field(...) - - class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index e75ef3835a12..2699d2792184 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -8,7 +8,7 @@ from enum import Enum from fnmatch import fnmatch from io import BytesIO -from typing import Literal +from typing import Any, Literal import torch from typing_extensions import override @@ -19,6 +19,7 @@ GeminiContent, GeminiFileData, GeminiGenerateContentRequest, + GeminiGenerationConfig, GeminiGenerateContentResponse, GeminiImageConfig, GeminiImageGenerateContentRequest, @@ -40,13 +41,18 @@ get_number_of_images, sync_op, tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, video_to_base64_string, ) GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_URL_INPUT_BUDGET = 10 +GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024 GEMINI_IMAGE_SYS_PROMPT = ( "You are an expert image-generation engine. You must ALWAYS produce an image.\n" "Interpret all user input—regardless of " @@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N return final_price / 1_000_000.0 +def create_video_parts(video_input: Input.Video) -> list[GeminiPart]: + """Convert a single video input to Gemini API compatible parts (inline MP4/H.264).""" + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + return [ + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.video_mp4, + data=base_64_string, + ) + ) + ] + + +def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]: + """Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item).""" + audio_parts: list[GeminiPart] = [] + for batch_index in range(audio_input["waveform"].shape[0]): + # Recreate an IO.AUDIO object for the given batch dimension index + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) + # Convert to MP3 format for compatibility with Gemini API + audio_bytes = audio_to_base64_string( + audio_at_index, + container_format="mp3", + codec_name="libmp3lame", + ) + audio_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.audio_mp3, + data=audio_bytes, + ) + ) + ) + return audio_parts + + +def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]: + """Expand any batched image tensors into individual (H, W, C) frames, preserving order.""" + frames: list[torch.Tensor] = [] + for img in images: + if len(img.shape) == 4: + frames.extend(img[i] for i in range(img.shape[0])) + else: + frames.append(img) + return frames + + +def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]: + """Expand any batched audio inputs into individual single-clip audio inputs, preserving order.""" + clips: list[Input.Audio] = [] + for audio in audios: + waveform = audio["waveform"] + for i in range(waveform.shape[0]): + clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"])) + return clips + + +async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart: + """Upload a single media unit to ComfyAPI storage and return a fileData (URL) part.""" + if kind == "image": + url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url)) + if kind == "audio": + url = await upload_audio_to_comfyapi( + cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3" + ) + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url)) + url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url)) + + +def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]: + """Encode a single media unit as an inline base64 part; returns (part, base64_length).""" + if kind == "image": + data = tensor_to_base64_string(payload, mime_type="image/webp") + mime = GeminiMimeType.image_webp + elif kind == "audio": + data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame") + mime = GeminiMimeType.audio_mp3 + else: + data = video_to_base64_string( + payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + mime = GeminiMimeType.video_mp4 + return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data) + + +async def build_gemini_media_parts( + cls: type[IO.ComfyNode], + images: list[Input.Image], + audios: list[Input.Audio], + videos: list[Input.Video], + *, + url_budget: int = GEMINI_URL_INPUT_BUDGET, + max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES, +) -> list[GeminiPart]: + """Build Gemini parts for multimodal inputs (images, audio, video). + + fileData URLs are preferred for every media type: the upload is fetched directly by the + model, keeping the request body tiny regardless of media size. The URL budget is shared + across all media and assigned largest-first (video, then audio, then images), so that if it + is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline + payload is capped by `max_inline_bytes`. + """ + units: list[tuple[str, Any]] = ( + [("video", v) for v in videos] + + [("audio", a) for a in _flatten_audio(audios)] + + [("image", f) for f in _flatten_images(images)] + ) + + parts: list[GeminiPart] = [] + url_used = 0 + inline_bytes = 0 + for kind, payload in units: + if url_used < url_budget: + parts.append(await _media_url_part(cls, kind, payload)) + url_used += 1 + continue + part, nbytes = _media_inline_part(kind, payload) + inline_bytes += nbytes + if inline_bytes > max_inline_bytes: + raise ValueError( + f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first " + f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media." + ) + parts.append(part) + return parts + + class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -407,58 +547,9 @@ def define_schema(cls): ) """, ), + is_deprecated=True, ) - @classmethod - def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: - """Convert video input to Gemini API compatible parts.""" - - base_64_string = video_to_base64_string( - video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 - ) - return [ - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.video_mp4, - data=base_64_string, - ) - ) - ] - - @classmethod - def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: - """ - Convert audio input to Gemini API compatible parts. - - Args: - audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate. - - Returns: - List of GeminiPart objects containing the encoded audio. - """ - audio_parts: list[GeminiPart] = [] - for batch_index in range(audio_input["waveform"].shape[0]): - # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = Input.Audio( - waveform=audio_input["waveform"][batch_index].unsqueeze(0), - sample_rate=audio_input["sample_rate"], - ) - # Convert to MP3 format for compatibility with Gemini API - audio_bytes = audio_to_base64_string( - audio_at_index, - container_format="mp3", - codec_name="libmp3lame", - ) - audio_parts.append( - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.audio_mp3, - data=audio_bytes, - ) - ) - ) - return audio_parts - @classmethod async def execute( cls, @@ -482,9 +573,9 @@ async def execute( if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: - parts.extend(cls.create_audio_parts(audio)) + parts.extend(create_audio_parts(audio)) if video is not None: - parts.extend(cls.create_video_parts(video)) + parts.extend(create_video_parts(video)) if files is not None: parts.extend(files) @@ -512,6 +603,210 @@ async def execute( return IO.NodeOutput(output_text or "Empty response from Gemini model...") +GEMINI_V2_MODELS: dict[str, str] = { + "Gemini 3.1 Pro": "gemini-3.1-pro-preview", + "Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview", +} + + +def _gemini_text_model_inputs(thinking_default: str) -> list[Input]: + """Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls).""" + return [ + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 17)], + min=0, + ), + tooltip="Optional image(s) to use as context for the model. Up to 16 images.", + ), + IO.Autogrow.Input( + "audio", + template=IO.Autogrow.TemplateNames( + IO.Audio.Input("audio"), + names=["audio_1"], + min=0, + ), + tooltip="Optional audio clip to use as context for the model.", + ), + IO.Autogrow.Input( + "video", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=["video_1"], + min=0, + ), + tooltip="Optional video clip to use as context for the model.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Input Files node.", + ), + IO.Combo.Input( + "thinking_level", + options=["LOW", "HIGH"], + default=thinking_default, + tooltip="How hard the model reasons internally before answering. " + "HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.", + advanced=True, + ), + IO.Int.Input( + "max_output_tokens", + default=32768, + min=16, + max=65536, + tooltip="Maximum tokens to generate, including the model's internal thinking. " + "With thinking_level HIGH, a low value can leave no room for the answer; raise this if " + "responses come back empty or truncated. The model stops early when finished, so a higher " + "cap costs nothing extra for short replies.", + advanced=True, + ), + ] + + +class GeminiNodeV2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNodeV2", + display_name="Google Gemini", + category="partner/text/Gemini", + essentials_category="Text Generation", + description="Generate text responses with Google's Gemini models. Provide a text prompt and, " + "optionally, one or more images, audio clips, videos, or files as multimodal context.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text input to the model. Include detailed instructions, questions, or context.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")), + IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")), + ], + tooltip="The Gemini model used to generate the response.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + advanced=True, + tooltip="Foundational instructions that dictate the model's behavior.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : { + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + model_id = GEMINI_V2_MODELS[model["model"]] + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + images = [t for t in (model.get("images") or {}).values() if t is not None] + audios = [a for a in (model.get("audio") or {}).values() if a is not None] + videos = [v for v in (model.get("video") or {}).values() if v is not None] + if images or audios or videos: + parts.extend(await build_gemini_media_parts(cls, images, audios, videos)) + files = model.get("files") + if files is not None: + parts.extend(files) + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"), + data=GeminiGenerateContentRequest( + contents=[ + GeminiContent( + role=GeminiRole.user, + parts=parts, + ) + ], + generationConfig=GeminiGenerationConfig( + temperature=model["temperature"], + topP=model["top_p"], + maxOutputTokens=model["max_output_tokens"], + seed=seed if seed > 0 else None, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + + output_text = get_text_from_response(response) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") + + class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. @@ -1222,6 +1517,7 @@ class GeminiExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ GeminiNode, + GeminiNodeV2, GeminiImage, GeminiImage2, GeminiNanoBanana2, From aeee53ff6a66191a93c3d6f0dff51a59fda202f5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 5 Jun 2026 21:52:15 +0300 Subject: [PATCH 11/45] [Partner Nodes] feat: add temperature and top_p to NanoBanan node (#14305) --- comfy_api_nodes/nodes_gemini.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 2699d2792184..3d4be60653ef 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -1424,6 +1424,26 @@ def define_schema(cls): tooltip="Foundational instructions that dictate an AI's behavior.", advanced=True, ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + optional=True, + tooltip="Controls randomness in generation. Lower is more focused/deterministic.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + optional=True, + tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.", + advanced=True, + ), ], outputs=[ IO.Image.Output(), @@ -1460,6 +1480,8 @@ async def execute( seed: int, response_modalities: str, system_prompt: str = "", + temperature: float = 1.0, + top_p: float = 0.95, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) model_choice = model["model"] @@ -1499,6 +1521,8 @@ async def execute( responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + temperature=temperature, + topP=top_p, ), systemInstruction=gemini_system_prompt, ), From 2ef2cf1a7cfc2dd1df9ae4f73a7f9cc279af136f Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 5 Jun 2026 15:30:58 -0400 Subject: [PATCH 12/45] feat: add PreviewGaussianSplat + PreviewPointCloud nodes (#14194) --- comfy_api/latest/_io.py | 14 +++ comfy_extras/nodes_gaussian_splat.py | 4 +- comfy_extras/nodes_load_3d.py | 141 +++++++++++++++++++++++++-- comfy_extras/nodes_save_3d.py | 6 ++ 4 files changed, 157 insertions(+), 8 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a3aa508ce005..37614a4c3325 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO): Type = File3D +@comfytype(io_type="FILE_3D_SPLAT_ANY") +class File3DSplatAny(ComfyTypeIO): + """General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat).""" + Type = File3D + + +@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY") +class File3DPointCloudAny(ComfyTypeIO): + """General point cloud file type - accepts any supported point cloud container (currently .ply).""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -2336,6 +2348,8 @@ def as_dict(self): "File3DSPLAT", "File3DSPZ", "File3DKSPLAT", + "File3DSplatAny", + "File3DPointCloudAny", "Hooks", "HookKeyframes", "TimestepsRange", diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 2ba3a38202ed..116c14fde480 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -488,7 +488,7 @@ def define_schema(cls): "spz: Niantic gzip-compressed (~10x smaller), base color only " ), ], - outputs=[IO.File3DAny.Output(display_name="model_3d")], + outputs=[IO.File3DSplatAny.Output(display_name="model_3d")], ) @classmethod @@ -516,7 +516,7 @@ def define_schema(cls): inputs=[ IO.MultiType.Input( IO.File3DAny.Input("model_3d"), - types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], tooltip="A gaussian splat 3D file", ), ], diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index b339dc4fffd3..77dd1173b9e5 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -136,7 +136,7 @@ def define_schema(cls): is_output_node=True, inputs=[ IO.MultiType.Input( - "model_file", + "model_3d", types=[ IO.File3DGLB, IO.File3DGLTF, @@ -155,7 +155,134 @@ def define_schema(cls): IO.Int.Input("height", default=1024, min=1, max=4096, step=1), ], outputs=[ - IO.File3DAny.Output(display_name="model_file"), + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + camera_info, + model_3d_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + +class PreviewGaussianSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewGaussianSplat", + display_name="Preview Splat", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view splat", + "view gaussian", + "view gaussian splat", + "preview gaussian", + "preview gaussian splat", + "view 3dgs", + "preview 3dgs", + "preview ply", + "preview spz", + "preview splat", + "preview ksplat", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DSplatAny, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + ], + tooltip="A gaussian splat 3D file.", + ), + IO.Load3D.Input("image"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DSplatAny.Output(display_name="model_3d"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + camera_info, + model_3d_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + +class PreviewPointCloud(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewPointCloud", + display_name="Preview Point Cloud", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view point cloud", + "view pointcloud", + "preview point cloud", + "preview pointcloud", + "preview ply", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DPointCloudAny, + IO.File3DPLY, + ], + tooltip="Point cloud file (.ply)", + ), + IO.Load3D.Input("image"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DPointCloudAny.Output(display_name="model_3d"), IO.Load3DCamera.Output(display_name="camera_info"), IO.Load3DModelInfo.Output(display_name="model_3d_info"), IO.Int.Output(display_name="width"), @@ -164,16 +291,16 @@ def define_schema(cls): ) @classmethod - def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: - filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}" - model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) camera_info_input = kwargs.get("camera_info", None) camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] model_3d_info_input = kwargs.get("model_3d_info", None) model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) return IO.NodeOutput( - model_file, + model_3d, camera_info, model_3d_info, width, @@ -189,6 +316,8 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: Load3D, Preview3D, Preview3DAdvanced, + PreviewGaussianSplat, + PreviewPointCloud, ] diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index a91549e7f6d8..1b6592bb25e9 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -337,6 +337,12 @@ def define_schema(cls): IO.File3DFBX, IO.File3DSTL, IO.File3DUSDZ, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + IO.File3DSplatAny, + IO.File3DPointCloudAny, IO.File3DAny, ], tooltip="Mesh or 3D file to save", From 986ce5b4f0935035bbee63b628d01e5dcd67a5a9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Jun 2026 12:41:44 -0700 Subject: [PATCH 13/45] Update AMD portable readme. (#14303) --- .../README_VERY_IMPORTANT.txt | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 2cbb00d99195..26aeeee52b42 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,28 +1,27 @@ -As of the time of writing this you need this driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html - -HOW TO RUN: - -If you have a AMD gpu: - -run_amd_gpu.bat - -If you have memory issues you can try disabling the smart memory management by running comfyui with: - -run_amd_gpu_disable_smart_memory.bat - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - -TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: -In the ComfyUI directory you will find a file: extra_model_paths.yaml.example -Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. - - - +As of the time of writing this you need a recent driver. Updating to the latest driver is recommended. + +HOW TO RUN: + +If you have a AMD gpu: + +run_amd_gpu.bat + +If you have memory issues you can try enabling the new dynamic memory management by running comfyui with: + +run_amd_gpu_enable_dynamic_vram.bat + +IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints + +You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors + + +RECOMMENDED WAY TO UPDATE: +To update the ComfyUI code: update\update_comfyui.bat + + +TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: +In the ComfyUI directory you will find a file: extra_model_paths.yaml.example +Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + + + From a65a5464c731932c1565bca95b729ecaf055162d Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 5 Jun 2026 17:18:41 -0400 Subject: [PATCH 14/45] BE-1172 fix(3d): save Preview3DAdvanced / PreviewGaussianSplat / PreviewPointCloud to temp/, rename viewport input (#14294) --- comfy_extras/nodes_load_3d.py | 38 +++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 77dd1173b9e5..b5f24707616c 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -51,6 +51,14 @@ def define_schema(cls): ], ) + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + @classmethod def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) @@ -148,7 +156,7 @@ def define_schema(cls): ], tooltip="3D model file from an upstream 3D node.", ), - IO.Load3D.Input("image"), + IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), @@ -164,14 +172,14 @@ def define_schema(cls): ) @classmethod - def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}" - model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) camera_info_input = kwargs.get("camera_info", None) - camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] model_3d_info_input = kwargs.get("model_3d_info", None) - model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, camera_info, @@ -216,7 +224,7 @@ def define_schema(cls): ], tooltip="A gaussian splat 3D file.", ), - IO.Load3D.Input("image"), + IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), @@ -232,14 +240,14 @@ def define_schema(cls): ) @classmethod - def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}" - model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) camera_info_input = kwargs.get("camera_info", None) - camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] model_3d_info_input = kwargs.get("model_3d_info", None) - model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, camera_info, @@ -275,7 +283,7 @@ def define_schema(cls): ], tooltip="Point cloud file (.ply)", ), - IO.Load3D.Input("image"), + IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), @@ -291,14 +299,14 @@ def define_schema(cls): ) @classmethod - def execute(cls, model_3d: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}" - model_3d.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) camera_info_input = kwargs.get("camera_info", None) - camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] model_3d_info_input = kwargs.get("model_3d_info", None) - model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, camera_info, From ea36cb16d62db2029309c983f77a6361534932af Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 5 Jun 2026 22:01:57 -0400 Subject: [PATCH 15/45] feat(3d): reorder Preview3DAdvanced / PreviewGaussianSplat / PreviewPointCloud inputs and outputs (#14308) --- comfy_extras/nodes_load_3d.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index b5f24707616c..455897859ec4 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -156,16 +156,16 @@ def define_schema(cls): ], tooltip="3D model file from an upstream 3D node.", ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), - IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), ], outputs=[ IO.File3DAny.Output(display_name="model_3d"), - IO.Load3DCamera.Output(display_name="camera_info"), IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), IO.Int.Output(display_name="width"), IO.Int.Output(display_name="height"), ], @@ -182,8 +182,8 @@ def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, - camera_info, model_3d_info, + camera_info, width, height, ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), @@ -224,16 +224,16 @@ def define_schema(cls): ], tooltip="A gaussian splat 3D file.", ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), - IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), ], outputs=[ IO.File3DSplatAny.Output(display_name="model_3d"), - IO.Load3DCamera.Output(display_name="camera_info"), IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), IO.Int.Output(display_name="width"), IO.Int.Output(display_name="height"), ], @@ -250,8 +250,8 @@ def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, - camera_info, model_3d_info, + camera_info, width, height, ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), @@ -283,16 +283,16 @@ def define_schema(cls): ], tooltip="Point cloud file (.ply)", ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Load3D.Input("viewport_state"), IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), - IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), ], outputs=[ IO.File3DPointCloudAny.Output(display_name="model_3d"), - IO.Load3DCamera.Output(display_name="camera_info"), IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), IO.Int.Output(display_name="width"), IO.Int.Output(display_name="height"), ], @@ -309,8 +309,8 @@ def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( model_3d, - camera_info, model_3d_info, + camera_info, width, height, ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), From 2cdaaf4a25fd5771da451e906a98e096397312a9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 6 Jun 2026 19:33:03 -0700 Subject: [PATCH 16/45] Update line endings check to ignore .ci files. (#14319) --- .github/workflows/check-line-endings.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml index eeb594d6cd39..a69a24a872f5 100644 --- a/.github/workflows/check-line-endings.yml +++ b/.github/workflows/check-line-endings.yml @@ -17,7 +17,7 @@ jobs: - name: Check for Windows line endings (CRLF) run: | # Get the list of changed files in the PR - CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }}) + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci') # Flag to track if CRLF is found CRLF_FOUND=false From 739061dd4cbb7434a19de8b39cd08cb232b38a7c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 7 Jun 2026 20:56:53 -0700 Subject: [PATCH 17/45] Use windows line endings for windows portable readmes. (#14334) --- .../README_VERY_IMPORTANT.txt | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 26aeeee52b42..2c72c8a1384c 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,27 +1,27 @@ -As of the time of writing this you need a recent driver. Updating to the latest driver is recommended. - -HOW TO RUN: - -If you have a AMD gpu: - -run_amd_gpu.bat - -If you have memory issues you can try enabling the new dynamic memory management by running comfyui with: - -run_amd_gpu_enable_dynamic_vram.bat - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - -TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: -In the ComfyUI directory you will find a file: extra_model_paths.yaml.example -Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. - - - +As of the time of writing this you need a recent driver. Updating to the latest driver is recommended. + +HOW TO RUN: + +If you have a AMD gpu: + +run_amd_gpu.bat + +If you have memory issues you can try enabling the new dynamic memory management by running comfyui with: + +run_amd_gpu_enable_dynamic_vram.bat + +IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints + +You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors + + +RECOMMENDED WAY TO UPDATE: +To update the ComfyUI code: update\update_comfyui.bat + + +TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: +In the ComfyUI directory you will find a file: extra_model_paths.yaml.example +Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + + + From 7863cf0e53ca599a84b3ec5bcda122e4ecc3765c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 8 Jun 2026 05:15:05 -0500 Subject: [PATCH 18/45] Add SeedVR2 support (CORE-6) (#14110) --- comfy/latent_formats.py | 5 + comfy/ldm/modules/attention.py | 84 +- comfy/ldm/modules/diffusionmodules/model.py | 8 +- comfy/ldm/seedvr/color_fix.py | 340 +++ comfy/ldm/seedvr/constants.py | 79 + comfy/ldm/seedvr/model.py | 1665 +++++++++++++ comfy/ldm/seedvr/vae.py | 2110 +++++++++++++++++ comfy/model_base.py | 12 + comfy/model_detection.py | 50 + comfy/sample.py | 8 +- comfy/sd.py | 237 +- comfy/supported_models.py | 31 +- comfy/supported_models_base.py | 2 +- comfy_extras/nodes_seedvr.py | 1015 ++++++++ nodes.py | 42 +- .../test_seedvr2_conditioning.py | 213 ++ .../comfy_extras_test/test_seedvr2_nodes.py | 55 + .../test_seedvr2_post_processing.py | 57 + tests-unit/comfy_test/model_detection_test.py | 60 + .../comfy_test/seedvr_vae_forward_test.py | 90 + tests-unit/comfy_test/test_seedvr2_dtype.py | 47 + .../comfy_test/test_seedvr2_internals.py | 341 +++ tests-unit/comfy_test/test_seedvr2_model.py | 308 +++ .../comfy_test/test_seedvr2_vae_decode.py | 91 + .../comfy_test/test_seedvr2_vae_tiled.py | 347 +++ .../test_seedvr_progressive_sampler.py | 126 + 26 files changed, 7383 insertions(+), 40 deletions(-) create mode 100644 comfy/ldm/seedvr/color_fix.py create mode 100644 comfy/ldm/seedvr/constants.py create mode 100644 comfy/ldm/seedvr/model.py create mode 100644 comfy/ldm/seedvr/vae.py create mode 100644 comfy_extras/nodes_seedvr.py create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_conditioning.py create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_nodes.py create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_post_processing.py create mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py create mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py create mode 100644 tests-unit/comfy_test/test_seedvr2_internals.py create mode 100644 tests-unit/comfy_test/test_seedvr2_model.py create mode 100644 tests-unit/comfy_test/test_seedvr2_vae_decode.py create mode 100644 tests-unit/comfy_test/test_seedvr2_vae_tiled.py create mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index bbdfd4bc2fac..fcbd97c5971a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,6 +4,7 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_dimensions = 2 + preserve_empty_channel_multiples = False latent_rgb_factors = None latent_rgb_factors_bias = None latent_rgb_factors_reshape = None @@ -779,6 +780,10 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class SeedVR2(LatentFormat): + latent_channels = 16 + preserve_empty_channel_multiples = True + class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 55360535af3b..b78e764c71ba 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -735,7 +735,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out +def _var_attention_qkv(q, k, v, heads, skip_reshape): + if skip_reshape: + return q, k, v, q.shape[-1] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + return ( + q.view(total_tokens, heads, head_dim), + k.view(k.shape[0], heads, head_dim), + v.view(v.shape[0], heads, head_dim), + head_dim, + ) + +def _var_attention_output(out, heads, head_dim, skip_output_reshape): + if skip_output_reshape: + return out + return out.reshape(-1, heads * head_dim) + + +def _use_blackwell_attention(): + device = model_management.get_torch_device() + if device.type != "cuda": + return False + major, minor = torch.cuda.get_device_capability(device) + return (major, minor) >= (12, 0) + + +def _validate_split_cu_seqlens(name, cu_seqlens, token_count): + if cu_seqlens.dtype not in (torch.int32, torch.int64): + raise ValueError(f"{name} must use an integer dtype") + if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: + raise ValueError(f"{name} must be a 1D tensor with at least two offsets") + if cu_seqlens[0].item() != 0: + raise ValueError(f"{name} must start at 0") + if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): + raise ValueError(f"{name} must be strictly increasing") + if cu_seqlens[-1].item() != token_count: + raise ValueError(f"{name} does not match token count") + + +def _split_indices(cu_seqlens): + return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) + + +def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + + _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) + _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) + if cu_seqlens_k[-1].item() != v.shape[0]: + raise ValueError("cu_seqlens_k does not match v token count") + + q_split_indices = _split_indices(cu_seqlens_q) + k_split_indices = _split_indices(cu_seqlens_k) + q_splits = torch.tensor_split(q, q_split_indices, dim=0) + k_splits = torch.tensor_split(k, k_split_indices, dim=0) + v_splits = torch.tensor_split(v, k_split_indices, dim=0) + if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): + raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") + + out = [] + for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): + q_i = q_i.permute(1, 0, 2).unsqueeze(0) + k_i = k_i.permute(1, 0, 2).unsqueeze(0) + v_i = v_i.permute(1, 0, 2).unsqueeze(0) + out_dtype = q_i.dtype + if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): + q_i = q_i.to(torch.bfloat16) + k_i = k_i.to(torch.bfloat16) + v_i = v_i.to(torch.bfloat16) + out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) + if out_i.dtype != out_dtype: + out_i = out_i.to(out_dtype) + out.append(out_i.squeeze(0).permute(1, 0, 2)) + + out = torch.cat(out, dim=0) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -758,6 +837,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad +logging.info("Using optimized_attention split-loop for variable-length attention") + optimized_attention_masked = optimized_attention @@ -773,6 +854,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape register_attention_function("pytorch", attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) register_attention_function("split", attention_split) +register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1209,5 +1291,3 @@ def forward( x = self.proj_out(x) out = x + x_in return out - - diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index fcbaa074fd84..235df0b835bb 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,7 @@ import xformers import xformers.ops + def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim): else: return None -def get_timestep_embedding(timesteps, embedding_dim): + +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim): assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py new file mode 100644 index 000000000000..7ddfc03af370 --- /dev/null +++ b/comfy/ldm/seedvr/color_fix.py @@ -0,0 +1,340 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.vae import safe_interpolate_operation +from comfy.ldm.seedvr.constants import ( + CIELAB_DELTA, + CIELAB_KAPPA, + D65_WHITE_X, + D65_WHITE_Z, + WAVELET_DECOMP_LEVELS, +) + + +def wavelet_blur(image: Tensor, radius): + max_safe_radius = max(1, min(image.shape[-2:]) // 8) + if radius > max_safe_radius: + radius = max_safe_radius + + num_channels = image.shape[1] + + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) + + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) + + return output + +def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): + high_freq = torch.zeros_like(image) + + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq.add_(image).sub_(low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: + + if content_feat.shape != style_feat.shape: + # Resize style to match content spatial dimensions + if len(content_feat.shape) >= 3: + # safe_interpolate_operation handles FP16 conversion automatically + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Decompose both features into frequency components + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq # Free memory immediately + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq # Free memory immediately + + if content_high_freq.shape != style_low_freq.shape: + style_low_freq = safe_interpolate_operation( + style_low_freq, + size=content_high_freq.shape[-2:], + mode='bilinear', + align_corners=False + ) + + content_high_freq.add_(style_low_freq) + + return content_high_freq.clamp_(-1.0, 1.0) + +def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: + original_shape = source.shape + + # Flatten + source_flat = source.flatten() + reference_flat = reference.flatten() + + # Sort both arrays + source_sorted, source_indices = torch.sort(source_flat) + reference_sorted, _ = torch.sort(reference_flat) + del reference_flat + + # Quantile mapping + n_source = len(source_sorted) + n_reference = len(reference_sorted) + + if n_source == n_reference: + matched_sorted = reference_sorted + else: + # Interpolate reference to match source quantiles + source_quantiles = torch.linspace(0, 1, n_source, device=device) + ref_indices = (source_quantiles * (n_reference - 1)).long() + ref_indices.clamp_(0, n_reference - 1) + matched_sorted = reference_sorted[ref_indices] + del source_quantiles, ref_indices, reference_sorted + + del source_sorted, source_flat + + # Reconstruct using argsort (portable across CUDA/ROCm/MPS) + inverse_indices = torch.argsort(source_indices) + del source_indices + matched_flat = matched_sorted[inverse_indices] + del matched_sorted, inverse_indices + + return matched_flat.reshape(original_shape) + +def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of CIELAB images to RGB color space.""" + L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] + + # LAB to XYZ + fy = (L + 16.0) / 116.0 + fx = a.div(500.0).add_(fy) + fz = fy - b / 200.0 + del L, a, b + + # XYZ transformation + x = torch.where( + fx > epsilon, + torch.pow(fx, 3.0), + fx.mul(116.0).sub_(16.0).div_(kappa) + ) + y = torch.where( + fy > epsilon, + torch.pow(fy, 3.0), + fy.mul(116.0).sub_(16.0).div_(kappa) + ) + z = torch.where( + fz > epsilon, + torch.pow(fz, 3.0), + fz.mul(116.0).sub_(16.0).div_(kappa) + ) + del fx, fy, fz + + # Apply D65 white point (in-place) + x.mul_(D65_WHITE_X) + # y *= 1.00000 # (no-op, skip) + z.mul_(D65_WHITE_Z) + + xyz = torch.stack([x, y, z], dim=1) + del x, y, z + + # Matrix multiplication: XYZ -> RGB + B, C, H, W = xyz.shape + xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) + del xyz + + # Ensure dtype consistency for matrix multiplication + xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) + rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) + del xyz_flat + + rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del rgb_linear_flat + + # Apply inverse gamma correction (delinearize) + mask = rgb_linear > 0.0031308 + rgb = torch.where( + mask, + torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), + rgb_linear * 12.92 + ) + del mask, rgb_linear + + return torch.clamp(rgb, 0.0, 1.0) + +def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" + # Apply sRGB gamma correction (linearize) + mask = rgb > 0.04045 + rgb_linear = torch.where( + mask, + torch.pow((rgb + 0.055) / 1.055, 2.4), + rgb / 12.92 + ) + del mask + + # Matrix multiplication: RGB -> XYZ + B, C, H, W = rgb_linear.shape + rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) + del rgb_linear + + # Ensure dtype consistency for matrix multiplication + rgb_flat = rgb_flat.to(dtype=matrix.dtype) + xyz_flat = torch.matmul(rgb_flat, matrix.T) + del rgb_flat + + xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del xyz_flat + + # Normalize by D65 white point (in-place) + xyz[:, 0].div_(D65_WHITE_X) # X + # xyz[:, 1] /= 1.00000 # Y (no-op, skip) + xyz[:, 2].div_(D65_WHITE_Z) # Z + + # XYZ to LAB transformation + epsilon_cubed = epsilon ** 3 + mask = xyz > epsilon_cubed + f_xyz = torch.where( + mask, + torch.pow(xyz, 1.0 / 3.0), + xyz.mul(kappa).add_(16.0).div_(116.0) + ) + del xyz, mask + + # Extract channels and compute LAB + L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + del f_xyz + + return torch.stack([L, a, b], dim=1) + +def lab_color_transfer( + content_feat: Tensor, + style_feat: Tensor, + luminance_weight: float = 0.8 +) -> Tensor: + content_feat = wavelet_reconstruction(content_feat, style_feat) + + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + device = content_feat.device + + def ensure_float32_precision(c): + orig_dtype = c.dtype + c = c.float() + return c, orig_dtype + content_feat, original_dtype = ensure_float32_precision(content_feat) + style_feat, _ = ensure_float32_precision(style_feat) + + rgb_to_xyz_matrix = torch.tensor([ + [0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041] + ], dtype=torch.float32, device=device) + + xyz_to_rgb_matrix = torch.tensor([ + [ 3.2404542, -1.5371385, -0.4985314], + [-0.9692660, 1.8760108, 0.0415560], + [ 0.0556434, -0.2040259, 1.0572252] + ], dtype=torch.float32, device=device) + + epsilon = CIELAB_DELTA + kappa = CIELAB_KAPPA + + content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + + # Convert to LAB color space + content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del content_feat + + style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del style_feat, rgb_to_xyz_matrix + + # Match chrominance channels (a*, b*) for accurate color transfer + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + + # Handle luminance with weighted blending + if luminance_weight < 1.0: + # Partially match luminance for better overall color accuracy + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) + # Blend: preserve some content L* for detail, adopt some style L* for color + result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) + del matched_L + else: + # Fully preserve content luminance + result_L = content_lab[:, 0] + + del content_lab, style_lab + + # Reconstruct LAB with corrected channels + result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) + del result_L, matched_a, matched_b + + # Convert back to RGB + result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + del result_lab, xyz_to_rgb_matrix + + # Convert back to [-1, 1] range (in-place) + result = result_rgb.mul_(2.0).sub_(1.0) + del result_rgb + + result = result.to(original_dtype) + + return result + + +def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: + return wavelet_reconstruction(content_feat, style_feat) + + +def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False, + ) + + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() + + b, c = content_feat.shape[:2] + content_flat = content_feat.reshape(b, c, -1) + style_flat = style_feat.reshape(b, c, -1) + + content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) + content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) + style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + del content_flat, style_flat + + normalized = (content_feat - content_mean) / content_std + del content_mean, content_std + result = normalized * style_std + style_mean + del normalized, style_mean, style_std + + result = result.clamp_(-1.0, 1.0) + if result.dtype != original_dtype: + result = result.to(original_dtype) + return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py new file mode 100644 index 000000000000..95838d1dd7f0 --- /dev/null +++ b/comfy/ldm/seedvr/constants.py @@ -0,0 +1,79 @@ +"""Named constants for the SeedVR2 integration, grouped by provenance. + +Provenance prefixes: +- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. +- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites + the upstream config/source path it was lifted from. +- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / + ISO / CIE values; cite the standard. +""" + +# -------------------------------------------------------------------------------------- +# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) +# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) +# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 +# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). +# -------------------------------------------------------------------------------------- +SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) +SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB + +# -------------------------------------------------------------------------------------- +# B. Fork heuristics (SEEDVR2 - this integration) +# -------------------------------------------------------------------------------------- +SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. + # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). +SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. +SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). +SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. + +# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. +SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. + +# -------------------------------------------------------------------------------------- +# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) +# -------------------------------------------------------------------------------------- +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). +BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). +BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). +BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. +BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). +BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. +BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). +BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. +BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). +BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). +BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). +BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). +BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). +# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. +BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. +BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. + +# -------------------------------------------------------------------------------------- +# D. Published standards (cite the literature) +# -------------------------------------------------------------------------------------- +ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. + +# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). +CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). +CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). +D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). +D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. +WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). + +# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and +# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the +# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py new file mode 100644 index 000000000000..3fa9fe07e870 --- /dev/null +++ b/comfy/ldm/seedvr/model.py @@ -0,0 +1,1665 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union, List, Dict, Any, Callable +import einops +from einops import rearrange +import torch.nn.functional as F +from math import ceil, pi +import torch +from itertools import chain +from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding +from comfy.ldm.modules.attention import optimized_var_attention +from torch.nn.modules.utils import _triple +from torch import nn +import math +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_720P_REF_AREA, + BYTEDANCE_MAX_TEMPORAL_WINDOW, + BYTEDANCE_ROPE_MAX_FREQ, + BYTEDANCE_SINUSOIDAL_DIM, + ROPE_THETA, + SEEDVR2_7B_MLP_CHUNK, + SEEDVR2_7B_VID_DIM, + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, +) +import comfy.model_management +import numbers + +def _torch_float8_types(): + return tuple( + getattr(torch, name) + for name in ( + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + ) + +class CustomRMSNorm(nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): + super(CustomRMSNorm, self).__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) + else: + self.register_parameter('weight', None) + + def forward(self, input): + + dims = tuple(range(-len(self.normalized_shape), 0)) + + normalized = input.float() + variance = normalized.pow(2).mean(dim=dims, keepdim=True) + rms = torch.sqrt(variance + self.eps) + + normalized = normalized / rms + + if self.elementwise_affine: + return normalized * self.weight.to(input.dtype) + return normalized + +class Cache: + def __init__(self, disable=False, prefix="", cache=None): + self.cache = cache if cache is not None else {} + self.disable = disable + self.prefix = prefix + + def __call__(self, key: str, fn: Callable): + if self.disable: + return fn() + + key = self.prefix + key + try: + result = self.cache[key] + except KeyError: + result = fn() + self.cache[key] = result + return result + + def namespace(self, namespace: str): + return Cache( + disable=self.disable, + prefix=self.prefix + namespace + ".", + cache=self.cache, + ) + + def get(self, key: str): + key = self.prefix + key + return self.cache[key] + +def repeat_concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: List, # (n) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + txt = [[x] * n for x, n in zip(txt, txt_repeat)] + txt = list(chain(*txt)) + return torch.cat(list(chain(*zip(vid, txt)))) + +def concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + return torch.cat(list(chain(*zip(vid, txt)))) + +def concat_idx( + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) + src_idx = torch.argsort(tgt_idx) + return ( + lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), + lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), + ) + + +def repeat_concat_idx( + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: torch.LongTensor, # (n) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + txt_repeat_list = txt_repeat.tolist() + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + src_idx = torch.argsort(tgt_idx) + txt_idx_len = len(tgt_idx) - len(vid_idx) + repeat_txt_len = (txt_len * txt_repeat).tolist() + + def unconcat_coalesce(all): + vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) + txt_out_coalesced = [] + for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): + txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) + txt_out_coalesced.append(txt) + return vid_out, torch.cat(txt_out_coalesced) + + return ( + lambda vid, txt: torch.cat([vid, txt])[tgt_idx], + lambda all: unconcat_coalesce(all), + ) + + +@dataclass +class MMArg: + vid: Any + txt: Any + +def safe_pad_operation(x, padding, mode='constant', value=0.0): + """Safe padding operation that handles Half precision only for problematic modes""" + # Modes qui nécessitent le fix Half precision + problematic_modes = ['replicate', 'reflect', 'circular'] + + if mode in problematic_modes: + try: + return F.pad(x, padding, mode=mode, value=value) + except RuntimeError as e: + if "not implemented for 'Half'" in str(e): + original_dtype = x.dtype + return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) + else: + raise e + else: + # Pour 'constant' et autres modes compatibles, pas de fix nécessaire + return F.pad(x, padding, mode=mode, value=value) + + +def get_args(key: str, args: List[Any]) -> List[Any]: + return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] + + +def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} + + +def get_window_op(name: str): + if name == "720pwin_by_size_bysize": + return make_720Pwindows_bysize + if name == "720pswin_by_size_bysize": + return make_shifted_720Pwindows_bysize + raise ValueError(f"Unknown windowing method: {name}") + + +# -------------------------------- Windowing -------------------------------- # +def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. + return [ + ( + slice(it * wt, min((it + 1) * wt, t)), + slice(ih * wh, min((ih + 1) * wh, h)), + slice(iw * ww, min((iw + 1) * ww, w)), + ) + for iw in range(nw) + if min((iw + 1) * ww, w) > iw * ww + for ih in range(nh) + if min((ih + 1) * wh, h) > ih * wh + for it in range(nt) + if min((it + 1) * wt, t) > it * wt + ] + +def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + + st, sh, sw = ( # shift size. + 0.5 if wt < t else 0, + 0.5 if wh < h else 0, + 0.5 if ww < w else 0, + ) + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. + nt, nh, nw = ( # number of window. + nt + 1 if st > 0 else 1, + nh + 1 if sh > 0 else 1, + nw + 1 if sw > 0 else 1, + ) + return [ + ( + slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), + slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), + slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), + ) + for iw in range(nw) + if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) + for ih in range(nh) + if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) + for it in range(nt) + if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) + ] + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True, + cache_max_seq_len = 8192 + ): + super().__init__() + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_freqs_seq_len = 0 + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.register_buffer('dummy', torch.tensor(0), persistent = False) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + + if not use_xpos: + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + + self.register_buffer('scale', scale, persistent = False) + self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_scales_seq_len = 0 + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def get_axial_freqs( + self, + *dims, + offsets = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + assert len(offsets) == len(dims) + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + # concat all freqs + + all_freqs = torch.broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + def forward( + self, + t, + seq_len: int | None = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and + not self.learned_freq and + exists(seq_len) and + self.freqs_for != 'pixel' and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs_seq_len + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache and offset == 0: + self.cached_freqs[:seq_len] = freqs.detach() + self.cached_freqs_seq_len = seq_len + + return freqs + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=BYTEDANCE_ROPE_MAX_FREQ, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + self.mm = False + + def forward( + self, + q: torch.FloatTensor, # b h l d + k: torch.FloatTensor, # b h l d + size: Tuple[int, int, int], + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + T, H, W = size + freqs = self.get_axial_freqs(T, H, W) + q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + q = apply_rotary_emb(freqs, q.float()).to(q.dtype) + k = apply_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "b h T H W d -> b h (T H W) d") + k = rearrange(k, "b h T H W d -> b h (T H W) d") + return q, k + + +class NaRotaryEmbedding3d(RotaryEmbedding3d): + def forward( + self, + q: torch.FloatTensor, + k: torch.FloatTensor, + shape: torch.LongTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) + freqs = freqs.to(device=q.device) + q = rearrange(q, "L h d -> h L d") + k = rearrange(k, "L h d -> h L d") + q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) + k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "h L d -> L h d") + k = rearrange(k, "h L d -> L h d") + return q, k + + @torch._dynamo.disable + def get_freqs( + self, + shape: torch.LongTensor, + ) -> torch.Tensor: + # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds + # 7B pixel RoPE with the interleaved-angle convention, not Comfy's + # Flux freqs_cis matrix. + plain_rope = RotaryEmbedding( + dim=self.rope.freqs.numel() * 2, + freqs_for="pixel", + max_freq=BYTEDANCE_ROPE_MAX_FREQ, + ) + plain_rope = plain_rope.to(self.rope.dummy.device) + freq_list = [] + for f, h, w in shape.tolist(): + freqs = plain_rope.get_axial_freqs(f, h, w) + freq_list.append(freqs.view(-1, freqs.size(-1))) + return torch.cat(freq_list, dim=0) + + +class MMRotaryEmbeddingBase(RotaryEmbeddingBase): + def __init__(self, dim: int, rope_dim: int): + super().__init__(dim, rope_dim) + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="lang", + theta=ROPE_THETA, + cache_if_possible=False, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + self.mm = True + +def slice_at_dim(t, dim_slice: slice, *, dim): + dim += (t.ndim if dim < 0 else 0) + colons = [slice(None)] * t.ndim + colons[dim] = dim_slice + return t[tuple(colons)] + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') +def exists(val): + return val is not None + +def apply_rotary_emb( + freqs, + t, + start_index = 0, + scale = 1., + seq_dim = -2, + freqs_seq_dim = None +): + dtype = t.dtype + if not exists(freqs_seq_dim): + if freqs.ndim == 2 or t.ndim == 3: + freqs_seq_dim = 0 + + if t.ndim == 3 or exists(freqs_seq_dim): + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) + + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + angles = freqs.to(t_middle.device)[..., ::2] + cos = torch.cos(angles) * scale + sin = torch.sin(angles) * scale + + col0 = torch.stack([cos, sin], dim=-1) + col1 = torch.stack([-sin, cos], dim=-1) + freqs_mat = torch.stack([col0, col1], dim=-1) + + t_middle_out = apply_rope1(t_middle, freqs_mat) + out = torch.cat((t_left, t_middle_out, t_right), dim=-1) + return out.type(dtype) + + +def _apply_seedvr2_rotary_emb( + freqs: torch.Tensor, + t: torch.Tensor, + start_index: int = 0, + scale: float = 1.0, + seq_dim: int = -2, + freqs_seq_dim: int | None = None, +) -> torch.Tensor: + dtype = t.dtype + if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): + freqs_seq_dim = 0 + + if t.ndim == 3 or freqs_seq_dim is not None: + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) + + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) + cos = freqs.cos() * scale + sin = freqs.sin() * scale + t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) + return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) + +def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: + """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" + angles = freqs_interleaved[..., ::2].float() + cos = torch.cos(angles) + sin = torch.sin(angles) + out = torch.stack([cos, -sin, sin, cos], dim=-1) + return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) + + +def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest + through; in-place for inference, cloned for training (autograd). Mirrors the legacy + ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives + ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when + ``rot_d == t.shape[-1]``. + """ + out = t.clone() if t.requires_grad or comfy.model_management.in_training else t + rot_d = 2 * freqs_cis.shape[-3] + seq_len = out.shape[-2] + for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): + end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) + freqs_chunk = freqs_cis[start:end] + if rot_d == out.shape[-1]: + out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) + else: + out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype) + return out + + +class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + vid_q: torch.FloatTensor, # L h d + vid_k: torch.FloatTensor, # L h d + vid_shape: torch.LongTensor, # B 3 + txt_q: torch.FloatTensor, # L h d + txt_k: torch.FloatTensor, # L h d + txt_shape: torch.LongTensor, # B 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_freqs, txt_freqs = cache( + "mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + target_device = vid_q.device + if vid_freqs.device != target_device: + vid_freqs = vid_freqs.to(target_device) + if txt_freqs.device != target_device: + txt_freqs = txt_freqs.to(target_device) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = _apply_rope1_partial(vid_q, vid_freqs) + vid_k = _apply_rope1_partial(vid_k, vid_freqs) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = _apply_rope1_partial(txt_q, txt_freqs) + txt_k = _apply_rope1_partial(txt_k, txt_freqs) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + + # Calculate actual max dimensions needed for this batch + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + + autocast_device = "cuda" if torch.cuda.is_available() else "cpu" + with torch.amp.autocast(autocast_device, enabled=False): + vid_freqs = self.get_axial_freqs( + max_temporal + 16, + max_height + 4, + max_width + 4, + ).float() + txt_freqs = self.get_axial_freqs(max_txt_len + 16) + + # Now slice as before + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) + txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) + + # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` + # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the + # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` + # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. + # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing + # 2x2 is the per-frequency rotation matrix that + # `comfy.ldm.flux.math.apply_rope1` expects. + return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + vid_only: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + self.vid_only = vid_only + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = ( + module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + if not vid_only + else None + ) + + def forward( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + *args, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.vid if not self.shared_weights else self.all + vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) + if not self.vid_only: + txt_module = self.txt if not self.shared_weights else self.all + txt = txt.to(device=vid.device, dtype=vid.dtype) + txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) + return vid, txt + +def get_na_rope(rope_type: Optional[str], dim: int): + if rope_type is None: + return None + if rope_type == "rope3d": + return NaRotaryEmbedding3d(dim=dim) + if rope_type == "mmrope3d": + return NaMMRotaryEmbedding3d(dim=dim) + +class NaMMAttention(nn.Module): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_norm, + qk_norm_eps: float, + rope_type: Optional[str], + rope_dim: int, + shared_weights: bool, + device, dtype, operations, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + self.heads = heads + inner_dim = heads * head_dim + qkv_dim = inner_dim * 3 + self.head_dim = head_dim + self.proj_qkv = MMModule( + operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype + ) + self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) + self.norm_q = MMModule( + qk_norm, + normalized_shape=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + device=device, dtype=dtype + ) + self.norm_k = MMModule( + qk_norm, + normalized_shape=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + device=device, dtype=dtype + ) + + + self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) + + def forward(self): + pass + +def window( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid = unflatten(hid, hid_shape) + hid = list(map(window_fn, hid)) + hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid, hid_shape = flatten(list(chain(*hid))) + return hid, hid_shape, hid_windows + +def window_idx( + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) + +class NaSwinAttention(NaMMAttention): + def __init__( + self, + *args, + window: Union[int, Tuple[int, int, int]], + window_method: bool, # shifted or not + **kwargs, + ): + super().__init__(*args, **kwargs) + self.version_7b = kwargs.get("version", False) + self.window = _triple(window) + self.window_method = window_method + assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.window_op = get_window_op(window_method) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len = txt_len.to(window_count.device) + + # window rope + if self.rope: + if self.version_7b: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + elif self.rope.mm: + # repeat text q and k for window mmrope + _, num_h, _ = txt_q.shape + txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = unflatten(txt_q_repeat, txt_shape) + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = list(chain(*txt_q_repeat)) + txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) + txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + + txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = unflatten(txt_k_repeat, txt_shape) + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = list(chain(*txt_k_repeat)) + txt_k_repeat, _ = flatten(txt_k_repeat) + txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + out = optimized_var_attention( + q=concat_win(vid_q, txt_q), + k=concat_win(vid_k, txt_k), + v=concat_win(vid_v, txt_v), + heads=self.heads, skip_reshape=True, skip_output_reshape=True, + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + ) + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h d -> l (h d)") + txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + + return vid_out, txt_out + +class MLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + device, dtype, operations + ): + super().__init__() + self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) + self.act = nn.GELU("tanh") + self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_in(x) + x = self.act(x) + x = self.proj_out(x) + return x + + +class SwiGLUMLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + multiple_of: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * dim * expand_ratio / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) + self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) + +def get_mlp(mlp_type: Optional[str] = "normal"): + # 3b and 7b uses different mlp types + if mlp_type == "normal": + return MLP + elif mlp_type == "swiglu": + return SwiGLUMLP + +class NaMMSRTransformerBlock(nn.Module): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm, + norm_eps: float, + ada, + qk_bias: bool, + qk_norm, + mlp_type: str, + shared_weights: bool, + rope_type: str, + rope_dim: int, + is_last_layer: bool, + device, dtype, operations, + **kwargs, + ): + super().__init__() + version = kwargs.get("version", False) + dim = MMArg(vid_dim, txt_dim) + self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + rope_type=rope_type, + rope_dim=rope_dim, + shared_weights=shared_weights, + window=kwargs.pop("window", None), + window_method=kwargs.pop("window_method", None), + version=version, + device=device, dtype=dtype, operations=operations + ) + + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_weights, + vid_only=is_last_layer, + device=device, dtype=dtype, operations=operations + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.is_last_layer = is_last_layer + self.version = version + + def _seedvr2_7b_mlp( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all + if comfy.model_management.in_training or vid.requires_grad: + vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) + else: + vid_out = None + offset = 0 + for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): + chunk_out = vid_module(chunk) + if vid_out is None: + vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) + vid_out[offset:offset + chunk_out.shape[0]] = chunk_out + offset += chunk_out.shape[0] + vid = vid_out + if not self.mlp.vid_only: + txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all + txt = txt.to(device=vid.device, dtype=vid.dtype) + txt = txt_module(txt) + return vid, txt + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + if self.version: + vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) + else: + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return vid + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + vid_shape_before_patchify = None + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + + t, h, w = self.patch_size + vid = self.proj(vid) + + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = flatten(vid) + + return vid, vid_shape + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = flatten(vid) + + vid = self.proj(vid) + return vid, vid_shape + +def expand_dims(x: torch.Tensor, dim: int, ndim: int): + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + modes: List[str] = ["in", "out"], + device = None, dtype = None, + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + + randn_kwargs = {"device": device} + fp8_types = _torch_float8_types() + if dtype is not None and dtype not in fp8_types: + randn_kwargs["dtype"] = dtype + + for l in layers: + if "in" in modes: + # Passing fp8 ``dtype=`` here would break CPU weight + # loads: CPU has no ``normal_kernel_cpu`` for fp8. + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) + self.register_parameter( + f"{l}_scale", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5 + 1) + ) + if "out" in modes: + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + if hid_len is not None: + slice_inputs = lambda x, dim: x + emb = cache( + f"emb_repeat_{idx}_{branch_tag}", + lambda: slice_inputs( + torch.repeat_interleave(emb, hid_len, dim=0), + dim=0, + ), + ) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift", None), + getattr(self, f"{layer}_scale", None), + getattr(self, f"{layer}_gate", None), + ) + + fp8_types = _torch_float8_types() + if fp8_types: + target_dtype = hid.dtype + + if shiftB is not None and shiftB.dtype in fp8_types: + shiftB = shiftB.to(target_dtype) + if scaleB is not None and scaleB.dtype in fp8_types: + scaleB = scaleB.to(target_dtype) + if gateB is not None and gateB.dtype in fp8_types: + gateB = gateB.to(target_dtype) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + if gateB is not None: + return hid.mul_(gateA + gateB) + else: + return hid.mul_(gateA) + + raise NotImplementedError + + +def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): + return emb1 if emb2 is None else emb1 + emb2 + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + device, dtype, operations + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) + self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ).to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) + hid_len = hid_shape.prod(-1) + hid = hid.split(hid_len.tolist()) + hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] + return hid + +def repeat( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, torch.LongTensor], # (b) +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + hid = unflatten(hid, hid_shape) + kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] + return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) + +class NaDiT(nn.Module): + + def __init__( + self, + norm_eps, + qk_rope, + num_layers, + mlp_type, + vid_in_channels = 33, + vid_out_channels = 16, + vid_dim = 2560, + txt_in_dim = 5120, + heads = 20, + head_dim = 128, + mm_layers = 10, + expand_ratio = 4, + qk_bias = False, + patch_size = [ 1,2,2 ], + shared_qkv: bool = False, + shared_mlp: bool = False, + window_method: Optional[Tuple[str]] = None, + temporal_window_size: int = None, + temporal_shifted: bool = False, + rope_dim = 128, + rope_type = "mmrope3d", + vid_out_norm: Optional[str] = None, + device = None, + dtype = None, + operations = None, + **kwargs, + ): + self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM + if self._7b_version: + rope_type = "rope3d" + self.dtype = dtype + factory_kwargs = {"device": device, "dtype": dtype} + window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] + txt_dim = vid_dim + emb_dim = vid_dim * 6 + block_type = ["mmdit_sr"] * num_layers + window = num_layers * [(4,3,3)] + ada = AdaSingle + norm = CustomRMSNorm + qk_norm = CustomRMSNorm + if isinstance(block_type, str): + block_type = [block_type] * num_layers + elif len(block_type) != num_layers: + raise ValueError("The ``block_type`` list should equal to ``num_layers``.") + super().__init__() + # ``torch.empty`` returns uninitialized memory, not zeros. The + # SeedVR2Conditioning fail-loud guard at + # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" + # from "buffer was never populated by the file" by checking + # ``positive_conditioning.abs().sum() == 0``. That sentinel is only + # reliable if the post-construction buffer state is deterministically + # zero, so explicitly zero-fill here rather than relying on the + # allocator's zero-on-alloc behavior (allocator-dependent and not + # contractual). When ``load_state_dict`` populates these buffers + # from a properly-baked SeedVR2 .safetensors, the in-place copy + # overwrites the zeros with the universal SeedVR2 conditioning + # tensors (shape (58, 5120) and (64, 5120) bf16). + self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + self.txt_in = ( + operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + device=device, dtype=dtype, operations=operations + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + if temporal_window_size is None or isinstance(temporal_window_size, int): + temporal_window_size = [temporal_window_size] * num_layers + if temporal_shifted is None or isinstance(temporal_shifted, bool): + temporal_shifted = [temporal_shifted] * num_layers + + rope_dim = rope_dim if rope_dim is not None else head_dim // 2 + self.blocks = nn.ModuleList( + [ + NaMMSRTransformerBlock( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + rope_dim = rope_dim, + window=window[i], + window_method=window_method[i], + temporal_window_size=temporal_window_size[i], + temporal_shifted=temporal_shifted[i], + is_last_layer=(i == num_layers - 1) and not self._7b_version, + rope_type = rope_type, + shared_weights=not ( + (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] + ), + version = self._7b_version, + operations = operations, + **kwargs, + **factory_kwargs + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + + self.need_txt_repeat = block_type[0] in [ + "mmdit_stwin", + "mmdit_stwin_spatial", + "mmdit_stwin_3d_spatial", + ] + + self.vid_out_norm = None + if vid_out_norm is not None: + self.vid_out_norm = CustomRMSNorm( + normalized_shape=vid_dim, + eps=norm_eps, + elementwise_affine=True, + device=device, dtype=dtype + ) + self.vid_out_ada = ada( + dim=vid_dim, + emb_dim=emb_dim, + layers=["out"], + modes=["in"], + device=device, dtype=dtype + ) + + def _resolve_text_conditioning(self, context, cond_or_uncond=None): + if context is None or getattr(context, "numel", lambda: None)() == 0: + context = self.positive_conditioning + return flatten([context]) + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + if context.shape[0] == 1: + context = context.squeeze(0) + return flatten([context]) + return flatten(context.unbind(0)) + if context.shape[0] % 2 != 0: + raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") + neg_cond, pos_cond = context.chunk(2, dim=0) + if pos_cond.shape[0] == 1: + pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) + return flatten([pos_cond, neg_cond]) + return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) + + @staticmethod + def _seedvr2_is_single_conditioning_branch(cond_or_uncond): + if cond_or_uncond is None or len(cond_or_uncond) == 0: + return False + first = cond_or_uncond[0] + return all(entry == first for entry in cond_or_uncond) + + def _swap_pos_neg_halves(self, out, cond_or_uncond=None): + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + return out + # ``dim=0`` is explicit on both calls. The contract is "split + # the batch axis into two halves and swap them"; making the + # axis load-bearing in source guards against silent drift if a + # future refactor reorders tensor axes. + pos, neg = out.chunk(2, dim=0) + return torch.cat([neg, pos], dim=0) + + def forward( + self, + x, + timestep, + context, # l c + disable_cache: bool = False, # for test # TODO ? // gives an error when set to True + **kwargs + ): + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + conditions = kwargs.get("condition") + b, tc, h, w = x.shape + x = x.view(b, 16, -1, h, w) + conditions = conditions.view(b, 17, -1, h, w) + x = x.movedim(1, -1) + conditions = conditions.movedim(1, -1) + cache = Cache(disable=disable_cache) + + txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) + + vid, vid_shape = flatten(x) + cond_latent, _ = flatten(conditions) + + vid = torch.cat([vid, cond_latent], dim=-1) + if txt_shape.size(-1) == 1 and self.need_txt_repeat: + txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + + txt = self.txt_in(txt) + + vid_shape_before_patchify = vid_shape + vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) + + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + for i, block in enumerate(self.blocks): + if ("block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( + vid=args["vid"], + txt=args["txt"], + vid_shape=args["vid_shape"], + txt_shape=args["txt_shape"], + emb=args["emb"], + cache=args["cache"], + ) + return out + out = blocks_replace[("block", i)]({ + "vid":vid, + "txt":txt, + "vid_shape":vid_shape, + "txt_shape":txt_shape, + "emb":emb, + "cache":cache, + }, {"original_block": block_wrap}) + vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] + else: + vid, txt, vid_shape, txt_shape = block( + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + if self.vid_out_norm: + vid = self.vid_out_norm(vid) + vid = self.vid_out_ada( + vid, + emb=emb, + layer="out", + mode="in", + hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), + cache=cache, + branch_tag="vid", + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) + vid = unflatten(vid, vid_shape) + out = torch.stack(vid) + out = out.movedim(-1, 1) + out = rearrange(out, "b c t h w -> b (c t) h w") + return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py new file mode 100644 index 000000000000..68b11c0ff813 --- /dev/null +++ b/comfy/ldm/seedvr/vae.py @@ -0,0 +1,2110 @@ +from contextlib import nullcontext +from typing import Literal, Optional, Tuple +import gc +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor +from contextlib import contextmanager +from comfy.utils import ProgressBar + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_BLOCK_OUT_CHANNELS, + BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, + BYTEDANCE_GN_CHUNKS_FP16, + BYTEDANCE_GN_CHUNKS_FP32, + BYTEDANCE_LOGVAR_CLAMP_MAX, + BYTEDANCE_LOGVAR_CLAMP_MIN, + BYTEDANCE_SLICING_SAMPLE_MIN, + BYTEDANCE_VAE_CONV_MEM_GIB, + BYTEDANCE_VAE_NORM_MEM_GIB, + BYTEDANCE_VAE_SCALING_FACTOR, + BYTEDANCE_VAE_SHIFTING_FACTOR, + BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, + BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, + SEEDVR2_LATENT_CHANNELS, +) +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention + +import math +from enum import Enum +from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND + +import logging +import comfy.model_management +import comfy.ops +ops = comfy.ops.disable_weight_init + + +def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): + if temporal_size is None: + return None + + temporal_size = int(temporal_size) + if temporal_size <= 0: + return 0 + + temporal_overlap = max(0, int(temporal_overlap or 0)) + temporal_overlap = min(temporal_overlap, temporal_size - 1) + temporal_step = temporal_size - temporal_overlap + temporal_scale = max(1, int(temporal_scale)) + return max(1, math.ceil(temporal_step / temporal_scale)) + + +def _seedvr2_clamped_spatial_overlap(overlap, tile_size): + overlap = max(0, int(overlap)) + tile_size = max(1, int(tile_size)) + return min(overlap, tile_size - 1) + + +def _seedvr2_clear_temporal_memory(model): + for module in model.modules(): + if hasattr(module, "memory"): + module.memory = None + + +@torch.inference_mode() +def tiled_vae( + x, + vae_model, + tile_size=(512, 512), + tile_overlap=(64, 64), + temporal_size=16, + temporal_overlap=0, + encode=True, + **kwargs, +): + gc.collect() + comfy.model_management.soft_empty_cache() + + x = x.to(next(vae_model.parameters()).dtype) + if x.ndim != 5: + x = x.unsqueeze(2) + + _, _, d, h, w = x.shape + + sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) + sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) + if encode: + slicing_attr = "slicing_sample_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) + else: + slicing_attr = "slicing_latent_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) + if encode: + ti_h, ti_w = tile_size + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) + blend_ov_h = max(0, ov_h // sf_s) + blend_ov_w = max(0, ov_w // sf_s) + target_d = (d + sf_t - 1) // sf_t + target_h = (h + sf_s - 1) // sf_s + target_w = (w + sf_s - 1) // sf_s + else: + ti_h = max(1, tile_size[0] // sf_s) + ti_w = max(1, tile_size[1] // sf_s) + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) + blend_ov_h = ov_h * sf_s + blend_ov_w = ov_w * sf_s + + target_d = max(1, d * sf_t - (sf_t - 1)) + target_h = h * sf_s + target_w = w * sf_s + + stride_h = max(1, ti_h - ov_h) + stride_w = max(1, ti_w - ov_w) + + storage_device = vae_model.device + result = None + count = None + def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): + device = torch.device(device) + _seedvr2_clear_temporal_memory(model) + t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() + old_device = getattr(model, "device", None) + model.device = device + old_slicing_min_size = getattr(model, slicing_attr, None) + if old_slicing_min_size is not None and slicing_min_size is not None: + if slicing_min_size <= 0: + setattr(model, slicing_attr, t_chunk.shape[2]) + else: + setattr(model, slicing_attr, slicing_min_size) + try: + if encode: + out = model.encode(t_chunk)[0] + else: + out = model.decode_(t_chunk) + finally: + if old_slicing_min_size is not None and slicing_min_size is not None: + setattr(model, slicing_attr, old_slicing_min_size) + if old_device is not None: + model.device = old_device + if isinstance(out, (tuple, list)): + out = out[0] + if out.ndim == 4: + out = out.unsqueeze(2) + return out.to(storage_device) + + ramp_cache = {} + def get_ramp(steps): + if steps not in ramp_cache: + t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) + ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) + return ramp_cache[steps] + + tile_ranges = [] + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + if y_idx > 0 and (y_end - y_idx) <= ov_h: + continue + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) + if x_idx > 0 and (x_end - x_idx) <= ov_w: + continue + tile_ranges.append((y_idx, y_end, x_idx, x_end)) + + total_tiles = len(tile_ranges) + bar = ProgressBar(total_tiles) + single_spatial_tile = h <= ti_h and w <= ti_w + + _seedvr2_clear_temporal_memory(vae_model) + + def run_tile(tile_index, tile_range): + y_idx, y_end, x_idx, x_end = tile_range + tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] + tile_out = run_temporal_chunks(tile_x) + return tile_index, y_idx, y_end, x_idx, x_end, tile_out + + ordered_tile_outputs = ( + run_tile(tile_index, tile_range) + for tile_index, tile_range in enumerate(tile_ranges) + ) + + for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: + + if single_spatial_tile: + result = tile_out[:, :, :target_d, :target_h, :target_w] + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + bar.update(1) + return result + + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) + + if encode: + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) + else: + ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] + xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) + + w_h = torch.ones((tile_out.shape[3],), device=storage_device) + w_w = torch.ones((tile_out.shape[4],), device=storage_device) + + if cur_ov_h > 0: + r = get_ramp(cur_ov_h) + if y_idx > 0: + w_h[:cur_ov_h] = r + if y_end < h: + w_h[-cur_ov_h:] = 1.0 - r + + if cur_ov_w > 0: + r = get_ramp(cur_ov_w) + if x_idx > 0: + w_w[:cur_ov_w] = r + if x_end < w: + w_w[-cur_ov_w:] = 1.0 - r + + final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) + + valid_d = min(tile_out.shape[2], result.shape[2]) + tile_out = tile_out[:, :, :valid_d, :, :] + + tile_out.mul_(final_weight) + + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out + count[:, :, :, ys:ye, xs:xe] += final_weight + + del tile_out, final_weight, w_h, w_w + bar.update(1) + + result.div_(count.clamp(min=1e-6)) + _seedvr2_clear_temporal_memory(vae_model) + + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + + return result + +_NORM_LIMIT = float("inf") +def get_norm_limit(): + return _NORM_LIMIT + + +def set_norm_limit(value: Optional[float] = None): + global _NORM_LIMIT + if value is None: + value = float("inf") + _NORM_LIMIT = value + +@contextmanager +def ignore_padding(model): + orig_padding = model.padding + model.padding = (0, 0, 0) + try: + yield + finally: + model.padding = orig_padding + +class MemoryState(Enum): + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + UNSET = 3 + +def get_cache_size(conv_module, input_len, pad_len, dim=0): + dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + remain_len = ( + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + ) + overlap_len = dilated_kernerl_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len # >= 0 + + assert output_len > 0 + return cache_len + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + sample = torch.randn( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def mode(self): + return self.mean + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + +# partial implementation of diffusers's Attention for comfyui +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + is_causal: bool = False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if norm_num_groups is not None: + self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + self.norm_q = None + self.norm_k = None + + self.norm_cross = None + self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + self.norm_added_q = None + self.norm_added_k = None + self.optimized_vae_attention = vae_attention() + + def __call__( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + + residual = hidden_states + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: + query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) + else: + hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + + return hidden_states + + +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): + with torch.no_grad(): + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) + return weight_3d + +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): + with torch.no_grad(): + bias_3d.copy_(bias_2d) + return bias_3d + + +def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): + weight_name = prefix + "weight" + bias_name = prefix + "bias" + if weight_name in state_dict: + weight_2d = state_dict[weight_name] + if weight_2d.dim() == 4: + weight_3d = inflate_weight_fn( + weight_2d=weight_2d, + weight_3d=layer.weight, + ) + state_dict[weight_name] = weight_3d + else: + return state_dict + if bias_name in state_dict: + bias_2d = state_dict[bias_name] + if bias_2d.dim() == 1: + bias_3d = inflate_bias_fn( + bias_2d=bias_2d, + bias_3d=layer.bias, + ) + state_dict[bias_name] = bias_3d + return state_dict + +def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x.to(input_dtype) + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x.to(input_dtype) + if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x).to(input_dtype) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + memory_occupy = x.numel() * x.element_size() / 1024**3 + if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): + num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) + assert norm_layer.num_groups % num_chunks == 0 + num_groups_per_chunk = norm_layer.num_groups // num_chunks + + x = list(x.chunk(num_chunks, dim=1)) + weights = norm_layer.weight.chunk(num_chunks, dim=0) + biases = norm_layer.bias.chunk(num_chunks, dim=0) + for i, (w, b) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + x[i] = x[i].to(input_dtype) + x = torch.cat(x, dim=1) + else: + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x.to(input_dtype) + raise NotImplementedError + +def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): + problematic_modes = ['bilinear', 'bicubic', 'trilinear'] + + if mode in problematic_modes: + try: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + except RuntimeError as e: + if ("not implemented for 'Half'" in str(e) or + "compute_indices_weights" in str(e)): + original_dtype = x.dtype + return F.interpolate( + x.float(), + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ).to(original_dtype) + else: + raise e + else: + # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + +_receptive_field_t = Literal["half", "full"] + +def extend_head(tensor, times: int = 2, memory = None): + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + assert times >= 0, "Invalid input for function 'extend_head'!" + if times == 0: + return tensor + else: + tile_repeat = [1] * tensor.ndim + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) + +def cache_send_recv(tensor, cache_size, times, memory=None): + recv_buffer = None + + if memory is not None: + recv_buffer = memory.to(tensor[0]) + elif times > 0: + tile_repeat = [1] * tensor[0].ndim + tile_repeat[2] = times + recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) + + return recv_buffer + +class InflatedCausalConv3d(ops.Conv3d): + def __init__( + self, + *args, + inflation_mode, + memory_device = "same", + **kwargs, + ): + self.inflation_mode = inflation_mode + self.memory = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.memory_device = memory_device + self.padding = (0, *self.padding[1:]) + self.memory_limit = float("inf") + self.logged_once = False + + def set_memory_limit(self, value: float): + self.memory_limit = value + + def set_memory_device(self, memory_device): + self.memory_device = memory_device + + def _conv_forward(self, input, weight, bias, *args, **kwargs): + if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and + weight.dtype in (torch.float16, torch.bfloat16) and + hasattr(torch.backends.cudnn, 'is_available') and + torch.backends.cudnn.is_available() and + getattr(torch.backends.cudnn, 'enabled', True)): + try: + out = torch.cudnn_convolution( + input, weight, self.padding, self.stride, self.dilation, self.groups, + benchmark=False, deterministic=False, allow_tf32=True + ) + if bias is not None: + out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) + return out + except RuntimeError: + pass + except NotImplementedError: + pass + try: + return super()._conv_forward(input, weight, bias, *args, **kwargs) + except NotImplementedError: + # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend + if not self.logged_once: + logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") + self.logged_once = True + return F.conv3d(input, weight, bias, *args, **kwargs) + + def memory_limit_conv( + self, + x, + *, + split_dim=3, + padding=(0, 0, 0, 0, 0, 0), + prev_cache=None, + ): + # Compatible with no limit. + if math.isinf(self.memory_limit): + if prev_cache is not None: + x = torch.cat([prev_cache, x], dim=split_dim - 1) + return super().forward(x) + + # Compute tensor shape after concat & padding. + shape = torch.tensor(x.size()) + if prev_cache is not None: + shape[split_dim - 1] += prev_cache.size(split_dim - 1) + shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) + memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + if memory_occupy < self.memory_limit or split_dim == x.ndim: + x_concat = x + if prev_cache is not None: + x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) + + def pad_and_forward(): + padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + if not padded.is_contiguous(): + padded = padded.contiguous() + with ignore_padding(self): + return torch.nn.Conv3d.forward(self, padded) + + return pad_and_forward() + + num_splits = math.ceil(memory_occupy / self.memory_limit) + size_per_split = x.size(split_dim) // num_splits + split_sizes = [size_per_split] * (num_splits - 1) + split_sizes += [x.size(split_dim) - sum(split_sizes)] + + x = list(x.split(split_sizes, dim=split_dim)) + if prev_cache is not None: + prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) + cache = None + for idx in range(len(x)): + if prev_cache is not None: + x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) + + lpad_dim = (x[idx].ndim - split_dim - 1) * 2 + rpad_dim = lpad_dim + 1 + padding = list(padding) + padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 + padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 + pad_len = padding[lpad_dim] + padding[rpad_dim] + padding = tuple(padding) + + next_cache = None + cache_len = cache.size(split_dim) if cache is not None else 0 + next_catch_size = get_cache_size( + conv_module=self, + input_len=x[idx].size(split_dim) + cache_len, + pad_len=pad_len, + dim=split_dim - 2, + ) + if next_catch_size != 0: + assert next_catch_size <= x[idx].size(split_dim) + next_cache = ( + x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + ) + + x[idx] = self.memory_limit_conv( + x[idx], + split_dim=split_dim + 1, + padding=padding, + prev_cache=cache + ) + + cache = next_cache + + output = torch.cat(x, dim=split_dim) + return output + + def forward( + self, + input, + memory_state: MemoryState = MemoryState.UNSET + ) -> Tensor: + assert memory_state != MemoryState.UNSET + if memory_state != MemoryState.ACTIVE: + self.memory = None + if ( + math.isinf(self.memory_limit) + and torch.is_tensor(input) + ): + return self.basic_forward(input, memory_state) + return self.slicing_forward(input, memory_state) + + def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): + mem_size = self.stride[0] - self.kernel_size[0] + if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.memory, times=-1) + else: + input = extend_head(input, times=self.temporal_padding * 2) + memory = ( + input[:, :, mem_size:].detach() + if (mem_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if ( + memory_state != MemoryState.DISABLED + and not self.training + and (self.memory_device is not None) + ): + self.memory = memory + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + return super().forward(input) + + def slicing_forward( + self, + input, + memory_state: MemoryState = MemoryState.UNSET, + ) -> Tensor: + squeeze_out = False + if torch.is_tensor(input): + input = [input] + squeeze_out = True + + cache_size = self.kernel_size[0] - self.stride[0] + cache = cache_send_recv( + input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 + ) + + # Single GPU inference - simplified memory management + if ( + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + and not self.training + and (self.memory_device is not None) + and cache_size != 0 + ): + if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: + input[0] = torch.cat([cache, input[0]], dim=2) + cache = None + if cache_size <= input[-1].size(2): + self.memory = input[-1][:, :, -cache_size:].detach().contiguous() + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + + padding = tuple(x for x in reversed(self.padding) for _ in range(2)) + for i in range(len(input)): + # Prepare cache for next input slice. + next_cache = None + cache_size = 0 + if i < len(input) - 1: + cache_len = cache.size(2) if cache is not None else 0 + cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) + if cache_size != 0: + if cache_size > input[i].size(2) and cache is not None: + input[i] = torch.cat([cache, input[i]], dim=2) + cache = None + assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + next_cache = input[i][:, :, -cache_size:] + + # Conv forward for this input slice. + input[i] = self.memory_limit_conv( + input[i], + padding=padding, + prev_cache=cache + ) + + # Update cache. + cache = next_cache + + return input[0] if squeeze_out else input + +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + if times == 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + +class Upsample3D(nn.Module): + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + temporal_up: bool = False, + spatial_up: bool = True, + slicing: bool = False, + interpolate = True, + name: str = "conv", + use_conv_transpose = False, + use_conv: bool = False, + padding = 1, + bias = True, + kernel_size = None, + **kwargs, + ): + super().__init__() + self.interpolate = interpolate + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv_transpose = use_conv_transpose + self.use_conv = use_conv + self.name = name + + self.conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + self.conv = ops.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + conv = self.conv if self.name == "conv" else self.Conv2d_0 + + # Note: lora_layer is not passed into constructor in the original implementation. + # So we make a simplification. + conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + 3, + padding=1, + inflation_mode=inflation_mode, + ) + + self.temporal_up = temporal_up + self.spatial_up = spatial_up + self.temporal_ratio = 2 if temporal_up else 1 + self.spatial_ratio = 2 if spatial_up else 1 + self.slicing = slicing + + assert not self.interpolate + # [Override] MAGViT v2 implementation + if not self.interpolate: + upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio + self.upscale_conv = ops.Conv3d( + self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 + ) + identity = ( + torch.eye(self.channels) + .repeat(upscale_ratio, 1) + .reshape_as(self.upscale_conv.weight) + ) + self.upscale_conv.weight.data.copy_(identity) + + if self.name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + self.norm = None + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state=None, + **kwargs, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. + hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + if self.slicing: + split_size = hidden_states.size(2) // 2 + hidden_states = list( + hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) + ) + else: + hidden_states = [hidden_states] + + for i in range(len(hidden_states)): + hidden_states[i] = self.upscale_conv(hidden_states[i]) + hidden_states[i] = rearrange( + hidden_states[i], + "b (x y z c) f h w -> b c (f z) (h x) (w y)", + x=self.spatial_ratio, + y=self.spatial_ratio, + z=self.temporal_ratio, + ) + + if self.temporal_up and memory_state != MemoryState.ACTIVE: + hidden_states[0] = remove_head(hidden_states[0]) + + if not self.slicing: + hidden_states = hidden_states[0] + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states, memory_state=memory_state) + else: + hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) + + if not self.slicing: + return hidden_states + else: + return torch.cat(hidden_states, dim=2) + + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution.""" + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + spatial_down: bool = False, + temporal_down: bool = False, + name: str = "conv", + kernel_size=3, + use_conv: bool = False, + padding = 1, + bias=True, + **kwargs, + ): + super().__init__() + self.padding = padding + self.name = name + self.channels = channels + self.out_channels = out_channels or channels + self.temporal_down = temporal_down + self.spatial_down = spatial_down + self.use_conv = use_conv + self.padding = padding + + self.temporal_ratio = 2 if temporal_down else 1 + self.spatial_ratio = 2 if spatial_down else 1 + + self.temporal_kernel = 3 if temporal_down else 1 + self.spatial_kernel = 3 if spatial_down else 1 + + if use_conv: + conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + padding=( + 1 if self.temporal_down else 0, + self.padding if self.spatial_down else 0, + self.padding if self.spatial_down else 0, + ), + inflation_mode=inflation_mode, + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool3d( + kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + ) + + self.conv = conv + + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state = None, + **kwargs, + ) -> torch.FloatTensor: + + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. + hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv and self.padding == 0 and self.spatial_down: + pad = (0, 1, 0, 1) + hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states, memory_state=memory_state) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + skip_time_act: bool = False, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + slicing: bool = False, + **kwargs, + ): + super().__init__() + self.up = up + self.down = down + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.use_in_shortcut = use_in_shortcut + self.output_scale_factor = output_scale_factor + self.skip_time_act = skip_time_act + self.nonlinearity = nn.SiLU() + if temb_channels is not None: + self.time_emb_proj = ops.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + if groups_out is None: + groups_out = groups + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.use_in_shortcut = self.in_channels != out_channels + self.dropout = torch.nn.Dropout(dropout) + self.conv1 = InflatedCausalConv3d( + self.in_channels, + self.out_channels, + kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), + stride=1, + padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), + inflation_mode=inflation_mode, + ) + + self.conv2 = InflatedCausalConv3d( + self.out_channels, + conv_2d_out_channels, + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample3D( + self.in_channels, + use_conv=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + elif self.down: + self.downsample = Downsample3D( + self.in_channels, + use_conv=False, + padding=1, + name="op", + inflation_mode=inflation_mode, + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedCausalConv3d( + self.in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=True, + inflation_mode=inflation_mode, + ) + + def forward( + self, input_tensor, temb, memory_state = None, **kwargs + ): + hidden_states = input_tensor + + hidden_states = causal_norm_wrapper(self.norm1, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor, memory_state=memory_state) + hidden_states = self.upsample(hidden_states, memory_state=memory_state) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor, memory_state=memory_state) + hidden_states = self.downsample(hidden_states, memory_state=memory_state) + + hidden_states = self.conv1(hidden_states, memory_state=memory_state) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if temb is not None: + hidden_states = hidden_states + temb + + hidden_states = causal_norm_wrapper(self.norm2, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_down: bool = True, + spatial_down: bool = True, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + temporal_down=temporal_down, + spatial_down=spatial_down, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state = None, + **kwargs, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + slicing=slicing, + ) + ) + + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_upsample: + # [Override] Replace module & use learnable upsample + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + temporal_up=temporal_up, + spatial_up=spatial_up, + interpolate=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + ] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + memory_state=None + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, memory_state=None): + video_length, frame_height, frame_width = hidden_states.size()[-3:] + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn(hidden_states, temb=temb) + hidden_states = rearrange( + hidden_states, "(b f) c h w -> b c f h w", f=video_length + ) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state) + + return hidden_states + + +class Encoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + # [Override] add extra_cond_dim, temporal down num + temporal_down_num: int = 2, + extra_cond_dim: int = None, + gradient_checkpoint: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_down_num = temporal_down_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + self.extra_cond_dim = extra_cond_dim + + self.conv_extra_cond = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + # [Override] to support temporal down block design + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last ones + + assert down_block_type == "DownEncoderBlock3D" + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + # Note: Don't know why set it as 0 + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + self.conv_extra_cond.append( + zero_module( + ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ) + if self.extra_cond_dim is not None and self.extra_cond_dim > 0 + else None + ) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = ops.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = InflatedCausalConv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + def forward( + self, + sample: torch.FloatTensor, + extra_cond=None, + memory_state = None + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = sample.to(next(self.parameters()).device) + sample = self.conv_in(sample, memory_state = memory_state) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample) + + else: + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = down_block(sample, memory_state=memory_state) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state = memory_state) + + return sample + + +class Decoder3D(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=norm_type, + temb_channels=temb_channels, + temporal_up=is_temporal_up_block, + slicing=is_slicing_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = ops.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedCausalConv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + # Note: Just copy from Decoder. + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + memory_state = None, + ) -> torch.FloatTensor: + + sample = sample.to(next(self.parameters()).device) + sample = self.conv_in(sample, memory_state=memory_state) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + +class VideoAutoencoderKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = SEEDVR2_LATENT_CHANNELS, + norm_num_groups: int = 32, + attention: bool = True, + temporal_scale_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + inflation_mode = "pad", + time_receptive_field: _receptive_field_t = "full", + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, + *args, + **kwargs, + ): + self.slicing_sample_min_size = slicing_sample_min_size + self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) + extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS + down_block_types = ("DownEncoderBlock3D",) * 4 + up_block_types = ("UpDecoderBlock3D",) * 4 + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + extra_cond_dim=extra_cond_dim, + # [Override] add temporal_down_num parameter + temporal_down_num=temporal_scale_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + slicing_up_num=slicing_up_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.quant_conv = ( + InflatedCausalConv3d( + in_channels=2 * latent_channels, + out_channels=2 * latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + InflatedCausalConv3d( + in_channels=latent_channels, + out_channels=latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_post_quant_conv + else None + ) + + # A hacky way to remove attention. + if not attention: + self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + + self.use_slicing = True + + def encode(self, x: torch.FloatTensor, return_dict: bool = True): + h = self.slicing_encode(x) + posterior = DiagonalGaussianDistribution(h).mode() + + if not return_dict: + return (posterior,) + + return posterior + + def decode_( + self, z: torch.Tensor, return_dict: bool = True + ): + decoded = self.slicing_decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def _encode( + self, x, memory_state = MemoryState.DISABLED + ) -> torch.Tensor: + _x = x.to(self.device) + h = self.encoder(_x, memory_state=memory_state) + if self.quant_conv is not None: + output = self.quant_conv(h, memory_state=memory_state) + else: + output = h + return output.to(x.device) + + def _decode( + self, z, memory_state = MemoryState.DISABLED + ) -> torch.Tensor: + _z = z.to(self.device) + + if self.post_quant_conv is not None: + _z = self.post_quant_conv(_z, memory_state=memory_state) + + output = self.decoder(_z, memory_state=memory_state) + return output.to(z.device) + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + sp_size =1 + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + split_size = max( + self.slicing_sample_min_size * sp_size, + getattr(self, "temporal_downsample_factor", 1), + ) + x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) + min_active_len = getattr(self, "temporal_downsample_factor", 1) + if len(x_slices) > 1 and x_slices[-1].shape[2] < min_active_len: + x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2) + x_slices.pop() + encoded_slices = [ + self._encode( + torch.cat((x[:, :, :1], x_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for x_idx in range(1, len(x_slices)): + encoded_slices.append( + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(encoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._encode(x) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + sp_size = 1 + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + decoded_slices = [ + self._decode( + torch.cat((z[:, :, :1], z_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING + ) + ] + for z_idx in range(1, len(z_slices)): + decoded_slices.append( + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(decoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._decode(z) + + def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs + ): + # x: [b c t h w] + def _unwrap(value): + return value[0] if isinstance(value, tuple) else value + + if mode == "encode": + return _unwrap(self.encode(x)) + elif mode == "decode": + return _unwrap(self.decode_(x)) + else: + latent = _unwrap(self.encode(x)) + return _unwrap(self.decode_(latent)) + +class VideoAutoencoderKLWrapper(VideoAutoencoderKL): + def __init__( + self, + *args, + spatial_downsample_factor = 8, + temporal_downsample_factor = 4, + freeze_encoder = True, + **kwargs, + ): + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + self.freeze_encoder = freeze_encoder + self.enable_tiling = False + super().__init__(*args, **kwargs) + self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) + + def forward(self, x: torch.FloatTensor): + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z) + return x, z, p + + def encode(self, x, orig_dims=None): + if x.ndim == 4: + x = x.unsqueeze(2) + x = x.to(dtype=next(self.parameters()).dtype) + self.device = x.device + p = super().encode(x) + z = p.squeeze(2) + return z, p + + def decode(self, z, seedvr2_tiling=None): + seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling + if not isinstance(seedvr2_tiling, dict): + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " + f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." + ) + + if z.ndim == 5: + b, c, t_latent, h, w = z.shape + if c != 16: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " + f"have 16 channels; got shape {tuple(z.shape)}." + ) + latent = z + elif z.ndim == 4: + b, tc, h, w = z.shape + if tc % 16 != 0: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " + "use collapsed channel layout (B, 16*T, H, W); " + f"got shape {tuple(z.shape)}." + ) + latent = z.reshape(b, 16, -1, h, w) + else: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " + "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " + f"got shape {tuple(z.shape)}." + ) + scale = BYTEDANCE_VAE_SCALING_FACTOR + shift = BYTEDANCE_VAE_SHIFTING_FACTOR + latent = latent / scale + shift + + self.device = latent.device + self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) + + if self.enable_tiling: + decode_seedvr2_args = dict(seedvr2_tiling) + tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) + ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) + decode_seedvr2_args["tile_overlap"] = ( + min(ov_h, max(0, tile_h - 8)), + min(ov_w, max(0, tile_w - 8)), + ) + x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) + if x.ndim == 4: + # tiled_vae squeezes the temporal axis when + # temporal_downsample_factor == 1 AND latent T == 1 + # (see tiled_vae line 179-180); re-add it so the post-decode + # pipeline can keep batch and time distinct on the tiled path. + x = x.unsqueeze(2) + else: + x = super().decode_(latent) + + # ensure even dims for save video + h, w = x.shape[-2:] + w2 = w - (w % 2) + h2 = h - (h % 2) + x = x[..., :h2, :w2] + + return x + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): + set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) + + for module in self.modules(): + if isinstance(module, InflatedCausalConv3d): + module.set_memory_device(memory_device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 042804771890..c084e23bb19f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,8 @@ import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.seedvr.model + import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -928,6 +930,16 @@ def extra_conds(self, **kwargs): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out +class SeedVR2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + condition = kwargs.get("condition", None) + if condition is not None: + out["condition"] = comfy.conds.CONDRegular(condition) + return out + class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c838d13338..9555810065c4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,6 +598,56 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` + # submodules) at EVERY block — verified by inspecting the 7B + # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means + # ``MMModule.shared_weights=False``). Native NaDiT computes + # per-block ``shared_weights = not (i < mm_layers)``, so to keep + # every block non-shared we set ``mm_layers = num_layers``. + # Without this, blocks at index >= mm_layers (default 10) try to + # load ``blocks.N.*.all.*`` keys that don't exist in the file, + # silently miss-load → all-black output. + dit_config["mm_layers"] = 36 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = True + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "normal" + return dit_config + elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # This checkpoint layout carries shared ``all.`` MMModule keys. + # Preserve the historical split: the initial blocks use separate + # vid/txt modules, later blocks use shared modules. + dit_config["mm_layers"] = 10 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = True + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "swiglu" + return dit_config + elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 2560 + dit_config["heads"] = 20 + dit_config["num_layers"] = 32 + dit_config["norm_eps"] = 1.0e-05 + dit_config["qk_rope"] = None + dit_config["mlp_type"] = "swiglu" + dit_config["vid_out_norm"] = True + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/sample.py b/comfy/sample.py index 2be0cae5f872..de71596b3e1d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None, is_empty = torch.count_nonzero(latent_image) == 0 if is_empty: if latent_format.latent_channels != latent_image.shape[1]: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + preserves_collapsed_channels = ( + getattr(latent_format, "preserve_empty_channel_multiples", False) + and latent_image.ndim == 4 + and latent_image.shape[1] % latent_format.latent_channels == 0 + ) + if not preserves_collapsed_channels: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if downscale_ratio_spacial is not None: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio diff --git a/comfy/sd.py b/comfy/sd.py index a66ba1bfb76e..8ac08ac42d86 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,3 +1,4 @@ +import inspect import json import torch from enum import Enum @@ -16,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.seedvr.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -84,6 +86,36 @@ import comfy.ldm.flux.redux +SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160 + + +def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w): + output_t = max(1, (latent_t - 1) * 4 + 1) + return output_t * latent_h * 8 * latent_w * 8 + + +def _seedvr2_vae_decode_memory_used(shape): + if len(shape) == 5: + candidates = [] + if shape[1] == 16: + candidates.append((shape[2], shape[3], shape[4])) + if shape[-1] == 16: + candidates.append((shape[1], shape[2], shape[3])) + if len(candidates) == 0: + candidates.append((shape[2], shape[3], shape[4])) + output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates) + elif len(shape) == 4: + latent_t = max(1, (shape[1] + 15) // 16) + latent_h, latent_w = shape[2], shape[3] + output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) + else: + latent_t, latent_h, latent_w = 1, shape[-2], shape[-1] + output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) + # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels + # plus int64 sort indices dominate peak memory, not the VAE weight dtype. + return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL + + def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: @@ -467,8 +499,10 @@ def decode(self, token_ids, skip_special_tokens=True): class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) + is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd + if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + if metadata is None or metadata.get("keep_diffusers_format") != "true": + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -540,6 +574,20 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 + self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() + self.latent_channels = 16 + self.latent_dim = 3 + self.disable_offload = True + self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape) + self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.process_input = lambda image: image * 2.0 - 1.0 + self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -667,6 +715,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -1006,6 +1055,40 @@ def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) + def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4): + sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8) + sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4) + if tile_t is None: + tile_t = 16 + if overlap_t is None: + overlap_t = 4 + if tile_t > 0: + temporal_size = tile_t * sf_t + temporal_overlap = max(0, overlap_t) * sf_t + else: + temporal_size = 0 + temporal_overlap = 0 + args = { + "enable_tiling": True, + "tile_size": (tile_y * sf_s, tile_x * sf_s), + "tile_overlap": (overlap * sf_s, overlap * sf_s), + "temporal_size": temporal_size, + "temporal_overlap": temporal_overlap, + } + output = self.first_stage_model.decode( + samples.to(self.vae_dtype).to(self.device), + seedvr2_tiling=args, + ) + return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) + + def _format_seedvr2_encoded_samples(self, samples): + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + if samples.ndim == 4: + samples = samples.unsqueeze(2) + samples = samples.contiguous() + samples = samples * 0.9152 + return samples + def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1042,6 +1125,36 @@ def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap= encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) + def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + if tile_y is None: + tile_y = 512 + if tile_x is None: + tile_x = 512 + if overlap is None: + overlap_y = 64 + overlap_x = 64 + else: + overlap_y = overlap + overlap_x = overlap + if tile_t is None: + tile_t = 9999 + if overlap_t is None: + overlap_t = 0 + overlap_y = min(overlap_y, max(0, tile_y - 8)) + overlap_x = min(overlap_x, max(0, tile_x - 8)) + self.first_stage_model.device = self.device + x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) + output = comfy.ldm.seedvr.vae.tiled_vae( + x, + self.first_stage_model, + tile_size=(tile_y, tile_x), + tile_overlap=(overlap_y, overlap_x), + temporal_size=tile_t, + temporal_overlap=overlap_t, + encode=True, + ) + return output.to(device=self.output_device, dtype=self.vae_output_dtype()) + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1089,16 +1202,40 @@ def decode(self, samples_in, vae_options={}): if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) + # SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)`` + # downstream of ``SeedVR2Conditioning`` (which performs the + # ``rearrange(b c t h w -> b (c t) h w)`` collapse). The + # generic ``decode_tiled_`` would treat the channel dim as + # spatial-only and crash on the collapsed (16, T) layout + # under ``tiled_scale``'s mask broadcast; route SeedVR2 4D + # latents to ``decode_tiled_seedvr2`` instead, whose wrapper + # dispatch handles both 4D and 5D inputs. + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + def decode_tiled( + self, + samples, + tile_x=None, + tile_y=None, + overlap=None, + tile_t=None, + overlap_t=None, + ): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1112,7 +1249,20 @@ def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=N args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if dims == 1 or self.extra_1d_channel is not None: + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): + seedvr2_args = {} + if tile_x is not None: + seedvr2_args["tile_x"] = tile_x + if tile_y is not None: + seedvr2_args["tile_y"] = tile_y + if overlap is not None: + seedvr2_args["overlap"] = overlap + if tile_t is not None: + seedvr2_args["tile_t"] = tile_t + if overlap_t is not None: + seedvr2_args["overlap_t"] = overlap_t + output = self.decode_tiled_seedvr2(samples, **seedvr2_args) + elif dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1154,6 +1304,8 @@ def encode(self, pixel_samples): else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) + if isinstance(out, tuple): + out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1173,20 +1325,23 @@ def encode(self, pixel_samples): if self.latent_dim == 3: tile = 256 overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) + else: + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - return samples + return self._format_seedvr2_encoded_samples(samples) def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) - if dims == 3: + if dims == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1210,22 +1365,47 @@ def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, ti elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + seedvr2_args = {} + if tile_x is not None: + seedvr2_args["tile_x"] = tile_x + else: + seedvr2_args["tile_x"] = 512 + if tile_y is not None: + seedvr2_args["tile_y"] = tile_y + else: + seedvr2_args["tile_y"] = 512 + if overlap is not None: + seedvr2_args["overlap"] = overlap + else: + seedvr2_args["overlap"] = 64 + if tile_t is not None: + seedvr2_args["tile_t"] = tile_t + else: + seedvr2_args["tile_t"] = 9999 + if overlap_t is not None: + seedvr2_args["overlap_t"] = overlap_t + else: + seedvr2_args["overlap_t"] = 0 + samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args) else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + spatial_overlap = overlap if overlap is not None else 64 + if overlap_t is None: + args["overlap"] = (1, spatial_overlap, spatial_overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - return samples + return self._format_seedvr2_encoded_samples(samples) def get_sd(self): return self.first_stage_model.state_dict() @@ -1752,6 +1932,17 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo return (model, clip, vae) + +def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): + set_dtype = model_config.set_inference_dtype + parameters = inspect.signature(set_dtype).parameters + supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) + if supports_device: + set_dtype(dtype, manual_cast_dtype, device=device) + else: + set_dtype(dtype, manual_cast_dtype) + + def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -1859,7 +2050,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2000,7 +2191,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7cf9c133b9cb..fa95003cc237 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1672,6 +1672,35 @@ def clip_target(self, state_dict={}): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class SeedVR2(supported_models_base.BASE): + unet_config = { + "image_model": "seedvr2" + } + latent_format = comfy.latent_formats.SeedVR2 + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + sampling_settings = { + "shift": 1.0, + } + + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + if ( + dtype == torch.float16 + and manual_cast_dtype is None + and comfy.model_management.should_use_bf16(device) + ): + manual_cast_dtype = torch.bfloat16 + super().set_inference_dtype(dtype, manual_cast_dtype, device=device) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SeedVR2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -2029,7 +2058,6 @@ def clip_target(self, state_dict={}): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) - class RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", @@ -2267,6 +2295,7 @@ def get_model(self, state_dict, prefix="", device=None): HiDream, HiDreamO1, Chroma, + SeedVR2, ChromaRadiance, ACEStep, ACEStep15, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0e7a829ba13b..572f9984e9e6 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py new file mode 100644 index 000000000000..d5cd029bacc8 --- /dev/null +++ b/comfy_extras/nodes_seedvr.py @@ -0,0 +1,1015 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import torch +import math +import logging +from einops import rearrange + +import gc +import comfy.model_management +import comfy.sample +import comfy.samplers +from comfy.ldm.seedvr.color_fix import ( + adain_color_transfer, + lab_color_transfer, + wavelet_color_transfer, +) +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_IMG_SHIFT_FIT, + BYTEDANCE_SCHEDULE_T, + BYTEDANCE_VID_SHIFT_FIT, + SEEDVR2_ADAIN_SCALE_MULTIPLIER, + SEEDVR2_COLOR_MEM_HEADROOM, + SEEDVR2_COND_CHANNELS, + SEEDVR2_DTYPE_BYTES_FLOOR, + SEEDVR2_LAB_SCALE_MULTIPLIER, + SEEDVR2_LATENT_CHANNELS, + SEEDVR2_OOM_BACKOFF_DIVISOR, + SEEDVR2_WAVELET_SCALE_MULTIPLIER, +) + +from torchvision.transforms import functional as TVF +from torchvision.transforms import Lambda +from torchvision.transforms.functional import InterpolationMode + + +_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( + "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" +) + +# Private sentinel for getattr default: distinguishes "attribute missing" +# from "attribute present but None" so the failure message is accurate. +_ATTR_MISSING = object() + + +def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): + """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" + attempts = [frames_per_chunk] + current_chunk_latent = ( + t_latent if t_pixel <= frames_per_chunk + else (frames_per_chunk - 1) // 4 + 1 + ) + current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) + seen = {frames_per_chunk} + + for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): + chunk_latent = max(1, math.ceil(t_latent / target_chunks)) + candidate = 4 * (chunk_latent - 1) + 1 + if candidate in seen: + continue + if candidate >= attempts[-1]: + continue + attempts.append(candidate) + seen.add(candidate) + + return attempts + + +def _resolve_seedvr2_diffusion_model(model): + """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" + inner = getattr(model, "model", _ATTR_MISSING) + if inner is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " + f"(got type {type(model).__name__})." + ) + if inner is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " + f"(input type {type(model).__name__})." + ) + diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) + if diffusion_model is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " + f"'diffusion_model' attribute (got type {type(inner).__name__})." + ) + if diffusion_model is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " + f"is None (model.model type {type(inner).__name__})." + ) + return diffusion_model + + +def _apply_rope_freqs_float32_cast(diffusion_model): + """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" + for module in diffusion_model.modules(): + if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): + if module.rope.freqs.data.dtype != torch.float32: + module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) + + +def clear_vae_memory(vae_model): + for module in vae_model.modules(): + if hasattr(module, "memory"): + module.memory = None + gc.collect() + comfy.model_management.soft_empty_cache() + +def expand_dims(tensor, ndim): + shape = tensor.shape + (1,) * (ndim - tensor.ndim) + return tensor.reshape(shape) + +def get_conditions(latent, latent_blur): + t, h, w, c = latent.shape + cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + +def timestep_transform(timesteps, latents_shapes): + vt = 4 + vs = 8 + frames = (latents_shapes[:, 0] - 1) * vt + 1 + heights = latents_shapes[:, 1] * vs + widths = latents_shapes[:, 2] * vs + + # Compute shift factor. + def get_lin_function(x1, y1, x2, y2): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) + vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) + shift = torch.where( + frames > 1, + vid_shift_fn(heights * widths * frames), + img_shift_fn(heights * widths), + ).to(timesteps.device) + + # Shift timesteps. + T = BYTEDANCE_SCHEDULE_T + timesteps = timesteps / T + timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) + timesteps = timesteps * T + return timesteps + +def inter(x_0, x_T, t): + t = expand_dims(t, x_0.ndim) + T = BYTEDANCE_SCHEDULE_T + B = lambda t: t / T + A = lambda t: 1 - (t / T) + return A(t) * x_0 + B(t) * x_T + +def div_pad(image, factor): + + height_factor, width_factor = factor + height, width = image.shape[-2:] + + pad_height = (height_factor - (height % height_factor)) % height_factor + pad_width = (width_factor - (width % width_factor)) % width_factor + + if pad_height == 0 and pad_width == 0: + return image + + if isinstance(image, torch.Tensor): + padding = (0, pad_width, 0, pad_height) + image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) + + return image + +def cut_videos(videos): + t = videos.size(1) + if t == 1: + return videos + if t <= 4 : + padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 - ((t - 1) % (4)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4) == 0 + return videos + +def _seedvr2_input_shorter_edge(images, node_name): + if images.dim() == 4: + return min(images.shape[1], images.shape[2]) + if images.dim() == 5: + return min(images.shape[2], images.shape[3]) + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + + +def _seedvr2_pad(images, upscaled_shorter_edge, node_name): + if upscaled_shorter_edge < 2: + raise ValueError( + f"{node_name}: input shorter edge must be at least 2 pixels; " + f"got {upscaled_shorter_edge}." + ) + if images.shape[-1] > 3: + images = images[..., :3] + if images.dim() == 4: + # Comfy video components arrive as a 4-D IMAGE frame sequence: + # (frames, H, W, C). SeedVR2 consumes that as one video. + images = images.unsqueeze(0) + elif images.dim() != 5: + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + images = images.permute(0, 1, 4, 2, 3) + + b, t, c, h, w = images.shape + images = images.reshape(b * t, c, h, w) + + clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + images = clip(images) + images = div_pad(images, (16, 16)) + _, _, new_h, new_w = images.shape + + images = images.reshape(b, t, c, new_h, new_w) + images = cut_videos(images) + images_bthwc = rearrange(images, "b t c h w -> b t h w c") + + return io.NodeOutput(images_bthwc) + + +class SeedVR2Preprocess(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Preprocess", + display_name="Pre-Process SeedVR2 Input", + category="image/upscaling", + description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.", + inputs=[ + io.Image.Input("resized_images", tooltip="The resized image to process."), + ], + outputs=[ + io.Image.Output("images"), + ] + ) + + @classmethod + def execute(cls, resized_images): + upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") + return _seedvr2_pad( + resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", + ) + + +class SeedVR2PostProcessing(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2PostProcessing", + display_name="Post-Process SeedVR2 Output", + category="image/upscaling", + description="Align the generated image with the original resized image and apply color correction.", + inputs=[ + io.Image.Input("images", tooltip="The generated image to process."), + io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), + io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), + ], + outputs=[io.Image.Output(display_name="images")], + ) + + @classmethod + def execute(cls, images, original_resized_images, color_correction_method): + alpha_input = None + if original_resized_images.shape[-1] == 4: + alpha_input = original_resized_images[..., 3:4] + original_resized_images = original_resized_images[..., :3] + decoded_5d, decoded_was_4d = cls._as_bthwc(images) + reference_full, _ = cls._as_bthwc(original_resized_images) + decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) + + b = min(decoded_5d.shape[0], reference_full.shape[0]) + t = min(decoded_5d.shape[1], reference_full.shape[1]) + reference_h = reference_full.shape[2] + reference_w = reference_full.shape[3] + + decoded_5d = decoded_5d[:b, :t, :, :, :] + target_h = min(decoded_5d.shape[2], reference_h) + target_w = min(decoded_5d.shape[3], reference_w) + decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] + if color_correction_method in ("lab", "wavelet", "adain"): + reference_5d = reference_full[:b, :t, :, :, :] + reference_5d = cls._resize_reference(reference_5d, target_h, target_w) + output_device = decoded_5d.device + decoded_raw = cls._to_seedvr2_raw(decoded_5d) + reference_raw = cls._to_seedvr2_raw(reference_5d) + decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") + reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") + output = cls._color_transfer_chunked( + decoded_flat, reference_flat, output_device, color_correction_method, + ) + output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) + output = output.add(1.0).div(2.0).clamp(0.0, 1.0) + elif color_correction_method == "none": + output = decoded_5d + else: + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + if alpha_input is not None: + alpha_5d, _ = cls._as_bthwc(alpha_input) + alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :] + output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) + h2 = output.shape[-3] - (output.shape[-3] % 2) + w2 = output.shape[-2] - (output.shape[-2] % 2) + output = output[:, :, :h2, :w2, :] + if decoded_was_4d: + output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) + return io.NodeOutput(output) + + @staticmethod + def _as_bthwc(images): + if images.ndim == 4: + return images.unsqueeze(0), True + if images.ndim == 5: + return images, False + raise ValueError( + f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" + ) + + @staticmethod + def _restore_reference_batch_time(decoded, reference): + if decoded.shape[0] != 1: + return decoded + ref_b, ref_t = reference.shape[:2] + if ref_b < 1 or decoded.shape[1] % ref_b != 0: + return decoded + decoded_t = decoded.shape[1] // ref_b + if decoded_t < ref_t: + return decoded + return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) + + @staticmethod + def _to_seedvr2_raw(images): + return images.mul(2.0).sub(1.0) + + @staticmethod + def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): + color_device = comfy.model_management.vae_device() + decoded_flat = decoded_flat.to(device=color_device) + reference_flat = reference_flat.to(device=color_device) + output = transfer_fn(decoded_flat, reference_flat) + return output.to(device=output_device) + + @staticmethod + def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): + color_device = comfy.model_management.vae_device() + result = None + for start in range(decoded_flat.shape[0]): + decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() + reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() + output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) + if result is None: + result = torch.empty( + (decoded_flat.shape[0],) + tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:start + 1].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") + return result + + @classmethod + def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): + chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) + while True: + next_chunk_size = None + try: + return cls._run_color_transfer_chunks( + decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if chunk_size <= 1: + raise RuntimeError( + "SeedVR2PostProcessing: color correction OOM at one frame; " + f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." + ) from e + next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) + + comfy.model_management.soft_empty_cache() + chunk_size = next_chunk_size + + @classmethod + def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): + result = None + for start in range(0, decoded_flat.shape[0], chunk_size): + end = min(start + chunk_size, decoded_flat.shape[0]) + decoded_chunk = decoded_flat[start:end] + reference_chunk = reference_flat[start:end] + if color_correction_method == "lab": + output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) + elif color_correction_method == "wavelet": + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, + ) + else: + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, adain_color_transfer, + ) + if result is None: + result = torch.empty( + (decoded_flat.shape[0],) + tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:end].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") + return result + + @classmethod + def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): + multiplier = cls._color_correction_memory_multiplier(color_correction_method) + frames = decoded_flat.shape[0] + _, channels, height, width = decoded_flat.shape + dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) + bytes_per_frame = height * width * channels * dtype_bytes * multiplier + if bytes_per_frame <= 0: + return frames + color_device = comfy.model_management.vae_device() + free_memory = comfy.model_management.get_free_memory(color_device) + chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) + return max(1, min(frames, chunk_size)) + + @staticmethod + def _color_correction_memory_multiplier(color_correction_method): + if color_correction_method == "lab": + return SEEDVR2_LAB_SCALE_MULTIPLIER + if color_correction_method == "wavelet": + return SEEDVR2_WAVELET_SCALE_MULTIPLIER + if color_correction_method == "adain": + return SEEDVR2_ADAIN_SCALE_MULTIPLIER + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + @staticmethod + def _resize_reference(reference, height, width): + if reference.shape[2] == height and reference.shape[3] == width: + return reference + b, t = reference.shape[:2] + reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") + resized = TVF.resize( + reference_flat, + size=(height, width), + interpolation=InterpolationMode.BICUBIC, + antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), + ) + return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) + + +class SeedVR2Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Conditioning", + display_name="Apply SeedVR2 Conditioning", + category="conditioning", + description="Build SeedVR2 positive/negative conditioning from a VAE latent.", + inputs=[ + io.Model.Input("model", tooltip="The SeedVR2 model."), + io.Latent.Input("vae_conditioning", display_name="latent"), + ], + outputs=[ + io.Model.Output(display_name = "model"), + io.Conditioning.Output(display_name = "positive"), + io.Conditioning.Output(display_name = "negative"), + io.Latent.Output(display_name = "latent"), + ], + ) + + @classmethod + def execute(cls, model, vae_conditioning) -> io.NodeOutput: + + vae_conditioning = vae_conditioning["samples"] + if vae_conditioning.ndim != 5: + raise ValueError( + "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " + f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." + ) + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + raise ValueError( + "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"got channel-last shape {tuple(vae_conditioning.shape)}." + ) + vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() + model_patcher = model + model = _resolve_seedvr2_diffusion_model(model_patcher) + pos_cond = model.positive_conditioning + neg_cond = model.negative_conditioning + + # Fail-loud guard against silently-wrong output when a + # DiT-only ``.safetensors`` (no ``positive_conditioning`` / + # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. + # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see + # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` + # leaves them at zero when the keys are absent. Detect that state + # here rather than at ``BaseModel.extra_conds`` (per sampling step, + # wasteful) or at the resolver helper (mixes structural shape with + # semantic content). Both buffers must be checked together — partial + # bake regressions could populate one but not the other. + if ( + pos_cond.float().abs().sum().item() == 0 + and neg_cond.float().abs().sum().item() == 0 + ): + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " + f"and negative_conditioning buffers are zero-valued — model " + f"file appears to be a DiT-only export missing " + f"the SeedVR2 conditioning tensors. " + f"Re-bake the file with ``positive_conditioning`` (58, 5120) " + f"and ``negative_conditioning`` (64, 5120) keys at top level, " + f"or load via CheckpointLoaderSimple from a bundled " + f"checkpoint." + ) + + _apply_rope_freqs_float32_cast(model) + + condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) + condition = condition.movedim(-1, 1) + latent = vae_conditioning.movedim(-1, 1) + + latent = rearrange(latent, "b c t h w -> b (c t) h w") + condition = rearrange(condition, "b c t h w -> b (c t) h w") + + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] + positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] + + return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) + +def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, + t_end: int, channels: int) -> torch.Tensor: + """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" + B, CT, H, W = tensor_4d.shape + if CT % channels != 0: + raise ValueError( + f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " + f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." + ) + T = CT // channels + if not (0 <= t_start < t_end <= T): + raise ValueError( + f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " + f"range for T={T}." + ) + new_T = t_end - t_start + sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() + return sliced.reshape(B, channels * new_T, H, W) + + +def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): + """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" + new_list = [] + for entry in cond_list: + text_cond, options = entry[0], entry[1] + if "condition" not in options: + new_list.append(entry) + continue + new_options = options.copy() + new_options["condition"] = _slice_collapsed_4d_along_t( + new_options["condition"], t_start, t_end, + SEEDVR2_COND_CHANNELS, + ) + new_list.append([text_cond, new_options]) + return new_list + + +def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, + samples_4d: torch.Tensor, + t_start: int, + t_end: int): + """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" + if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: + return _slice_collapsed_4d_along_t( + noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, + ) + return noise_mask + + +def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: + """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" + if len(chunks_4d) == 0: + raise ValueError("_concat_chunks_along_t: empty chunk list.") + fives = [] + for ch in chunks_4d: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " + f"channel dim {CT} not divisible by channels={channels}." + ) + T = CT // channels + fives.append(ch.reshape(B, channels, T, H, W)) + cat = torch.cat(fives, dim=2).contiguous() + B, C, T_total, H, W = cat.shape + return cat.reshape(B, C * T_total, H, W) + + +def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: + """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): + Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` + (dead-band would collapse a tiny transition). Window shape matched to the reference + overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. + """ + if overlap < 1: + raise ValueError( + f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." + ) + if overlap >= 3: + t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) + blend_start = 1.0 / 3.0 + blend_end = 2.0 / 3.0 + u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) + return 0.5 + 0.5 * torch.cos(torch.pi * u) + return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) + + +def _blend_overlap_region(prev_tail_5d: torch.Tensor, + cur_head_5d: torch.Tensor) -> torch.Tensor: + """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" + if prev_tail_5d.shape != cur_head_5d.shape: + raise ValueError( + f"_blend_overlap_region: shape mismatch " + f"prev {tuple(prev_tail_5d.shape)} vs " + f"cur {tuple(cur_head_5d.shape)}." + ) + overlap = int(prev_tail_5d.shape[2]) + w_prev_1d = _hann_blend_weights_1d( + overlap, prev_tail_5d.device, prev_tail_5d.dtype, + ) + # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. + w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) + w_cur = 1.0 - w_prev + return prev_tail_5d * w_prev + cur_head_5d * w_cur + + +def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, + overlap_latent: int) -> torch.Tensor: + """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" + if len(chunk_specs) == 0: + raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") + if overlap_latent < 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: overlap_latent must be " + f">= 0; got {overlap_latent}." + ) + + # Validate channel divisibility once and capture per-chunk T. + chunk_5d = [] + for t_start, t_end, ch in chunk_specs: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk shape " + f"{tuple(ch.shape)} channel dim {CT} not divisible " + f"by channels={channels}." + ) + T = CT // channels + if t_end - t_start != T: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " + f"declared range [{t_start}:{t_end}]." + ) + chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) + + if overlap_latent == 0: + # Fast path: pure concat in the caller-provided chunk order. + return _concat_chunks_along_t( + [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) + for _, _, c in chunk_5d], + channels, + ) + + T_total = max(t_end for _, t_end, _ in chunk_5d) + first_5d = chunk_5d[0][2] + B = first_5d.shape[0] + H = first_5d.shape[3] + W = first_5d.shape[4] + result = torch.empty( + (B, channels, T_total, H, W), + device=first_5d.device, dtype=first_5d.dtype, + ) + filled_until = 0 + for i, (cs, ce, ct_5d) in enumerate(chunk_5d): + chunk_T = int(ct_5d.shape[2]) + if i == 0: + result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + continue + # Overlap region width is bounded by both the previous fill + # frontier and the current chunk's actual length (for runt + # final chunks shorter than the configured overlap). + overlap_len = min(filled_until - cs, chunk_T) + if overlap_len > 0: + prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() + cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() + blended = _blend_overlap_region(prev_tail, cur_head) + result[:, :, cs:cs + overlap_len, :, :] = blended + tail_start = cs + overlap_len + tail_end = ce + if tail_end > tail_start: + result[:, :, tail_start:tail_end, :, :] = ( + ct_5d[:, :, overlap_len:, :, :] + ) + else: + # Disjoint chunks (overlap_latent set but this pair did not + # actually overlap, e.g. step_latent equal to chunk_latent + # in a degenerate config). Treat as concat. + result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + + return result.contiguous().reshape(B, channels * T_total, H, W) + + +def _run_standard_sample(model, seed: int, steps: int, cfg: float, + sampler_name: str, scheduler: str, + positive, negative, latent: dict, + denoise: float) -> dict: + """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" + samples_in = latent["samples"] + samples_in = comfy.sample.fix_empty_latent_channels( + model, samples_in, latent.get("downscale_ratio_spacial", None), + ) + batch_inds = latent.get("batch_index", None) + noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) + noise_mask = latent.get("noise_mask", None) + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, samples_in, + denoise=denoise, noise_mask=noise_mask, seed=seed, + ) + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = samples + return out + + +class SeedVR2ProgressiveSampler(io.ComfyNode): + """Sequential temporal chunking sampler for SeedVR2 native. + + Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that + OOM on long sequences. The latent enters the sampler in SeedVR2's + collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` + at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that + tensor along the temporal axis, runs the configured inner sampler + sequentially per chunk against the standard ``comfy.sample.sample`` + entry point, and concatenates per-chunk outputs back into a single + ``(B, 16*T_total, H, W)`` latent. + + ``frames_per_chunk`` is expressed in pixel-frame units to match the + SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the + VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` + maps to ``(F - 1) // 4 + 1`` latent-frame chunks. + + Determinism contract: a single noise tensor is generated once from + the user seed and sliced per chunk (rather than re-seeding each + chunk), so a workflow that fits in a single chunk produces output + identical to a workflow that fits in N chunks at the same seed, + modulo the inherent T-axis chunk-boundary independence of the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2ProgressiveSampler", + display_name="Sample SeedVR2 (Progressive)", + category="sampling", + description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", + inputs=[ + io.Model.Input("model", tooltip="The model used for denoising the input latent."), + io.Int.Input("seed", default=0, min=0, + max=0xffffffffffffffff, + control_after_generate=True, + tooltip="The random seed used for creating the noise."), + io.Int.Input("steps", default=20, min=1, max=10000, + tooltip="The number of steps used in the denoising process."), + io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, + step=0.1, round=0.01, + tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), + io.Combo.Input("sampler_name", + options=comfy.samplers.SAMPLER_NAMES, + tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), + io.Combo.Input("scheduler", + options=comfy.samplers.SCHEDULER_NAMES, + tooltip="The scheduler controls how noise is gradually removed to form the image."), + io.Conditioning.Input("positive", + tooltip="The conditioning describing the attributes you want to include in the image."), + io.Conditioning.Input("negative", + tooltip="The conditioning describing the attributes you want to exclude from the image."), + io.Latent.Input("latent", + tooltip="The latent image to denoise."), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, + step=0.01, + tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), + io.Int.Input("frames_per_chunk", default=21, min=1, + max=16384, step=4, + tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), + io.Int.Input("temporal_overlap", default=0, min=0, + max=16384, + tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), + io.Combo.Input("chunking_mode", + options=["manual", "auto"], + default="manual", + tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), + ], + outputs=[io.Latent.Output(display_name="latent")], + ) + + @classmethod + def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + frames_per_chunk, temporal_overlap, + chunking_mode="manual") -> io.NodeOutput: + # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline + # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), + # imposed at ``cut_videos`` upstream and propagated through the VAE's + # temporal_downsample_factor=4. Reject violations explicitly before + # any model invocation; a silent rounding would mis-align chunk + # boundaries with the 4n+1 lattice. + if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " + f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " + f"got {frames_per_chunk}." + ) + + samples_4d = latent["samples"] + samples_4d = comfy.sample.fix_empty_latent_channels( + model, samples_4d, + latent.get("downscale_ratio_spacial", None), + ) + if samples_4d.ndim != 4: + raise ValueError( + f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " + f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." + ) + B, CT, H, W = samples_4d.shape + if CT % SEEDVR2_LATENT_CHANNELS != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " + f"not divisible by SeedVR2 latent channels " + f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " + f"SeedVR2-shaped." + ) + T_latent = CT // SEEDVR2_LATENT_CHANNELS + T_pixel = 4 * (T_latent - 1) + 1 + + if chunking_mode not in ("manual", "auto"): + raise ValueError( + f"SeedVR2ProgressiveSampler: chunking_mode must be " + f"'manual' or 'auto'; got {chunking_mode!r}." + ) + + if chunking_mode == "auto": + attempts = _seedvr2_auto_chunk_attempts( + T_latent, T_pixel, frames_per_chunk, + ) + for i, attempt_frames_per_chunk in enumerate(attempts): + retry = False + try: + return cls.execute( + model=model, seed=seed, steps=steps, cfg=cfg, + sampler_name=sampler_name, scheduler=scheduler, + positive=positive, negative=negative, + latent=latent, denoise=denoise, + frames_per_chunk=attempt_frames_per_chunk, + temporal_overlap=temporal_overlap, + chunking_mode="manual", + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if i == len(attempts) - 1: + raise RuntimeError( + "SeedVR2ProgressiveSampler: exhausted auto " + "chunking attempts after OOM. Tried " + f"frames_per_chunk values {attempts}." + ) from e + retry = True + + if retry: + logging.warning( + "SeedVR2ProgressiveSampler auto chunking OOM at " + "frames_per_chunk=%s; retrying with " + "frames_per_chunk=%s.", + attempt_frames_per_chunk, attempts[i + 1], + ) + comfy.model_management.soft_empty_cache() + + # Short-circuit: total fits in one chunk -> standard path with no + # chunking overhead. Output of this branch is byte-identical to the + # built-in KSampler given the same (model, seed, steps, cfg, + # sampler_name, scheduler, positive, negative, latent, + # denoise) tuple. + if T_pixel <= frames_per_chunk: + return io.NodeOutput(_run_standard_sample( + model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + )) + + # Map pixel chunk -> latent chunk. Each chunk's latent length is + # at most ``chunk_latent``; the final chunk may be a runt that + # is automatically 4n+1-aligned in the pixel domain by the + # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer + # T_latent corresponds to a valid 4n+1 pixel count). + chunk_latent = (frames_per_chunk - 1) // 4 + 1 + + # ``temporal_overlap`` is exposed in latent-frame units, but users + # do not know the derived latent chunk length. Treat oversized + # values as "maximum valid overlap" while preserving a strictly + # positive chunk-loop stride. + if temporal_overlap < 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " + f"got {temporal_overlap}." + ) + temporal_overlap = min(temporal_overlap, chunk_latent - 1) + step_latent = chunk_latent - temporal_overlap + + # Generate full noise once from the user seed, then slice along T + # per chunk. Using one global noise tensor (rather than re-seeding + # per chunk) preserves seed-determinism across chunk-count + # variations: the same (seed, total T_latent) always produces the + # same noise samples regardless of how the work is partitioned. + batch_inds = latent.get("batch_index", None) + noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) + + noise_mask = latent.get("noise_mask", None) + + # Build the flat list of chunk ranges first so the chunking + # geometry is fully known before any sample call. + chunk_ranges = [] + for chunk_start in range(0, T_latent, step_latent): + chunk_end = min(chunk_start + chunk_latent, T_latent) + if chunk_start >= chunk_end: + # The final iteration of a stride that lands exactly on + # T_latent produces a zero-length chunk; skip it. + break + chunk_ranges.append((chunk_start, chunk_end)) + if chunk_end >= T_latent: + break + + def _sample_one_chunk(chunk_start, chunk_end): + samples_chunk = _slice_collapsed_4d_along_t( + samples_4d, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + noise_chunk = _slice_collapsed_4d_along_t( + noise_full, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + positive_chunk = _slice_seedvr2_cond_along_t( + positive, chunk_start, chunk_end, + ) + negative_chunk = _slice_seedvr2_cond_along_t( + negative, chunk_start, chunk_end, + ) + + # Per-chunk noise_mask handling: standard masks are passed + # through for KSampler expansion; pre-expanded collapsed + # masks are sliced. + chunk_noise_mask = None + if noise_mask is not None: + chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( + noise_mask, samples_4d, chunk_start, chunk_end, + ) + + return comfy.sample.sample( + model, noise_chunk, steps, cfg, sampler_name, scheduler, + positive_chunk, negative_chunk, samples_chunk, + denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, + ) + + chunk_specs = [] + for chunk_start, chunk_end in chunk_ranges: + chunk_samples = _sample_one_chunk(chunk_start, chunk_end) + chunk_specs.append((chunk_start, chunk_end, chunk_samples)) + + final = _concat_chunks_with_overlap_blend( + chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, + ) + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = final + return io.NodeOutput(out) + + +class SeedVRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SeedVR2Conditioning, + SeedVR2Preprocess, + SeedVR2PostProcessing, + SeedVR2ProgressiveSampler, + ] + +async def comfy_entrypoint() -> SeedVRExtension: + return SeedVRExtension() diff --git a/nodes.py b/nodes.py index 2f5a478b59e3..d9ac53eded42 100644 --- a/nodes.py +++ b/nodes.py @@ -47,14 +47,18 @@ if args.enable_manager: import comfyui_manager + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() + def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) + MAX_RESOLUTION=16384 + class CLIPTextEncode(ComfyNodeABC): @classmethod def INPUT_TYPES(s) -> InputTypeDict: @@ -323,8 +327,8 @@ def INPUT_TYPES(s): return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -334,18 +338,32 @@ def INPUT_TYPES(s): def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 temporal_compression = vae.temporal_compression_decode() if temporal_compression is not None: - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) + if temporal_size <= 0: + temporal_size = 0 + temporal_overlap = 0 + else: + requested_temporal_overlap = temporal_overlap + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression) + if requested_temporal_overlap > 0: + temporal_overlap = max(1, temporal_overlap) else: temporal_size = None temporal_overlap = None compression = vae.spacial_compression_decode() - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) + images = vae.decode_tiled( + samples["samples"], + tile_x=tile_size // compression, + tile_y=tile_size // compression, + overlap=overlap // compression, + tile_t=temporal_size, + overlap_t=temporal_overlap, + ) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -362,7 +380,7 @@ def INPUT_TYPES(s): def encode(self, vae, pixels): t = vae.encode(pixels) - return ({"samples":t}, ) + return ({"samples": t}, ) class VAEEncodeTiled: @classmethod @@ -370,8 +388,8 @@ def INPUT_TYPES(s): return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -379,6 +397,9 @@ def INPUT_TYPES(s): CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): + if temporal_size <= 0: + temporal_size = 0 + temporal_overlap = 0 t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) @@ -2418,6 +2439,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py new file mode 100644 index 000000000000..2a6e3d43075d --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -0,0 +1,213 @@ +"""Consolidated SeedVR2 conditioning and refactor regression tests. + +Merges the prior test_seedvr2_refactor_nodes.py and +test_seedvr_conditioning_hardening.py modules. Refactor tests use the +top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests +use _import_nodes_seedvr_isolated() for sys.modules isolation when +mocking comfy.model_management. +""" + +import importlib +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +_SENTINEL = object() +_TARGETS = ( + ("comfy.model_management", "comfy"), + ("comfy_extras.nodes_seedvr", "comfy_extras"), +) + + +def _import_nodes_seedvr_isolated(): + """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" + priors = [] + for mod_name, parent_name in _TARGETS: + prior_mod = sys.modules.get(mod_name, _SENTINEL) + parent = sys.modules.get(parent_name) + attr = mod_name.split(".")[-1] + prior_attr = ( + getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL + ) + priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) + + mock_mm = MagicMock() + for fn in ( + "xformers_enabled", "xformers_enabled_vae", + "pytorch_attention_enabled", "pytorch_attention_enabled_vae", + "sage_attention_enabled", "flash_attention_enabled", + "is_intel_xpu", + ): + getattr(mock_mm, fn).return_value = False + tv = torch.version.__version__.split(".") + mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) + mock_mm.WINDOWS = False + sys.modules["comfy.model_management"] = mock_mm + if sys.modules.get("comfy") is None: + import comfy as _comfy_pkg # noqa: F401 + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( + importlib.import_module("comfy_extras.nodes_seedvr") + ) + + def _restore(): + for mod_name, parent_name, attr, prior_mod, prior_attr in priors: + if prior_mod is _SENTINEL: + sys.modules.pop(mod_name, None) + else: + sys.modules[mod_name] = prior_mod + parent = sys.modules.get(parent_name) + if parent is None: + continue + if prior_attr is _SENTINEL: + if hasattr(parent, attr): + delattr(parent, attr) + else: + setattr(parent, attr, prior_attr) + + return nodes_seedvr, _restore + + +class _Rope(nn.Module): + """Minimal RoPE stub exposing a `freqs` parameter.""" + def __init__(self): + super().__init__() + self.freqs = nn.Parameter(torch.zeros(4)) + + +class _Block(nn.Module): + """Minimal transformer block stub holding a `_Rope`.""" + def __init__(self): + super().__init__() + self.rope = _Rope() + + +class _DiffusionModel(nn.Module): + """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" + def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + super().__init__() + self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) + pos = torch.zeros if zero_conditioning else torch.ones + self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) + + +class _ModelInner: + """Inner model wrapper exposing `.diffusion_model`.""" + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _ModelPatcher: + """ModelPatcher stub exposing `.model._ModelInner`.""" + def __init__(self, diffusion_model): + self.model = _ModelInner(diffusion_model) + + +def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + schema = nodes_seedvr.SeedVR2Conditioning.define_schema() + assert [input_item.id for input_item in schema.inputs] == [ + "model", + "vae_conditioning", + ] + assert schema.inputs[1].display_name == "latent" + assert [output.display_name for output in schema.outputs] == [ + "model", + "positive", + "negative", + "latent", + ] + finally: + restore() + + +def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + vae_conditioning = {"samples": samples} + + _, first_positive, first_negative, first_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + _, second_positive, second_negative, second_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + + expected_latent = samples.reshape(1, 6, 2, 2) + channel_last = samples.movedim(1, -1).contiguous() + expected_condition = torch.cat( + [ + channel_last, + torch.ones((*channel_last.shape[:-1], 1)), + ], + dim=-1, + ).movedim(-1, 1).reshape(1, 9, 2, 2) + + assert torch.equal(first_latent["samples"], expected_latent) + assert torch.equal(second_latent["samples"], expected_latent) + assert torch.equal( + first_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + first_negative[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_negative[0][1]["condition"], + expected_condition, + ) + finally: + restore() + + +def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ), ( + "Fail-loud message must use the standard " + "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " + f"can match it. Got: {message!r}" + ) + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py new file mode 100644 index 000000000000..f7d9a4f65ab3 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -0,0 +1,55 @@ +import importlib +import inspect +import sys +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_seedvr_node_signature_matches_schema(): + mock_mm = MagicMock() + mock_mm.xformers_enabled.return_value = False + mock_mm.xformers_enabled_vae.return_value = False + mock_mm.sage_attention_enabled.return_value = False + mock_mm.flash_attention_enabled.return_value = False + + sentinel = object() + prior_cpu = cli_args.cpu + cli_args.cpu = True + prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) + comfy_pkg = sys.modules.get("comfy") + prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel + + with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + sys.modules.pop("comfy_extras.nodes_seedvr", None) + try: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + schema_ids = [i.id for i in node_cls.define_schema().inputs] + exec_params = [ + p for p in inspect.signature(node_cls.execute).parameters.keys() + if p != "cls" + ] + assert schema_ids == exec_params, ( + f"{node_cls.__name__} schema/execute drift: " + f"schema_ids={schema_ids}, exec_params={exec_params}" + ) + finally: + cli_args.cpu = prior_cpu + if prior_module is sentinel: + sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_module + if comfy_pkg is not None: + if prior_mm_attr is sentinel: + if hasattr(comfy_pkg, "model_management"): + delattr(comfy_pkg, "model_management") + else: + setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py new file mode 100644 index 000000000000..a27a8f8df24d --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -0,0 +1,57 @@ +from unittest.mock import patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_seedvr2_post_processing_schema(): + schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() + + assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] + assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] + assert schema.inputs[2].default == "lab" + assert schema.outputs[0].get_io_type() == "IMAGE" + + +def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): + decoded = torch.full((1, 3, 4, 4), 0.25) + reference = torch.full((1, 3, 4, 4), 0.75) + + def _lab(content, style): + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + try: + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, torch.device("cpu"), "lab", + ) + except RuntimeError as exc: + assert "color_correction_method=lab" in str(exc) + assert " method=lab" not in str(exc) + else: + raise AssertionError("expected RuntimeError for one-frame LAB OOM") + + +def test_seedvr2_post_processing_unknown_color_correction_method_raises(): + decoded = torch.zeros(1, 2, 4, 4, 3) + original = torch.zeros(1, 2, 4, 4, 3) + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") + except ValueError as exc: + assert "color_correction_method" in str(exc) + else: + raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 4e9350602d7a..c63f69a0df11 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd(): return sd +def _make_seedvr2_7b_separate_mm_sd(): + return { + "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + } + + +def _make_seedvr2_7b_shared_mm_sd(): + return { + "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + +def _make_seedvr2_3b_shared_mm_sd(): + return { + "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -125,6 +143,48 @@ def test_flux_schnell_comfyui_detected_as_flux_schnell(self): assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" + def test_seedvr2_7b_separate_mm_detection_config(self): + sd = _make_seedvr2_7b_separate_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 36 + assert unet_config["mlp_type"] == "normal" + assert unet_config["qk_rope"] is True + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_7b_shared_mm_detection_config(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 10 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is True + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_3b_shared_mm_detection_config(self): + sd = _make_seedvr2_3b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 2560 + assert unet_config["heads"] == 20 + assert unet_config["num_layers"] == 32 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is None + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py new file mode 100644 index 000000000000..f9dbd68906d3 --- /dev/null +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -0,0 +1,90 @@ +"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must +honor the actual tensor/tuple return contract of ``encode()`` and +``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` +or ``.sample`` attributes on those returns. + +The pre-fix body raised ``AttributeError: 'Tensor' object has no +attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and +``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` +for ``mode == "decode"`` (the class only defines ``decode_`` with a +trailing underscore). The post-fix body unwraps the optional one-element +tuple shape that ``return_dict=False`` produces and returns the tensor +directly. + +Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses +the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and +overrides ``encode``/``decode_`` with known tensors so the contract can +be probed without loading any real VAE weights. +""" + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 + + +_LATENT_SHAPE = (1, 16, 2, 2, 2) +_DECODED_SHAPE = (1, 3, 5, 16, 16) +_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) +_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) + + +class _StubVAE(VideoAutoencoderKL): + def __init__(self): + nn.Module.__init__(self) + self._encode_out = torch.zeros(*_LATENT_SHAPE) + self._decode_out = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return self._encode_out + + def decode_(self, z, return_dict=True): + return self._decode_out + + +def test_forward_encode_returns_tensor(): + vae = _StubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="encode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_LATENT_SHAPE) + + +def test_forward_decode_returns_tensor(): + vae = _StubVAE() + z = torch.zeros(*_INPUT_DECODE_SHAPE) + result = vae.forward(z, mode="decode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +class _TupleReturningStubVAE(VideoAutoencoderKL): + """Stub variant whose ``encode``/``decode_`` return the + ``(tensor,)`` one-element tuple shape ``return_dict=False`` produces + in the parent class. Exercises the unwrap branch of + ``VideoAutoencoderKL.forward``. + """ + + def __init__(self): + nn.Module.__init__(self) + self._encode_tensor = torch.zeros(*_LATENT_SHAPE) + self._decode_tensor = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return (self._encode_tensor,) + + def decode_(self, z, return_dict=True): + return (self._decode_tensor,) + + +def test_forward_all_unwraps_one_tuple_at_each_step(): + vae = _TupleReturningStubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="all") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py new file mode 100644 index 000000000000..e5d79a306b72 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -0,0 +1,47 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sd +import comfy.supported_models +import comfy.ldm.seedvr.model as seedvr_model + + +def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): + bf16_device = object() + fp16_device = object() + + monkeypatch.setattr( + comfy.supported_models.comfy.model_management, + "should_use_bf16", + lambda device=None: device is bf16_device, + ) + + bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) + assert bf16_config.manual_cast_dtype is torch.bfloat16 + + fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) + assert fp16_config.manual_cast_dtype is None + + +def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): + context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) + + torch.testing.assert_close(txt, context.squeeze(0)) + torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) + + +def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): + estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) + old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + + assert estimate == 101 * 960 * 1280 * 160 + assert estimate > 15 * 1024 ** 3 + assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py new file mode 100644 index 000000000000..5b008ea6e97b --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -0,0 +1,341 @@ +"""Consolidated SeedVR2 internals regression tests. + +Sources (all merged verbatim, helper names disambiguated where colliding): + + * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy + apply_rotary_emb wrapper oracle at fp32. + * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare + memory_occupy against get_norm_limit(), not float('inf'). + * SeedVR2 variable-length attention split-loop contract. + +Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and +comfy.ldm.modules.attention transitively pull in comfy.model_management, +which probes torch.cuda.current_device() at import time unless args.cpu is +set first. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.modules.attention as attention # noqa: E402 +import comfy.ops as comfy_ops # noqa: E402 +from comfy.ldm.seedvr.model import ( # noqa: E402 + Cache, + NaMMRotaryEmbedding3d, +) +from comfy.ldm.seedvr.vae import ( # noqa: E402 + causal_norm_wrapper, + set_norm_limit, +) +from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 + + +# --------------------------------------------------------------------------- +# RoPE rewrite tests (test_seedvr_rope_rewrite.py) +# --------------------------------------------------------------------------- + +# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains +# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. +_DIM = 192 +_HEADS = 4 +_VID_T, _VID_H, _VID_W = 2, 4, 4 +_TXT_L = 8 +_L_VID = _VID_T * _VID_H * _VID_W +_SEED = 0 + + +def _make_inputs(dtype=torch.float32, device="cpu"): + """Construct the 6 forward inputs + cache. Deterministic via local + Generator so global RNG state is not mutated. + """ + g = torch.Generator(device=device).manual_seed(_SEED) + vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) + txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) + cache = Cache(disable=True) + return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache + + +def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): + """Reproduce the pre-rewrite ``get_freqs`` body verbatim against + ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, + unchanged by the rewrite). + """ + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + with torch.amp.autocast(device_type="cuda", enabled=False): + vid_freqs_full = rope.get_axial_freqs( + min(max_temporal + 16, 1024), + min(max_height + 4, 128), + min(max_width + 4, 128), + ).float() + txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) + txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + + +def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, + txt_q, txt_k, txt_shape): + """Compute expected forward output via the unchanged + ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the + oracle. The wrapper itself is out of scope for the rewrite (Shape B). + """ + vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) + vid_freqs = vid_freqs.to(vid_q.device) + txt_freqs = txt_freqs.to(txt_q.device) + + from einops import rearrange + + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q_out = rearrange(vid_q_out, "h L d -> L h d") + vid_k_out = rearrange(vid_k_out, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q_out = rearrange(txt_q_out, "h L d -> L h d") + txt_k_out = rearrange(txt_k_out, "h L d -> L h d") + return vid_q_out, vid_k_out, txt_q_out, txt_k_out + + +def test_namm_forward_output_tensor_equal_against_legacy_oracle(): + rope = NaMMRotaryEmbedding3d(dim=_DIM) + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() + + expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( + rope, + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, + ) + + actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, cache, + ) + + torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, + msg="vid_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, + msg="vid_k output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, + msg="txt_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, + msg="txt_k output diverges from wrapper oracle") + + +# --------------------------------------------------------------------------- +# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) +# --------------------------------------------------------------------------- + +_NUM_CHANNELS = 8 +_NUM_GROUPS = 4 +_TENSOR_SHAPE = (1, 8, 2, 4, 4) + +_GROUPNORM_SUBCLASSES = [ + pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), + pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), +] + + +@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) +def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): + real_group_norm = vae_mod.F.group_norm + set_norm_limit(1e-9) + try: + gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) + gn.eval() + + forward_hook_calls = [] + + def _hook(module, inputs, output): + forward_hook_calls.append(tuple(inputs[0].shape)) + + spy_calls = [] + + def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): + spy_calls.append({"num_groups": int(num_groups_arg)}) + return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) + + handle = gn.register_forward_hook(_hook) + try: + with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): + out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) + finally: + handle.remove() + + full_calls = len(forward_hook_calls) + chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) + + assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE + assert full_calls == 0, ( + f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" + ) + assert chunked_calls > 0, ( + f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}" + ) + finally: + set_norm_limit(None) + + +# --------------------------------------------------------------------------- +# SeedVR2 var_attention split-loop tests +# --------------------------------------------------------------------------- + +def test_var_attention_registry_contains_always_available_entries(): + assert ( + attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] + is attention.var_attention_optimized_split + ) + + +def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): + dim = 8 + heads = 2 + head_dim = 4 + attn = seedvr_model.NaSwinAttention( + vid_dim=dim, + txt_dim=dim, + heads=heads, + head_dim=head_dim, + qk_bias=False, + qk_norm=seedvr_model.CustomRMSNorm, + qk_norm_eps=1e-6, + rope_type=None, + rope_dim=head_dim, + shared_weights=False, + window=(2, 1, 1), + window_method="720pwin_by_size_bysize", + version=True, + device="cpu", + dtype=torch.float32, + operations=comfy_ops.disable_weight_init, + ) + generator = torch.Generator(device="cpu").manual_seed(11) + vid = torch.randn(8, dim, generator=generator) + txt = torch.randn(3, dim, generator=generator) + vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) + txt_shape = torch.tensor([[3]], dtype=torch.long) + calls = [] + + def fake_optimized_var_attention(**kwargs): + calls.append(kwargs) + return kwargs["q"] + + monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) + + vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) + + assert tuple(vid_out.shape) == (8, dim) + assert tuple(txt_out.shape) == (3, dim) + assert len(calls) == 1 + call = calls[0] + assert tuple(call["q"].shape) == (14, heads, head_dim) + assert tuple(call["k"].shape) == (14, heads, head_dim) + assert tuple(call["v"].shape) == (14, heads, head_dim) + assert call["heads"] == heads + assert call["skip_reshape"] is True + assert call["skip_output_reshape"] is True + torch.testing.assert_close( + call["cu_seqlens_q"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + torch.testing.assert_close( + call["cu_seqlens_k"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + + +def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): + heads = 2 + head_dim = 3 + q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) + k = q + 100 + v = q + 200 + cu = torch.tensor([0, 2, 5], dtype=torch.int32) + calls = [] + + def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): + calls.append( + { + "q_shape": tuple(q_arg.shape), + "k_shape": tuple(k_arg.shape), + "v_shape": tuple(v_arg.shape), + "heads": heads_arg, + "kwargs": kwargs, + } + ) + return q_arg + v_arg + + monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) + + out = var_attention_optimized_split( + q, + k, + v, + heads, + cu, + cu, + skip_reshape=True, + skip_output_reshape=True, + ) + + assert tuple(out.shape) == (5, heads, head_dim) + assert len(calls) == 2 + assert calls[0]["q_shape"] == (1, heads, 2, head_dim) + assert calls[1]["q_shape"] == (1, heads, 3, head_dim) + assert all(call["heads"] == heads for call in calls) + assert all(call["kwargs"]["skip_reshape"] is True for call in calls) + assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) + torch.testing.assert_close(out, q + v, rtol=0, atol=0) + + +def test_var_attention_optimized_split_rejects_bad_offsets(): + q = torch.randn(5, 2, 3) + cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) + cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) + + with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): + var_attention_optimized_split( + q, + q, + q, + 2, + cu_bad, + cu_ok, + skip_reshape=True, + skip_output_reshape=True, + ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py new file mode 100644 index 000000000000..f2b9bcbbec8b --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -0,0 +1,308 @@ +"""Consolidated SeedVR2 model/graph/forward regression tests. + +Merged from: +- seedvr_model_test.py +- test_seedvr_7b_final_block_text_path.py +- test_seedvr_forward_no_device_cast.py +- test_seedvr_latent_format.py +- test_seedvr2_vae_graph_boundaries.py +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import torch +from torch import nn + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy # noqa: E402 +import comfy.latent_formats # noqa: E402 +import comfy.ldm.seedvr.model # noqa: E402 +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.model_management # noqa: E402 +import comfy.sample # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +import nodes as nodes_mod # noqa: E402 +from comfy.ldm.seedvr.model import NaDiT # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def _make_standin(positive_conditioning): + class _StandIn(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "positive_conditioning", positive_conditioning + ) + + _resolve_text_conditioning = NaDiT._resolve_text_conditioning + + return _StandIn() + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +class _StubModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + +def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: + flags = [] + + class _Block(_StubModule): + def __init__(self, *args, **kwargs): + flags.append(kwargs["is_last_layer"]) + super().__init__() + + monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) + monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) + monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) + monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) + + seedvr_model.NaDiT( + norm_eps=1e-5, + qk_rope=None, + num_layers=4, + mlp_type="normal", + vid_dim=vid_dim, + txt_in_dim=txt_in_dim, + heads=24, + mm_layers=3, + ) + + return flags + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +class _Model: + def __init__(self, latent_format): + self._latent_format = latent_format + + def get_model_object(self, name): + assert name == "latent_format" + return self._latent_format + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +class _Patcher: + def get_free_memory(self, device): + return 1024 * 1024 * 1024 + + +class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self, encoded): + nn.Module.__init__(self) + self.encoded = encoded + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.seen = [] + + def encode(self, x): + self.seen.append(tuple(x.shape)) + return self.encoded.to(device=x.device, dtype=x.dtype) + + +class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.calls = [] + + def decode(self, z, seedvr2_tiling=None): + self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) + if z.ndim == 4: + b, tc, h, w = z.shape + t = tc // 16 + else: + b, _, t, h, w = z.shape + return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) + + +def _make_vae(wrapper): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = wrapper + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) + vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + vae.output_channels = 3 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.crop_input = False + vae.not_video = False + vae.patcher = _Patcher() + vae.process_input = lambda image: image + vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) + vae.vae_output_dtype = lambda: torch.float32 + vae.memory_used_encode = lambda shape, dtype: 1 + vae.memory_used_decode = lambda shape, dtype: 1 + vae.throw_exception_if_invalid = lambda: None + vae.vae_encode_crop_pixels = lambda pixels: pixels + vae.spacial_compression_decode = lambda: 8 + vae.temporal_compression_decode = lambda: 4 + return vae + + +# --------------------------------------------------------------------------- +# Tests from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def test_missing_context_falls_back_to_positive_buffer(): + """AC: ``context is None`` falls back to the registered + ``positive_conditioning`` buffer and runs to completion — no + silent zero substitution, no raised exception. + """ + pos_buffer = torch.full((58, 5120), 7.0) + standin = _make_standin(pos_buffer) + txt, txt_shape = standin._resolve_text_conditioning(None) + assert txt.shape == (58, 5120) + assert (txt == 7.0).all(), ( + "fallback path must use the positive_conditioning buffer " + "verbatim, not a zero tensor" + ) + assert txt_shape.shape == (1, 1) + assert txt_shape[0, 0].item() == 58 + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): + assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ + False, + False, + False, + False, + ] + + +def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + generator = torch.Generator(device="cpu").manual_seed(0) + q = torch.randn(4, 2, 128, generator=generator) + k = torch.randn(4, 2, 128, generator=generator) + shape = torch.tensor([[1, 2, 2]], dtype=torch.long) + freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) + + expected_q = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + q.permute(1, 0, 2).float(), + ).to(q.dtype).permute(1, 0, 2) + expected_k = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + k.permute(1, 0, 2).float(), + ).to(k.dtype).permute(1, 0, 2) + + actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) + + torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) + torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): + latent_format = comfy.latent_formats.SeedVR2() + latent_image = torch.zeros(1, 1, 4, 5) + + fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) + + assert latent_format.latent_channels == 16 + assert latent_format.latent_dimensions == 2 + assert fixed.shape == (1, 16, 4, 5) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + + encoded = torch.full((1, 16, 2, 4, 5), 2.0) + vae = _make_vae(_EncodeWrapper(encoded)) + pixels = torch.zeros(1, 5, 32, 40, 3) + + node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] + node_latent = node_output["samples"] + assert set(node_output) == {"samples"} + assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert node_latent.dtype == torch.float32 + assert node_latent.stride()[-1] == 1 + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + + tiled = torch.full((1, 16, 2, 4, 5), 3.0) + monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) + tiled_output = nodes_mod.VAEEncodeTiled().encode( + vae, + pixels, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + )[0] + tiled_latent = tiled_output["samples"] + assert set(tiled_output) == {"samples"} + assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tiled_latent.dtype == torch.float32 + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + + +def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 2, 4, 5)}, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + ) + + assert vae.first_stage_model.calls == [ + { + "shape": (1, 16, 2, 4, 5), + "seedvr2_tiling": { + "enable_tiling": True, + "tile_size": (512, 512), + "tile_overlap": (64, 64), + "temporal_size": 16, + "temporal_overlap": 4, + }, + } + ] diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py new file mode 100644 index 000000000000..ea9f978f38b9 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_decode.py @@ -0,0 +1,91 @@ +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + return wrapper + + +def _fingerprint_decode_(self, z, return_dict=True): + b = int(z.shape[0]) + t = int(z.shape[2]) + h = int(z.shape[3]) + w = int(z.shape[4]) + out = torch.empty(b, 3, t, h * 8, w * 8) + for batch_idx in range(b): + out[batch_idx].fill_(float(batch_idx + 1)) + return out + + +def _decode_with_patches(wrapper, z): + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): + return wrapper.decode(z) + + +def test_decode_b2_t3_multi_frame_batch_unchanged(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + + assert tuple(out.shape) == (2, 3, 3, 16, 16) + + +class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.calls = [] + + def parameters(self): + return iter([torch.nn.Parameter(torch.zeros(()))]) + +def _decode_stub(self, latent): + self.calls.append(tuple(latent.shape)) + return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) + + +def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): + wrapper = _Wrapper() + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): + wrapper.decode(torch.zeros(1, 16, 4)) + + +def _t_padded(t_in: int) -> int: + if t_in == 1: + return 1 + if t_in <= 4: + return 5 + if (t_in - 1) % 4 == 0: + return t_in + return t_in + (4 - ((t_in - 1) % 4)) + + +@pytest.mark.parametrize("t_in", [1, 5, 9]) +def test_t_padded_matches_cut_videos(t_in): + dummy = torch.zeros(1, t_in, 1, 1, 1) + assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py new file mode 100644 index 000000000000..40079bbe2c47 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -0,0 +1,347 @@ +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_decode_latent_min_size_override.py +# --------------------------------------------------------------------------- + + +def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): + from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae + + class StubVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 2 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, t_chunk): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return VideoAutoencoderKL.slicing_decode(self, t_chunk) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + b, c, d, h, w = z.shape + return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) + + vae = StubVAEModel() + z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + + assert vae.decode_min_sizes == [5] + assert vae.memory_states == [MemoryState.DISABLED] + assert vae.slicing_latent_min_size == 2 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_encode_runt_slice_override.py +# --------------------------------------------------------------------------- + + +def test_zero_temporal_size_preserves_min_size_when_encode_raises(): + from comfy.ldm.seedvr.vae import tiled_vae + + class RaisingVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + raise RuntimeError("simulated encode failure") + + vae = RaisingVAEModel() + x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) + + raised = False + try: + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + except RuntimeError as exc: + if "simulated encode failure" not in str(exc): + raise + raised = True + + assert raised + assert vae.slicing_sample_min_size == 4 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_temporal_slicing.py +# --------------------------------------------------------------------------- + + +class _SlicingDecodeVAE(nn.Module): + def __init__(self, slicing_latent_min_size): + super().__init__() + self.slicing_latent_min_size = slicing_latent_min_size + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, z): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + x = z[:, :1].repeat( + 1, + 3, + 1, + self.spatial_downsample_factor, + self.spatial_downsample_factor, + ) + return x + + +def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): + vae = _SlicingDecodeVAE(slicing_latent_min_size=2) + z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=12, + temporal_overlap=4, + encode=False, + ) + + assert vae.decode_min_sizes == [2] + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.slicing_latent_min_size == 2 + + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + seedvr2_tiling = { + "enable_tiling": True, + "tile_size": (64, 64), + "tile_overlap": (0, 0), + "temporal_size": 8, + "temporal_overlap": 7, + } + + captured = {} + + def _fake_tiled_vae(latent, model, **kwargs): + captured.update(kwargs) + return torch.zeros(1, 3, 1, 16, 16) + + with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): + wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + + assert captured["temporal_overlap"] == 7 + + +# --------------------------------------------------------------------------- +# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py +# --------------------------------------------------------------------------- + + +def _force_oom(*a, **k): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def _make_vae(first_stage_model, latent_channels, latent_dim): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = first_stage_model + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = vae.downscale_ratio = 8 + vae.upscale_index_formula = vae.downscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = latent_channels + vae.latent_dim = latent_dim + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_decode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_decode = lambda *a, **k: 1 + return vae + + +def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): + mm = sd_mod.model_management + with ExitStack() as stack: + stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) + stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) + stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) + stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) + stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) + if patch_wrapper_decode: + stack.enter_context(patch.object( + seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", + side_effect=_force_oom)) + vae.decode(samples) + + +def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) + assert seedvr2_call.call_count == 1 + assert generic_call.call_count == 0 + + +def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): + first_stage = MagicMock() + first_stage.decode = MagicMock(side_effect=_force_oom) + vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) + assert generic_call.call_count == 1 + assert seedvr2_call.call_count == 0 + + +# --------------------------------------------------------------------------- +# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py +# --------------------------------------------------------------------------- + + +def _populate_common_vae_attrs_fallback(vae): + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + vae.not_video = False + vae.crop_input = False + vae.pad_channel_value = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_encode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_encode = lambda *a, **k: 1 + + +def _make_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _make_non_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _force_regular_encode_oom(*args, **kwargs): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): + vae = _make_seedvr2_vae_fallback() + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " + f"input under OOM fallback; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " + f"{generic_call.call_count} calls." + ) + + +def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): + vae = _make_non_seedvr2_vae_fallback() + vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) + vae.upscale_ratio = (lambda a: a * 4, 8, 8) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + with patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode_tiled(pixel_samples) + + assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py new file mode 100644 index 000000000000..05291989edfa --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py @@ -0,0 +1,126 @@ +"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sample # noqa: E402 +import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 +from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 + +_LAT_C = 16 +_COND_C = 17 + + +def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): + """Build minimal SeedVR2-shaped sampling inputs.""" + samples_5d = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W) + samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() + + cond_5d = torch.arange( + B * _COND_C * T * H * W, dtype=torch.float32 + ).reshape(B, _COND_C, T, H, W) + 10000.0 + cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() + + text_pos = torch.zeros(1, 4, 32) + text_neg = torch.zeros(1, 4, 32) + positive = [[text_pos, {"condition": cond.clone()}]] + negative = [[text_neg, {"condition": cond.clone()}]] + latent_image = {"samples": samples} + return latent_image, positive, negative, samples_5d, cond_5d + + +def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): + return latent_image + + +def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): + """Return a tensor whose values encode ``(seed, position)``.""" + base = torch.arange( + latent_image.numel(), dtype=torch.float32 + ).reshape(latent_image.shape) + return base + float(seed) * 1e6 + + +def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): + schema = SeedVR2ProgressiveSampler.define_schema() + inputs = {item.id: item for item in schema.inputs} + + assert inputs["chunking_mode"].options == ["manual", "auto"] + assert inputs["chunking_mode"].default == "manual" + + +def test_auto_chunking_walks_two_three_four_chunk_ladder(): + """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" + latent, pos, neg, _, _ = _make_inputs(T=17) + calls = [] + + def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, + scheduler, positive, negative, + latent_image, denoise=1.0, + noise_mask=None, seed=None): + calls.append(tuple(latent_image.shape)) + if latent_image.shape[1] > _LAT_C * 5: + raise torch.cuda.OutOfMemoryError("chunk too large") + return latent_image.clone() + + with patch.object(comfy.sample, "sample", + side_effect=_oom_until_four_chunks), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent=latent, + denoise=1.0, frames_per_chunk=65, temporal_overlap=0, + chunking_mode="auto", + ) + + assert calls[:4] == [ + (1, _LAT_C * 17, 8, 8), + (1, _LAT_C * 9, 8, 8), + (1, _LAT_C * 6, 8, 8), + (1, _LAT_C * 5, 8, 8), + ] + assert torch.equal(out.result[0]["samples"], latent["samples"]) + assert soft_empty.call_count == 3 + + +@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) +def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): + """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" + latent, pos, neg, _, _ = _make_inputs(T=5) + + sampler_called = {"n": 0} + + def _should_not_be_called(*args, **kwargs): + sampler_called["n"] += 1 + return torch.zeros(1) + + with patch.object(comfy.sample, "sample", + side_effect=_should_not_be_called), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + with pytest.raises(ValueError) as excinfo: + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent=latent, + denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, + ) + assert str(bad_chunk) in str(excinfo.value) + assert sampler_called["n"] == 0 From 38f750d80e3ded90eb4e2e6a16b48ab9240587c4 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Mon, 8 Jun 2026 22:58:52 +0800 Subject: [PATCH 19/45] chore: update embedded docs to v0.5.3 (#14350) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 613553d8f614..a49d968af491 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.45.15 comfyui-workflow-templates==0.9.98 -comfyui-embedded-docs==0.5.2 +comfyui-embedded-docs==0.5.3 torch torchsde torchvision From fc258b10e54d71dca16f2b1b1e127614d2620817 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:30:28 +0300 Subject: [PATCH 20/45] Add Color primitive (#14260) --- comfy_extras/nodes_color.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py index 01a05035e2a3..688254e4eb22 100644 --- a/comfy_extras/nodes_color.py +++ b/comfy_extras/nodes_color.py @@ -7,29 +7,29 @@ class ColorToRGBInt(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="ColorToRGBInt", - display_name="Color to RGB Int", + display_name="Color Picker", category="utilities", - description="Convert a color to a RGB integer value.", + description="Return a color RGB integer value and hexadecimal representation.", inputs=[ io.Color.Input("color"), ], outputs=[ io.Int.Output(display_name="rgb_int"), + io.Color.Output(display_name="hex") ], ) @classmethod - def execute( - cls, - color: str, - ) -> io.NodeOutput: + def execute(cls, color: str) -> io.NodeOutput: # expect format #RRGGBB if len(color) != 7 or color[0] != "#": raise ValueError("Color must be in format #RRGGBB") r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) - return io.NodeOutput(r * 256 * 256 + g * 256 + b) + + rgb_int = r * 256 * 256 + g * 256 + b + return io.NodeOutput(rgb_int, color) class ColorExtension(ComfyExtension): From a1c434eb65113673c483a922e84cb0493622b3a3 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:05:10 +0200 Subject: [PATCH 21/45] Improve ResolutionSelector (#14309) --- comfy_extras/nodes_resolution.py | 33 ++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index dc405291c5cd..083e47ae46df 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -6,24 +6,24 @@ class AspectRatio(str, Enum): SQUARE = "1:1 (Square)" + PHOTO_V = "2:3 (Portrait Photo)" PHOTO_H = "3:2 (Photo)" + STANDARD_V = "3:4 (Portrait Standard)" STANDARD_H = "4:3 (Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" WIDESCREEN_H = "16:9 (Widescreen)" ULTRAWIDE_H = "21:9 (Ultrawide)" - PHOTO_V = "2:3 (Portrait Photo)" - STANDARD_V = "3:4 (Portrait Standard)" - WIDESCREEN_V = "9:16 (Portrait Widescreen)" ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_V: (2, 3), AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_V: (3, 4), AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_V: (9, 16), AspectRatio.WIDESCREEN_H: (16, 9), AspectRatio.ULTRAWIDE_H: (21, 9), - AspectRatio.PHOTO_V: (2, 3), - AspectRatio.STANDARD_V: (3, 4), - AspectRatio.WIDESCREEN_V: (9, 16), } @@ -50,26 +50,35 @@ def define_schema(cls): min=0.1, max=16.0, step=0.1, - tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + tooltip="Target total megapixels. 1.0 MP ≈ 1024x1024 for square.", + ), + io.Int.Input( + id="multiple", + default=8, + min=8, + max=128, + step=4, + tooltip="Nearest multiple of the result to set the selected resolution to.", + advanced=True, ), ], outputs=[ io.Int.Output( - "width", tooltip="Calculated width in pixels (multiple of 8)." + "width", tooltip="Calculated width in pixels multiplied by the selected multiple." ), io.Int.Output( - "height", tooltip="Calculated height in pixels (multiple of 8)." + "height", tooltip="Calculated height in pixels multiplied by the selected multiple." ), ], ) @classmethod - def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + def execute(cls, aspect_ratio: str, megapixels: float, multiple: int) -> io.NodeOutput: w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] total_pixels = megapixels * 1024 * 1024 scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) - width = round(w_ratio * scale / 8) * 8 - height = round(h_ratio * scale / 8) * 8 + width = round(w_ratio * scale / multiple) * multiple + height = round(h_ratio * scale / multiple) * multiple return io.NodeOutput(width, height) From a0a055bc4e4f2878c106bf8cf69c1aaa30f8b840 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Mon, 8 Jun 2026 14:27:50 -0700 Subject: [PATCH 22/45] feat(assets): extract image dimensions at ingest and emit on asset responses (#13991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(assets): extract image dimensions at ingest and emit on asset responses Image assets now carry width/height under the existing `metadata` field on asset responses, shaped as `{"kind": "image", "width": W, "height": H}`. This lets consumers get original dimensions (e.g. for clients that render server-side thumbnails and can't recover them from naturalWidth/Height) without an extra round-trip. Dimensions are written to AssetReference.system_metadata across three ingest paths: - Direct file ingest (upload, in-place registration): Pillow reads the image header right after hashing, while the file is still in OS page cache. Non-image MIME types are skipped without touching the file. - From-hash registration: this path never reads the file bytes, so dimensions are best-effort copied from any prior sibling reference of the same asset that already carries kind=image metadata. Missing siblings, non-image siblings, or absent dimension keys leave the new reference's metadata unchanged. - Scanner enrichment: extends the existing system_metadata write in enrich_asset so scanner-registered images get the same treatment as uploaded ones. Existing system_metadata keys (e.g. safetensors fields written by the enricher, download provenance) are preserved through merge. Existing assets ingested before this change retain their current metadata — no automatic backfill in this PR. Tests cover image emission, non-image no-op, merge preservation, and the from-hash sibling back-fill (including the no-sibling and non-image-sibling cases). * fix(assets): validate sibling dimensions before backfilling Per CodeRabbit review on #13991: the previous loop accepted any sibling with `kind == "image"` and copied whichever dimension keys happened to be present, then returned. A partial sibling (kind set but missing or invalid width/height) could persist incomplete metadata onto the new reference even when a later sibling had valid dimensions. Now we validate that the sibling has both width and height as positive integers before adopting its dimensions, and continue scanning to the next sibling otherwise. * fix(assets): reject booleans in sibling dimension validation (use type-is) Per CodeRabbit follow-up on #13991: bool is a subclass of int in Python, so isinstance(True, int) is True. The previous strict-int gate would have accepted width=True (truthy + > 0) as a valid dimension. Realistic occurrence is low (extract_image_dimensions returns proper ints, JSON doesn't serialize bools as numbers), but the validation gate exists for defense-in-depth so it should be actually strict. --------- Co-authored-by: guill --- app/assets/scanner.py | 5 + app/assets/services/image_dimensions.py | 63 ++++++ app/assets/services/ingest.py | 99 +++++++++ .../services/test_image_dimensions.py | 86 ++++++++ .../assets_test/services/test_ingest.py | 208 +++++++++++++++++- 5 files changed, 460 insertions(+), 1 deletion(-) create mode 100644 app/assets/services/image_dimensions.py create mode 100644 tests-unit/assets_test/services/test_image_dimensions.py diff --git a/app/assets/scanner.py b/app/assets/scanner.py index ebb6869aff85..495c3044329d 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -33,6 +33,7 @@ verify_file_unchanged, ) from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.path_utils import ( compute_relative_filename, @@ -506,6 +507,10 @@ def enrich_asset( if extract_metadata and metadata: system_metadata = metadata.to_user_metadata() + if mime_type and mime_type.startswith("image/"): + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if dims: + system_metadata.update(dims) set_reference_system_metadata(session, reference_id, system_metadata) if full_hash: diff --git a/app/assets/services/image_dimensions.py b/app/assets/services/image_dimensions.py new file mode 100644 index 000000000000..ccd97399ac23 --- /dev/null +++ b/app/assets/services/image_dimensions.py @@ -0,0 +1,63 @@ +"""Image dimension extraction for asset ingest. + +Reads only the image header via Pillow to capture width/height cheaply, +without a full pixel decode. Returns a metadata dict suitable for merging +into ``AssetReference.system_metadata``. +""" +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def extract_image_dimensions( + file_path: str, mime_type: str | None = None +) -> dict[str, Any] | None: + """Extract image dimensions for the file at ``file_path``. + + Args: + file_path: Absolute path to a file on disk. + mime_type: Optional MIME type hint. When provided and not prefixed + with ``image/``, extraction is skipped without touching the file. + + Returns: + ``{"kind": "image", "width": W, "height": H}`` when the file is a + recognizable image with positive dimensions, otherwise ``None``. + + The dict shape is intended to be merged into ``system_metadata`` so the + asset response surfaces ``metadata.kind`` plus dimension fields for image + assets. Forward-compatible: future media kinds (e.g. ``"video"`` with + duration/fps) can extend this shape without schema changes. + """ + if mime_type is not None and not mime_type.startswith("image/"): + return None + + try: + from PIL import Image, UnidentifiedImageError + except ImportError: + logger.debug( + "Pillow not available; skipping image dimension extraction for %s", + file_path, + ) + return None + + try: + with Image.open(file_path) as img: + width, height = img.size + except (OSError, UnidentifiedImageError, ValueError) as exc: + logger.debug( + "Failed to read image dimensions from %s: %s", file_path, exc + ) + return None + + if ( + not isinstance(width, int) + or not isinstance(height, int) + or width <= 0 + or height <= 0 + ): + return None + + return {"kind": "image", "width": width, "height": height} diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f0b070517308..3b6dc237c435 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -17,9 +17,11 @@ get_reference_by_file_path, get_reference_tags, get_or_create_reference, + list_references_by_asset_id, reference_exists, remove_missing_tag_for_asset_id, set_reference_metadata, + set_reference_system_metadata, set_reference_tags, update_asset_hash_and_mime, upsert_asset, @@ -29,6 +31,7 @@ from app.assets.helpers import get_utc_now, normalize_tags from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.path_utils import ( compute_relative_filename, get_name_and_tags_from_asset_path, @@ -118,6 +121,14 @@ def _ingest_file_from_path( user_metadata=user_metadata, ) + _maybe_store_image_dimensions( + session, + reference_id=reference_id, + file_path=locator, + mime_type=mime_type, + current_system_metadata=ref.system_metadata, + ) + try: remove_missing_tag_for_asset_id(session, asset_id=asset.id) except Exception: @@ -288,6 +299,13 @@ def _register_existing_asset( user_metadata=new_meta, ) + _backfill_image_dimensions_from_siblings( + session, + asset_id=asset.id, + new_reference_id=ref.id, + current_system_metadata=ref.system_metadata, + ) + if tags is not None: set_reference_tags( session, @@ -334,6 +352,87 @@ def _update_metadata_with_filename( ) +_IMAGE_DIMENSION_KEYS = ("kind", "width", "height") + + +def _maybe_store_image_dimensions( + session: Session, + reference_id: str, + file_path: str, + mime_type: str | None, + current_system_metadata: dict | None, +) -> None: + """Populate ``kind``/``width``/``height`` on system_metadata for image refs. + + Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written + safetensors metadata, download provenance) are preserved by merge. + """ + if not mime_type or not mime_type.startswith("image/"): + return + + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if not dims: + return + + current = current_system_metadata or {} + merged = dict(current) + merged.update(dims) + if merged != current: + set_reference_system_metadata( + session, + reference_id=reference_id, + system_metadata=merged, + ) + + +def _backfill_image_dimensions_from_siblings( + session: Session, + asset_id: str, + new_reference_id: str, + current_system_metadata: dict | None, +) -> None: + """Copy image dimension keys from any sibling reference of the same asset. + + The from-hash path doesn't read the file bytes, so dimensions can't be + extracted there directly. When another reference of the same asset already + carries image dimensions, copy them onto the new reference so consumers + see consistent metadata regardless of how the asset was registered. + + Best-effort: missing siblings, non-image siblings, or absent dimension + keys leave the target reference unchanged. + """ + current = current_system_metadata or {} + if current.get("kind") == "image" and "width" in current and "height" in current: + return + + for sibling in list_references_by_asset_id(session, asset_id): + if sibling.id == new_reference_id: + continue + meta = sibling.system_metadata or {} + if meta.get("kind") != "image": + continue + width = meta.get("width") + height = meta.get("height") + if ( + type(width) is not int + or type(height) is not int + or width <= 0 + or height <= 0 + ): + continue + merged = dict(current) + merged["kind"] = "image" + merged["width"] = width + merged["height"] = height + if merged != current: + set_reference_system_metadata( + session, + reference_id=new_reference_id, + system_metadata=merged, + ) + return + + def _sanitize_filename(name: str | None, fallback: str) -> str: n = os.path.basename((name or "").strip() or fallback) return n if n else fallback diff --git a/tests-unit/assets_test/services/test_image_dimensions.py b/tests-unit/assets_test/services/test_image_dimensions.py new file mode 100644 index 000000000000..ac275eae2752 --- /dev/null +++ b/tests-unit/assets_test/services/test_image_dimensions.py @@ -0,0 +1,86 @@ +"""Tests for the image_dimensions service.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from PIL import Image + +from app.assets.services.image_dimensions import extract_image_dimensions + + +def _make_png(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(123, 45, 67)) + img.save(path, format="PNG") + return path + + +def _make_jpeg(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(10, 20, 30)) + img.save(path, format="JPEG", quality=80) + return path + + +class TestExtractImageDimensions: + def test_extracts_png_dimensions(self, tmp_path: Path): + f = _make_png(tmp_path / "rect.png", (320, 240)) + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result == {"kind": "image", "width": 320, "height": 240} + + def test_extracts_jpeg_dimensions(self, tmp_path: Path): + f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080)) + + result = extract_image_dimensions(str(f), mime_type="image/jpeg") + + assert result == {"kind": "image", "width": 1920, "height": 1080} + + def test_works_when_mime_type_is_none(self, tmp_path: Path): + f = _make_png(tmp_path / "no_mime.png", (50, 100)) + + result = extract_image_dimensions(str(f), mime_type=None) + + assert result == {"kind": "image", "width": 50, "height": 100} + + def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path): + # Path doesn't need to exist — non-image MIME short-circuits. + result = extract_image_dimensions( + str(tmp_path / "model.safetensors"), + mime_type="application/octet-stream", + ) + + assert result is None + + @pytest.mark.parametrize( + "mime", + ["application/json", "text/plain", "video/mp4", "audio/mpeg"], + ) + def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str): + f = tmp_path / "file.bin" + f.write_bytes(b"\x00\x01\x02") + + assert extract_image_dimensions(str(f), mime_type=mime) is None + + def test_returns_none_for_missing_file(self, tmp_path: Path): + result = extract_image_dimensions( + str(tmp_path / "does_not_exist.png"), mime_type="image/png" + ) + + assert result is None + + def test_returns_none_for_corrupt_image(self, tmp_path: Path): + f = tmp_path / "corrupt.png" + f.write_bytes(b"not actually a png file") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None + + def test_returns_none_for_empty_file(self, tmp_path: Path): + f = tmp_path / "empty.png" + f.write_bytes(b"") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index b153f9795eef..12a3bdfe62ae 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -4,10 +4,12 @@ from unittest.mock import patch import pytest +from PIL import Image from sqlalchemy.orm import Session as SASession, Session from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag from app.assets.database.queries import get_reference_tags +from app.assets.helpers import get_utc_now from app.assets.services.ingest import ( _ingest_file_from_path, _register_existing_asset, @@ -15,6 +17,11 @@ ) +def _make_png(path: Path, size: tuple[int, int]) -> Path: + Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG") + return path + + class TestIngestFileFromPath: def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): file_path = temp_dir / "test_file.bin" @@ -279,4 +286,203 @@ def _create_session(): ref_tags = sess.query(AssetReferenceTag).all() ref_tag_names = {rt.tag_name for rt in ref_tags} assert "output" in ref_tag_names - assert "my-job" in ref_tag_names + + +class TestIngestImageDimensions: + """system_metadata should carry {kind, width, height} for image assets.""" + + def test_image_asset_emits_dimensions( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "shot.png", (640, 480)) + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.system_metadata == { + "kind": "image", + "width": 640, + "height": 480, + } + + def test_non_image_asset_leaves_system_metadata_empty( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = temp_dir / "model.safetensors" + f.write_bytes(b"not an image") + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:safetensors1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.system_metadata in (None, {}) + + def test_preserves_existing_system_metadata_keys( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "annotated.png", (100, 200)) + + # First pass populates a sentinel system_metadata key (simulating prior + # enricher write). + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"} + session.commit() + + # Second pass with the same path triggers the merge code path again. + _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000001, + mime_type="image/png", + ) + + session.refresh(ref) + assert ref.system_metadata["kind"] == "image" + assert ref.system_metadata["width"] == 100 + assert ref.system_metadata["height"] == 200 + assert ref.system_metadata["source_url"] == "https://example/x.png" + + +class TestRegisterExistingAssetBackfill: + """The from-hash path back-fills dimensions from a sibling reference.""" + + def _add_reference( + self, + session: Session, + asset: Asset, + name: str, + system_metadata: dict | None = None, + ) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + name=name, + owner_id="", + created_at=now, + updated_at=now, + last_access_time=now, + system_metadata=system_metadata or {}, + ) + session.add(ref) + session.flush() + return ref + + def test_backfills_dimensions_from_sibling_image_reference( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": "image", "width": 800, "height": 600}, + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:shared", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 800 + assert ref.system_metadata.get("height") == 600 + + def test_no_backfill_when_sibling_has_no_image_metadata( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"base_model": "flux"}, # no kind=image + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:nodims", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + meta = ref.system_metadata or {} + assert "kind" not in meta + assert "width" not in meta + assert "height" not in meta + + def test_no_backfill_when_no_sibling_exists( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:lonely", + name="solo.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata in (None, {}) + + def test_backfill_preserves_caller_supplied_keys( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": "image", "width": 1024, "height": 768}, + ) + session.commit() + + # Simulate a from-hash path where the new reference already carries + # some system_metadata (e.g. a download-provenance source_url written + # by an earlier step). The back-fill must merge dim keys without + # clobbering existing keys. + result = _register_existing_asset( + asset_hash="blake3:preserve", + name="from_hash.png", + owner_id="user-x", + ) + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + # Seed a sentinel key and re-run back-fill via a second register call + # to exercise the merge path with pre-existing data. + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"} + session.commit() + + assert ref.system_metadata.get("source_url") == "https://example/p" + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 1024 + assert ref.system_metadata.get("height") == 768 From 00b633f368e68ffc229084ed819354c29006f92c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:00:20 -0700 Subject: [PATCH 23/45] Revert "Add SeedVR2 support (CORE-6) (#14110)" (#14359) This reverts commit 7863cf0e53ca599a84b3ec5bcda122e4ecc3765c. --- comfy/latent_formats.py | 5 - comfy/ldm/modules/attention.py | 84 +- comfy/ldm/modules/diffusionmodules/model.py | 8 +- comfy/ldm/seedvr/color_fix.py | 340 --- comfy/ldm/seedvr/constants.py | 79 - comfy/ldm/seedvr/model.py | 1665 ------------- comfy/ldm/seedvr/vae.py | 2110 ----------------- comfy/model_base.py | 12 - comfy/model_detection.py | 50 - comfy/sample.py | 8 +- comfy/sd.py | 237 +- comfy/supported_models.py | 31 +- comfy/supported_models_base.py | 2 +- comfy_extras/nodes_seedvr.py | 1015 -------- nodes.py | 42 +- .../test_seedvr2_conditioning.py | 213 -- .../comfy_extras_test/test_seedvr2_nodes.py | 55 - .../test_seedvr2_post_processing.py | 57 - tests-unit/comfy_test/model_detection_test.py | 60 - .../comfy_test/seedvr_vae_forward_test.py | 90 - tests-unit/comfy_test/test_seedvr2_dtype.py | 47 - .../comfy_test/test_seedvr2_internals.py | 341 --- tests-unit/comfy_test/test_seedvr2_model.py | 308 --- .../comfy_test/test_seedvr2_vae_decode.py | 91 - .../comfy_test/test_seedvr2_vae_tiled.py | 347 --- .../test_seedvr_progressive_sampler.py | 126 - 26 files changed, 40 insertions(+), 7383 deletions(-) delete mode 100644 comfy/ldm/seedvr/color_fix.py delete mode 100644 comfy/ldm/seedvr/constants.py delete mode 100644 comfy/ldm/seedvr/model.py delete mode 100644 comfy/ldm/seedvr/vae.py delete mode 100644 comfy_extras/nodes_seedvr.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_conditioning.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_nodes.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_post_processing.py delete mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_internals.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_model.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_decode.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_tiled.py delete mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fcbd97c5971a..bbdfd4bc2fac 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,7 +4,6 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_dimensions = 2 - preserve_empty_channel_multiples = False latent_rgb_factors = None latent_rgb_factors_bias = None latent_rgb_factors_reshape = None @@ -780,10 +779,6 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 -class SeedVR2(LatentFormat): - latent_channels = 16 - preserve_empty_channel_multiples = True - class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b78e764c71ba..55360535af3b 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out -def _var_attention_qkv(q, k, v, heads, skip_reshape): - if skip_reshape: - return q, k, v, q.shape[-1] - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - return ( - q.view(total_tokens, heads, head_dim), - k.view(k.shape[0], heads, head_dim), - v.view(v.shape[0], heads, head_dim), - head_dim, - ) - -def _var_attention_output(out, heads, head_dim, skip_output_reshape): - if skip_output_reshape: - return out - return out.reshape(-1, heads * head_dim) - - -def _use_blackwell_attention(): - device = model_management.get_torch_device() - if device.type != "cuda": - return False - major, minor = torch.cuda.get_device_capability(device) - return (major, minor) >= (12, 0) - - -def _validate_split_cu_seqlens(name, cu_seqlens, token_count): - if cu_seqlens.dtype not in (torch.int32, torch.int64): - raise ValueError(f"{name} must use an integer dtype") - if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: - raise ValueError(f"{name} must be a 1D tensor with at least two offsets") - if cu_seqlens[0].item() != 0: - raise ValueError(f"{name} must start at 0") - if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): - raise ValueError(f"{name} must be strictly increasing") - if cu_seqlens[-1].item() != token_count: - raise ValueError(f"{name} does not match token count") - - -def _split_indices(cu_seqlens): - return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) - - -def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - - _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) - _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) - if cu_seqlens_k[-1].item() != v.shape[0]: - raise ValueError("cu_seqlens_k does not match v token count") - - q_split_indices = _split_indices(cu_seqlens_q) - k_split_indices = _split_indices(cu_seqlens_k) - q_splits = torch.tensor_split(q, q_split_indices, dim=0) - k_splits = torch.tensor_split(k, k_split_indices, dim=0) - v_splits = torch.tensor_split(v, k_split_indices, dim=0) - if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): - raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") - - out = [] - for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): - q_i = q_i.permute(1, 0, 2).unsqueeze(0) - k_i = k_i.permute(1, 0, 2).unsqueeze(0) - v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_dtype = q_i.dtype - if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): - q_i = q_i.to(torch.bfloat16) - k_i = k_i.to(torch.bfloat16) - v_i = v_i.to(torch.bfloat16) - out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) - if out_i.dtype != out_dtype: - out_i = out_i.to(out_dtype) - out.append(out_i.squeeze(0).permute(1, 0, 2)) - - out = torch.cat(out, dim=0) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -837,8 +758,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -logging.info("Using optimized_attention split-loop for variable-length attention") - optimized_attention_masked = optimized_attention @@ -854,7 +773,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a register_attention_function("pytorch", attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) register_attention_function("split", attention_split) -register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1291,3 +1209,5 @@ def forward( x = self.proj_out(x) out = x + x_in return out + + diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 235df0b835bb..fcbaa074fd84 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,7 +13,6 @@ import xformers import xformers.ops - def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim): else: return None - -def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): +def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - downscale_freq_shift) + emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py deleted file mode 100644 index 7ddfc03af370..000000000000 --- a/comfy/ldm/seedvr/color_fix.py +++ /dev/null @@ -1,340 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.vae import safe_interpolate_operation -from comfy.ldm.seedvr.constants import ( - CIELAB_DELTA, - CIELAB_KAPPA, - D65_WHITE_X, - D65_WHITE_Z, - WAVELET_DECOMP_LEVELS, -) - - -def wavelet_blur(image: Tensor, radius): - max_safe_radius = max(1, min(image.shape[-2:]) // 8) - if radius > max_safe_radius: - radius = max_safe_radius - - num_channels = image.shape[1] - - kernel_vals = [ - [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], - [0.0625, 0.125, 0.0625], - ] - kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') - output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - - return output - -def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): - high_freq = torch.zeros_like(image) - - for i in range(levels): - radius = 2 ** i - low_freq = wavelet_blur(image, radius) - high_freq.add_(image).sub_(low_freq) - image = low_freq - - return high_freq, low_freq - -def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: - - if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions - if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - # Decompose both features into frequency components - content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately - - if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( - style_low_freq, - size=content_high_freq.shape[-2:], - mode='bilinear', - align_corners=False - ) - - content_high_freq.add_(style_low_freq) - - return content_high_freq.clamp_(-1.0, 1.0) - -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: - original_shape = source.shape - - # Flatten - source_flat = source.flatten() - reference_flat = reference.flatten() - - # Sort both arrays - source_sorted, source_indices = torch.sort(source_flat) - reference_sorted, _ = torch.sort(reference_flat) - del reference_flat - - # Quantile mapping - n_source = len(source_sorted) - n_reference = len(reference_sorted) - - if n_source == n_reference: - matched_sorted = reference_sorted - else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) - ref_indices = (source_quantiles * (n_reference - 1)).long() - ref_indices.clamp_(0, n_reference - 1) - matched_sorted = reference_sorted[ref_indices] - del source_quantiles, ref_indices, reference_sorted - - del source_sorted, source_flat - - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) - inverse_indices = torch.argsort(source_indices) - del source_indices - matched_flat = matched_sorted[inverse_indices] - del matched_sorted, inverse_indices - - return matched_flat.reshape(original_shape) - -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" - L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - - # LAB to XYZ - fy = (L + 16.0) / 116.0 - fx = a.div(500.0).add_(fy) - fz = fy - b / 200.0 - del L, a, b - - # XYZ transformation - x = torch.where( - fx > epsilon, - torch.pow(fx, 3.0), - fx.mul(116.0).sub_(16.0).div_(kappa) - ) - y = torch.where( - fy > epsilon, - torch.pow(fy, 3.0), - fy.mul(116.0).sub_(16.0).div_(kappa) - ) - z = torch.where( - fz > epsilon, - torch.pow(fz, 3.0), - fz.mul(116.0).sub_(16.0).div_(kappa) - ) - del fx, fy, fz - - # Apply D65 white point (in-place) - x.mul_(D65_WHITE_X) - # y *= 1.00000 # (no-op, skip) - z.mul_(D65_WHITE_Z) - - xyz = torch.stack([x, y, z], dim=1) - del x, y, z - - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape - xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) - del xyz - - # Ensure dtype consistency for matrix multiplication - xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) - rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) - del xyz_flat - - rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del rgb_linear_flat - - # Apply inverse gamma correction (delinearize) - mask = rgb_linear > 0.0031308 - rgb = torch.where( - mask, - torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), - rgb_linear * 12.92 - ) - del mask, rgb_linear - - return torch.clamp(rgb, 0.0, 1.0) - -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) - mask = rgb > 0.04045 - rgb_linear = torch.where( - mask, - torch.pow((rgb + 0.055) / 1.055, 2.4), - rgb / 12.92 - ) - del mask - - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape - rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) - del rgb_linear - - # Ensure dtype consistency for matrix multiplication - rgb_flat = rgb_flat.to(dtype=matrix.dtype) - xyz_flat = torch.matmul(rgb_flat, matrix.T) - del rgb_flat - - xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del xyz_flat - - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(D65_WHITE_X) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(D65_WHITE_Z) # Z - - # XYZ to LAB transformation - epsilon_cubed = epsilon ** 3 - mask = xyz > epsilon_cubed - f_xyz = torch.where( - mask, - torch.pow(xyz, 1.0 / 3.0), - xyz.mul(kappa).add_(16.0).div_(116.0) - ) - del xyz, mask - - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] - del f_xyz - - return torch.stack([L, a, b], dim=1) - -def lab_color_transfer( - content_feat: Tensor, - style_feat: Tensor, - luminance_weight: float = 0.8 -) -> Tensor: - content_feat = wavelet_reconstruction(content_feat, style_feat) - - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) - - rgb_to_xyz_matrix = torch.tensor([ - [0.4124564, 0.3575761, 0.1804375], - [0.2126729, 0.7151522, 0.0721750], - [0.0193339, 0.1191920, 0.9503041] - ], dtype=torch.float32, device=device) - - xyz_to_rgb_matrix = torch.tensor([ - [ 3.2404542, -1.5371385, -0.4985314], - [-0.9692660, 1.8760108, 0.0415560], - [ 0.0556434, -0.2040259, 1.0572252] - ], dtype=torch.float32, device=device) - - epsilon = CIELAB_DELTA - kappa = CIELAB_KAPPA - - content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del content_feat - - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del style_feat, rgb_to_xyz_matrix - - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) - - # Handle luminance with weighted blending - if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color - result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) - del matched_L - else: - # Fully preserve content luminance - result_L = content_lab[:, 0] - - del content_lab, style_lab - - # Reconstruct LAB with corrected channels - result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) - del result_L, matched_a, matched_b - - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) - del result_lab, xyz_to_rgb_matrix - - # Convert back to [-1, 1] range (in-place) - result = result_rgb.mul_(2.0).sub_(1.0) - del result_rgb - - result = result.to(original_dtype) - - return result - - -def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: - return wavelet_reconstruction(content_feat, style_feat) - - -def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False, - ) - - original_dtype = content_feat.dtype - content_feat = content_feat.float() - style_feat = style_feat.float() - - b, c = content_feat.shape[:2] - content_flat = content_feat.reshape(b, c, -1) - style_flat = style_feat.reshape(b, c, -1) - - content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) - content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) - style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - del content_flat, style_flat - - normalized = (content_feat - content_mean) / content_std - del content_mean, content_std - result = normalized * style_std + style_mean - del normalized, style_mean, style_std - - result = result.clamp_(-1.0, 1.0) - if result.dtype != original_dtype: - result = result.to(original_dtype) - return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py deleted file mode 100644 index 95838d1dd7f0..000000000000 --- a/comfy/ldm/seedvr/constants.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Named constants for the SeedVR2 integration, grouped by provenance. - -Provenance prefixes: -- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. -- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites - the upstream config/source path it was lifted from. -- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - ISO / CIE values; cite the standard. -""" - -# -------------------------------------------------------------------------------------- -# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) -# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) -# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 -# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). -# -------------------------------------------------------------------------------------- -SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) -SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB - -# -------------------------------------------------------------------------------------- -# B. Fork heuristics (SEEDVR2 - this integration) -# -------------------------------------------------------------------------------------- -SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. - # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. -SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). -SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. -SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. -SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). -SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). -SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. - -# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) -SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. -SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. -SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. -SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. - -# -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) -# -------------------------------------------------------------------------------------- -BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. -BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. -BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). -BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). -BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). -BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). -BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). -BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. -BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). -BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. -BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). -BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. -BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). -BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). -BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). -BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). -BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). -# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. -BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. -BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. - -# -------------------------------------------------------------------------------------- -# D. Published standards (cite the literature) -# -------------------------------------------------------------------------------------- -ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. - -# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). -CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). -CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). -D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). -D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. -WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). - -# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and -# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the -# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py deleted file mode 100644 index 3fa9fe07e870..000000000000 --- a/comfy/ldm/seedvr/model.py +++ /dev/null @@ -1,1665 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union, List, Dict, Any, Callable -import einops -from einops import rearrange -import torch.nn.functional as F -from math import ceil, pi -import torch -from itertools import chain -from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding -from comfy.ldm.modules.attention import optimized_var_attention -from torch.nn.modules.utils import _triple -from torch import nn -import math -from comfy.ldm.flux.math import apply_rope1 -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_720P_REF_AREA, - BYTEDANCE_MAX_TEMPORAL_WINDOW, - BYTEDANCE_ROPE_MAX_FREQ, - BYTEDANCE_SINUSOIDAL_DIM, - ROPE_THETA, - SEEDVR2_7B_MLP_CHUNK, - SEEDVR2_7B_VID_DIM, - SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, -) -import comfy.model_management -import numbers - -def _torch_float8_types(): - return tuple( - getattr(torch, name) - for name in ( - "float8_e4m3fn", - "float8_e4m3fnuz", - "float8_e5m2", - "float8_e5m2fnuz", - "float8_e8m0fnu", - ) - if hasattr(torch, name) - ) - -class CustomRMSNorm(nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): - super(CustomRMSNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) - else: - self.register_parameter('weight', None) - - def forward(self, input): - - dims = tuple(range(-len(self.normalized_shape), 0)) - - normalized = input.float() - variance = normalized.pow(2).mean(dim=dims, keepdim=True) - rms = torch.sqrt(variance + self.eps) - - normalized = normalized / rms - - if self.elementwise_affine: - return normalized * self.weight.to(input.dtype) - return normalized - -class Cache: - def __init__(self, disable=False, prefix="", cache=None): - self.cache = cache if cache is not None else {} - self.disable = disable - self.prefix = prefix - - def __call__(self, key: str, fn: Callable): - if self.disable: - return fn() - - key = self.prefix + key - try: - result = self.cache[key] - except KeyError: - result = fn() - self.cache[key] = result - return result - - def namespace(self, namespace: str): - return Cache( - disable=self.disable, - prefix=self.prefix + namespace + ".", - cache=self.cache, - ) - - def get(self, key: str): - key = self.prefix + key - return self.cache[key] - -def repeat_concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: List, # (n) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - txt = [[x] * n for x, n in zip(txt, txt_repeat)] - txt = list(chain(*txt)) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat_idx( - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) - src_idx = torch.argsort(tgt_idx) - return ( - lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), - lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), - ) - - -def repeat_concat_idx( - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: torch.LongTensor, # (n) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) - src_idx = torch.argsort(tgt_idx) - txt_idx_len = len(tgt_idx) - len(vid_idx) - repeat_txt_len = (txt_len * txt_repeat).tolist() - - def unconcat_coalesce(all): - vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) - txt_out_coalesced = [] - for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): - txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) - txt_out_coalesced.append(txt) - return vid_out, torch.cat(txt_out_coalesced) - - return ( - lambda vid, txt: torch.cat([vid, txt])[tgt_idx], - lambda all: unconcat_coalesce(all), - ) - - -@dataclass -class MMArg: - vid: Any - txt: Any - -def safe_pad_operation(x, padding, mode='constant', value=0.0): - """Safe padding operation that handles Half precision only for problematic modes""" - # Modes qui nécessitent le fix Half precision - problematic_modes = ['replicate', 'reflect', 'circular'] - - if mode in problematic_modes: - try: - return F.pad(x, padding, mode=mode, value=value) - except RuntimeError as e: - if "not implemented for 'Half'" in str(e): - original_dtype = x.dtype - return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) - else: - raise e - else: - # Pour 'constant' et autres modes compatibles, pas de fix nécessaire - return F.pad(x, padding, mode=mode, value=value) - - -def get_args(key: str, args: List[Any]) -> List[Any]: - return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] - - -def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} - - -def get_window_op(name: str): - if name == "720pwin_by_size_bysize": - return make_720Pwindows_bysize - if name == "720pswin_by_size_bysize": - return make_shifted_720Pwindows_bysize - raise ValueError(f"Unknown windowing method: {name}") - - -# -------------------------------- Windowing -------------------------------- # -def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. - return [ - ( - slice(it * wt, min((it + 1) * wt, t)), - slice(ih * wh, min((ih + 1) * wh, h)), - slice(iw * ww, min((iw + 1) * ww, w)), - ) - for iw in range(nw) - if min((iw + 1) * ww, w) > iw * ww - for ih in range(nh) - if min((ih + 1) * wh, h) > ih * wh - for it in range(nt) - if min((it + 1) * wt, t) > it * wt - ] - -def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - - st, sh, sw = ( # shift size. - 0.5 if wt < t else 0, - 0.5 if wh < h else 0, - 0.5 if ww < w else 0, - ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. - nt + 1 if st > 0 else 1, - nh + 1 if sh > 0 else 1, - nw + 1 if sw > 0 else 1, - ) - return [ - ( - slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), - slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), - slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), - ) - for iw in range(nw) - if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) - for ih in range(nh) - if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) - for it in range(nt) - if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) - ] - -class RotaryEmbedding(nn.Module): - def __init__( - self, - dim, - custom_freqs = None, - freqs_for = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - learned_freq = False, - use_xpos = False, - xpos_scale_base = 512, - interpolate_factor = 1., - theta_rescale_factor = 1., - seq_before_head_dim = False, - cache_if_possible = True, - cache_max_seq_len = 8192 - ): - super().__init__() - - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - self.freqs_for = freqs_for - - if exists(custom_freqs): - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() - - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len - - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_freqs_seq_len = 0 - - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) - - # default sequence dimension - - self.seq_before_head_dim = seq_before_head_dim - self.default_seq_dim = -3 if seq_before_head_dim else -2 - - # interpolation factors - - assert interpolate_factor >= 1. - self.interpolate_factor = interpolate_factor - - # xpos - - self.use_xpos = use_xpos - - if not use_xpos: - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - self.scale_base = xpos_scale_base - - self.register_buffer('scale', scale, persistent = False) - self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_scales_seq_len = 0 - - # add apply_rotary_emb as static method - - self.apply_rotary_emb = staticmethod(apply_rotary_emb) - - @property - def device(self): - return self.dummy.device - - def get_axial_freqs( - self, - *dims, - offsets = None - ): - Colon = slice(None) - all_freqs = [] - - # handle offset - - if exists(offsets): - assert len(offsets) == len(dims) - - for ind, dim in enumerate(dims): - - offset = 0 - if exists(offsets): - offset = offsets[ind] - - if self.freqs_for == 'pixel': - pos = torch.linspace(-1, 1, steps = dim, device = self.device) - else: - pos = torch.arange(dim, device = self.device) - - pos = pos + offset - - freqs = self.forward(pos, seq_len = dim) - - all_axis = [None] * len(dims) - all_axis[ind] = Colon - - new_axis_slice = (Ellipsis, *all_axis, Colon) - all_freqs.append(freqs[new_axis_slice]) - - # concat all freqs - - all_freqs = torch.broadcast_tensors(*all_freqs) - return torch.cat(all_freqs, dim = -1) - - def forward( - self, - t, - seq_len: int | None = None, - offset = 0 - ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - return self.cached_freqs[offset:(offset + seq_len)].detach() - - freqs = self.freqs - - freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) - - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len - - return freqs - -class RotaryEmbeddingBase(nn.Module): - def __init__(self, dim: int, rope_dim: int): - super().__init__() - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - - def get_axial_freqs(self, *dims): - return self.rope.get_axial_freqs(*dims) - - -class RotaryEmbedding3d(RotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - self.mm = False - - def forward( - self, - q: torch.FloatTensor, # b h l d - k: torch.FloatTensor, # b h l d - size: Tuple[int, int, int], - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - T, H, W = size - freqs = self.get_axial_freqs(T, H, W) - q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - q = apply_rotary_emb(freqs, q.float()).to(q.dtype) - k = apply_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "b h T H W d -> b h (T H W) d") - k = rearrange(k, "b h T H W d -> b h (T H W) d") - return q, k - - -class NaRotaryEmbedding3d(RotaryEmbedding3d): - def forward( - self, - q: torch.FloatTensor, - k: torch.FloatTensor, - shape: torch.LongTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) - freqs = freqs.to(device=q.device) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") - q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) - k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") - return q, k - - @torch._dynamo.disable - def get_freqs( - self, - shape: torch.LongTensor, - ) -> torch.Tensor: - # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds - # 7B pixel RoPE with the interleaved-angle convention, not Comfy's - # Flux freqs_cis matrix. - plain_rope = RotaryEmbedding( - dim=self.rope.freqs.numel() * 2, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - plain_rope = plain_rope.to(self.rope.dummy.device) - freq_list = [] - for f, h, w in shape.tolist(): - freqs = plain_rope.get_axial_freqs(f, h, w) - freq_list.append(freqs.view(-1, freqs.size(-1))) - return torch.cat(freq_list, dim=0) - - -class MMRotaryEmbeddingBase(RotaryEmbeddingBase): - def __init__(self, dim: int, rope_dim: int): - super().__init__(dim, rope_dim) - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="lang", - theta=ROPE_THETA, - cache_if_possible=False, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - self.mm = True - -def slice_at_dim(t, dim_slice: slice, *, dim): - dim += (t.ndim if dim < 0 else 0) - colons = [slice(None)] * t.ndim - colons[dim] = dim_slice - return t[tuple(colons)] - -# rotary embedding helper functions - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') -def exists(val): - return val is not None - -def apply_rotary_emb( - freqs, - t, - start_index = 0, - scale = 1., - seq_dim = -2, - freqs_seq_dim = None -): - dtype = t.dtype - if not exists(freqs_seq_dim): - if freqs.ndim == 2 or t.ndim == 3: - freqs_seq_dim = 0 - - if t.ndim == 3 or exists(freqs_seq_dim): - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - angles = freqs.to(t_middle.device)[..., ::2] - cos = torch.cos(angles) * scale - sin = torch.sin(angles) * scale - - col0 = torch.stack([cos, sin], dim=-1) - col1 = torch.stack([-sin, cos], dim=-1) - freqs_mat = torch.stack([col0, col1], dim=-1) - - t_middle_out = apply_rope1(t_middle, freqs_mat) - out = torch.cat((t_left, t_middle_out, t_right), dim=-1) - return out.type(dtype) - - -def _apply_seedvr2_rotary_emb( - freqs: torch.Tensor, - t: torch.Tensor, - start_index: int = 0, - scale: float = 1.0, - seq_dim: int = -2, - freqs_seq_dim: int | None = None, -) -> torch.Tensor: - dtype = t.dtype - if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): - freqs_seq_dim = 0 - - if t.ndim == 3 or freqs_seq_dim is not None: - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) - cos = freqs.cos() * scale - sin = freqs.sin() * scale - t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) - return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) - -def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" - angles = freqs_interleaved[..., ::2].float() - cos = torch.cos(angles) - sin = torch.sin(angles) - out = torch.stack([cos, -sin, sin, cos], dim=-1) - return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) - - -def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest - through; in-place for inference, cloned for training (autograd). Mirrors the legacy - ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives - ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when - ``rot_d == t.shape[-1]``. - """ - out = t.clone() if t.requires_grad or comfy.model_management.in_training else t - rot_d = 2 * freqs_cis.shape[-3] - seq_len = out.shape[-2] - for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): - end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) - freqs_chunk = freqs_cis[start:end] - if rot_d == out.shape[-1]: - out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) - else: - out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype) - return out - - -class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - - def forward( - self, - vid_q: torch.FloatTensor, # L h d - vid_k: torch.FloatTensor, # L h d - vid_shape: torch.LongTensor, # B 3 - txt_q: torch.FloatTensor, # L h d - txt_k: torch.FloatTensor, # L h d - txt_shape: torch.LongTensor, # B 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_freqs, txt_freqs = cache( - "mmrope_freqs_3d", - lambda: self.get_freqs(vid_shape, txt_shape), - ) - target_device = vid_q.device - if vid_freqs.device != target_device: - vid_freqs = vid_freqs.to(target_device) - if txt_freqs.device != target_device: - txt_freqs = txt_freqs.to(target_device) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q = _apply_rope1_partial(vid_q, vid_freqs) - vid_k = _apply_rope1_partial(vid_k, vid_freqs) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q = _apply_rope1_partial(txt_q, txt_freqs) - txt_k = _apply_rope1_partial(txt_k, txt_freqs) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") - return vid_q, vid_k, txt_q, txt_k - - @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks - def get_freqs( - self, - vid_shape: torch.LongTensor, - txt_shape: torch.LongTensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - ]: - - # Calculate actual max dimensions needed for this batch - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - - autocast_device = "cuda" if torch.cuda.is_available() else "cpu" - with torch.amp.autocast(autocast_device, enabled=False): - vid_freqs = self.get_axial_freqs( - max_temporal + 16, - max_height + 4, - max_width + 4, - ).float() - txt_freqs = self.get_axial_freqs(max_txt_len + 16) - - # Now slice as before - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) - txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) - txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) - - # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` - # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the - # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` - # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. - # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing - # 2x2 is the per-frequency rotation matrix that - # `comfy.ldm.flux.math.apply_rope1` expects. - return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) - -class MMModule(nn.Module): - def __init__( - self, - module: Callable[..., nn.Module], - *args, - shared_weights: bool = False, - vid_only: bool = False, - **kwargs, - ): - super().__init__() - self.shared_weights = shared_weights - self.vid_only = vid_only - if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) - self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - else: - self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - self.txt = ( - module(*get_args("txt", args), **get_kwargs("txt", kwargs)) - if not vid_only - else None - ) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - *args, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.vid if not self.shared_weights else self.all - vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) - if not self.vid_only: - txt_module = self.txt if not self.shared_weights else self.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) - return vid, txt - -def get_na_rope(rope_type: Optional[str], dim: int): - if rope_type is None: - return None - if rope_type == "rope3d": - return NaRotaryEmbedding3d(dim=dim) - if rope_type == "mmrope3d": - return NaMMRotaryEmbedding3d(dim=dim) - -class NaMMAttention(nn.Module): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_norm, - qk_norm_eps: float, - rope_type: Optional[str], - rope_dim: int, - shared_weights: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - self.heads = heads - inner_dim = heads * head_dim - qkv_dim = inner_dim * 3 - self.head_dim = head_dim - self.proj_qkv = MMModule( - operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype - ) - self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) - self.norm_q = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - self.norm_k = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - - - self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) - - def forward(self): - pass - -def window( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid = unflatten(hid, hid_shape) - hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows - -def window_idx( - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - tgt_windows, - ) - -class NaSwinAttention(NaMMAttention): - def __init__( - self, - *args, - window: Union[int, Tuple[int, int, int]], - window_method: bool, # shifted or not - **kwargs, - ): - super().__init__(*args, **kwargs) - self.version_7b = kwargs.get("version", False) - self.window = _triple(window) - self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) - - self.window_op = get_window_op(window_method) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - - # re-org the input seq for window attn - cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") - - def make_window(x: torch.Tensor): - t, h, w, _ = x.shape - window_slices = self.window_op((t, h, w), self.window) - return [x[st, sh, sw] for (st, sh, sw) in window_slices] - - window_partition, window_reverse, window_shape, window_count = cache_win( - "win_transform", - lambda: window_idx(vid_shape, make_window), - ) - vid_qkv_win = window_partition(vid_qkv) - - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - - vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len = txt_len.to(window_count.device) - - # window rope - if self.rope: - if self.version_7b: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - elif self.rope.mm: - # repeat text q and k for window mmrope - _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") - txt_q_repeat = unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] - txt_q_repeat = list(chain(*txt_q_repeat)) - txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) - - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") - txt_k_repeat = unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] - txt_k_repeat = list(chain(*txt_k_repeat)) - txt_k_repeat, _ = flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) - - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) - concat_win, unconcat_win = cache_win( - "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) - ) - out = optimized_var_attention( - q=concat_win(vid_q, txt_q), - k=concat_win(vid_k, txt_k), - v=concat_win(vid_v, txt_v), - heads=self.heads, skip_reshape=True, skip_output_reshape=True, - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - ) - vid_out, txt_out = unconcat_win(out) - - vid_out = rearrange(vid_out, "l h d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") - vid_out = window_reverse(vid_out) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - - return vid_out, txt_out - -class MLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - device, dtype, operations - ): - super().__init__() - self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) - self.act = nn.GELU("tanh") - self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_in(x) - x = self.act(x) - x = self.proj_out(x) - return x - - -class SwiGLUMLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - multiple_of: int = 256, - device=None, dtype=None, operations=None - ): - super().__init__() - hidden_dim = int(2 * dim * expand_ratio / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) - self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) - -def get_mlp(mlp_type: Optional[str] = "normal"): - # 3b and 7b uses different mlp types - if mlp_type == "normal": - return MLP - elif mlp_type == "swiglu": - return SwiGLUMLP - -class NaMMSRTransformerBlock(nn.Module): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm, - norm_eps: float, - ada, - qk_bias: bool, - qk_norm, - mlp_type: str, - shared_weights: bool, - rope_type: str, - rope_dim: int, - is_last_layer: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - version = kwargs.get("version", False) - dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) - - self.attn = NaSwinAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - rope_type=rope_type, - rope_dim=rope_dim, - shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), - version=version, - device=device, dtype=dtype, operations=operations - ) - - self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.mlp = MMModule( - get_mlp(mlp_type), - dim=dim, - expand_ratio=expand_ratio, - shared_weights=shared_weights, - vid_only=is_last_layer, - device=device, dtype=dtype, operations=operations - ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.is_last_layer = is_last_layer - self.version = version - - def _seedvr2_7b_mlp( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all - if comfy.model_management.in_training or vid.requires_grad: - vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) - else: - vid_out = None - offset = 0 - for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): - chunk_out = vid_module(chunk) - if vid_out is None: - vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) - vid_out[offset:offset + chunk_out.shape[0]] = chunk_out - offset += chunk_out.shape[0] - vid = vid_out - if not self.mlp.vid_only: - txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt) - return vid, txt - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - emb: torch.FloatTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.LongTensor, - torch.LongTensor, - ]: - hid_len = MMArg( - cache("vid_len", lambda: vid_shape.prod(-1)), - cache("txt_len", lambda: txt_shape.prod(-1)), - ) - ada_kwargs = { - "emb": emb, - "hid_len": hid_len, - "cache": cache, - "branch_tag": MMArg("vid", "txt"), - } - - vid_attn, txt_attn = self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) - if self.version: - vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) - else: - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp, vid_shape, txt_shape - -class PatchOut(nn.Module): - def __init__( - self, - out_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) - if t > 1: - vid = vid[:, :, (t - 1) :] - return vid - -class NaPatchOut(PatchOut): - def forward( - self, - vid: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - vid_shape_before_patchify = None - ) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, - ]: - - t, h, w = self.patch_size - vid = self.proj(vid) - - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] - vid, vid_shape = flatten(vid) - - return vid, vid_shape - -class PatchIn(nn.Module): - def __init__( - self, - in_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - if t > 1: - assert vid.size(2) % t == 1 - vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) - vid = self.proj(vid) - return vid - -class NaPatchIn(PatchIn): - def forward( - self, - vid: torch.Tensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - ) -> torch.Tensor: - cache = cache.namespace("patch") - vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) - t, h, w = self.patch_size - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) - vid, vid_shape = flatten(vid) - - vid = self.proj(vid) - return vid, vid_shape - -def expand_dims(x: torch.Tensor, dim: int, ndim: int): - shape = x.shape - shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] - return x.reshape(shape) - - -class AdaSingle(nn.Module): - def __init__( - self, - dim: int, - emb_dim: int, - layers: List[str], - modes: List[str] = ["in", "out"], - device = None, dtype = None, - ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" - super().__init__() - self.dim = dim - self.emb_dim = emb_dim - self.layers = layers - - randn_kwargs = {"device": device} - fp8_types = _torch_float8_types() - if dtype is not None and dtype not in fp8_types: - randn_kwargs["dtype"] = dtype - - for l in layers: - if "in" in modes: - # Passing fp8 ``dtype=`` here would break CPU weight - # loads: CPU has no ``normal_kernel_cpu`` for fp8. - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - self.register_parameter( - f"{l}_scale", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5 + 1) - ) - if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - - def forward( - self, - hid: torch.FloatTensor, # b ... c - emb: torch.FloatTensor, # b d - layer: str, - mode: str, - cache: Cache = Cache(disable=True), - branch_tag: str = "", - hid_len: Optional[torch.LongTensor] = None, # b - ) -> torch.FloatTensor: - idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] - emb = expand_dims(emb, 1, hid.ndim + 1) - - if hid_len is not None: - slice_inputs = lambda x, dim: x - emb = cache( - f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), - dim=0, - ), - ) - - shiftA, scaleA, gateA = emb.unbind(-1) - shiftB, scaleB, gateB = ( - getattr(self, f"{layer}_shift", None), - getattr(self, f"{layer}_scale", None), - getattr(self, f"{layer}_gate", None), - ) - - fp8_types = _torch_float8_types() - if fp8_types: - target_dtype = hid.dtype - - if shiftB is not None and shiftB.dtype in fp8_types: - shiftB = shiftB.to(target_dtype) - if scaleB is not None and scaleB.dtype in fp8_types: - scaleB = scaleB.to(target_dtype) - if gateB is not None and gateB.dtype in fp8_types: - gateB = gateB.to(target_dtype) - - if mode == "in": - return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) - if mode == "out": - if gateB is not None: - return hid.mul_(gateA + gateB) - else: - return hid.mul_(gateA) - - raise NotImplementedError - - -def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): - return emb1 if emb2 is None else emb1 + emb2 - - -class TimeEmbedding(nn.Module): - def __init__( - self, - sinusoidal_dim: int, - hidden_dim: int, - output_dim: int, - device, dtype, operations - ): - super().__init__() - self.sinusoidal_dim = sinusoidal_dim - self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) - self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) - self.act = nn.SiLU() - - def forward( - self, - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], - device: torch.device, - dtype: torch.dtype, - ) -> torch.FloatTensor: - if not torch.is_tensor(timestep): - timestep = torch.tensor([timestep], device=device, dtype=dtype) - if timestep.ndim == 0: - timestep = timestep[None] - - emb = get_timestep_embedding( - timesteps=timestep, - embedding_dim=self.sinusoidal_dim, - flip_sin_to_cos=False, - downscale_freq_shift=0, - ).to(dtype) - emb = self.proj_in(emb) - emb = self.act(emb) - emb = self.proj_hid(emb) - emb = self.act(emb) - emb = self.proj_out(emb) - return emb - -def flatten( - hid: List[torch.FloatTensor], # List of (*** c) -) -> Tuple[ - torch.FloatTensor, # (L c) - torch.LongTensor, # (b n) -]: - assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) - hid = torch.cat([x.flatten(0, -2) for x in hid]) - return hid, shape - - -def unflatten( - hid: torch.FloatTensor, # (L c) or (L ... c) - hid_shape: torch.LongTensor, # (b n) -) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) - hid_len = hid_shape.prod(-1) - hid = hid.split(hid_len.tolist()) - hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] - return hid - -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - -class NaDiT(nn.Module): - - def __init__( - self, - norm_eps, - qk_rope, - num_layers, - mlp_type, - vid_in_channels = 33, - vid_out_channels = 16, - vid_dim = 2560, - txt_in_dim = 5120, - heads = 20, - head_dim = 128, - mm_layers = 10, - expand_ratio = 4, - qk_bias = False, - patch_size = [ 1,2,2 ], - shared_qkv: bool = False, - shared_mlp: bool = False, - window_method: Optional[Tuple[str]] = None, - temporal_window_size: int = None, - temporal_shifted: bool = False, - rope_dim = 128, - rope_type = "mmrope3d", - vid_out_norm: Optional[str] = None, - device = None, - dtype = None, - operations = None, - **kwargs, - ): - self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM - if self._7b_version: - rope_type = "rope3d" - self.dtype = dtype - factory_kwargs = {"device": device, "dtype": dtype} - window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] - txt_dim = vid_dim - emb_dim = vid_dim * 6 - block_type = ["mmdit_sr"] * num_layers - window = num_layers * [(4,3,3)] - ada = AdaSingle - norm = CustomRMSNorm - qk_norm = CustomRMSNorm - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - # ``torch.empty`` returns uninitialized memory, not zeros. The - # SeedVR2Conditioning fail-loud guard at - # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" - # from "buffer was never populated by the file" by checking - # ``positive_conditioning.abs().sum() == 0``. That sentinel is only - # reliable if the post-construction buffer state is deterministically - # zero, so explicitly zero-fill here rather than relying on the - # allocator's zero-on-alloc behavior (allocator-dependent and not - # contractual). When ``load_state_dict`` populates these buffers - # from a properly-baked SeedVR2 .safetensors, the in-place copy - # overwrites the zeros with the universal SeedVR2 conditioning - # tensors (shape (58, 5120) and (64, 5120) bf16). - self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) - self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - self.txt_in = ( - operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - device=device, dtype=dtype, operations=operations - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - if temporal_window_size is None or isinstance(temporal_window_size, int): - temporal_window_size = [temporal_window_size] * num_layers - if temporal_shifted is None or isinstance(temporal_shifted, bool): - temporal_shifted = [temporal_shifted] * num_layers - - rope_dim = rope_dim if rope_dim is not None else head_dim // 2 - self.blocks = nn.ModuleList( - [ - NaMMSRTransformerBlock( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - rope_dim = rope_dim, - window=window[i], - window_method=window_method[i], - temporal_window_size=temporal_window_size[i], - temporal_shifted=temporal_shifted[i], - is_last_layer=(i == num_layers - 1) and not self._7b_version, - rope_type = rope_type, - shared_weights=not ( - (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] - ), - version = self._7b_version, - operations = operations, - **kwargs, - **factory_kwargs - ) - for i in range(num_layers) - ] - ) - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - - self.need_txt_repeat = block_type[0] in [ - "mmdit_stwin", - "mmdit_stwin_spatial", - "mmdit_stwin_3d_spatial", - ] - - self.vid_out_norm = None - if vid_out_norm is not None: - self.vid_out_norm = CustomRMSNorm( - normalized_shape=vid_dim, - eps=norm_eps, - elementwise_affine=True, - device=device, dtype=dtype - ) - self.vid_out_ada = ada( - dim=vid_dim, - emb_dim=emb_dim, - layers=["out"], - modes=["in"], - device=device, dtype=dtype - ) - - def _resolve_text_conditioning(self, context, cond_or_uncond=None): - if context is None or getattr(context, "numel", lambda: None)() == 0: - context = self.positive_conditioning - return flatten([context]) - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - if context.shape[0] == 1: - context = context.squeeze(0) - return flatten([context]) - return flatten(context.unbind(0)) - if context.shape[0] % 2 != 0: - raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") - neg_cond, pos_cond = context.chunk(2, dim=0) - if pos_cond.shape[0] == 1: - pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) - return flatten([pos_cond, neg_cond]) - return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) - - @staticmethod - def _seedvr2_is_single_conditioning_branch(cond_or_uncond): - if cond_or_uncond is None or len(cond_or_uncond) == 0: - return False - first = cond_or_uncond[0] - return all(entry == first for entry in cond_or_uncond) - - def _swap_pos_neg_halves(self, out, cond_or_uncond=None): - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - return out - # ``dim=0`` is explicit on both calls. The contract is "split - # the batch axis into two halves and swap them"; making the - # axis load-bearing in source guards against silent drift if a - # future refactor reorders tensor axes. - pos, neg = out.chunk(2, dim=0) - return torch.cat([neg, pos], dim=0) - - def forward( - self, - x, - timestep, - context, # l c - disable_cache: bool = False, # for test # TODO ? // gives an error when set to True - **kwargs - ): - transformer_options = kwargs.get("transformer_options", {}) - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - conditions = kwargs.get("condition") - b, tc, h, w = x.shape - x = x.view(b, 16, -1, h, w) - conditions = conditions.view(b, 17, -1, h, w) - x = x.movedim(1, -1) - conditions = conditions.movedim(1, -1) - cache = Cache(disable=disable_cache) - - txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) - - vid, vid_shape = flatten(x) - cond_latent, _ = flatten(conditions) - - vid = torch.cat([vid, cond_latent], dim=-1) - if txt_shape.size(-1) == 1 and self.need_txt_repeat: - txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - - txt = self.txt_in(txt) - - vid_shape_before_patchify = vid_shape - vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) - - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - - for i, block in enumerate(self.blocks): - if ("block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( - vid=args["vid"], - txt=args["txt"], - vid_shape=args["vid_shape"], - txt_shape=args["txt_shape"], - emb=args["emb"], - cache=args["cache"], - ) - return out - out = blocks_replace[("block", i)]({ - "vid":vid, - "txt":txt, - "vid_shape":vid_shape, - "txt_shape":txt_shape, - "emb":emb, - "cache":cache, - }, {"original_block": block_wrap}) - vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] - else: - vid, txt, vid_shape, txt_shape = block( - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - if self.vid_out_norm: - vid = self.vid_out_norm(vid) - vid = self.vid_out_ada( - vid, - emb=emb, - layer="out", - mode="in", - hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), - cache=cache, - branch_tag="vid", - ) - - vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) - vid = unflatten(vid, vid_shape) - out = torch.stack(vid) - out = out.movedim(-1, 1) - out = rearrange(out, "b c t h w -> b (c t) h w") - return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py deleted file mode 100644 index 68b11c0ff813..000000000000 --- a/comfy/ldm/seedvr/vae.py +++ /dev/null @@ -1,2110 +0,0 @@ -from contextlib import nullcontext -from typing import Literal, Optional, Tuple -import gc -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import Tensor -from contextlib import contextmanager -from comfy.utils import ProgressBar - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_BLOCK_OUT_CHANNELS, - BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, - BYTEDANCE_GN_CHUNKS_FP16, - BYTEDANCE_GN_CHUNKS_FP32, - BYTEDANCE_LOGVAR_CLAMP_MAX, - BYTEDANCE_LOGVAR_CLAMP_MIN, - BYTEDANCE_SLICING_SAMPLE_MIN, - BYTEDANCE_VAE_CONV_MEM_GIB, - BYTEDANCE_VAE_NORM_MEM_GIB, - BYTEDANCE_VAE_SCALING_FACTOR, - BYTEDANCE_VAE_SHIFTING_FACTOR, - BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, - BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, - SEEDVR2_LATENT_CHANNELS, -) -from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.modules.diffusionmodules.model import vae_attention - -import math -from enum import Enum -from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND - -import logging -import comfy.model_management -import comfy.ops -ops = comfy.ops.disable_weight_init - - -def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): - if temporal_size is None: - return None - - temporal_size = int(temporal_size) - if temporal_size <= 0: - return 0 - - temporal_overlap = max(0, int(temporal_overlap or 0)) - temporal_overlap = min(temporal_overlap, temporal_size - 1) - temporal_step = temporal_size - temporal_overlap - temporal_scale = max(1, int(temporal_scale)) - return max(1, math.ceil(temporal_step / temporal_scale)) - - -def _seedvr2_clamped_spatial_overlap(overlap, tile_size): - overlap = max(0, int(overlap)) - tile_size = max(1, int(tile_size)) - return min(overlap, tile_size - 1) - - -def _seedvr2_clear_temporal_memory(model): - for module in model.modules(): - if hasattr(module, "memory"): - module.memory = None - - -@torch.inference_mode() -def tiled_vae( - x, - vae_model, - tile_size=(512, 512), - tile_overlap=(64, 64), - temporal_size=16, - temporal_overlap=0, - encode=True, - **kwargs, -): - gc.collect() - comfy.model_management.soft_empty_cache() - - x = x.to(next(vae_model.parameters()).dtype) - if x.ndim != 5: - x = x.unsqueeze(2) - - _, _, d, h, w = x.shape - - sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) - sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) - if encode: - slicing_attr = "slicing_sample_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) - else: - slicing_attr = "slicing_latent_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) - if encode: - ti_h, ti_w = tile_size - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) - blend_ov_h = max(0, ov_h // sf_s) - blend_ov_w = max(0, ov_w // sf_s) - target_d = (d + sf_t - 1) // sf_t - target_h = (h + sf_s - 1) // sf_s - target_w = (w + sf_s - 1) // sf_s - else: - ti_h = max(1, tile_size[0] // sf_s) - ti_w = max(1, tile_size[1] // sf_s) - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) - blend_ov_h = ov_h * sf_s - blend_ov_w = ov_w * sf_s - - target_d = max(1, d * sf_t - (sf_t - 1)) - target_h = h * sf_s - target_w = w * sf_s - - stride_h = max(1, ti_h - ov_h) - stride_w = max(1, ti_w - ov_w) - - storage_device = vae_model.device - result = None - count = None - def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): - device = torch.device(device) - _seedvr2_clear_temporal_memory(model) - t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() - old_device = getattr(model, "device", None) - model.device = device - old_slicing_min_size = getattr(model, slicing_attr, None) - if old_slicing_min_size is not None and slicing_min_size is not None: - if slicing_min_size <= 0: - setattr(model, slicing_attr, t_chunk.shape[2]) - else: - setattr(model, slicing_attr, slicing_min_size) - try: - if encode: - out = model.encode(t_chunk)[0] - else: - out = model.decode_(t_chunk) - finally: - if old_slicing_min_size is not None and slicing_min_size is not None: - setattr(model, slicing_attr, old_slicing_min_size) - if old_device is not None: - model.device = old_device - if isinstance(out, (tuple, list)): - out = out[0] - if out.ndim == 4: - out = out.unsqueeze(2) - return out.to(storage_device) - - ramp_cache = {} - def get_ramp(steps): - if steps not in ramp_cache: - t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) - ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) - return ramp_cache[steps] - - tile_ranges = [] - for y_idx in range(0, h, stride_h): - y_end = min(y_idx + ti_h, h) - if y_idx > 0 and (y_end - y_idx) <= ov_h: - continue - for x_idx in range(0, w, stride_w): - x_end = min(x_idx + ti_w, w) - if x_idx > 0 and (x_end - x_idx) <= ov_w: - continue - tile_ranges.append((y_idx, y_end, x_idx, x_end)) - - total_tiles = len(tile_ranges) - bar = ProgressBar(total_tiles) - single_spatial_tile = h <= ti_h and w <= ti_w - - _seedvr2_clear_temporal_memory(vae_model) - - def run_tile(tile_index, tile_range): - y_idx, y_end, x_idx, x_end = tile_range - tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] - tile_out = run_temporal_chunks(tile_x) - return tile_index, y_idx, y_end, x_idx, x_end, tile_out - - ordered_tile_outputs = ( - run_tile(tile_index, tile_range) - for tile_index, tile_range in enumerate(tile_ranges) - ) - - for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: - - if single_spatial_tile: - result = tile_out[:, :, :target_d, :target_h, :target_w] - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - bar.update(1) - return result - - if result is None: - b_out, c_out = tile_out.shape[0], tile_out.shape[1] - result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) - count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) - - if encode: - ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] - xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - else: - ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] - xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - - w_h = torch.ones((tile_out.shape[3],), device=storage_device) - w_w = torch.ones((tile_out.shape[4],), device=storage_device) - - if cur_ov_h > 0: - r = get_ramp(cur_ov_h) - if y_idx > 0: - w_h[:cur_ov_h] = r - if y_end < h: - w_h[-cur_ov_h:] = 1.0 - r - - if cur_ov_w > 0: - r = get_ramp(cur_ov_w) - if x_idx > 0: - w_w[:cur_ov_w] = r - if x_end < w: - w_w[-cur_ov_w:] = 1.0 - r - - final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - - valid_d = min(tile_out.shape[2], result.shape[2]) - tile_out = tile_out[:, :, :valid_d, :, :] - - tile_out.mul_(final_weight) - - result[:, :, :valid_d, ys:ye, xs:xe] += tile_out - count[:, :, :, ys:ye, xs:xe] += final_weight - - del tile_out, final_weight, w_h, w_w - bar.update(1) - - result.div_(count.clamp(min=1e-6)) - _seedvr2_clear_temporal_memory(vae_model) - - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - - return result - -_NORM_LIMIT = float("inf") -def get_norm_limit(): - return _NORM_LIMIT - - -def set_norm_limit(value: Optional[float] = None): - global _NORM_LIMIT - if value is None: - value = float("inf") - _NORM_LIMIT = value - -@contextmanager -def ignore_padding(model): - orig_padding = model.padding - model.padding = (0, 0, 0) - try: - yield - finally: - model.padding = orig_padding - -class MemoryState(Enum): - DISABLED = 0 - INITIALIZING = 1 - ACTIVE = 2 - UNSET = 3 - -def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 - remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) - ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 - - assert output_len > 0 - return cache_len - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: torch.Tensor, deterministic: bool = False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: - sample = torch.randn( - self.mean.shape, - generator=generator, - device=self.parameters.device, - dtype=self.parameters.dtype, - ) - x = self.mean + self.std * sample - return x - - def mode(self): - return self.mean - -class SpatialNorm(nn.Module): - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__() - self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - f_size = f.shape[-2:] - zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - -# partial implementation of diffusers's Attention for comfyui -class Attention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - is_causal: bool = False, - ): - super().__init__() - - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if norm_num_groups is not None: - self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - self.norm_q = None - self.norm_k = None - - self.norm_cross = None - self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - else: - self.add_q_proj = None - self.add_k_proj = None - self.add_v_proj = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - else: - self.to_add_out = None - - self.norm_added_q = None - self.norm_added_k = None - self.optimized_vae_attention = vae_attention() - - def __call__( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - - residual = hidden_states - if self.spatial_norm is not None: - hidden_states = self.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) - - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: - query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) - else: - hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if self.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / self.rescale_output_factor - - return hidden_states - - -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): - with torch.no_grad(): - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - return weight_3d - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): - with torch.no_grad(): - bias_3d.copy_(bias_2d) - return bias_3d - - -def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - weight_name = prefix + "weight" - bias_name = prefix + "bias" - if weight_name in state_dict: - weight_2d = state_dict[weight_name] - if weight_2d.dim() == 4: - weight_3d = inflate_weight_fn( - weight_2d=weight_2d, - weight_3d=layer.weight, - ) - state_dict[weight_name] = weight_3d - else: - return state_dict - if bias_name in state_dict: - bias_2d = state_dict[bias_name] - if bias_2d.dim() == 1: - bias_3d = inflate_bias_fn( - bias_2d=bias_2d, - bias_3d=layer.bias, - ) - state_dict[bias_name] = bias_3d - return state_dict - -def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: - input_dtype = x.dtype - if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): - if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") - x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") - return x.to(input_dtype) - if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") - x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") - return x.to(input_dtype) - if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): - if x.ndim <= 4: - return norm_layer(x).to(input_dtype) - if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") - memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): - num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) - assert norm_layer.num_groups % num_chunks == 0 - num_groups_per_chunk = norm_layer.num_groups // num_chunks - - x = list(x.chunk(num_chunks, dim=1)) - weights = norm_layer.weight.chunk(num_chunks, dim=0) - biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) - x[i] = x[i].to(input_dtype) - x = torch.cat(x, dim=1) - else: - x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x.to(input_dtype) - raise NotImplementedError - -def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): - problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - - if mode in problematic_modes: - try: - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or - "compute_indices_weights" in str(e)): - original_dtype = x.dtype - return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ).to(original_dtype) - else: - raise e - else: - # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - -_receptive_field_t = Literal["half", "full"] - -def extend_head(tensor, times: int = 2, memory = None): - if memory is not None: - return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" - if times == 0: - return tensor - else: - tile_repeat = [1] * tensor.ndim - tile_repeat[2] = times - return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) - -def cache_send_recv(tensor, cache_size, times, memory=None): - recv_buffer = None - - if memory is not None: - recv_buffer = memory.to(tensor[0]) - elif times > 0: - tile_repeat = [1] * tensor[0].ndim - tile_repeat[2] = times - recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) - - return recv_buffer - -class InflatedCausalConv3d(ops.Conv3d): - def __init__( - self, - *args, - inflation_mode, - memory_device = "same", - **kwargs, - ): - self.inflation_mode = inflation_mode - self.memory = None - super().__init__(*args, **kwargs) - self.temporal_padding = self.padding[0] - self.memory_device = memory_device - self.padding = (0, *self.padding[1:]) - self.memory_limit = float("inf") - self.logged_once = False - - def set_memory_limit(self, value: float): - self.memory_limit = value - - def set_memory_device(self, memory_device): - self.memory_device = memory_device - - def _conv_forward(self, input, weight, bias, *args, **kwargs): - if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and - weight.dtype in (torch.float16, torch.bfloat16) and - hasattr(torch.backends.cudnn, 'is_available') and - torch.backends.cudnn.is_available() and - getattr(torch.backends.cudnn, 'enabled', True)): - try: - out = torch.cudnn_convolution( - input, weight, self.padding, self.stride, self.dilation, self.groups, - benchmark=False, deterministic=False, allow_tf32=True - ) - if bias is not None: - out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) - return out - except RuntimeError: - pass - except NotImplementedError: - pass - try: - return super()._conv_forward(input, weight, bias, *args, **kwargs) - except NotImplementedError: - # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend - if not self.logged_once: - logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") - self.logged_once = True - return F.conv3d(input, weight, bias, *args, **kwargs) - - def memory_limit_conv( - self, - x, - *, - split_dim=3, - padding=(0, 0, 0, 0, 0, 0), - prev_cache=None, - ): - # Compatible with no limit. - if math.isinf(self.memory_limit): - if prev_cache is not None: - x = torch.cat([prev_cache, x], dim=split_dim - 1) - return super().forward(x) - - # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) - if prev_cache is not None: - shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB - if memory_occupy < self.memory_limit or split_dim == x.ndim: - x_concat = x - if prev_cache is not None: - x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) - - def pad_and_forward(): - padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) - if not padded.is_contiguous(): - padded = padded.contiguous() - with ignore_padding(self): - return torch.nn.Conv3d.forward(self, padded) - - return pad_and_forward() - - num_splits = math.ceil(memory_occupy / self.memory_limit) - size_per_split = x.size(split_dim) // num_splits - split_sizes = [size_per_split] * (num_splits - 1) - split_sizes += [x.size(split_dim) - sum(split_sizes)] - - x = list(x.split(split_sizes, dim=split_dim)) - if prev_cache is not None: - prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) - cache = None - for idx in range(len(x)): - if prev_cache is not None: - x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) - - lpad_dim = (x[idx].ndim - split_dim - 1) * 2 - rpad_dim = lpad_dim + 1 - padding = list(padding) - padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 - padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 - pad_len = padding[lpad_dim] + padding[rpad_dim] - padding = tuple(padding) - - next_cache = None - cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( - conv_module=self, - input_len=x[idx].size(split_dim) + cache_len, - pad_len=pad_len, - dim=split_dim - 2, - ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) - next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) - ) - - x[idx] = self.memory_limit_conv( - x[idx], - split_dim=split_dim + 1, - padding=padding, - prev_cache=cache - ) - - cache = next_cache - - output = torch.cat(x, dim=split_dim) - return output - - def forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET - ) -> Tensor: - assert memory_state != MemoryState.UNSET - if memory_state != MemoryState.ACTIVE: - self.memory = None - if ( - math.isinf(self.memory_limit) - and torch.is_tensor(input) - ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) - - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): - mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) - else: - input = extend_head(input, times=self.temporal_padding * 2) - memory = ( - input[:, :, mem_size:].detach() - if (mem_size != 0 and memory_state != MemoryState.DISABLED) - else None - ) - if ( - memory_state != MemoryState.DISABLED - and not self.training - and (self.memory_device is not None) - ): - self.memory = memory - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - return super().forward(input) - - def slicing_forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET, - ) -> Tensor: - squeeze_out = False - if torch.is_tensor(input): - input = [input] - squeeze_out = True - - cache_size = self.kernel_size[0] - self.stride[0] - cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 - ) - - # Single GPU inference - simplified memory management - if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing - and not self.training - and (self.memory_device is not None) - and cache_size != 0 - ): - if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: - input[0] = torch.cat([cache, input[0]], dim=2) - cache = None - if cache_size <= input[-1].size(2): - self.memory = input[-1][:, :, -cache_size:].detach().contiguous() - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - - padding = tuple(x for x in reversed(self.padding) for _ in range(2)) - for i in range(len(input)): - # Prepare cache for next input slice. - next_cache = None - cache_size = 0 - if i < len(input) - 1: - cache_len = cache.size(2) if cache is not None else 0 - cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) - if cache_size != 0: - if cache_size > input[i].size(2) and cache is not None: - input[i] = torch.cat([cache, input[i]], dim=2) - cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" - next_cache = input[i][:, :, -cache_size:] - - # Conv forward for this input slice. - input[i] = self.memory_limit_conv( - input[i], - padding=padding, - prev_cache=cache - ) - - # Update cache. - cache = next_cache - - return input[0] if squeeze_out else input - -def remove_head(tensor: Tensor, times: int = 1) -> Tensor: - if times == 0: - return tensor - return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) - -class Upsample3D(nn.Module): - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - temporal_up: bool = False, - spatial_up: bool = True, - slicing: bool = False, - interpolate = True, - name: str = "conv", - use_conv_transpose = False, - use_conv: bool = False, - padding = 1, - bias = True, - kernel_size = None, - **kwargs, - ): - super().__init__() - self.interpolate = interpolate - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv_transpose = use_conv_transpose - self.use_conv = use_conv - self.name = name - - self.conv = None - if use_conv_transpose: - if kernel_size is None: - kernel_size = 4 - self.conv = ops.ConvTranspose2d( - channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias - ) - elif use_conv: - if kernel_size is None: - kernel_size = 3 - self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - - conv = self.conv if self.name == "conv" else self.Conv2d_0 - - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - 3, - padding=1, - inflation_mode=inflation_mode, - ) - - self.temporal_up = temporal_up - self.spatial_up = spatial_up - self.temporal_ratio = 2 if temporal_up else 1 - self.spatial_ratio = 2 if spatial_up else 1 - self.slicing = slicing - - assert not self.interpolate - # [Override] MAGViT v2 implementation - if not self.interpolate: - upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = ops.Conv3d( - self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 - ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) - - if self.name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - self.norm = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state=None, - **kwargs, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv_transpose: - return self.conv(hidden_states) - - if self.slicing: - split_size = hidden_states.size(2) // 2 - hidden_states = list( - hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) - ) - else: - hidden_states = [hidden_states] - - for i in range(len(hidden_states)): - hidden_states[i] = self.upscale_conv(hidden_states[i]) - hidden_states[i] = rearrange( - hidden_states[i], - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, - ) - - if self.temporal_up and memory_state != MemoryState.ACTIVE: - hidden_states[0] = remove_head(hidden_states[0]) - - if not self.slicing: - hidden_states = hidden_states[0] - - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states, memory_state=memory_state) - else: - hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) - - if not self.slicing: - return hidden_states - else: - return torch.cat(hidden_states, dim=2) - - -class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution.""" - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - spatial_down: bool = False, - temporal_down: bool = False, - name: str = "conv", - kernel_size=3, - use_conv: bool = False, - padding = 1, - bias=True, - **kwargs, - ): - super().__init__() - self.padding = padding - self.name = name - self.channels = channels - self.out_channels = out_channels or channels - self.temporal_down = temporal_down - self.spatial_down = spatial_down - self.use_conv = use_conv - self.padding = padding - - self.temporal_ratio = 2 if temporal_down else 1 - self.spatial_ratio = 2 if spatial_down else 1 - - self.temporal_kernel = 3 if temporal_down else 1 - self.spatial_kernel = 3 if spatial_down else 1 - - if use_conv: - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - padding=( - 1 if self.temporal_down else 0, - self.padding if self.spatial_down else 0, - self.padding if self.spatial_down else 0, - ), - inflation_mode=inflation_mode, - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool3d( - kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - ) - - self.conv = conv - - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv and self.padding == 0 and self.spatial_down: - pad = (0, 1, 0, 1) - hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - hidden_states = self.conv(hidden_states, memory_state=memory_state) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - time_embedding_norm: str = "default", - output_scale_factor: float = 1.0, - skip_time_act: bool = False, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - slicing: bool = False, - **kwargs, - ): - super().__init__() - self.up = up - self.down = down - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - conv_2d_out_channels = conv_2d_out_channels or out_channels - self.use_in_shortcut = use_in_shortcut - self.output_scale_factor = output_scale_factor - self.skip_time_act = skip_time_act - self.nonlinearity = nn.SiLU() - if temb_channels is not None: - self.time_emb_proj = ops.Linear(temb_channels, out_channels) - else: - self.time_emb_proj = None - self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - if groups_out is None: - groups_out = groups - self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.use_in_shortcut = self.in_channels != out_channels - self.dropout = torch.nn.Dropout(dropout) - self.conv1 = InflatedCausalConv3d( - self.in_channels, - self.out_channels, - kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), - stride=1, - padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), - inflation_mode=inflation_mode, - ) - - self.conv2 = InflatedCausalConv3d( - self.out_channels, - conv_2d_out_channels, - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample3D( - self.in_channels, - use_conv=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - elif self.down: - self.downsample = Downsample3D( - self.in_channels, - use_conv=False, - padding=1, - name="op", - inflation_mode=inflation_mode, - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = InflatedCausalConv3d( - self.in_channels, - conv_2d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=True, - inflation_mode=inflation_mode, - ) - - def forward( - self, input_tensor, temb, memory_state = None, **kwargs - ): - hidden_states = input_tensor - - hidden_states = causal_norm_wrapper(self.norm1, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, memory_state=memory_state) - hidden_states = self.upsample(hidden_states, memory_state=memory_state) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, memory_state=memory_state) - hidden_states = self.downsample(hidden_states, memory_state=memory_state) - - hidden_states = self.conv1(hidden_states, memory_state=memory_state) - - if self.time_emb_proj is not None: - if not self.skip_time_act: - temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] - - if temb is not None: - hidden_states = hidden_states + temb - - hidden_states = causal_norm_wrapper(self.norm2, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_down: bool = True, - spatial_down: bool = True, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - temporal_down=temporal_down, - spatial_down=spatial_down, - inflation_mode=inflation_mode, - ) - ] - ) - else: - self.downsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up: bool = True, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - slicing=slicing, - ) - ) - - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_upsample: - # [Override] Replace module & use learnable upsample - self.upsamplers = nn.ModuleList( - [ - Upsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - temporal_up=temporal_up, - spatial_up=spatial_up, - interpolate=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - ] - ) - else: - self.upsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - memory_state=None - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ] - attentions = [] - - if attention_head_dim is None: - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=( - resnet_groups if resnet_time_scale_shift == "default" else None - ), - spatial_norm_dim=( - temb_channels if resnet_time_scale_shift == "spatial" else None - ), - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, memory_state=None): - video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) - - return hidden_states - - -class Encoder3D(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention=True, - # [Override] add extra_cond_dim, temporal down num - temporal_down_num: int = 2, - extra_cond_dim: int = None, - gradient_checkpoint: bool = False, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_down_num = temporal_down_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - self.extra_cond_dim = extra_cond_dim - - self.conv_extra_cond = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design - is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - - assert down_block_type == "DownEncoderBlock3D" - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - # Note: Don't know why set it as 0 - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - temporal_down=is_temporal_down_block, - spatial_down=True, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.down_blocks.append(down_block) - - def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module - - self.conv_extra_cond.append( - zero_module( - ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) - ) - if self.extra_cond_dim is not None and self.extra_cond_dim > 0 - else None - ) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # out - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = InflatedCausalConv3d( - block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - def forward( - self, - sample: torch.FloatTensor, - extra_cond=None, - memory_state = None - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state = memory_state) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample) - - else: - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample, memory_state=memory_state) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state = memory_state) - - return sample - - -class Decoder3D(nn.Module): - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - # [Override] add temporal up block - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_up_num = temporal_up_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - is_temporal_up_block = i < self.temporal_up_num - is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num - # Note: Keep symmetric - - assert up_block_type == "UpDecoderBlock3D" - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=norm_type, - temb_channels=temb_channels, - temporal_up=is_temporal_up_block, - slicing=is_slicing_up_block, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - self.conv_out = InflatedCausalConv3d( - block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - # Note: Just copy from Decoder. - def forward( - self, - sample: torch.FloatTensor, - latent_embeds: Optional[torch.FloatTensor] = None, - memory_state = None, - ) -> torch.FloatTensor: - - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state=memory_state) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - -class VideoAutoencoderKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - layers_per_block: int = 2, - act_fn: str = "silu", - latent_channels: int = SEEDVR2_LATENT_CHANNELS, - norm_num_groups: int = 32, - attention: bool = True, - temporal_scale_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - inflation_mode = "pad", - time_receptive_field: _receptive_field_t = "full", - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, - *args, - **kwargs, - ): - self.slicing_sample_min_size = slicing_sample_min_size - self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) - extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None - block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS - down_block_types = ("DownEncoderBlock3D",) * 4 - up_block_types = ("UpDecoderBlock3D",) * 4 - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - extra_cond_dim=extra_cond_dim, - # [Override] add temporal_down_num parameter - temporal_down_num=temporal_scale_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - # [Override] add temporal_up_num parameter - temporal_up_num=temporal_scale_num, - slicing_up_num=slicing_up_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - self.quant_conv = ( - InflatedCausalConv3d( - in_channels=2 * latent_channels, - out_channels=2 * latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_quant_conv - else None - ) - self.post_quant_conv = ( - InflatedCausalConv3d( - in_channels=latent_channels, - out_channels=latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_post_quant_conv - else None - ) - - # A hacky way to remove attention. - if not attention: - self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) - self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) - - self.use_slicing = True - - def encode(self, x: torch.FloatTensor, return_dict: bool = True): - h = self.slicing_encode(x) - posterior = DiagonalGaussianDistribution(h).mode() - - if not return_dict: - return (posterior,) - - return posterior - - def decode_( - self, z: torch.Tensor, return_dict: bool = True - ): - decoded = self.slicing_decode(z) - - if not return_dict: - return (decoded,) - - return decoded - - def _encode( - self, x, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _x = x.to(self.device) - h = self.encoder(_x, memory_state=memory_state) - if self.quant_conv is not None: - output = self.quant_conv(h, memory_state=memory_state) - else: - output = h - return output.to(x.device) - - def _decode( - self, z, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _z = z.to(self.device) - - if self.post_quant_conv is not None: - _z = self.post_quant_conv(_z, memory_state=memory_state) - - output = self.decoder(_z, memory_state=memory_state) - return output.to(z.device) - - def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size =1 - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: - split_size = max( - self.slicing_sample_min_size * sp_size, - getattr(self, "temporal_downsample_factor", 1), - ) - x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) - min_active_len = getattr(self, "temporal_downsample_factor", 1) - if len(x_slices) > 1 and x_slices[-1].shape[2] < min_active_len: - x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2) - x_slices.pop() - encoded_slices = [ - self._encode( - torch.cat((x[:, :, :1], x_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for x_idx in range(1, len(x_slices)): - encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(encoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._encode(x) - - def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = 1 - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) - decoded_slices = [ - self._decode( - torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING - ) - ] - for z_idx in range(1, len(z_slices)): - decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(decoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._decode(z) - - def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] - def _unwrap(value): - return value[0] if isinstance(value, tuple) else value - - if mode == "encode": - return _unwrap(self.encode(x)) - elif mode == "decode": - return _unwrap(self.decode_(x)) - else: - latent = _unwrap(self.encode(x)) - return _unwrap(self.decode_(latent)) - -class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - def __init__( - self, - *args, - spatial_downsample_factor = 8, - temporal_downsample_factor = 4, - freeze_encoder = True, - **kwargs, - ): - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder - self.enable_tiling = False - super().__init__(*args, **kwargs) - self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) - - def forward(self, x: torch.FloatTensor): - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) - x = self.decode(z) - return x, z, p - - def encode(self, x, orig_dims=None): - if x.ndim == 4: - x = x.unsqueeze(2) - x = x.to(dtype=next(self.parameters()).dtype) - self.device = x.device - p = super().encode(x) - z = p.squeeze(2) - return z, p - - def decode(self, z, seedvr2_tiling=None): - seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling - if not isinstance(seedvr2_tiling, dict): - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " - f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." - ) - - if z.ndim == 5: - b, c, t_latent, h, w = z.shape - if c != 16: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " - f"have 16 channels; got shape {tuple(z.shape)}." - ) - latent = z - elif z.ndim == 4: - b, tc, h, w = z.shape - if tc % 16 != 0: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " - "use collapsed channel layout (B, 16*T, H, W); " - f"got shape {tuple(z.shape)}." - ) - latent = z.reshape(b, 16, -1, h, w) - else: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " - "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " - f"got shape {tuple(z.shape)}." - ) - scale = BYTEDANCE_VAE_SCALING_FACTOR - shift = BYTEDANCE_VAE_SHIFTING_FACTOR - latent = latent / scale + shift - - self.device = latent.device - self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) - - if self.enable_tiling: - decode_seedvr2_args = dict(seedvr2_tiling) - tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) - ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) - decode_seedvr2_args["tile_overlap"] = ( - min(ov_h, max(0, tile_h - 8)), - min(ov_w, max(0, tile_w - 8)), - ) - x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) - if x.ndim == 4: - # tiled_vae squeezes the temporal axis when - # temporal_downsample_factor == 1 AND latent T == 1 - # (see tiled_vae line 179-180); re-add it so the post-decode - # pipeline can keep batch and time distinct on the tiled path. - x = x.unsqueeze(2) - else: - x = super().decode_(latent) - - # ensure even dims for save video - h, w = x.shape[-2:] - w2 = w - (w % 2) - h2 = h - (h % 2) - x = x[..., :h2, :w2] - - return x - - def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): - set_norm_limit(norm_max_mem) - for m in self.modules(): - if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) - - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) diff --git a/comfy/model_base.py b/comfy/model_base.py index c084e23bb19f..042804771890 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 -import comfy.ldm.seedvr.model - import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -930,16 +928,6 @@ def extra_conds(self, **kwargs): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out -class SeedVR2(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - condition = kwargs.get("condition", None) - if condition is not None: - out["condition"] = comfy.conds.CONDRegular(condition) - return out - class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9555810065c4..74c838d13338 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` - # submodules) at EVERY block — verified by inspecting the 7B - # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means - # ``MMModule.shared_weights=False``). Native NaDiT computes - # per-block ``shared_weights = not (i < mm_layers)``, so to keep - # every block non-shared we set ``mm_layers = num_layers``. - # Without this, blocks at index >= mm_layers (default 10) try to - # load ``blocks.N.*.all.*`` keys that don't exist in the file, - # silently miss-load → all-black output. - dit_config["mm_layers"] = 36 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "normal" - return dit_config - elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # This checkpoint layout carries shared ``all.`` MMModule keys. - # Preserve the historical split: the initial blocks use separate - # vid/txt modules, later blocks use shared modules. - dit_config["mm_layers"] = 10 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "swiglu" - return dit_config - elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 2560 - dit_config["heads"] = 20 - dit_config["num_layers"] = 32 - dit_config["norm_eps"] = 1.0e-05 - dit_config["qk_rope"] = None - dit_config["mlp_type"] = "swiglu" - dit_config["vid_out_norm"] = True - return dit_config - if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/sample.py b/comfy/sample.py index de71596b3e1d..2be0cae5f872 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None, is_empty = torch.count_nonzero(latent_image) == 0 if is_empty: if latent_format.latent_channels != latent_image.shape[1]: - preserves_collapsed_channels = ( - getattr(latent_format, "preserve_empty_channel_multiples", False) - and latent_image.ndim == 4 - and latent_image.shape[1] % latent_format.latent_channels == 0 - ) - if not preserves_collapsed_channels: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if downscale_ratio_spacial is not None: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio diff --git a/comfy/sd.py b/comfy/sd.py index 8ac08ac42d86..a66ba1bfb76e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -import inspect import json import torch from enum import Enum @@ -17,7 +16,6 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae -import comfy.ldm.seedvr.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -86,36 +84,6 @@ import comfy.ldm.flux.redux -SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160 - - -def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w): - output_t = max(1, (latent_t - 1) * 4 + 1) - return output_t * latent_h * 8 * latent_w * 8 - - -def _seedvr2_vae_decode_memory_used(shape): - if len(shape) == 5: - candidates = [] - if shape[1] == 16: - candidates.append((shape[2], shape[3], shape[4])) - if shape[-1] == 16: - candidates.append((shape[1], shape[2], shape[3])) - if len(candidates) == 0: - candidates.append((shape[2], shape[3], shape[4])) - output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates) - elif len(shape) == 4: - latent_t = max(1, (shape[1] + 15) // 16) - latent_h, latent_w = shape[2], shape[3] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - else: - latent_t, latent_h, latent_w = 1, shape[-2], shape[-1] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels - # plus int64 sort indices dominate peak memory, not the VAE weight dtype. - return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL - - def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: @@ -499,10 +467,8 @@ def decode(self, token_ids, skip_special_tokens=True): class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd - if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - if metadata is None or metadata.get("keep_diffusers_format") != "true": - sd = diffusers_convert.convert_vae_state_dict(sd) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -574,20 +540,6 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 - self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.latent_channels = 16 - self.latent_dim = 3 - self.disable_offload = True - self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape) - self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) - self.process_input = lambda image: image * 2.0 - 1.0 - self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -715,7 +667,6 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -1055,40 +1006,6 @@ def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) - def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4): - sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8) - sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4) - if tile_t is None: - tile_t = 16 - if overlap_t is None: - overlap_t = 4 - if tile_t > 0: - temporal_size = tile_t * sf_t - temporal_overlap = max(0, overlap_t) * sf_t - else: - temporal_size = 0 - temporal_overlap = 0 - args = { - "enable_tiling": True, - "tile_size": (tile_y * sf_s, tile_x * sf_s), - "tile_overlap": (overlap * sf_s, overlap * sf_s), - "temporal_size": temporal_size, - "temporal_overlap": temporal_overlap, - } - output = self.first_stage_model.decode( - samples.to(self.vae_dtype).to(self.device), - seedvr2_tiling=args, - ) - return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) - - def _format_seedvr2_encoded_samples(self, samples): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - if samples.ndim == 4: - samples = samples.unsqueeze(2) - samples = samples.contiguous() - samples = samples * 0.9152 - return samples - def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1125,36 +1042,6 @@ def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap= encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): - if tile_y is None: - tile_y = 512 - if tile_x is None: - tile_x = 512 - if overlap is None: - overlap_y = 64 - overlap_x = 64 - else: - overlap_y = overlap - overlap_x = overlap - if tile_t is None: - tile_t = 9999 - if overlap_t is None: - overlap_t = 0 - overlap_y = min(overlap_y, max(0, tile_y - 8)) - overlap_x = min(overlap_x, max(0, tile_x - 8)) - self.first_stage_model.device = self.device - x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) - output = comfy.ldm.seedvr.vae.tiled_vae( - x, - self.first_stage_model, - tile_size=(tile_y, tile_x), - tile_overlap=(overlap_y, overlap_x), - temporal_size=tile_t, - temporal_overlap=overlap_t, - encode=True, - ) - return output.to(device=self.output_device, dtype=self.vae_output_dtype()) - def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1202,40 +1089,16 @@ def decode(self, samples_in, vae_options={}): if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - # SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)`` - # downstream of ``SeedVR2Conditioning`` (which performs the - # ``rearrange(b c t h w -> b (c t) h w)`` collapse). The - # generic ``decode_tiled_`` would treat the channel dim as - # spatial-only and crash on the collapsed (16, T) layout - # under ``tiled_scale``'s mask broadcast; route SeedVR2 4D - # latents to ``decode_tiled_seedvr2`` instead, whose wrapper - # dispatch handles both 4D and 5D inputs. - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_(samples_in) + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled( - self, - samples, - tile_x=None, - tile_y=None, - overlap=None, - tile_t=None, - overlap_t=None, - ): + def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1249,20 +1112,7 @@ def decode_tiled( args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - if overlap is not None: - seedvr2_args["overlap"] = overlap - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - output = self.decode_tiled_seedvr2(samples, **seedvr2_args) - elif dims == 1 or self.extra_1d_channel is not None: + if dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1304,8 +1154,6 @@ def encode(self, pixel_samples): else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) - if isinstance(out, tuple): - out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1325,23 +1173,20 @@ def encode(self, pixel_samples): if self.latent_dim == 3: tile = 256 overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) - else: - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - return self._format_seedvr2_encoded_samples(samples) + return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) - if dims == 3 and pixel_samples.ndim < 5: + if dims == 3: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1365,47 +1210,22 @@ def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, ti elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - else: - seedvr2_args["tile_x"] = 512 - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - else: - seedvr2_args["tile_y"] = 512 - if overlap is not None: - seedvr2_args["overlap"] = overlap - else: - seedvr2_args["overlap"] = 64 - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - else: - seedvr2_args["tile_t"] = 9999 - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - else: - seedvr2_args["overlap_t"] = 0 - samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) else: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - spatial_overlap = overlap if overlap is not None else 64 - if overlap_t is None: - args["overlap"] = (1, spatial_overlap, spatial_overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - return self._format_seedvr2_encoded_samples(samples) + return samples def get_sd(self): return self.first_stage_model.state_dict() @@ -1932,17 +1752,6 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo return (model, clip, vae) - -def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): - set_dtype = model_config.set_inference_dtype - parameters = inspect.signature(set_dtype).parameters - supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) - if supports_device: - set_dtype(dtype, manual_cast_dtype, device=device) - else: - set_dtype(dtype, manual_cast_dtype) - - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fa95003cc237..7cf9c133b9cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1672,35 +1672,6 @@ def clip_target(self, state_dict={}): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) -class SeedVR2(supported_models_base.BASE): - unet_config = { - "image_model": "seedvr2" - } - latent_format = comfy.latent_formats.SeedVR2 - - vae_key_prefix = ["vae."] - text_encoder_key_prefix = ["text_encoders."] - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] - sampling_settings = { - "shift": 1.0, - } - - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): - if ( - dtype == torch.float16 - and manual_cast_dtype is None - and comfy.model_management.should_use_bf16(device) - ): - manual_cast_dtype = torch.bfloat16 - super().set_inference_dtype(dtype, manual_cast_dtype, device=device) - - def get_model(self, state_dict, prefix="", device=None): - out = model_base.SeedVR2(self, device=device) - return out - - def clip_target(self, state_dict={}): - return None - class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -2058,6 +2029,7 @@ def clip_target(self, state_dict={}): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + class RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", @@ -2295,7 +2267,6 @@ def get_model(self, state_dict, prefix="", device=None): HiDream, HiDreamO1, Chroma, - SeedVR2, ChromaRadiance, ACEStep, ACEStep15, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 572f9984e9e6..0e7a829ba13b 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py deleted file mode 100644 index d5cd029bacc8..000000000000 --- a/comfy_extras/nodes_seedvr.py +++ /dev/null @@ -1,1015 +0,0 @@ -from typing_extensions import override -from comfy_api.latest import ComfyExtension, io -import torch -import math -import logging -from einops import rearrange - -import gc -import comfy.model_management -import comfy.sample -import comfy.samplers -from comfy.ldm.seedvr.color_fix import ( - adain_color_transfer, - lab_color_transfer, - wavelet_color_transfer, -) -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_IMG_SHIFT_FIT, - BYTEDANCE_SCHEDULE_T, - BYTEDANCE_VID_SHIFT_FIT, - SEEDVR2_ADAIN_SCALE_MULTIPLIER, - SEEDVR2_COLOR_MEM_HEADROOM, - SEEDVR2_COND_CHANNELS, - SEEDVR2_DTYPE_BYTES_FLOOR, - SEEDVR2_LAB_SCALE_MULTIPLIER, - SEEDVR2_LATENT_CHANNELS, - SEEDVR2_OOM_BACKOFF_DIVISOR, - SEEDVR2_WAVELET_SCALE_MULTIPLIER, -) - -from torchvision.transforms import functional as TVF -from torchvision.transforms import Lambda -from torchvision.transforms.functional import InterpolationMode - - -_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( - "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" -) - -# Private sentinel for getattr default: distinguishes "attribute missing" -# from "attribute present but None" so the failure message is accurate. -_ATTR_MISSING = object() - - -def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): - """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" - attempts = [frames_per_chunk] - current_chunk_latent = ( - t_latent if t_pixel <= frames_per_chunk - else (frames_per_chunk - 1) // 4 + 1 - ) - current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) - seen = {frames_per_chunk} - - for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): - chunk_latent = max(1, math.ceil(t_latent / target_chunks)) - candidate = 4 * (chunk_latent - 1) + 1 - if candidate in seen: - continue - if candidate >= attempts[-1]: - continue - attempts.append(candidate) - seen.add(candidate) - - return attempts - - -def _resolve_seedvr2_diffusion_model(model): - """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" - inner = getattr(model, "model", _ATTR_MISSING) - if inner is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " - f"(got type {type(model).__name__})." - ) - if inner is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " - f"(input type {type(model).__name__})." - ) - diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) - if diffusion_model is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " - f"'diffusion_model' attribute (got type {type(inner).__name__})." - ) - if diffusion_model is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " - f"is None (model.model type {type(inner).__name__})." - ) - return diffusion_model - - -def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" - for module in diffusion_model.modules(): - if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): - if module.rope.freqs.data.dtype != torch.float32: - module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) - - -def clear_vae_memory(vae_model): - for module in vae_model.modules(): - if hasattr(module, "memory"): - module.memory = None - gc.collect() - comfy.model_management.soft_empty_cache() - -def expand_dims(tensor, ndim): - shape = tensor.shape + (1,) * (ndim - tensor.ndim) - return tensor.reshape(shape) - -def get_conditions(latent, latent_blur): - t, h, w, c = latent.shape - cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - -def timestep_transform(timesteps, latents_shapes): - vt = 4 - vs = 8 - frames = (latents_shapes[:, 0] - 1) * vt + 1 - heights = latents_shapes[:, 1] * vs - widths = latents_shapes[:, 2] * vs - - # Compute shift factor. - def get_lin_function(x1, y1, x2, y2): - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) - vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) - shift = torch.where( - frames > 1, - vid_shift_fn(heights * widths * frames), - img_shift_fn(heights * widths), - ).to(timesteps.device) - - # Shift timesteps. - T = BYTEDANCE_SCHEDULE_T - timesteps = timesteps / T - timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) - timesteps = timesteps * T - return timesteps - -def inter(x_0, x_T, t): - t = expand_dims(t, x_0.ndim) - T = BYTEDANCE_SCHEDULE_T - B = lambda t: t / T - A = lambda t: 1 - (t / T) - return A(t) * x_0 + B(t) * x_T - -def div_pad(image, factor): - - height_factor, width_factor = factor - height, width = image.shape[-2:] - - pad_height = (height_factor - (height % height_factor)) % height_factor - pad_width = (width_factor - (width % width_factor)) % width_factor - - if pad_height == 0 and pad_width == 0: - return image - - if isinstance(image, torch.Tensor): - padding = (0, pad_width, 0, pad_height) - image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - - return image - -def cut_videos(videos): - t = videos.size(1) - if t == 1: - return videos - if t <= 4 : - padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 - ((t - 1) % (4)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4) == 0 - return videos - -def _seedvr2_input_shorter_edge(images, node_name): - if images.dim() == 4: - return min(images.shape[1], images.shape[2]) - if images.dim() == 5: - return min(images.shape[2], images.shape[3]) - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - - -def _seedvr2_pad(images, upscaled_shorter_edge, node_name): - if upscaled_shorter_edge < 2: - raise ValueError( - f"{node_name}: input shorter edge must be at least 2 pixels; " - f"got {upscaled_shorter_edge}." - ) - if images.shape[-1] > 3: - images = images[..., :3] - if images.dim() == 4: - # Comfy video components arrive as a 4-D IMAGE frame sequence: - # (frames, H, W, C). SeedVR2 consumes that as one video. - images = images.unsqueeze(0) - elif images.dim() != 5: - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - images = images.permute(0, 1, 4, 2, 3) - - b, t, c, h, w = images.shape - images = images.reshape(b * t, c, h, w) - - clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - images = clip(images) - images = div_pad(images, (16, 16)) - _, _, new_h, new_w = images.shape - - images = images.reshape(b, t, c, new_h, new_w) - images = cut_videos(images) - images_bthwc = rearrange(images, "b t c h w -> b t h w c") - - return io.NodeOutput(images_bthwc) - - -class SeedVR2Preprocess(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Preprocess", - display_name="Pre-Process SeedVR2 Input", - category="image/upscaling", - description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.", - inputs=[ - io.Image.Input("resized_images", tooltip="The resized image to process."), - ], - outputs=[ - io.Image.Output("images"), - ] - ) - - @classmethod - def execute(cls, resized_images): - upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") - return _seedvr2_pad( - resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", - ) - - -class SeedVR2PostProcessing(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2PostProcessing", - display_name="Post-Process SeedVR2 Output", - category="image/upscaling", - description="Align the generated image with the original resized image and apply color correction.", - inputs=[ - io.Image.Input("images", tooltip="The generated image to process."), - io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), - io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), - ], - outputs=[io.Image.Output(display_name="images")], - ) - - @classmethod - def execute(cls, images, original_resized_images, color_correction_method): - alpha_input = None - if original_resized_images.shape[-1] == 4: - alpha_input = original_resized_images[..., 3:4] - original_resized_images = original_resized_images[..., :3] - decoded_5d, decoded_was_4d = cls._as_bthwc(images) - reference_full, _ = cls._as_bthwc(original_resized_images) - decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) - - b = min(decoded_5d.shape[0], reference_full.shape[0]) - t = min(decoded_5d.shape[1], reference_full.shape[1]) - reference_h = reference_full.shape[2] - reference_w = reference_full.shape[3] - - decoded_5d = decoded_5d[:b, :t, :, :, :] - target_h = min(decoded_5d.shape[2], reference_h) - target_w = min(decoded_5d.shape[3], reference_w) - decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] - if color_correction_method in ("lab", "wavelet", "adain"): - reference_5d = reference_full[:b, :t, :, :, :] - reference_5d = cls._resize_reference(reference_5d, target_h, target_w) - output_device = decoded_5d.device - decoded_raw = cls._to_seedvr2_raw(decoded_5d) - reference_raw = cls._to_seedvr2_raw(reference_5d) - decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") - reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") - output = cls._color_transfer_chunked( - decoded_flat, reference_flat, output_device, color_correction_method, - ) - output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) - output = output.add(1.0).div(2.0).clamp(0.0, 1.0) - elif color_correction_method == "none": - output = decoded_5d - else: - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - if alpha_input is not None: - alpha_5d, _ = cls._as_bthwc(alpha_input) - alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :] - output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) - h2 = output.shape[-3] - (output.shape[-3] % 2) - w2 = output.shape[-2] - (output.shape[-2] % 2) - output = output[:, :, :h2, :w2, :] - if decoded_was_4d: - output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) - return io.NodeOutput(output) - - @staticmethod - def _as_bthwc(images): - if images.ndim == 4: - return images.unsqueeze(0), True - if images.ndim == 5: - return images, False - raise ValueError( - f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" - ) - - @staticmethod - def _restore_reference_batch_time(decoded, reference): - if decoded.shape[0] != 1: - return decoded - ref_b, ref_t = reference.shape[:2] - if ref_b < 1 or decoded.shape[1] % ref_b != 0: - return decoded - decoded_t = decoded.shape[1] // ref_b - if decoded_t < ref_t: - return decoded - return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) - - @staticmethod - def _to_seedvr2_raw(images): - return images.mul(2.0).sub(1.0) - - @staticmethod - def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): - color_device = comfy.model_management.vae_device() - decoded_flat = decoded_flat.to(device=color_device) - reference_flat = reference_flat.to(device=color_device) - output = transfer_fn(decoded_flat, reference_flat) - return output.to(device=output_device) - - @staticmethod - def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): - color_device = comfy.model_management.vae_device() - result = None - for start in range(decoded_flat.shape[0]): - decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() - reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() - output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:start + 1].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") - return result - - @classmethod - def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): - chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) - while True: - next_chunk_size = None - try: - return cls._run_color_transfer_chunks( - decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if chunk_size <= 1: - raise RuntimeError( - "SeedVR2PostProcessing: color correction OOM at one frame; " - f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." - ) from e - next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - - comfy.model_management.soft_empty_cache() - chunk_size = next_chunk_size - - @classmethod - def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): - result = None - for start in range(0, decoded_flat.shape[0], chunk_size): - end = min(start + chunk_size, decoded_flat.shape[0]) - decoded_chunk = decoded_flat[start:end] - reference_chunk = reference_flat[start:end] - if color_correction_method == "lab": - output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) - elif color_correction_method == "wavelet": - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, - ) - else: - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, adain_color_transfer, - ) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:end].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") - return result - - @classmethod - def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): - multiplier = cls._color_correction_memory_multiplier(color_correction_method) - frames = decoded_flat.shape[0] - _, channels, height, width = decoded_flat.shape - dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) - bytes_per_frame = height * width * channels * dtype_bytes * multiplier - if bytes_per_frame <= 0: - return frames - color_device = comfy.model_management.vae_device() - free_memory = comfy.model_management.get_free_memory(color_device) - chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) - return max(1, min(frames, chunk_size)) - - @staticmethod - def _color_correction_memory_multiplier(color_correction_method): - if color_correction_method == "lab": - return SEEDVR2_LAB_SCALE_MULTIPLIER - if color_correction_method == "wavelet": - return SEEDVR2_WAVELET_SCALE_MULTIPLIER - if color_correction_method == "adain": - return SEEDVR2_ADAIN_SCALE_MULTIPLIER - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - @staticmethod - def _resize_reference(reference, height, width): - if reference.shape[2] == height and reference.shape[3] == width: - return reference - b, t = reference.shape[:2] - reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") - resized = TVF.resize( - reference_flat, - size=(height, width), - interpolation=InterpolationMode.BICUBIC, - antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), - ) - return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) - - -class SeedVR2Conditioning(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Conditioning", - display_name="Apply SeedVR2 Conditioning", - category="conditioning", - description="Build SeedVR2 positive/negative conditioning from a VAE latent.", - inputs=[ - io.Model.Input("model", tooltip="The SeedVR2 model."), - io.Latent.Input("vae_conditioning", display_name="latent"), - ], - outputs=[ - io.Model.Output(display_name = "model"), - io.Conditioning.Output(display_name = "positive"), - io.Conditioning.Output(display_name = "negative"), - io.Latent.Output(display_name = "latent"), - ], - ) - - @classmethod - def execute(cls, model, vae_conditioning) -> io.NodeOutput: - - vae_conditioning = vae_conditioning["samples"] - if vae_conditioning.ndim != 5: - raise ValueError( - "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " - f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." - ) - if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: - raise ValueError( - "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " - f"got channel-last shape {tuple(vae_conditioning.shape)}." - ) - vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() - model_patcher = model - model = _resolve_seedvr2_diffusion_model(model_patcher) - pos_cond = model.positive_conditioning - neg_cond = model.negative_conditioning - - # Fail-loud guard against silently-wrong output when a - # DiT-only ``.safetensors`` (no ``positive_conditioning`` / - # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. - # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see - # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` - # leaves them at zero when the keys are absent. Detect that state - # here rather than at ``BaseModel.extra_conds`` (per sampling step, - # wasteful) or at the resolver helper (mixes structural shape with - # semantic content). Both buffers must be checked together — partial - # bake regressions could populate one but not the other. - if ( - pos_cond.float().abs().sum().item() == 0 - and neg_cond.float().abs().sum().item() == 0 - ): - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " - f"and negative_conditioning buffers are zero-valued — model " - f"file appears to be a DiT-only export missing " - f"the SeedVR2 conditioning tensors. " - f"Re-bake the file with ``positive_conditioning`` (58, 5120) " - f"and ``negative_conditioning`` (64, 5120) keys at top level, " - f"or load via CheckpointLoaderSimple from a bundled " - f"checkpoint." - ) - - _apply_rope_freqs_float32_cast(model) - - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) - condition = condition.movedim(-1, 1) - latent = vae_conditioning.movedim(-1, 1) - - latent = rearrange(latent, "b c t h w -> b (c t) h w") - condition = rearrange(condition, "b c t h w -> b (c t) h w") - - negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] - positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - - return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) - -def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, - t_end: int, channels: int) -> torch.Tensor: - """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" - B, CT, H, W = tensor_4d.shape - if CT % channels != 0: - raise ValueError( - f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " - f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." - ) - T = CT // channels - if not (0 <= t_start < t_end <= T): - raise ValueError( - f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " - f"range for T={T}." - ) - new_T = t_end - t_start - sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() - return sliced.reshape(B, channels * new_T, H, W) - - -def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" - new_list = [] - for entry in cond_list: - text_cond, options = entry[0], entry[1] - if "condition" not in options: - new_list.append(entry) - continue - new_options = options.copy() - new_options["condition"] = _slice_collapsed_4d_along_t( - new_options["condition"], t_start, t_end, - SEEDVR2_COND_CHANNELS, - ) - new_list.append([text_cond, new_options]) - return new_list - - -def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, - samples_4d: torch.Tensor, - t_start: int, - t_end: int): - """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" - if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: - return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, - ) - return noise_mask - - -def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" - if len(chunks_4d) == 0: - raise ValueError("_concat_chunks_along_t: empty chunk list.") - fives = [] - for ch in chunks_4d: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " - f"channel dim {CT} not divisible by channels={channels}." - ) - T = CT // channels - fives.append(ch.reshape(B, channels, T, H, W)) - cat = torch.cat(fives, dim=2).contiguous() - B, C, T_total, H, W = cat.shape - return cat.reshape(B, C * T_total, H, W) - - -def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): - Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` - (dead-band would collapse a tiny transition). Window shape matched to the reference - overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. - """ - if overlap < 1: - raise ValueError( - f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." - ) - if overlap >= 3: - t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) - blend_start = 1.0 / 3.0 - blend_end = 2.0 / 3.0 - u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) - return 0.5 + 0.5 * torch.cos(torch.pi * u) - return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) - - -def _blend_overlap_region(prev_tail_5d: torch.Tensor, - cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" - if prev_tail_5d.shape != cur_head_5d.shape: - raise ValueError( - f"_blend_overlap_region: shape mismatch " - f"prev {tuple(prev_tail_5d.shape)} vs " - f"cur {tuple(cur_head_5d.shape)}." - ) - overlap = int(prev_tail_5d.shape[2]) - w_prev_1d = _hann_blend_weights_1d( - overlap, prev_tail_5d.device, prev_tail_5d.dtype, - ) - # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. - w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) - w_cur = 1.0 - w_prev - return prev_tail_5d * w_prev + cur_head_5d * w_cur - - -def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, - overlap_latent: int) -> torch.Tensor: - """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" - if len(chunk_specs) == 0: - raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") - if overlap_latent < 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: overlap_latent must be " - f">= 0; got {overlap_latent}." - ) - - # Validate channel divisibility once and capture per-chunk T. - chunk_5d = [] - for t_start, t_end, ch in chunk_specs: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk shape " - f"{tuple(ch.shape)} channel dim {CT} not divisible " - f"by channels={channels}." - ) - T = CT // channels - if t_end - t_start != T: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " - f"declared range [{t_start}:{t_end}]." - ) - chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) - - if overlap_latent == 0: - # Fast path: pure concat in the caller-provided chunk order. - return _concat_chunks_along_t( - [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) - for _, _, c in chunk_5d], - channels, - ) - - T_total = max(t_end for _, t_end, _ in chunk_5d) - first_5d = chunk_5d[0][2] - B = first_5d.shape[0] - H = first_5d.shape[3] - W = first_5d.shape[4] - result = torch.empty( - (B, channels, T_total, H, W), - device=first_5d.device, dtype=first_5d.dtype, - ) - filled_until = 0 - for i, (cs, ce, ct_5d) in enumerate(chunk_5d): - chunk_T = int(ct_5d.shape[2]) - if i == 0: - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - continue - # Overlap region width is bounded by both the previous fill - # frontier and the current chunk's actual length (for runt - # final chunks shorter than the configured overlap). - overlap_len = min(filled_until - cs, chunk_T) - if overlap_len > 0: - prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() - cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() - blended = _blend_overlap_region(prev_tail, cur_head) - result[:, :, cs:cs + overlap_len, :, :] = blended - tail_start = cs + overlap_len - tail_end = ce - if tail_end > tail_start: - result[:, :, tail_start:tail_end, :, :] = ( - ct_5d[:, :, overlap_len:, :, :] - ) - else: - # Disjoint chunks (overlap_latent set but this pair did not - # actually overlap, e.g. step_latent equal to chunk_latent - # in a degenerate config). Treat as concat. - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - - return result.contiguous().reshape(B, channels * T_total, H, W) - - -def _run_standard_sample(model, seed: int, steps: int, cfg: float, - sampler_name: str, scheduler: str, - positive, negative, latent: dict, - denoise: float) -> dict: - """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" - samples_in = latent["samples"] - samples_in = comfy.sample.fix_empty_latent_channels( - model, samples_in, latent.get("downscale_ratio_spacial", None), - ) - batch_inds = latent.get("batch_index", None) - noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) - noise_mask = latent.get("noise_mask", None) - samples = comfy.sample.sample( - model, noise, steps, cfg, sampler_name, scheduler, - positive, negative, samples_in, - denoise=denoise, noise_mask=noise_mask, seed=seed, - ) - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = samples - return out - - -class SeedVR2ProgressiveSampler(io.ComfyNode): - """Sequential temporal chunking sampler for SeedVR2 native. - - Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that - OOM on long sequences. The latent enters the sampler in SeedVR2's - collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that - tensor along the temporal axis, runs the configured inner sampler - sequentially per chunk against the standard ``comfy.sample.sample`` - entry point, and concatenates per-chunk outputs back into a single - ``(B, 16*T_total, H, W)`` latent. - - ``frames_per_chunk`` is expressed in pixel-frame units to match the - SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the - VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` - maps to ``(F - 1) // 4 + 1`` latent-frame chunks. - - Determinism contract: a single noise tensor is generated once from - the user seed and sliced per chunk (rather than re-seeding each - chunk), so a workflow that fits in a single chunk produces output - identical to a workflow that fits in N chunks at the same seed, - modulo the inherent T-axis chunk-boundary independence of the model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2ProgressiveSampler", - display_name="Sample SeedVR2 (Progressive)", - category="sampling", - description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", - inputs=[ - io.Model.Input("model", tooltip="The model used for denoising the input latent."), - io.Int.Input("seed", default=0, min=0, - max=0xffffffffffffffff, - control_after_generate=True, - tooltip="The random seed used for creating the noise."), - io.Int.Input("steps", default=20, min=1, max=10000, - tooltip="The number of steps used in the denoising process."), - io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01, - tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), - io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES, - tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), - io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES, - tooltip="The scheduler controls how noise is gradually removed to form the image."), - io.Conditioning.Input("positive", - tooltip="The conditioning describing the attributes you want to include in the image."), - io.Conditioning.Input("negative", - tooltip="The conditioning describing the attributes you want to exclude from the image."), - io.Latent.Input("latent", - tooltip="The latent image to denoise."), - io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01, - tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), - io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4, - tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), - io.Int.Input("temporal_overlap", default=0, min=0, - max=16384, - tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), - io.Combo.Input("chunking_mode", - options=["manual", "auto"], - default="manual", - tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), - ], - outputs=[io.Latent.Output(display_name="latent")], - ) - - @classmethod - def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - frames_per_chunk, temporal_overlap, - chunking_mode="manual") -> io.NodeOutput: - # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline - # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), - # imposed at ``cut_videos`` upstream and propagated through the VAE's - # temporal_downsample_factor=4. Reject violations explicitly before - # any model invocation; a silent rounding would mis-align chunk - # boundaries with the 4n+1 lattice. - if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " - f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " - f"got {frames_per_chunk}." - ) - - samples_4d = latent["samples"] - samples_4d = comfy.sample.fix_empty_latent_channels( - model, samples_4d, - latent.get("downscale_ratio_spacial", None), - ) - if samples_4d.ndim != 4: - raise ValueError( - f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " - f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." - ) - B, CT, H, W = samples_4d.shape - if CT % SEEDVR2_LATENT_CHANNELS != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " - f"not divisible by SeedVR2 latent channels " - f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " - f"SeedVR2-shaped." - ) - T_latent = CT // SEEDVR2_LATENT_CHANNELS - T_pixel = 4 * (T_latent - 1) + 1 - - if chunking_mode not in ("manual", "auto"): - raise ValueError( - f"SeedVR2ProgressiveSampler: chunking_mode must be " - f"'manual' or 'auto'; got {chunking_mode!r}." - ) - - if chunking_mode == "auto": - attempts = _seedvr2_auto_chunk_attempts( - T_latent, T_pixel, frames_per_chunk, - ) - for i, attempt_frames_per_chunk in enumerate(attempts): - retry = False - try: - return cls.execute( - model=model, seed=seed, steps=steps, cfg=cfg, - sampler_name=sampler_name, scheduler=scheduler, - positive=positive, negative=negative, - latent=latent, denoise=denoise, - frames_per_chunk=attempt_frames_per_chunk, - temporal_overlap=temporal_overlap, - chunking_mode="manual", - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if i == len(attempts) - 1: - raise RuntimeError( - "SeedVR2ProgressiveSampler: exhausted auto " - "chunking attempts after OOM. Tried " - f"frames_per_chunk values {attempts}." - ) from e - retry = True - - if retry: - logging.warning( - "SeedVR2ProgressiveSampler auto chunking OOM at " - "frames_per_chunk=%s; retrying with " - "frames_per_chunk=%s.", - attempt_frames_per_chunk, attempts[i + 1], - ) - comfy.model_management.soft_empty_cache() - - # Short-circuit: total fits in one chunk -> standard path with no - # chunking overhead. Output of this branch is byte-identical to the - # built-in KSampler given the same (model, seed, steps, cfg, - # sampler_name, scheduler, positive, negative, latent, - # denoise) tuple. - if T_pixel <= frames_per_chunk: - return io.NodeOutput(_run_standard_sample( - model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - )) - - # Map pixel chunk -> latent chunk. Each chunk's latent length is - # at most ``chunk_latent``; the final chunk may be a runt that - # is automatically 4n+1-aligned in the pixel domain by the - # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer - # T_latent corresponds to a valid 4n+1 pixel count). - chunk_latent = (frames_per_chunk - 1) // 4 + 1 - - # ``temporal_overlap`` is exposed in latent-frame units, but users - # do not know the derived latent chunk length. Treat oversized - # values as "maximum valid overlap" while preserving a strictly - # positive chunk-loop stride. - if temporal_overlap < 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " - f"got {temporal_overlap}." - ) - temporal_overlap = min(temporal_overlap, chunk_latent - 1) - step_latent = chunk_latent - temporal_overlap - - # Generate full noise once from the user seed, then slice along T - # per chunk. Using one global noise tensor (rather than re-seeding - # per chunk) preserves seed-determinism across chunk-count - # variations: the same (seed, total T_latent) always produces the - # same noise samples regardless of how the work is partitioned. - batch_inds = latent.get("batch_index", None) - noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) - - noise_mask = latent.get("noise_mask", None) - - # Build the flat list of chunk ranges first so the chunking - # geometry is fully known before any sample call. - chunk_ranges = [] - for chunk_start in range(0, T_latent, step_latent): - chunk_end = min(chunk_start + chunk_latent, T_latent) - if chunk_start >= chunk_end: - # The final iteration of a stride that lands exactly on - # T_latent produces a zero-length chunk; skip it. - break - chunk_ranges.append((chunk_start, chunk_end)) - if chunk_end >= T_latent: - break - - def _sample_one_chunk(chunk_start, chunk_end): - samples_chunk = _slice_collapsed_4d_along_t( - samples_4d, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - noise_chunk = _slice_collapsed_4d_along_t( - noise_full, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - positive_chunk = _slice_seedvr2_cond_along_t( - positive, chunk_start, chunk_end, - ) - negative_chunk = _slice_seedvr2_cond_along_t( - negative, chunk_start, chunk_end, - ) - - # Per-chunk noise_mask handling: standard masks are passed - # through for KSampler expansion; pre-expanded collapsed - # masks are sliced. - chunk_noise_mask = None - if noise_mask is not None: - chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( - noise_mask, samples_4d, chunk_start, chunk_end, - ) - - return comfy.sample.sample( - model, noise_chunk, steps, cfg, sampler_name, scheduler, - positive_chunk, negative_chunk, samples_chunk, - denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, - ) - - chunk_specs = [] - for chunk_start, chunk_end in chunk_ranges: - chunk_samples = _sample_one_chunk(chunk_start, chunk_end) - chunk_specs.append((chunk_start, chunk_end, chunk_samples)) - - final = _concat_chunks_with_overlap_blend( - chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, - ) - - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = final - return io.NodeOutput(out) - - -class SeedVRExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [ - SeedVR2Conditioning, - SeedVR2Preprocess, - SeedVR2PostProcessing, - SeedVR2ProgressiveSampler, - ] - -async def comfy_entrypoint() -> SeedVRExtension: - return SeedVRExtension() diff --git a/nodes.py b/nodes.py index d9ac53eded42..2f5a478b59e3 100644 --- a/nodes.py +++ b/nodes.py @@ -47,18 +47,14 @@ if args.enable_manager: import comfyui_manager - def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() - def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) - MAX_RESOLUTION=16384 - class CLIPTextEncode(ComfyNodeABC): @classmethod def INPUT_TYPES(s) -> InputTypeDict: @@ -327,8 +323,8 @@ def INPUT_TYPES(s): return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -338,32 +334,18 @@ def INPUT_TYPES(s): def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 temporal_compression = vae.temporal_compression_decode() if temporal_compression is not None: - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 - else: - requested_temporal_overlap = temporal_overlap - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression) - if requested_temporal_overlap > 0: - temporal_overlap = max(1, temporal_overlap) + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) else: temporal_size = None temporal_overlap = None compression = vae.spacial_compression_decode() - images = vae.decode_tiled( - samples["samples"], - tile_x=tile_size // compression, - tile_y=tile_size // compression, - overlap=overlap // compression, - tile_t=temporal_size, - overlap_t=temporal_overlap, - ) + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -380,7 +362,7 @@ def INPUT_TYPES(s): def encode(self, vae, pixels): t = vae.encode(pixels) - return ({"samples": t}, ) + return ({"samples":t}, ) class VAEEncodeTiled: @classmethod @@ -388,8 +370,8 @@ def INPUT_TYPES(s): return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -397,9 +379,6 @@ def INPUT_TYPES(s): CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) @@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", - "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py deleted file mode 100644 index 2a6e3d43075d..000000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Consolidated SeedVR2 conditioning and refactor regression tests. - -Merges the prior test_seedvr2_refactor_nodes.py and -test_seedvr_conditioning_hardening.py modules. Refactor tests use the -top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests -use _import_nodes_seedvr_isolated() for sys.modules isolation when -mocking comfy.model_management. -""" - -import importlib -import sys -from unittest.mock import MagicMock - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -_SENTINEL = object() -_TARGETS = ( - ("comfy.model_management", "comfy"), - ("comfy_extras.nodes_seedvr", "comfy_extras"), -) - - -def _import_nodes_seedvr_isolated(): - """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" - priors = [] - for mod_name, parent_name in _TARGETS: - prior_mod = sys.modules.get(mod_name, _SENTINEL) - parent = sys.modules.get(parent_name) - attr = mod_name.split(".")[-1] - prior_attr = ( - getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL - ) - priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) - - mock_mm = MagicMock() - for fn in ( - "xformers_enabled", "xformers_enabled_vae", - "pytorch_attention_enabled", "pytorch_attention_enabled_vae", - "sage_attention_enabled", "flash_attention_enabled", - "is_intel_xpu", - ): - getattr(mock_mm, fn).return_value = False - tv = torch.version.__version__.split(".") - mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) - mock_mm.WINDOWS = False - sys.modules["comfy.model_management"] = mock_mm - if sys.modules.get("comfy") is None: - import comfy as _comfy_pkg # noqa: F401 - comfy_pkg = sys.modules.get("comfy") - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( - importlib.import_module("comfy_extras.nodes_seedvr") - ) - - def _restore(): - for mod_name, parent_name, attr, prior_mod, prior_attr in priors: - if prior_mod is _SENTINEL: - sys.modules.pop(mod_name, None) - else: - sys.modules[mod_name] = prior_mod - parent = sys.modules.get(parent_name) - if parent is None: - continue - if prior_attr is _SENTINEL: - if hasattr(parent, attr): - delattr(parent, attr) - else: - setattr(parent, attr, prior_attr) - - return nodes_seedvr, _restore - - -class _Rope(nn.Module): - """Minimal RoPE stub exposing a `freqs` parameter.""" - def __init__(self): - super().__init__() - self.freqs = nn.Parameter(torch.zeros(4)) - - -class _Block(nn.Module): - """Minimal transformer block stub holding a `_Rope`.""" - def __init__(self): - super().__init__() - self.rope = _Rope() - - -class _DiffusionModel(nn.Module): - """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" - def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): - super().__init__() - self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) - pos = torch.zeros if zero_conditioning else torch.ones - self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) - self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) - - -class _ModelInner: - """Inner model wrapper exposing `.diffusion_model`.""" - def __init__(self, diffusion_model): - self.diffusion_model = diffusion_model - - -class _ModelPatcher: - """ModelPatcher stub exposing `.model._ModelInner`.""" - def __init__(self, diffusion_model): - self.model = _ModelInner(diffusion_model) - - -def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - schema = nodes_seedvr.SeedVR2Conditioning.define_schema() - assert [input_item.id for input_item in schema.inputs] == [ - "model", - "vae_conditioning", - ] - assert schema.inputs[1].display_name == "latent" - assert [output.display_name for output in schema.outputs] == [ - "model", - "positive", - "negative", - "latent", - ] - finally: - restore() - - -def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel() - patcher = _ModelPatcher(diffusion_model) - samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) - vae_conditioning = {"samples": samples} - - _, first_positive, first_negative, first_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - _, second_positive, second_negative, second_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - - expected_latent = samples.reshape(1, 6, 2, 2) - channel_last = samples.movedim(1, -1).contiguous() - expected_condition = torch.cat( - [ - channel_last, - torch.ones((*channel_last.shape[:-1], 1)), - ], - dim=-1, - ).movedim(-1, 1).reshape(1, 9, 2, 2) - - assert torch.equal(first_latent["samples"], expected_latent) - assert torch.equal(second_latent["samples"], expected_latent) - assert torch.equal( - first_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - first_negative[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_negative[0][1]["condition"], - expected_condition, - ) - finally: - restore() - - -def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel(zero_conditioning=True) - patcher = _ModelPatcher(diffusion_model) - vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} - - with pytest.raises(RuntimeError) as excinfo: - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, vae_conditioning, - ) - - message = str(excinfo.value) - assert message.startswith( - nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX - ), ( - "Fail-loud message must use the standard " - "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " - f"can match it. Got: {message!r}" - ) - assert "positive_conditioning" in message - assert "negative_conditioning" in message - finally: - restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py deleted file mode 100644 index f7d9a4f65ab3..000000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py +++ /dev/null @@ -1,55 +0,0 @@ -import importlib -import inspect -import sys -from unittest.mock import MagicMock, patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -def test_seedvr_node_signature_matches_schema(): - mock_mm = MagicMock() - mock_mm.xformers_enabled.return_value = False - mock_mm.xformers_enabled_vae.return_value = False - mock_mm.sage_attention_enabled.return_value = False - mock_mm.flash_attention_enabled.return_value = False - - sentinel = object() - prior_cpu = cli_args.cpu - cli_args.cpu = True - prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) - comfy_pkg = sys.modules.get("comfy") - prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel - - with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - sys.modules.pop("comfy_extras.nodes_seedvr", None) - try: - nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") - for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): - schema_ids = [i.id for i in node_cls.define_schema().inputs] - exec_params = [ - p for p in inspect.signature(node_cls.execute).parameters.keys() - if p != "cls" - ] - assert schema_ids == exec_params, ( - f"{node_cls.__name__} schema/execute drift: " - f"schema_ids={schema_ids}, exec_params={exec_params}" - ) - finally: - cli_args.cpu = prior_cpu - if prior_module is sentinel: - sys.modules.pop("comfy_extras.nodes_seedvr", None) - else: - sys.modules["comfy_extras.nodes_seedvr"] = prior_module - if comfy_pkg is not None: - if prior_mm_attr is sentinel: - if hasattr(comfy_pkg, "model_management"): - delattr(comfy_pkg, "model_management") - else: - setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py deleted file mode 100644 index a27a8f8df24d..000000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest.mock import patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _schema_ids(items): - return [item.id for item in items] - - -def test_seedvr2_post_processing_schema(): - schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() - - assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] - assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] - assert schema.inputs[2].default == "lab" - assert schema.outputs[0].get_io_type() == "IMAGE" - - -def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): - decoded = torch.full((1, 3, 4, 4), 0.25) - reference = torch.full((1, 3, 4, 4), 0.75) - - def _lab(content, style): - raise torch.cuda.OutOfMemoryError("CUDA out of memory") - - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) - - with patch.object(nodes_seedvr, "lab_color_transfer", _lab): - try: - nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( - decoded, reference, torch.device("cpu"), "lab", - ) - except RuntimeError as exc: - assert "color_correction_method=lab" in str(exc) - assert " method=lab" not in str(exc) - else: - raise AssertionError("expected RuntimeError for one-frame LAB OOM") - - -def test_seedvr2_post_processing_unknown_color_correction_method_raises(): - decoded = torch.zeros(1, 2, 4, 4, 3) - original = torch.zeros(1, 2, 4, 4, 3) - try: - nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") - except ValueError as exc: - assert "color_correction_method" in str(exc) - else: - raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index c63f69a0df11..4e9350602d7a 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd(): return sd -def _make_seedvr2_7b_separate_mm_sd(): - return { - "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), - } - - -def _make_seedvr2_7b_shared_mm_sd(): - return { - "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - -def _make_seedvr2_3b_shared_mm_sd(): - return { - "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -143,48 +125,6 @@ def test_flux_schnell_comfyui_detected_as_flux_schnell(self): assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" - def test_seedvr2_7b_separate_mm_detection_config(self): - sd = _make_seedvr2_7b_separate_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 36 - assert unet_config["mlp_type"] == "normal" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_7b_shared_mm_detection_config(self): - sd = _make_seedvr2_7b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 10 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_3b_shared_mm_detection_config(self): - sd = _make_seedvr2_3b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 2560 - assert unet_config["heads"] == 20 - assert unet_config["num_layers"] == 32 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is None - def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py deleted file mode 100644 index f9dbd68906d3..000000000000 --- a/tests-unit/comfy_test/seedvr_vae_forward_test.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must -honor the actual tensor/tuple return contract of ``encode()`` and -``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` -or ``.sample`` attributes on those returns. - -The pre-fix body raised ``AttributeError: 'Tensor' object has no -attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and -``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` -for ``mode == "decode"`` (the class only defines ``decode_`` with a -trailing underscore). The post-fix body unwraps the optional one-element -tuple shape that ``return_dict=False`` produces and returns the tensor -directly. - -Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses -the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and -overrides ``encode``/``decode_`` with known tensors so the contract can -be probed without loading any real VAE weights. -""" - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 - - -_LATENT_SHAPE = (1, 16, 2, 2, 2) -_DECODED_SHAPE = (1, 3, 5, 16, 16) -_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) -_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) - - -class _StubVAE(VideoAutoencoderKL): - def __init__(self): - nn.Module.__init__(self) - self._encode_out = torch.zeros(*_LATENT_SHAPE) - self._decode_out = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return self._encode_out - - def decode_(self, z, return_dict=True): - return self._decode_out - - -def test_forward_encode_returns_tensor(): - vae = _StubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="encode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_LATENT_SHAPE) - - -def test_forward_decode_returns_tensor(): - vae = _StubVAE() - z = torch.zeros(*_INPUT_DECODE_SHAPE) - result = vae.forward(z, mode="decode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) - - -class _TupleReturningStubVAE(VideoAutoencoderKL): - """Stub variant whose ``encode``/``decode_`` return the - ``(tensor,)`` one-element tuple shape ``return_dict=False`` produces - in the parent class. Exercises the unwrap branch of - ``VideoAutoencoderKL.forward``. - """ - - def __init__(self): - nn.Module.__init__(self) - self._encode_tensor = torch.zeros(*_LATENT_SHAPE) - self._decode_tensor = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return (self._encode_tensor,) - - def decode_(self, z, return_dict=True): - return (self._decode_tensor,) - - -def test_forward_all_unwraps_one_tuple_at_each_step(): - vae = _TupleReturningStubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="all") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py deleted file mode 100644 index e5d79a306b72..000000000000 --- a/tests-unit/comfy_test/test_seedvr2_dtype.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sd -import comfy.supported_models -import comfy.ldm.seedvr.model as seedvr_model - - -def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): - bf16_device = object() - fp16_device = object() - - monkeypatch.setattr( - comfy.supported_models.comfy.model_management, - "should_use_bf16", - lambda device=None: device is bf16_device, - ) - - bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) - assert bf16_config.manual_cast_dtype is torch.bfloat16 - - fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) - assert fp16_config.manual_cast_dtype is None - - -def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): - context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) - - txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) - - torch.testing.assert_close(txt, context.squeeze(0)) - torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) - - -def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): - estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) - old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 - - assert estimate == 101 * 960 * 1280 * 160 - assert estimate > 15 * 1024 ** 3 - assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py deleted file mode 100644 index 5b008ea6e97b..000000000000 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Consolidated SeedVR2 internals regression tests. - -Sources (all merged verbatim, helper names disambiguated where colliding): - - * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy - apply_rotary_emb wrapper oracle at fp32. - * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare - memory_occupy against get_norm_limit(), not float('inf'). - * SeedVR2 variable-length attention split-loop contract. - -Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and -comfy.ldm.modules.attention transitively pull in comfy.model_management, -which probes torch.cuda.current_device() at import time unless args.cpu is -set first. -""" - -from __future__ import annotations - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.modules.attention as attention # noqa: E402 -import comfy.ops as comfy_ops # noqa: E402 -from comfy.ldm.seedvr.model import ( # noqa: E402 - Cache, - NaMMRotaryEmbedding3d, -) -from comfy.ldm.seedvr.vae import ( # noqa: E402 - causal_norm_wrapper, - set_norm_limit, -) -from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 - - -# --------------------------------------------------------------------------- -# RoPE rewrite tests (test_seedvr_rope_rewrite.py) -# --------------------------------------------------------------------------- - -# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains -# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. -_DIM = 192 -_HEADS = 4 -_VID_T, _VID_H, _VID_W = 2, 4, 4 -_TXT_L = 8 -_L_VID = _VID_T * _VID_H * _VID_W -_SEED = 0 - - -def _make_inputs(dtype=torch.float32, device="cpu"): - """Construct the 6 forward inputs + cache. Deterministic via local - Generator so global RNG state is not mutated. - """ - g = torch.Generator(device=device).manual_seed(_SEED) - vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) - txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) - cache = Cache(disable=True) - return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache - - -def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): - """Reproduce the pre-rewrite ``get_freqs`` body verbatim against - ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, - unchanged by the rewrite). - """ - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - with torch.amp.autocast(device_type="cuda", enabled=False): - vid_freqs_full = rope.get_axial_freqs( - min(max_temporal + 16, 1024), - min(max_height + 4, 128), - min(max_width + 4, 128), - ).float() - txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) - txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) - - -def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, - txt_q, txt_k, txt_shape): - """Compute expected forward output via the unchanged - ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the - oracle. The wrapper itself is out of scope for the rewrite (Shape B). - """ - vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) - vid_freqs = vid_freqs.to(vid_q.device) - txt_freqs = txt_freqs.to(txt_q.device) - - from einops import rearrange - - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) - vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) - vid_q_out = rearrange(vid_q_out, "h L d -> L h d") - vid_k_out = rearrange(vid_k_out, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) - txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) - txt_q_out = rearrange(txt_q_out, "h L d -> L h d") - txt_k_out = rearrange(txt_k_out, "h L d -> L h d") - return vid_q_out, vid_k_out, txt_q_out, txt_k_out - - -def test_namm_forward_output_tensor_equal_against_legacy_oracle(): - rope = NaMMRotaryEmbedding3d(dim=_DIM) - vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() - - expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( - rope, - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, - ) - - actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, cache, - ) - - torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, - msg="vid_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, - msg="vid_k output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, - msg="txt_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, - msg="txt_k output diverges from wrapper oracle") - - -# --------------------------------------------------------------------------- -# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) -# --------------------------------------------------------------------------- - -_NUM_CHANNELS = 8 -_NUM_GROUPS = 4 -_TENSOR_SHAPE = (1, 8, 2, 4, 4) - -_GROUPNORM_SUBCLASSES = [ - pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), - pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), -] - - -@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) -def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): - real_group_norm = vae_mod.F.group_norm - set_norm_limit(1e-9) - try: - gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) - gn.eval() - - forward_hook_calls = [] - - def _hook(module, inputs, output): - forward_hook_calls.append(tuple(inputs[0].shape)) - - spy_calls = [] - - def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): - spy_calls.append({"num_groups": int(num_groups_arg)}) - return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) - - handle = gn.register_forward_hook(_hook) - try: - with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): - out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) - finally: - handle.remove() - - full_calls = len(forward_hook_calls) - chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) - - assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE - assert full_calls == 0, ( - f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" - ) - assert chunked_calls > 0, ( - f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}" - ) - finally: - set_norm_limit(None) - - -# --------------------------------------------------------------------------- -# SeedVR2 var_attention split-loop tests -# --------------------------------------------------------------------------- - -def test_var_attention_registry_contains_always_available_entries(): - assert ( - attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] - is attention.var_attention_optimized_split - ) - - -def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): - dim = 8 - heads = 2 - head_dim = 4 - attn = seedvr_model.NaSwinAttention( - vid_dim=dim, - txt_dim=dim, - heads=heads, - head_dim=head_dim, - qk_bias=False, - qk_norm=seedvr_model.CustomRMSNorm, - qk_norm_eps=1e-6, - rope_type=None, - rope_dim=head_dim, - shared_weights=False, - window=(2, 1, 1), - window_method="720pwin_by_size_bysize", - version=True, - device="cpu", - dtype=torch.float32, - operations=comfy_ops.disable_weight_init, - ) - generator = torch.Generator(device="cpu").manual_seed(11) - vid = torch.randn(8, dim, generator=generator) - txt = torch.randn(3, dim, generator=generator) - vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) - txt_shape = torch.tensor([[3]], dtype=torch.long) - calls = [] - - def fake_optimized_var_attention(**kwargs): - calls.append(kwargs) - return kwargs["q"] - - monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) - - vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) - - assert tuple(vid_out.shape) == (8, dim) - assert tuple(txt_out.shape) == (3, dim) - assert len(calls) == 1 - call = calls[0] - assert tuple(call["q"].shape) == (14, heads, head_dim) - assert tuple(call["k"].shape) == (14, heads, head_dim) - assert tuple(call["v"].shape) == (14, heads, head_dim) - assert call["heads"] == heads - assert call["skip_reshape"] is True - assert call["skip_output_reshape"] is True - torch.testing.assert_close( - call["cu_seqlens_q"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - torch.testing.assert_close( - call["cu_seqlens_k"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - - -def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): - heads = 2 - head_dim = 3 - q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) - k = q + 100 - v = q + 200 - cu = torch.tensor([0, 2, 5], dtype=torch.int32) - calls = [] - - def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): - calls.append( - { - "q_shape": tuple(q_arg.shape), - "k_shape": tuple(k_arg.shape), - "v_shape": tuple(v_arg.shape), - "heads": heads_arg, - "kwargs": kwargs, - } - ) - return q_arg + v_arg - - monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) - - out = var_attention_optimized_split( - q, - k, - v, - heads, - cu, - cu, - skip_reshape=True, - skip_output_reshape=True, - ) - - assert tuple(out.shape) == (5, heads, head_dim) - assert len(calls) == 2 - assert calls[0]["q_shape"] == (1, heads, 2, head_dim) - assert calls[1]["q_shape"] == (1, heads, 3, head_dim) - assert all(call["heads"] == heads for call in calls) - assert all(call["kwargs"]["skip_reshape"] is True for call in calls) - assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) - torch.testing.assert_close(out, q + v, rtol=0, atol=0) - - -def test_var_attention_optimized_split_rejects_bad_offsets(): - q = torch.randn(5, 2, 3) - cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) - cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - - with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): - var_attention_optimized_split( - q, - q, - q, - 2, - cu_bad, - cu_ok, - skip_reshape=True, - skip_output_reshape=True, - ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py deleted file mode 100644 index f2b9bcbbec8b..000000000000 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Consolidated SeedVR2 model/graph/forward regression tests. - -Merged from: -- seedvr_model_test.py -- test_seedvr_7b_final_block_text_path.py -- test_seedvr_forward_no_device_cast.py -- test_seedvr_latent_format.py -- test_seedvr2_vae_graph_boundaries.py -""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -import torch -from torch import nn - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy # noqa: E402 -import comfy.latent_formats # noqa: E402 -import comfy.ldm.seedvr.model # noqa: E402 -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.model_management # noqa: E402 -import comfy.sample # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -import nodes as nodes_mod # noqa: E402 -from comfy.ldm.seedvr.model import NaDiT # noqa: E402 - - -# --------------------------------------------------------------------------- -# Helpers from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def _make_standin(positive_conditioning): - class _StandIn(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer( - "positive_conditioning", positive_conditioning - ) - - _resolve_text_conditioning = NaDiT._resolve_text_conditioning - - return _StandIn() - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -class _StubModule(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - -def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: - flags = [] - - class _Block(_StubModule): - def __init__(self, *args, **kwargs): - flags.append(kwargs["is_last_layer"]) - super().__init__() - - monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) - monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) - monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) - monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) - - seedvr_model.NaDiT( - norm_eps=1e-5, - qk_rope=None, - num_layers=4, - mlp_type="normal", - vid_dim=vid_dim, - txt_in_dim=txt_in_dim, - heads=24, - mm_layers=3, - ) - - return flags - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -class _Model: - def __init__(self, latent_format): - self._latent_format = latent_format - - def get_model_object(self, name): - assert name == "latent_format" - return self._latent_format - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -class _Patcher: - def get_free_memory(self, device): - return 1024 * 1024 * 1024 - - -class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self, encoded): - nn.Module.__init__(self) - self.encoded = encoded - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.seen = [] - - def encode(self, x): - self.seen.append(tuple(x.shape)) - return self.encoded.to(device=x.device, dtype=x.dtype) - - -class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.calls = [] - - def decode(self, z, seedvr2_tiling=None): - self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) - if z.ndim == 4: - b, tc, h, w = z.shape - t = tc // 16 - else: - b, _, t, h, w = z.shape - return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) - - -def _make_vae(wrapper): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = wrapper - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) - vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - vae.output_channels = 3 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.crop_input = False - vae.not_video = False - vae.patcher = _Patcher() - vae.process_input = lambda image: image - vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) - vae.vae_output_dtype = lambda: torch.float32 - vae.memory_used_encode = lambda shape, dtype: 1 - vae.memory_used_decode = lambda shape, dtype: 1 - vae.throw_exception_if_invalid = lambda: None - vae.vae_encode_crop_pixels = lambda pixels: pixels - vae.spacial_compression_decode = lambda: 8 - vae.temporal_compression_decode = lambda: 4 - return vae - - -# --------------------------------------------------------------------------- -# Tests from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def test_missing_context_falls_back_to_positive_buffer(): - """AC: ``context is None`` falls back to the registered - ``positive_conditioning`` buffer and runs to completion — no - silent zero substitution, no raised exception. - """ - pos_buffer = torch.full((58, 5120), 7.0) - standin = _make_standin(pos_buffer) - txt, txt_shape = standin._resolve_text_conditioning(None) - assert txt.shape == (58, 5120) - assert (txt == 7.0).all(), ( - "fallback path must use the positive_conditioning buffer " - "verbatim, not a zero tensor" - ) - assert txt_shape.shape == (1, 1) - assert txt_shape[0, 0].item() == 58 - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): - assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ - False, - False, - False, - False, - ] - - -def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): - rope = seedvr_model.get_na_rope("rope3d", dim=64) - generator = torch.Generator(device="cpu").manual_seed(0) - q = torch.randn(4, 2, 128, generator=generator) - k = torch.randn(4, 2, 128, generator=generator) - shape = torch.tensor([[1, 2, 2]], dtype=torch.long) - freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) - - expected_q = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - q.permute(1, 0, 2).float(), - ).to(q.dtype).permute(1, 0, 2) - expected_k = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - k.permute(1, 0, 2).float(), - ).to(k.dtype).permute(1, 0, 2) - - actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) - - torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) - torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): - latent_format = comfy.latent_formats.SeedVR2() - latent_image = torch.zeros(1, 1, 4, 5) - - fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) - - assert latent_format.latent_channels == 16 - assert latent_format.latent_dimensions == 2 - assert fixed.shape == (1, 16, 4, 5) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - - encoded = torch.full((1, 16, 2, 4, 5), 2.0) - vae = _make_vae(_EncodeWrapper(encoded)) - pixels = torch.zeros(1, 5, 32, 40, 3) - - node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] - node_latent = node_output["samples"] - assert set(node_output) == {"samples"} - assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) - assert node_latent.dtype == torch.float32 - assert node_latent.stride()[-1] == 1 - assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) - - tiled = torch.full((1, 16, 2, 4, 5), 3.0) - monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) - tiled_output = nodes_mod.VAEEncodeTiled().encode( - vae, - pixels, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - )[0] - tiled_latent = tiled_output["samples"] - assert set(tiled_output) == {"samples"} - assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) - assert tiled_latent.dtype == torch.float32 - assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) - - -def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - vae = _make_vae(_DecodeWrapper()) - - nodes_mod.VAEDecodeTiled().decode( - vae, - {"samples": torch.zeros(1, 16, 2, 4, 5)}, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - ) - - assert vae.first_stage_model.calls == [ - { - "shape": (1, 16, 2, 4, 5), - "seedvr2_tiling": { - "enable_tiling": True, - "tile_size": (512, 512), - "tile_overlap": (64, 64), - "temporal_size": 16, - "temporal_overlap": 4, - }, - } - ] diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py deleted file mode 100644 index ea9f978f38b9..000000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_decode.py +++ /dev/null @@ -1,91 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - return wrapper - - -def _fingerprint_decode_(self, z, return_dict=True): - b = int(z.shape[0]) - t = int(z.shape[2]) - h = int(z.shape[3]) - w = int(z.shape[4]) - out = torch.empty(b, 3, t, h * 8, w * 8) - for batch_idx in range(b): - out[batch_idx].fill_(float(batch_idx + 1)) - return out - - -def _decode_with_patches(wrapper, z): - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): - return wrapper.decode(z) - - -def test_decode_b2_t3_multi_frame_batch_unchanged(): - wrapper = _make_wrapper() - - out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) - - assert tuple(out.shape) == (2, 3, 3, 16, 16) - - -class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.calls = [] - - def parameters(self): - return iter([torch.nn.Parameter(torch.zeros(()))]) - -def _decode_stub(self, latent): - self.calls.append(tuple(latent.shape)) - return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) - - -def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): - wrapper = _Wrapper() - - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): - out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) - - assert tuple(out.shape) == (1, 3, 2, 32, 40) - assert wrapper.calls == [(1, 16, 2, 4, 5)] - - -def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): - wrapper = _Wrapper() - - with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): - wrapper.decode(torch.zeros(1, 16, 4)) - - -def _t_padded(t_in: int) -> int: - if t_in == 1: - return 1 - if t_in <= 4: - return 5 - if (t_in - 1) % 4 == 0: - return t_in - return t_in + (4 - ((t_in - 1) % 4)) - - -@pytest.mark.parametrize("t_in", [1, 5, 9]) -def test_t_padded_matches_cut_videos(t_in): - dummy = torch.zeros(1, t_in, 1, 1, 1) - assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py deleted file mode 100644 index 40079bbe2c47..000000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ /dev/null @@ -1,347 +0,0 @@ -from contextlib import ExitStack -from unittest.mock import MagicMock, patch - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_decode_latent_min_size_override.py -# --------------------------------------------------------------------------- - - -def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): - from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae - - class StubVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_latent_min_size = 2 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, t_chunk): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return VideoAutoencoderKL.slicing_decode(self, t_chunk) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - b, c, d, h, w = z.shape - return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) - - vae = StubVAEModel() - z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=False, - ) - - assert vae.decode_min_sizes == [5] - assert vae.memory_states == [MemoryState.DISABLED] - assert vae.slicing_latent_min_size == 2 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_encode_runt_slice_override.py -# --------------------------------------------------------------------------- - - -def test_zero_temporal_size_preserves_min_size_when_encode_raises(): - from comfy.ldm.seedvr.vae import tiled_vae - - class RaisingVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_sample_min_size = 4 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def encode(self, t_chunk): - raise RuntimeError("simulated encode failure") - - vae = RaisingVAEModel() - x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) - - raised = False - try: - tiled_vae( - x, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=True, - ) - except RuntimeError as exc: - if "simulated encode failure" not in str(exc): - raise - raised = True - - assert raised - assert vae.slicing_sample_min_size == 4 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_temporal_slicing.py -# --------------------------------------------------------------------------- - - -class _SlicingDecodeVAE(nn.Module): - def __init__(self, slicing_latent_min_size): - super().__init__() - self.slicing_latent_min_size = slicing_latent_min_size - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, z): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - x = z[:, :1].repeat( - 1, - 3, - 1, - self.spatial_downsample_factor, - self.spatial_downsample_factor, - ) - return x - - -def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): - vae = _SlicingDecodeVAE(slicing_latent_min_size=2) - z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=12, - temporal_overlap=4, - encode=False, - ) - - assert vae.decode_min_sizes == [2] - assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] - assert vae.slicing_latent_min_size == 2 - - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - seedvr2_tiling = { - "enable_tiling": True, - "tile_size": (64, 64), - "tile_overlap": (0, 0), - "temporal_size": 8, - "temporal_overlap": 7, - } - - captured = {} - - def _fake_tiled_vae(latent, model, **kwargs): - captured.update(kwargs) - return torch.zeros(1, 3, 1, 16, 16) - - with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): - wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) - - assert captured["temporal_overlap"] == 7 - - -# --------------------------------------------------------------------------- -# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py -# --------------------------------------------------------------------------- - - -def _force_oom(*a, **k): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def _make_vae(first_stage_model, latent_channels, latent_dim): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = first_stage_model - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = vae.downscale_ratio = 8 - vae.upscale_index_formula = vae.downscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = latent_channels - vae.latent_dim = latent_dim - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_decode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_decode = lambda *a, **k: 1 - return vae - - -def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): - mm = sd_mod.model_management - with ExitStack() as stack: - stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) - stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) - stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) - if patch_wrapper_decode: - stack.enter_context(patch.object( - seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", - side_effect=_force_oom)) - vae.decode(samples) - - -def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper) - vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) - assert seedvr2_call.call_count == 1 - assert generic_call.call_count == 0 - - -def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): - first_stage = MagicMock() - first_stage.decode = MagicMock(side_effect=_force_oom) - vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) - assert generic_call.call_count == 1 - assert seedvr2_call.call_count == 0 - - -# --------------------------------------------------------------------------- -# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py -# --------------------------------------------------------------------------- - - -def _populate_common_vae_attrs_fallback(vae): - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = 8 - vae.upscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = 8 - vae.downscale_index_formula = None - vae.not_video = False - vae.crop_input = False - vae.pad_channel_value = None - - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_encode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_encode = lambda *a, **k: 1 - - -def _make_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper - ) - vae.first_stage_model = wrapper - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _make_non_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = MagicMock() - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _force_regular_encode_oom(*args, **kwargs): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): - vae = _make_seedvr2_vae_fallback() - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - - with patch.object(sd_mod.model_management, "raise_non_oom", - lambda e: None), \ - patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.model_management, "soft_empty_cache", - lambda: None), \ - patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", - side_effect=_force_regular_encode_oom), \ - patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, - create=True), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode(pixel_samples) - - assert seedvr2_call.call_count == 1, ( - f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " - f"input under OOM fallback; got {seedvr2_call.call_count} calls." - ) - assert generic_call.call_count == 0, ( - f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " - f"{generic_call.call_count} calls." - ) - - -def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): - vae = _make_non_seedvr2_vae_fallback() - vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) - vae.upscale_ratio = (lambda a: a * 4, 8, 8) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - with patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode_tiled(pixel_samples) - - assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py deleted file mode 100644 index 05291989edfa..000000000000 --- a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sample # noqa: E402 -import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 -from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 - -_LAT_C = 16 -_COND_C = 17 - - -def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): - """Build minimal SeedVR2-shaped sampling inputs.""" - samples_5d = torch.arange( - B * _LAT_C * T * H * W, dtype=torch.float32 - ).reshape(B, _LAT_C, T, H, W) - samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() - - cond_5d = torch.arange( - B * _COND_C * T * H * W, dtype=torch.float32 - ).reshape(B, _COND_C, T, H, W) + 10000.0 - cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() - - text_pos = torch.zeros(1, 4, 32) - text_neg = torch.zeros(1, 4, 32) - positive = [[text_pos, {"condition": cond.clone()}]] - negative = [[text_neg, {"condition": cond.clone()}]] - latent_image = {"samples": samples} - return latent_image, positive, negative, samples_5d, cond_5d - - -def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): - return latent_image - - -def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): - """Return a tensor whose values encode ``(seed, position)``.""" - base = torch.arange( - latent_image.numel(), dtype=torch.float32 - ).reshape(latent_image.shape) - return base + float(seed) * 1e6 - - -def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): - schema = SeedVR2ProgressiveSampler.define_schema() - inputs = {item.id: item for item in schema.inputs} - - assert inputs["chunking_mode"].options == ["manual", "auto"] - assert inputs["chunking_mode"].default == "manual" - - -def test_auto_chunking_walks_two_three_four_chunk_ladder(): - """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" - latent, pos, neg, _, _ = _make_inputs(T=17) - calls = [] - - def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, - scheduler, positive, negative, - latent_image, denoise=1.0, - noise_mask=None, seed=None): - calls.append(tuple(latent_image.shape)) - if latent_image.shape[1] > _LAT_C * 5: - raise torch.cuda.OutOfMemoryError("chunk too large") - return latent_image.clone() - - with patch.object(comfy.sample, "sample", - side_effect=_oom_until_four_chunks), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise), \ - patch.object(nodes_seedvr_mod.comfy.model_management, - "soft_empty_cache") as soft_empty: - out = SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=65, temporal_overlap=0, - chunking_mode="auto", - ) - - assert calls[:4] == [ - (1, _LAT_C * 17, 8, 8), - (1, _LAT_C * 9, 8, 8), - (1, _LAT_C * 6, 8, 8), - (1, _LAT_C * 5, 8, 8), - ] - assert torch.equal(out.result[0]["samples"], latent["samples"]) - assert soft_empty.call_count == 3 - - -@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) -def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): - """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" - latent, pos, neg, _, _ = _make_inputs(T=5) - - sampler_called = {"n": 0} - - def _should_not_be_called(*args, **kwargs): - sampler_called["n"] += 1 - return torch.zeros(1) - - with patch.object(comfy.sample, "sample", - side_effect=_should_not_be_called), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise): - with pytest.raises(ValueError) as excinfo: - SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, - ) - assert str(bad_chunk) in str(excinfo.value) - assert sampler_called["n"] == 0 From cb9f6394160808f7d25163f6cc2ea300c6841ef9 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 9 Jun 2026 12:19:13 +0900 Subject: [PATCH 24/45] chore(openapi): sync shared API contract from cloud@5273c30 (#14266) --- openapi.yaml | 229 +++++++++++++++++---------------------------------- 1 file changed, 76 insertions(+), 153 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index b7e21245fa72..2510f97d08d5 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,11 +3,6 @@ components: Asset: description: Represents a user-owned asset (image, video, or other generated output). properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string created_at: description: Timestamp when the asset was created format: date-time @@ -16,8 +11,12 @@ components: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -139,17 +138,16 @@ components: AssetUpdated: description: Response returned when an existing asset is successfully updated. properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string display_name: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -828,7 +826,11 @@ components: type: string type: object PaginationInfo: - description: Offset/limit-based pagination metadata included in list responses. + description: | + Pagination metadata included in list responses. Supports both legacy + offset/limit pagination and cursor-based pagination. When cursor-based + pagination is used, `next_cursor` is the primary pagination token and + `offset`/`total` may be zero. properties: has_more: description: Whether more items are available beyond this page @@ -837,12 +839,19 @@ components: description: Items per page minimum: 1 type: integer + next_cursor: + description: | + Opaque cursor for the next page. Pass this value as the `after` + query parameter on the next request. Empty or absent when there + are no more results. + type: string offset: - description: Current offset (0-based) + deprecated: true + description: 'Current offset (0-based). Deprecated: use cursor-based pagination.' minimum: 0 type: integer total: - description: Total number of items matching filters + description: Total number of items matching filters (may be 0 when using cursor pagination) minimum: 0 type: integer required: @@ -1518,17 +1527,11 @@ paths: schema: default: true type: boolean - - description: Filter assets by exact content hash. Preferred over asset_hash. + - description: Filter assets by exact content hash. in: query name: hash schema: type: string - - deprecated: true - description: 'Deprecated: use hash instead. Filter assets by exact content hash.' - in: query - name: asset_hash - schema: - type: string - description: | Opaque cursor for keyset pagination. Pass the `next_cursor` value from the previous response to fetch the next page. When provided, @@ -1571,42 +1574,12 @@ paths: - file post: description: | - Uploads a new asset to the system with associated metadata. - Supports two upload methods: - 1. Direct file upload (multipart/form-data) - 2. URL-based upload (application/json with source: "url") + Creates a new asset from a direct file upload (multipart/form-data) with associated metadata. If an asset with the same hash already exists, returns the existing asset. - operationId: uploadAsset + operationId: createAsset requestBody: content: - application/json: - schema: - properties: - name: - description: Display name for the asset (used to determine file extension) - type: string - preview_id: - description: Optional preview asset ID - format: uuid - type: string - tags: - description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array - url: - description: HTTP/HTTPS URL to download the asset from - format: uri - type: string - user_metadata: - additionalProperties: true - description: Custom metadata to store with the asset - type: object - required: - - url - - name - type: object multipart/form-data: schema: properties: @@ -1614,6 +1587,10 @@ paths: description: The asset file to upload format: binary type: string + hash: + description: Content hash of the file. + pattern: ^(blake3|sha256):[a-f0-9]{64}$ + type: string id: description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset. format: uuid @@ -1629,10 +1606,8 @@ paths: format: uuid type: string tags: - description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array + description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + type: string user_metadata: description: Custom JSON metadata as a string type: string @@ -1641,36 +1616,32 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset already existed for this user (deduplicated by content hash); the + existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset created successfully + description: Asset created successfully (created_new=true) "400": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Invalid request (bad file, invalid URL, invalid content type, etc.) + description: Invalid request (bad file, invalid content type, etc.) "401": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Unauthorized - "403": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL requires authentication or access denied - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL not found "413": content: application/json: @@ -1683,19 +1654,13 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' description: Unsupported media type - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Download failed due to network error or timeout "500": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Internal server error - summary: Upload a new asset + summary: Create a new asset tags: - file /api/assets/{id}: @@ -1730,7 +1695,7 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) + description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)' "500": content: application/json: @@ -1783,7 +1748,7 @@ paths: description: | Updates an asset's metadata. At least one field must be provided. Only name, mime_type, preview_id, and user_metadata can be updated. - For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint. + For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags. operationId: updateAsset parameters: - description: Asset ID @@ -1982,76 +1947,6 @@ paths: summary: Add tags to asset tags: - file - put: - description: Adds and removes tags from an asset in a single operation - operationId: updateAssetTags - parameters: - - description: Asset ID - in: path - name: id - required: true - schema: - format: uuid - type: string - requestBody: - content: - application/json: - schema: - description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items. - minProperties: 1 - properties: - add: - description: Tags to add to the asset. Can be empty if remove has items. - items: - type: string - type: array - remove: - description: Tags to remove from the asset. Can be empty if add has items. - items: - type: string - type: array - type: object - required: true - responses: - "200": - content: - application/json: - schema: - $ref: '#/components/schemas/TagsModificationResponse' - description: Tags updated successfully - "400": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Invalid request - "401": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Unauthorized - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Asset not found - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Reserved tag validation error - "500": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Internal server error - summary: Update asset tags - tags: - - file /api/assets/from-hash: post: description: | @@ -2090,12 +1985,20 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset reference already existed for this user (deduplicated by content + hash); the existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset reference created successfully + description: Asset reference created successfully (created_new=true) "400": content: application/json: @@ -2887,7 +2790,21 @@ paths: - asc - desc type: string - - description: Pagination offset (0-based) + - description: | + Opaque cursor for keyset pagination. Pass the `next_cursor` value + from a previous response to fetch the next page. + Cursor pagination is supported only when `sort_by=create_time` + (default). If `sort_by=execution_time`, `after` is ignored and + offset/limit pagination is used. + Cursors are opaque base64url payloads — clients should treat them + as strings and not parse the contents. + example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0 + in: query + name: after + schema: + type: string + - deprecated: true + description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.' in: query name: offset schema: @@ -2909,6 +2826,12 @@ paths: schema: $ref: '#/components/schemas/JobsListResponse' description: Success - Jobs retrieved + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request (e.g. malformed pagination cursor). "401": content: application/json: From f89999289abe06c638e15d1895e3c7805bd486b1 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Mon, 8 Jun 2026 20:55:49 -0700 Subject: [PATCH 25/45] fix: Add back apply_rotary_emb for Qwen Image (#14364) --- comfy/ldm/qwen_image/model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 3462d8108dd4..e49886dd9931 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -51,6 +51,18 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: return hidden_states +# Addin this back because Nunchaku custom nodes rely on it, see comment here: +# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161 +# TODO: Eventually remove this once we natively support SVDQuants +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() From 8ed7f458d055b565d063343bf94dab99f10f649a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:11:05 +0300 Subject: [PATCH 26/45] Allow custom templates with Ideogram4 TE (#14374) --- comfy/text_encoders/ideogram4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 55e655d67a2f..84243772d0cd 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -32,7 +32,9 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): - if llama_template is None: + if text.startswith('<|im_start|>'): + llama_text = text + elif llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) From 1639dc7a7041eaaf7ad96f8c7ea2894be01a7d28 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:55:00 +1000 Subject: [PATCH 27/45] main/server: Add --debug-hang (#14371) Add an option to debug a hang with ctrl-C, dumping the backtraces to see where its stuck or slow. --- comfy/cli_args.py | 2 ++ main.py | 15 ++++++++++++++- server.py | 9 +++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a4cabcc65a85..cba0dfa34036 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -166,6 +166,8 @@ class PerformanceFeature(enum.Enum): parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.") + parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") diff --git a/main.py b/main.py index 239a520138d5..7fcc8e97d520 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ from utils.mime_types import init_mime_types import faulthandler import logging +import signal import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context @@ -37,7 +38,19 @@ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -faulthandler.enable(file=sys.stderr, all_threads=False) +faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang) +if __name__ == "__main__" and args.debug_hang: + dumping_traceback = False + + def dump_traceback_on_sigint(signum, frame): + global dumping_traceback + if dumping_traceback: + raise KeyboardInterrupt + dumping_traceback = True + faulthandler.dump_traceback(file=sys.stderr, all_threads=True) + raise KeyboardInterrupt + + signal.signal(signal.SIGINT, dump_traceback_on_sigint) import comfy_aimdo.control diff --git a/server.py b/server.py index 268441bd1a87..a85c1e59147c 100644 --- a/server.py +++ b/server.py @@ -1253,6 +1253,15 @@ async def start_multi_address(self, addresses, call_on_start=None, verbose=True) if verbose: logging.info("Starting server\n") + if args.debug_hang: + logging.info( + f"{'-' * 80}\n" + "ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n" + "the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n" + "terminal to dump the python backtraces for debugging. Please attach the extra\n" + "debug info to your bug report.\n" + f"{'-' * 80}" + ) for addr in addresses: address = addr[0] port = addr[1] From 07c53f8f0fa6b014a46756eaa5a07fa9e411ccad Mon Sep 17 00:00:00 2001 From: kelseyee <971704395@qq.com> Date: Tue, 9 Jun 2026 21:57:58 +0800 Subject: [PATCH 28/45] Add LoRA key mapping for LTXV/LTXAV models (#14349) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 4e0ea29e090d..2c8d0f0bfd1f 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + return key_map From 184009c2f60db7b2e7dc4a80c28f9bc6029408d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:24:09 +0300 Subject: [PATCH 29/45] feat: Add model support for SCAIL-2 (#14373) * initial SCAIL2 support --- comfy/ldm/wan/model.py | 57 ++++++- comfy/model_base.py | 74 +++++++++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 ++ comfy_extras/nodes_scail.py | 321 ++++++++++++++++++++++++++++++++++++ comfy_extras/nodes_wan.py | 58 ------- nodes.py | 1 + 7 files changed, 462 insertions(+), 63 deletions(-) create mode 100644 comfy_extras/nodes_scail.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 70dfe7b16f05..9178b334470e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1631,13 +1631,15 @@ def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120 self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs): if reference_latent is not None: x = torch.cat((reference_latent, x), dim=2) # embeddings x = self.patch_embedding(x.float()).to(x.dtype) + if ref_mask_latents is not None: # SCAIL-2 additive mask stream + x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) @@ -1645,6 +1647,8 @@ def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_opt scail_pose_seq_len = 0 if pose_latents is not None: scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + if sam_latents is not None: # SCAIL-2 additive mask stream + scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype) scail_x = scail_x.flatten(2).transpose(1, 2) scail_pose_seq_len = scail_x.shape[1] x = torch.cat([x, scail_x], dim=1) @@ -1695,7 +1699,36 @@ def block_wrap(args): return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, + # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + if ref_mask_flag is not None and not bool(ref_mask_flag): + REF_ROPE_H = 120.0 + POSE_ROPE_W = 120.0 + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + main_t_patches = t - ref_t_patches + + parts = [] + if ref_t_patches > 0: + ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} + parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) + if main_t_patches > 0: + parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + + if pose_latents is not None: + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + h_scale = h / H_pose + w_scale = w / W_pose + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) + + return torch.cat(parts, dim=1) + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) if pose_latents is None: @@ -1719,12 +1752,16 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No return torch.cat([main_freqs, pose_freqs], dim=1) - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) if pose_latents is not None: pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + if ref_mask_latents is not None: # SCAIL-2 + ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size) + if sam_latents is not None: # SCAIL-2 + sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size) t_len = t if time_dim_concat is not None: @@ -1737,5 +1774,15 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) t_len += reference_latent.shape[2] - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] + ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2 + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w] + + +class SCAIL2WanModel(SCAILWanModel): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) diff --git a/comfy/model_base.py b/comfy/model_base.py index 042804771890..d212a7c2aa21 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1754,6 +1754,80 @@ def extra_conds_shapes(self, **kwargs): return out +class WAN21_SCAIL2(WAN21_SCAIL): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents") + self.memory_usage_shape_process = { + "pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + "sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + } + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_flag = kwargs.get("ref_mask_flag", None) + if ref_mask_flag is not None: + out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag) + + return out + + def extra_conds_shapes(self, **kwargs): + out = super().extra_conds_shapes(**kwargs) + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + s = driving_mask_28ch.shape + out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]] + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + s = ref_mask_28ch.shape + out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]] + return out + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key in ("sam_latents", "pose_latents"): + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + + def concat_cond(self, **kwargs): + # The 4 extra channels are the history_mask (1 at clean-anchor frames). + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels != 4: + return super().concat_cond(**kwargs) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + return torch.zeros_like(noise)[:, :4] + + device = kwargs["device"] + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + return mask + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + # Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image. + return latent_image + + class WAN22_WanDancer(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c838d13338..290938bd6d2c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail2" elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "scail" elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7cf9c133b9cb..42325d71cc7d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1450,6 +1450,17 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) return out + +class WAN21_SCAIL2(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail2", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device) + return out + class WAN22_WanDancer(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -2259,6 +2270,7 @@ def get_model(self, state_dict, prefix="", device=None): WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, + WAN21_SCAIL2, WAN22_WanDancer, Hunyuan3Dv2mini, Hunyuan3Dv2, diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py new file mode 100644 index 000000000000..a740442dee1e --- /dev/null +++ b/comfy_extras/nodes_scail.py @@ -0,0 +1,321 @@ +"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3 +preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes.""" + +from typing_extensions import override + +import torch +import torch.nn.functional as F + +import nodes +import node_helpers +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.sam3.tracker import unpack_masks + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + + +# Model was trained on these exact colors; deviating degrades multi-identity quality. +DEFAULT_PALETTE = [ + (0.0, 0.0, 1.0), # Blue + (1.0, 0.0, 0.0), # Red + (0.0, 1.0, 0.0), # Green + (1.0, 0.0, 1.0), # Magenta + (0.0, 1.0, 1.0), # Cyan + (1.0, 1.0, 0.0), # Yellow +] + + +def _unpack(track_data): + packed = track_data["packed_masks"] + if packed is None or packed.shape[1] == 0: + return None + return unpack_masks(packed) + + +def _first_frame_cx_area(masks_bool): + first = masks_bool[0].float() + H, W = first.shape[-2], first.shape[-1] + n_pixels = H * W + grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) + area = first.sum(dim=(-1, -2)).clamp_(min=1) + cx = (first * grid_x).sum(dim=(-1, -2)) / area + return (cx / W).tolist(), (area / n_pixels).tolist() + + +def _subset_track_data(track_data, obj_indices): + out = dict(track_data) + packed = track_data["packed_masks"] + if packed is None or not obj_indices: + out["packed_masks"] = None + if "scores" in out: + out["scores"] = [] + return out + out["packed_masks"] = packed[:, obj_indices].contiguous() + scores = track_data.get("scores") + if scores is not None: + out["scores"] = [scores[i] for i in obj_indices if i < len(scores)] + return out + + +def _render_colored_masks(track_data, background="black"): + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) + if packed is None or packed.shape[1] == 0: + T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] + out = torch.empty(T, H, W, 3, device=device, dtype=dtype) + out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2] + return out + T, N_obj = packed.shape[0], packed.shape[1] + colors = torch.tensor( + [DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)], + device=device, dtype=dtype, + ) + masks_full = unpack_masks(packed.to(device)).float() + Hm, Wm = masks_full.shape[-2], masks_full.shape[-1] + masks_full = F.interpolate( + masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" + ).view(T, N_obj, H, W) > 0.5 + any_mask = masks_full.any(dim=1) + obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) + color_overlay = colors[obj_idx_map] + bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) + return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) + + +def _extract_mask_to_28ch(rgb_video): + """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent + (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) + threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking.""" + T, H, W, _ = rgb_video.shape + _ON_THRESH = 225.0 / 255.0 + mask = rgb_video.movedim(-1, 1).float() + R = (mask[:, 0:1] > _ON_THRESH).float() + G = (mask[:, 1:2] > _ON_THRESH).float() + B = (mask[:, 2:3] > _ON_THRESH).float() + nR, nG, nB = 1 - R, 1 - G, 1 - B + binary_7ch = torch.cat([ + R * G * B, # white + R * nG * nB, # red + nR * G * nB, # green + nR * nG * B, # blue + R * G * nB, # yellow + R * nG * B, # magenta + nR * G * B, # cyan + ], dim=1) + H_lat, W_lat = H, W + for _ in range(3): + H_lat = (H_lat + 1) // 2 + W_lat = (W_lat + 1) // 2 + binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area') + T_latent = (T - 1) // 4 + 1 + padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0) + out = padded.view(T_latent, 28, H_lat, W_lat) + return out.unsqueeze(0) + + +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="model/conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."), + io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), + io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), + io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), + io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, + video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None, + pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + noise_mask = None + + ref_mask_flag = not replacement_mode + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag}) + + prev_trimmed = None + if previous_frames is not None and previous_frames.shape[0] > 0: + prev_trimmed = previous_frames[-previous_frame_count:] + video_frame_offset -= prev_trimmed.shape[0] + video_frame_offset = max(0, video_frame_offset) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) + reference_image = reference_image * is_char + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + if pose_video_mask is not None: + if pose_video_mask.shape[0] <= video_frame_offset: + pose_video_mask = None + else: + pose_video_mask = pose_video_mask[video_frame_offset:] + + # Truncate pose+mask jointly to the shorter of the two, capped at length. + ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None] + if ts: + T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1 + if pose_video is not None: + pose_video = pose_video[:T_kept] + if pose_video_mask is not None: + pose_video_mask = pose_video_mask[:T_kept] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + if pose_video_mask is not None: + mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw) + positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) + + if reference_image_mask is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) + ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) + + if prev_trimmed is not None: + pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + prev_latent = vae.encode(pf[:, :, :, :3]) + prev_latent_frames = min(prev_latent.shape[2], latent.shape[2]) + latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype) + noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype) + noise_mask[:, :, :prev_latent_frames] = 0.0 + + out_latent = {"samples": latent} + if noise_mask is not None: + out_latent["noise_mask"] = noise_mask + return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length) + + +class SCAIL2ColoredMask(io.ComfyNode): + """Render SAM3 tracks for the driving pose video and (optionally) the reference + image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` + across both outputs guarantees identity K maps to the same color on both + sides, for multi-person workflow consistency. + reference_image_mask is always rendered black-bg (model convention) + pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SCAIL2ColoredMask", + display_name="Create SCAIL-2 Colored Mask", + category="conditioning/video_models/scail", + inputs=[ + SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."), + SAM3TrackData.Input("ref_track_data", optional=True, + tooltip="SAM3 track of the reference image."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), + io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", + tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), + io.Boolean.Input("replacement_mode", default=False, + tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."), + ], + outputs=[ + io.Image.Output("pose_video_mask"), + io.Image.Output("reference_image_mask"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): + def _prep(td): + masks_bool = _unpack(td) + if sort_by != "none" and masks_bool is not None: + cx, area = _first_frame_cx_area(masks_bool) + if sort_by == "left_to_right": + order = sorted(range(len(cx)), key=lambda i: cx[i]) + else: # "area" + order = sorted(range(len(area)), key=lambda i: -area[i]) + td = _subset_track_data(td, order) + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + packed = td.get("packed_masks") + n_obj = packed.shape[1] if packed is not None else 0 + indices = [i for i in indices if 0 <= i < n_obj] + td = _subset_track_data(td, indices) + return td + + drv = _prep(driving_track_data) + mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") + + if ref_track_data is not None: + ref = _prep(ref_track_data) + reference_image_mask = _render_colored_masks(ref, "black") + else: + H, W = drv["orig_size"] + reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + + return io.NodeOutput(mask_video, reference_image_mask) + + +class SCAILExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanSCAILToVideo, + SCAIL2ColoredMask, + ] + + +async def comfy_entrypoint() -> SCAILExtension: + return SCAILExtension() diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 67d3a84434f1..d73be8e00700 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,63 +1456,6 @@ def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, wi return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) -class WanSCAILToVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="WanSCAILToVideo", - category="model/conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Vae.Input("vae"), - io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), - io.Int.Input("batch_size", default=1, min=1, max=4096), - io.ClipVisionOutput.Input("clip_vision_output", optional=True), - io.Image.Input("reference_image", optional=True), - io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), - io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), - io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), - io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), - ], - is_experimental=True, - ) - - @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - - ref_latent = None - if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - ref_latent = vae.encode(reference_image[:, :, :, :3]) - - if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - if pose_video is not None: - pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) - pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength - positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - - out_latent = {} - out_latent["samples"] = latent - return io.NodeOutput(positive, negative, out_latent) - - class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1533,7 +1476,6 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, - WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/nodes.py b/nodes.py index 2f5a478b59e3..4bf768045c98 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", + "nodes_scail.py", "nodes_void.py", "nodes_wandancer.py", "nodes_hidream_o1.py", From 9fc6f5f6dd80ff7433edd25959fc7983e5d4d962 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Tue, 9 Jun 2026 23:36:56 +0800 Subject: [PATCH 30/45] Move bg_removal_model input socket to first position for nicer display (#14353) --- comfy_extras/nodes_bg_removal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_bg_removal.py b/comfy_extras/nodes_bg_removal.py index 9dc9ad854369..c7b33a821720 100644 --- a/comfy_extras/nodes_bg_removal.py +++ b/comfy_extras/nodes_bg_removal.py @@ -36,15 +36,15 @@ def define_schema(cls): category="image/background removal", description="Generates a foreground mask to remove the background from an image using a background removal model.", inputs=[ - IO.Image.Input("image", tooltip="Input image to remove the background from"), - IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask") + IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"), + IO.Image.Input("image", tooltip="Input image to remove the background from") ], outputs=[ IO.Mask.Output("mask", tooltip="Generated foreground mask") ] ) @classmethod - def execute(cls, image, bg_removal_model): + def execute(cls, bg_removal_model, image): mask = bg_removal_model.encode_image(image) return IO.NodeOutput(mask) From 6f01b244a29e7b15b4a33168ceb4d603a6244c86 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 10 Jun 2026 03:57:04 +1000 Subject: [PATCH 31/45] mm: dont reset cast buffers in cleanup_models_gc() (#14372) cleanup_models_gc can be called once per load_models_gpu via free_memory, which in turn can de-activate an active model via this reset_cast_buffers. cleanup_models_gc() could also come via obscure garbage collector paths so limit reset_cast_buffers to the post-node callsite instead. --- comfy/model_management.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8e786c0a507b..9dc0a4e13cd8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -958,8 +958,6 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False - reset_cast_buffers() - for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): From ad564899d37aec4cf151cf77b69bf1f4b6afe233 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 10 Jun 2026 03:55:29 +0800 Subject: [PATCH 32/45] Ensure conditions are not trainable to avoid bugs (#14368) --- comfy_extras/nodes_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 046eeaaf5d4f..273f55e7c2f2 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -15,6 +15,7 @@ import comfy.sd import comfy.utils import comfy.model_management +from comfy.conds import CONDRegular, CONDList from comfy.cli_args import args, PerformanceFeature import comfy_extras.nodes_custom_sampler import folder_paths @@ -120,6 +121,11 @@ def process_cond_list(d, prefix=""): process_cond_list(v, f"{prefix}.{k}") elif isinstance(v, torch.Tensor): d[k] = v.clone() + elif isinstance(v, CONDList): + v.cond = [t.detach() if isinstance(t, torch.Tensor) else t for t in v.cond] + elif isinstance(v, CONDRegular): + if isinstance(v.cond, torch.Tensor): + v.cond = v.cond.detach() elif isinstance(v, (list, tuple)): for index, item in enumerate(v): process_cond_list(item, f"{prefix}.{k}.{index}") From f8e51b674c75f41b3960d65ce83f77301ab297c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 10 Jun 2026 02:47:34 +0300 Subject: [PATCH 33/45] feat: Add Bernini-R model support (Wan video) (CORE-279) (#14216) --- comfy/ldm/wan/model.py | 31 ++++++++- comfy/model_base.py | 18 ++++++ comfy_extras/nodes_bernini.py | 115 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 4 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/nodes_bernini.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9178b334470e..282408891364 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -8,7 +8,7 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.math import apply_rope1, rope import comfy.ldm.common_dit import comfy.model_management import comfy.patcher_extension @@ -570,6 +570,14 @@ def forward_orig( full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) + # In-context reference (Bernini) + context_latents = kwargs.get("context_latents", None) + main_len = x.shape[1] + if context_latents is not None: + for lat in context_latents: + cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2) + x = torch.cat([x, cl], dim=1) + # context context = self.text_embedding(context) @@ -599,6 +607,9 @@ def block_wrap(args): # head x = self.head(x, e) + if context_latents is not None: + x = x[:, :main_len] + if full_ref is not None: x = x[:, full_ref.shape[1]:] @@ -606,7 +617,7 @@ def block_wrap(args): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -638,6 +649,13 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) + + # In-context reference: a non-zero source_id composes an extra rotation into the spatial rope + if source_id: + d = self.dim // self.num_heads + pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) + id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype) + freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot) return freqs def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): @@ -661,6 +679,15 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr t_len += 1 freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + # In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0). + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents] + for i, lat in enumerate(context_latents): + freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1) + kwargs = {**kwargs, "context_latents": context_latents} + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): diff --git a/comfy/model_base.py b/comfy/model_base.py index d212a7c2aa21..2a46d1fc1924 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1518,8 +1518,26 @@ def extra_conds(self, **kwargs): if reference_latents is not None: out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + # In-context reference conditioning (Bernini) + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) + return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # In-context cond slicing (Bernini) + if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list): + dim = window.dim + out = [] + for lat in cond_value.cond: + if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]: + out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list)) + else: + out.append(lat.to(device)) + return cond_value._copy_with(out) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_CausalAR(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py new file mode 100644 index 000000000000..227fa5753bdd --- /dev/null +++ b/comfy_extras/nodes_bernini.py @@ -0,0 +1,115 @@ +import torch +from typing_extensions import override + +import comfy.model_management +import comfy.utils +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def _resize_long_edge(image, max_size, stride=16): + """Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`""" + h, w = image.shape[1], image.shape[2] + scale = min(max_size / max(h, w), 1.0) + nh = max(stride, round(h * scale / stride) * stride) + nw = max(stride, round(w * scale / stride) * stride) + return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1) + + +class BerniniConditioning(io.ComfyNode): + """Bernini in-context conditioning for a Wan2.2-A14B model. + + Attaches the VAE-encoded source video / reference images to the conditioning + source video first, then each reference image + + The task is inferred from which inputs are connected: + (nothing) -> t2v (text-to-video) + source_video -> v2v (video-to-video) + source_video + ref_images -> rv2v (reference-guided video editing) + ref_images only -> r2v (reference-to-video) + source_video + ref_video -> ads2v (insert image/video into video) + + source_video is the edit base / canvas (resized to width x height). + reference_video is moving content to composite in. + Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...). + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BerniniConditioning", + display_name="Bernini Conditioning", + category="conditioning/video_models", + description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)." + "Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=8192, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("source_video", optional=True, tooltip=( + "Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")), + io.Image.Input("reference_video", optional=True, tooltip=( + "Video to insert into the source video (ads2v).")), + io.Autogrow.Input("reference_images", optional=True, + template=io.Autogrow.TemplatePrefix( + input=io.Image.Input("reference_image", tooltip=( + "Reference image injected as an in-context token (r2v, rv2v).")), + prefix="reference_image_", min=0, max=8)), + io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( + "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, + source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], + device=comfy.model_management.intermediate_device()) + + # source_video (1), reference_video (2), reference_images (3, 4, ...). + context = [] + if source_video is not None: + vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + context.append(vae.encode(vid[:, :, :, :3])) + + if reference_video is not None: + ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect + context.append(vae.encode(ref_vid[:, :, :, :3])) + + # reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a + # separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame). + if reference_images: + for name in sorted(reference_images): + imgs = reference_images[name] + if imgs is None: + continue + for i in range(imgs.shape[0]): + img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref + context.append(vae.encode(img[:, :, :, :3])) + + if context: + positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) + negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) + + return io.NodeOutput(positive, negative, {"samples": latent}) + + +class BerniniExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + BerniniConditioning, + ] + + +async def comfy_entrypoint() -> BerniniExtension: + return BerniniExtension() diff --git a/nodes.py b/nodes.py index 4bf768045c98..fb6952badbaa 100644 --- a/nodes.py +++ b/nodes.py @@ -2404,6 +2404,7 @@ async def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_bernini.py", "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", From 5ece24e73c0785e95a45a45a2773ecd1a6f5339b Mon Sep 17 00:00:00 2001 From: Talmaj Date: Wed, 10 Jun 2026 03:28:24 +0200 Subject: [PATCH 34/45] Depth anything 3 (Core-135) (#13853) Co-authored-by: Alexis Rolland --- comfy/image_encoders/dino2.py | 333 ++++++++- comfy/ldm/colormap.py | 25 + comfy/ldm/depth_anything_3/camera.py | 177 +++++ comfy/ldm/depth_anything_3/dpt.py | 489 +++++++++++++ comfy/ldm/depth_anything_3/model.py | 236 ++++++ comfy/ldm/depth_anything_3/preprocess.py | 128 ++++ comfy/ldm/depth_anything_3/ray_pose.py | 272 +++++++ .../reference_view_selector.py | 87 +++ comfy/ldm/depth_anything_3/transform.py | 160 ++++ comfy/model_base.py | 7 + comfy/model_detection.py | 89 +++ comfy/supported_models.py | 18 + comfy_extras/nodes_depth_anything_3.py | 681 ++++++++++++++++++ comfy_extras/nodes_moge.py | 14 +- nodes.py | 3 +- 15 files changed, 2687 insertions(+), 32 deletions(-) create mode 100644 comfy/ldm/colormap.py create mode 100644 comfy/ldm/depth_anything_3/camera.py create mode 100644 comfy/ldm/depth_anything_3/dpt.py create mode 100644 comfy/ldm/depth_anything_3/model.py create mode 100644 comfy/ldm/depth_anything_3/preprocess.py create mode 100644 comfy/ldm/depth_anything_3/ray_pose.py create mode 100644 comfy/ldm/depth_anything_3/reference_view_selector.py create mode 100644 comfy/ldm/depth_anything_3/transform.py create mode 100644 comfy_extras/nodes_depth_anything_3.py diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index ee86f8309ec0..53e4fdb6c3a3 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,7 +1,13 @@ import torch +import torch.nn.functional as F + from comfy.text_encoders.bert import BertAttention import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.depth_anything_3.reference_view_selector import ( + select_reference_view, reorder_by_reference, restore_original_order, + THRESH_FOR_REF_SELECTION, +) class Dino2AttentionOutput(torch.nn.Module): @@ -14,13 +20,41 @@ def forward(self, x): class Dino2AttentionBlock(torch.nn.Module): - def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations, + qk_norm=False): super().__init__() + self.heads = heads + self.head_dim = embed_dim // heads self.attention = BertAttention(embed_dim, heads, dtype, device, operations) self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) - - def forward(self, x, mask, optimized_attention): - return self.output(self.attention(x, mask, optimized_attention)) + if qk_norm: + self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + else: + self.q_norm = None + self.k_norm = None + + def forward(self, x, mask, optimized_attention, pos=None, rope=None): + # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions). + if self.q_norm is None and rope is None: + return self.output(self.attention(x, mask, optimized_attention)) + + # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE. + attn = self.attention + B, N, C = x.shape + h = self.heads + d = self.head_dim + q = attn.query(x).view(B, N, h, d).transpose(1, 2) + k = attn.key(x).view(B, N, h, d).transpose(1, 2) + v = attn.value(x).view(B, N, h, d).transpose(1, 2) + if self.q_norm is not None: + q = self.q_norm(q) + k = self.k_norm(k) + if rope is not None and pos is not None: + q = rope(q, pos) + k = rope(k, pos) + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + return self.output(out) class LayerScale(torch.nn.Module): @@ -64,9 +98,11 @@ def forward(self, x): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn, + qk_norm=False): super().__init__() - self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations, + qk_norm=qk_norm) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) if use_swiglu_ffn: @@ -76,19 +112,90 @@ def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, us self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, x, optimized_attention): - x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None): + x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention, + pos=pos, rope=rope)) x = x + self.layer_scale2(self.mlp(self.norm2(x))) return x +# ----------------------------------------------------------------------------- +# 2D Rotary position embedding (DA3 extension) +# ----------------------------------------------------------------------------- + + +class _PositionGetter: + """Cache (h, w) -> flat (y, x) position grid used to feed ``rope``.""" + + def __init__(self): + self._cache: dict = {} + + def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: + key = (height, width, device) + if key not in self._cache: + y = torch.arange(height, device=device) + x = torch.arange(width, device=device) + self._cache[key] = torch.cartesian_prod(y, x) + cached = self._cache[key] + return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(torch.nn.Module): + """2D RoPE used by DA3-Small/Base. No learnable parameters.""" + + def __init__(self, frequency: float = 100.0): + super().__init__() + self.base_frequency = frequency + self._freq_cache: dict = {} + + def _components(self, dim: int, seq_len: int, device, dtype): + key = (dim, seq_len, device, dtype) + if key not in self._freq_cache: + exp = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exp) + pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + ang = torch.einsum("i,j->ij", pos, inv_freq) + ang = ang.to(dtype) + ang = torch.cat((ang, ang), dim=-1) + self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) + return self._freq_cache[key] + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d(self, tokens, positions, cos_c, sin_c): + cos = F.embedding(positions, cos_c)[:, None, :, :] + sin = F.embedding(positions, sin_c)[:, None, :, :] + return (tokens * cos) + (self._rotate(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + feature_dim = tokens.size(-1) // 2 + max_pos = int(positions.max()) + 1 + cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) + v, h = tokens.chunk(2, dim=-1) + v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) + h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) + return torch.cat((v, h), dim=-1) + + class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn, + qknorm_start: int = -1): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) - for _ in range(num_layers)]) + self.layer = torch.nn.ModuleList([ + Dino2Block( + dim, num_heads, layer_norm_eps, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qk_norm=(qknorm_start != -1 and i >= qknorm_start), + ) + for i in range(num_layers) + ]) def forward(self, x, intermediate_output=None): + # Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions). optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) if intermediate_output is not None: @@ -122,16 +229,27 @@ def forward(self, pixel_values): class Dino2Embeddings(torch.nn.Module): - def __init__(self, dim, dtype, device, operations): + def __init__(self, dim, dtype, device, operations, + patch_size: int = 14, image_size: int = 518, + use_mask_token: bool = True, + num_camera_tokens: int = 0): super().__init__() - patch_size = 14 - image_size = 518 self.patch_size = patch_size + self.image_size = image_size self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key. - self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + if use_mask_token: + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + else: + self.mask_token = None + if num_camera_tokens > 0: + # DA3 stores (ref_token, src_token) pairs that get injected at the + # alt-attn boundary; see ``Dinov2Model._inject_camera_token``. + self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device)) + else: + self.camera_token = None def interpolate_pos_encoding(self, x, h_pixels, w_pixels): pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32) @@ -140,12 +258,22 @@ def interpolate_pos_encoding(self, x, h_pixels, w_pixels): patch_pos = pos_embed[:, 1:] N = patch_pos.shape[1] M = int(N ** 0.5) + assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})" h0 = h_pixels // self.patch_size w0 = w_pixels // self.patch_size - scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # scale_factor is (height_scale, width_scale) -- height MUST come first; + # swapping these only happens to work for square inputs and breaks + # non-square paths like DA3-Small / DA3-Base multi-view. + scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2) patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False) + assert (h0, w0) == patch_pos.shape[-2:], ( + f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match " + f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); " + f"check scale_factor axis order and +0.1 rounding workaround" + ) patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2) return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype) @@ -168,12 +296,51 @@ def __init__(self, config_dict, dtype, device, operations): heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] use_swiglu_ffn = config_dict["use_swiglu_ffn"] + patch_size = config_dict.get("patch_size", 14) + image_size = config_dict.get("image_size", 518) + use_mask_token = config_dict.get("use_mask_token", True) + + # DA3 extensions (all default to disabled). + self.alt_start = config_dict.get("alt_start", -1) + self.qknorm_start = config_dict.get("qknorm_start", -1) + self.rope_start = config_dict.get("rope_start", -1) + self.cat_token = config_dict.get("cat_token", False) + rope_freq = config_dict.get("rope_freq", 100.0) + + self.embed_dim = dim + self.patch_size = patch_size + self.num_register_tokens = 0 + self.patch_start_idx = 1 - self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + if self.rope_start != -1 and rope_freq > 0: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + self._position_getter = _PositionGetter() + else: + self.rope = None + self._position_getter = None + + # camera_token shape: (1, 2, dim) -> (ref_token, src_token). + num_cam_tokens = 2 if self.alt_start != -1 else 0 + + self.embeddings = Dino2Embeddings( + dim, dtype, device, operations, + patch_size=patch_size, image_size=image_size, + use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens, + ) + self.encoder = Dino2Encoder( + dim, heads, layer_norm_eps, num_layers, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qknorm_start=self.qknorm_start, + ) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + if self.alt_start != -1: + raise RuntimeError( + "Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not " + "apply DA3 extensions (RoPE, alternating attention, camera-token injection). " + "Use get_intermediate_layers_da3() for Depth Anything 3 models." + ) x = self.embeddings(pixel_values) x, i = self.encoder(x, intermediate_output=intermediate_output) x = self.layernorm(x) @@ -181,6 +348,7 @@ def forward(self, pixel_values, attention_mask=None, intermediate_output=None): return x, i, pooled_output, None def get_intermediate_layers(self, pixel_values, indices, apply_norm=True): + """Single-view multi-layer feature extraction.""" x = self.embeddings(pixel_values) optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) n_layers = len(self.encoder.layer) @@ -197,3 +365,132 @@ def get_intermediate_layers(self, pixel_values, indices, apply_norm=True): if i >= max_idx: break return [cache[i] for i in resolved] + + # ------------------------------------------------------------------ + # Depth Anything 3 forward + # ------------------------------------------------------------------ + def _prepare_rope_positions(self, B, S, H, W, device): + if self.rope is None: + return None, None + ph, pw = H // self.patch_size, W // self.patch_size + pos = self._position_getter(B * S, ph, pw, device=device) + # Shift so the cls/cam token at position 0 is reserved for "no diff". + pos = pos + 1 + cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype) + # Per-view local: real grid positions for patches, 0 for cls token. + pos_local = torch.cat([cls_pos, pos], dim=1) + # Global (across views): same grid positions; cls token still at 0, + # but patches share the same positions in every view. + pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1) + return pos_local, pos_global + + def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor: + # x: (B, S, N, C). Replace token at index 0 with the camera token. + if cam_token is not None: + inj = cam_token + else: + ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype) + ref_token = ct[:, :1].expand(B, -1, -1) + src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1) + inj = torch.cat([ref_token, src_token], dim=1) + x = x.clone() + x[:, :, 0] = inj + return x + + def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None): + """Multi-view multi-layer feature extraction used by Depth Anything 3.""" + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \ + f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}" + B, S, _, H, W = pixel_values.shape + + # Patch + cls + (interpolated) pos embed for each view. + x = pixel_values.reshape(B * S, 3, H, W) + x = self.embeddings(x) # (B*S, 1+N, C) + x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C) + + pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device) + # optimized_attention is only used by blocks without QK-norm/RoPE + # (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA. + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + out_set = set(out_layers) + export_set = set(export_feat_layers) if export_feat_layers else set() + outputs: list[torch.Tensor] = [] + aux_outputs: list[torch.Tensor] = [] + local_x = x + b_idx = None + + + for i, blk in enumerate(self.encoder.layer): + apply_rope = self.rope is not None and i >= self.rope_start + block_rope = self.rope if apply_rope else None + l_pos = pos_local if apply_rope else None + g_pos = pos_global if apply_rope else None + + # Reference-view selection threshold: matches the upstream constant + # THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied + # cam_token is provided (camera info already pins the geometry). + if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None): + b_idx = select_reference_view(x, strategy=ref_view_strategy) + x = reorder_by_reference(x, b_idx) + local_x = reorder_by_reference(local_x, b_idx) + + if self.alt_start != -1 and i == self.alt_start: + x = self._inject_camera_token(x, B, S, cam_token) + + if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): + # Global attention across views: flatten S into the seq dim. + t = x.reshape(B, S * x.shape[-2], x.shape[-1]) + p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + else: + # Per-view local attention. + t = x.reshape(B * S, x.shape[-2], x.shape[-1]) + p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + local_x = x + + if i in out_set: + if self.cat_token: + out_x = torch.cat([local_x, x], dim=-1) + else: + out_x = x + # Restore original view order on the way out so heads see views + # in the user's expected order. + if b_idx is not None and self.alt_start != -1: + out_x = restore_original_order(out_x, b_idx) + outputs.append(out_x) + + if i in export_set: + aux = x + if b_idx is not None and self.alt_start != -1: + aux = restore_original_order(aux, b_idx) + aux_outputs.append(aux) + + # Apply final norm. When cat_token is set, only the right half + # ("global" features) is normalised; the left half is left as-is to + # match the upstream DA3 head signature. + normed: list[torch.Tensor] = [] + cls_tokens: list[torch.Tensor] = [] + for out_x in outputs: + cls_tokens.append(out_x[:, :, 0]) + if out_x.shape[-1] == self.embed_dim: + normed.append(self.layernorm(out_x)) + elif out_x.shape[-1] == self.embed_dim * 2: + left = out_x[..., :self.embed_dim] + right = self.layernorm(out_x[..., self.embed_dim:]) + normed.append(torch.cat([left, right], dim=-1)) + else: + raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") + + # Drop cls/cam token from the patch sequence. + normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] + + # Final layernorm + drop cls token from auxiliary features too. + aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :] + for o in aux_outputs] + return list(zip(normed, cls_tokens)), aux_normed diff --git a/comfy/ldm/colormap.py b/comfy/ldm/colormap.py new file mode 100644 index 000000000000..1f4d88bd91fd --- /dev/null +++ b/comfy/ldm/colormap.py @@ -0,0 +1,25 @@ +"""Colormap utilities for depth and geometry visualisation.""" + +from __future__ import annotations + +import torch + + +def turbo(x: torch.Tensor) -> torch.Tensor: + """Anton Mikhailov polynomial approximation of the Turbo colormap. + + Args: + x: Float tensor with values in [0, 1]. + + Returns: + RGB tensor of the same shape as ``x`` with a trailing size-3 dimension. + """ + x = x.clamp(0.0, 1.0) + x2 = x * x + x3 = x2 * x + x4 = x2 * x2 + x5 = x4 * x + r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 + g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 + b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 + return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) diff --git a/comfy/ldm/depth_anything_3/camera.py b/comfy/ldm/depth_anything_3/camera.py new file mode 100644 index 000000000000..65a57d66f36f --- /dev/null +++ b/comfy/ldm/depth_anything_3/camera.py @@ -0,0 +1,177 @@ +"""Camera-token encoder and decoder for Depth Anything 3.""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention_for_device +from .transform import affine_inverse, extri_intri_to_pose_encoding + + +# ----------------------------------------------------------------------- +# Building blocks (mirror depth_anything_3.model.utils.{attention,block}) +# ----------------------------------------------------------------------- + + +class _Mlp(nn.Module): + """Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(F.gelu(self.fc1(x))) + + +class _LayerScale(nn.Module): + """Per-channel learnable scaling. Matches upstream LayerScale.""" + + def __init__(self, dim, *, device=None, dtype=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * self.gamma.to(dtype=x.dtype, device=x.device) + + +class _Attention(nn.Module): + """ Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention; + Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias}).""" + + def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) # each (B, N, C) + attn_fn = optimized_attention_for_device(x.device, small_input=True) + out = attn_fn(q, k, v, heads=self.num_heads) + return self.proj(out) + + +class _Block(nn.Module): + """Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block.""" + + def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations) + self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + + def forward(self, x): + x = x + self.ls1(self.attn(self.norm1(x))) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class CameraEnc(nn.Module): + """Encode per-view (extrinsics, intrinsics) into a camera token. + + Maps a 9-D pose-encoding vector through a small MLP up to the backbone's + ``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output + has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start`` + of the DINOv2 backbone in place of the cls token. + + Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly. + """ + + def __init__( + self, + dim_out: int = 1024, + dim_in: int = 9, + trunk_depth: int = 4, + target_dim: int = 9, + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + *, + device=None, dtype=None, operations=None, + **_kwargs, + ): + super().__init__() + self.target_dim = target_dim + self.trunk_depth = trunk_depth + self.trunk = nn.Sequential(*[ + _Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio, + init_values=init_values, + device=device, dtype=dtype, operations=operations) + for _ in range(trunk_depth) + ]) + self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.pose_branch = _Mlp( + in_features=dim_in, + hidden_features=dim_out // 2, + out_features=dim_out, + device=device, dtype=dtype, operations=operations, + ) + + def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor, + image_size_hw) -> torch.Tensor: + """Encode camera parameters into ``(B, S, dim_out)`` tokens.""" + c2ws = affine_inverse(extrinsics) + pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw) + tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype)) + tokens = self.token_norm(tokens) + tokens = self.trunk(tokens) + tokens = self.trunk_norm(tokens) + return tokens + + +class CameraDec(nn.Module): + """Decode the final cam token into a 9-D pose encoding. + + Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is + always predicted by the network; the quaternion and FoV can either be + predicted or supplied via ``camera_encoding`` (used at training time + when GT cameras are available -- not exercised at inference here). + + Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly. + """ + + def __init__(self, dim_in: int = 1536, + *, device=None, dtype=None, operations=None, **_kwargs): + super().__init__() + d = dim_in + self.backbone = nn.Sequential( + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + ) + self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype) + self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype) + self.fc_fov = nn.Sequential( + operations.Linear(d, 2, device=device, dtype=dtype), + nn.ReLU(), + ) + + def forward(self, feat: torch.Tensor, + camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor: + """Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc.""" + B, N = feat.shape[:2] + feat = feat.reshape(B * N, -1) + feat = self.backbone(feat) + out_t = self.fc_t(feat.float()).reshape(B, N, 3) + if camera_encoding is None: + out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) + out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) + else: + out_qvec = camera_encoding[..., 3:7] + out_fov = camera_encoding[..., -2:] + return torch.cat([out_t, out_qvec, out_fov], dim=-1) diff --git a/comfy/ldm/depth_anything_3/dpt.py b/comfy/ldm/depth_anything_3/dpt.py new file mode 100644 index 000000000000..fb940873bcae --- /dev/null +++ b/comfy/ldm/depth_anything_3/dpt.py @@ -0,0 +1,489 @@ +"""DPT / DualDPT heads for Depth Anything 3.""" + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Permute(nn.Module): + def __init__(self, dims: Tuple[int, ...]): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(*self.dims) + + +def _custom_interpolate( + x: torch.Tensor, + size: Optional[Tuple[int, int]] = None, + scale_factor: Optional[float] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + if size is None: + assert scale_factor is not None + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks] + return torch.cat(outs, dim=0).contiguous() + return F.interpolate(x, size=size, mode=mode, align_corners=align_corners) + + +def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor: + """Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span).""" + diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + return torch.stack((uu, vv), dim=-1) # (H, W, 2) + + +def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor: + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0)) + pos = pos.reshape(-1) + out = torch.einsum("m,d->md", pos, omega) + return torch.cat([out.sin(), out.cos()], dim=1).float() + + +def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor: + H, W, _ = pos_grid.shape + pos_flat = pos_grid.reshape(-1, 2) + emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) + emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) + emb = torch.cat([emb_x, emb_y], dim=-1) + return emb.view(H, W, embed_dim) + + +def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Stateless UV positional embedding added to a feature map (B, C, h, w).""" + pw, ph = x.shape[-1], x.shape[-2] + pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = _position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype) + return x + pe + + +def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor: + act = (activation or "linear").lower() + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return F.softplus(x) + if act == "tanh": + return torch.tanh(x) + return x + + +# ----------------------------------------------------------------------------- +# Fusion building blocks +# ----------------------------------------------------------------------------- + + +class ResidualConvUnit(nn.Module): + def __init__(self, features: int, device=None, dtype=None, operations=None): + super().__init__() + self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype) + self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype) + self.activation = nn.ReLU(inplace=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.activation(x) + out = self.conv1(out) + out = self.activation(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None): + super().__init__() + self.align_corners = align_corners + self.has_residual = has_residual + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + else: + self.resConfUnit1 = None + self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype) + + def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor: + y = xs[0] + if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: + y = y + self.resConfUnit1(xs[1]) + y = self.resConfUnit2(y) + if size is None: + up_kwargs = {"scale_factor": 2.0} + else: + up_kwargs = {"size": size} + y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners) + y = self.out_conv(y) + return y + + +class _Scratch(nn.Module): + """Container that mirrors upstream ``scratch`` attribute layout.""" + + +def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch: + scratch = _Scratch() + scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype) + scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype) + scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype) + scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype) + return scratch + + +def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock: + return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations) + + +# ----------------------------------------------------------------------------- +# DPT (single head + optional sky head) -- used by DA3Mono/Metric +# ----------------------------------------------------------------------------- + + +class DPT(nn.Module): + """Single-head DPT used by DA3Mono-Large and DA3Metric-Large.""" + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", + norm_type: str = "idt", + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.head_main = head_name + self.sky_name = sky_name + self.out_dim = output_dim + self.has_conf = output_dim > 1 + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + if norm_type == "layer": + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + else: + self.norm = nn.Identity() + + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict: + # feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C) + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = _add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = self.scratch.layer3_rn(resized[2]) + l4_rn = self.scratch.layer4_rn(resized[3]) + + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + out = self.scratch.refinenet1(out, l1_rn) + + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = self.scratch.output_conv1(out) + fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + fused = _add_pos_embed(fused, W, H) + feat = fused + + main_logits = self.scratch.output_conv2(feat) + outs = {} + if self.has_conf: + fmap = main_logits.permute(0, 2, 3, 1) + pred = _apply_activation(fmap[..., :-1], self.activation) + conf = _apply_activation(fmap[..., -1], self.conf_activation) + outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1]) + outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:]) + else: + pred = _apply_activation(main_logits, self.activation) + outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:]) + + if self.use_sky_head: + sky_logits = self.scratch.sky_output_conv2(feat) + if self.sky_activation.lower() == "sigmoid": + sky = torch.sigmoid(sky_logits) + elif self.sky_activation.lower() == "relu": + sky = F.relu(sky_logits) + else: + sky = sky_logits + outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:]) + + return outs + + +# ----------------------------------------------------------------------------- +# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base +# ----------------------------------------------------------------------------- + + +class DualDPT(nn.Module): + """Two-head DPT used by DA3-Small / DA3-Base.""" + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 2, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + down_ratio: int = 1, + aux_pyramid_levels: int = 4, + aux_out1_conv_num: int = 5, + head_names: Tuple[str, str] = ("depth", "ray"), + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.aux_levels = aux_pyramid_levels + self.aux_out1_conv_num = aux_out1_conv_num + self.head_main, self.head_aux = head_names + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + # Toggle the auxiliary ray branch at runtime. Default off (mono path). + # DepthAnything3Net flips this on when running multi-view + ray-pose. + self.enable_aux: bool = False + + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations) + # Main fusion chain + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations) + # Auxiliary fusion chain (separate copies) + self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + + # Main head neck + final projection + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + + # Aux pre-head per level (multi-level pyramid) + self.scratch.output_conv1_aux = nn.ModuleList([ + self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations) + for _ in range(self.aux_levels) + ]) + + # Aux final projection per level (includes LayerNorm permute path). + ln_seq = [Permute((0, 2, 3, 1)), + operations.LayerNorm(head_features_2, device=device, dtype=dtype), + Permute((0, 3, 1, 2))] + self.scratch.output_conv2_aux = nn.ModuleList([ + nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + *ln_seq, + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + for _ in range(self.aux_levels) + ]) + + @staticmethod + def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential: + # aux_out1_conv_num=5 in all Apache-2.0 variants. + return nn.Sequential( + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict: + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = _add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = self.scratch.layer3_rn(resized[2]) + l4_rn = self.scratch.layer4_rn(resized[3]) + + # Main pyramid (output_conv1 is applied inside the upstream `_fuse`, + # before interpolation -- replicate that order here). + m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + if self.enable_aux: + a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:]) + aux_pyr = [a4] + m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:]) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:])) + m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:]) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:])) + m = self.scratch.refinenet1(m, l1_rn) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn)) + m = self.scratch.output_conv1(m) + + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + m = _add_pos_embed(m, W, H) + main_logits = self.scratch.output_conv2(m) + fmap = main_logits.permute(0, 2, 3, 1) + depth_pred = _apply_activation(fmap[..., :-1], self.activation) + depth_conf = _apply_activation(fmap[..., -1], self.conf_activation) + + outs = { + self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]), + f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]), + } + + if self.enable_aux: + # Auxiliary "ray" head (multi-level inside) -- only the last level + # is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``: + # each aux pyramid level goes through ``output_conv1_aux[i]`` + # (5-layer conv stack that ends at ``features // 2`` channels), + # then the last level optionally gets a pos-embed and finally + # ``output_conv2_aux[-1]``. + aux_processed = [ + self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr) + ] + last_aux = aux_processed[-1] + if self.pos_embed: + last_aux = _add_pos_embed(last_aux, W, H) + last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux) + fmap_last = last_aux_logits.permute(0, 2, 3, 1) + # Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation. + aux_pred = fmap_last[..., :-1] + aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation) + outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:]) + outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:]) + + return outs diff --git a/comfy/ldm/depth_anything_3/model.py b/comfy/ldm/depth_anything_3/model.py new file mode 100644 index 000000000000..f3c8a5ee339c --- /dev/null +++ b/comfy/ldm/depth_anything_3/model.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from typing import Dict, Optional, Sequence + +import torch +import torch.nn as nn + +from comfy.image_encoders.dino2 import Dinov2Model + +from .camera import CameraDec, CameraEnc +from .dpt import DPT, DualDPT +from .ray_pose import get_extrinsic_from_camray +from .transform import affine_inverse, pose_encoding_to_extri_intri + + +_HEAD_REGISTRY = { + "dpt": DPT, + "dualdpt": DualDPT, +} + + +# Backbone presets (mirror the upstream DINOv2 ViT variants). +_BACKBONE_PRESETS = { + "vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False), + "vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False), + "vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False), + "vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True), +} + + +def _build_backbone_config( + backbone_name: str, + *, + alt_start: int, + qknorm_start: int, + rope_start: int, + cat_token: bool, +) -> dict: + if backbone_name not in _BACKBONE_PRESETS: + raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}") + cfg = dict(_BACKBONE_PRESETS[backbone_name]) + cfg.update(dict( + layer_norm_eps=1e-6, + patch_size=14, + image_size=518, + # No mask_token in DA3 weights; omit param to avoid load warnings. + use_mask_token=False, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + rope_freq=100.0, + )) + return cfg + + +class DepthAnything3Net(nn.Module): + + PATCH_SIZE = 14 + + def __init__( + self, + # --- Backbone --- + backbone_name: str = "vitl", + out_layers: Sequence[int] = (4, 11, 17, 23), + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = False, + # --- Head --- + head_type: str = "dpt", # dpt or dualdpt + head_dim_in: int = 1024, + head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf + head_features: int = 256, + head_out_channels: Sequence[int] = (256, 512, 1024, 1024), + head_use_sky_head: bool = True, # ignored by DualDPT + head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT + # --- Camera (multi-view) --- + has_cam_enc: bool = False, + has_cam_dec: bool = False, + cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim) + cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token) + # ComfyUI plumbing + device=None, dtype=None, operations=None, + **_ignored, + ): + super().__init__() + head_cls = _HEAD_REGISTRY[head_type.lower()] + self.head_type = head_type.lower() + self.has_sky = (self.head_type == "dpt") and head_use_sky_head + self.has_conf = head_output_dim > 1 + self.out_layers = list(out_layers) + + backbone_cfg = _build_backbone_config( + backbone_name, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + ) + self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations) + + head_kwargs = dict( + dim_in=head_dim_in, + patch_size=self.PATCH_SIZE, + output_dim=head_output_dim, + features=head_features, + out_channels=tuple(head_out_channels), + device=device, dtype=dtype, operations=operations, + ) + if self.head_type == "dpt": + head_kwargs.update( + use_sky_head=head_use_sky_head, + pos_embed=(False if head_pos_embed is None else head_pos_embed), + ) + else: # dualdpt + head_kwargs.update( + pos_embed=(True if head_pos_embed is None else head_pos_embed), + ) + self.head = head_cls(**head_kwargs) + + # Built only if checkpoint has weights; cam_enc output dim == embed_dim. + embed_dim = backbone_cfg["hidden_size"] + if has_cam_enc: + self.cam_enc = CameraEnc( + dim_out=cam_dim_out if cam_dim_out is not None else embed_dim, + num_heads=max(1, embed_dim // 64), + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_enc = None + if has_cam_dec: + default_dim = embed_dim * (2 if cat_token else 1) + self.cam_dec = CameraDec( + dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim, + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_dec = None + + self.dtype = dtype + + def forward( + self, + image: torch.Tensor, + extrinsics: Optional[torch.Tensor] = None, + intrinsics: Optional[torch.Tensor] = None, + *, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + export_feat_layers: Optional[Sequence[int]] = None, + **_unused, + ) -> Dict[str, torch.Tensor]: + """Run depth and optionally pose prediction.""" + if image.ndim == 4: + image = image.unsqueeze(1) # (B, 1, 3, H, W) + assert image.ndim == 5 and image.shape[2] == 3, \ + f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}" + + B, S, _, H, W = image.shape + assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ + f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" + + # Camera-token preparation (multi-view path). + cam_token = None + if extrinsics is not None and intrinsics is not None and self.cam_enc is not None: + cam_token = self.cam_enc(extrinsics, intrinsics, (H, W)) + + # Toggle aux ray output on/off depending on what the caller asked for. + if isinstance(self.head, DualDPT): + self.head.enable_aux = bool(use_ray_pose) + + feats, aux_feats = self.backbone.get_intermediate_layers_da3( + image, self.out_layers, cam_token=cam_token, + ref_view_strategy=ref_view_strategy, + export_feat_layers=export_feat_layers, + ) + head_out = self.head(feats, H=H, W=W, patch_start_idx=0) + + # Pose prediction. + out: Dict[str, torch.Tensor] = {} + if use_ray_pose and "ray" in head_out and "ray_conf" in head_out: + ray = head_out["ray"] + ray_conf = head_out["ray_conf"] + extr_c2w, focal, pp = get_extrinsic_from_camray( + ray, ray_conf, ray.shape[-3], ray.shape[-2], + ) + # Match the upstream output: w2c, drop the homogeneous row. + extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :] + # Build pixel-space intrinsics from the normalised focal/pp output. + intr = torch.eye(3, device=ray.device, dtype=ray.dtype) + intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone() + intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W + intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H + intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5 + intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5 + out["extrinsics"] = extr_w2c + out["intrinsics"] = intr + elif self.cam_dec is not None and S > 1: + # Decode the cam-token of the final out_layer into a pose encoding. + cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec) + pose_enc = self.cam_dec(cam_feat) + c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W)) + # Match the upstream output convention: w2c (world->camera), 3x4. + c2w_4x4 = torch.cat([ + c2w_3x4, + torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype) + .view(1, 1, 1, 4).expand(B, S, 1, 4), + ], dim=-2) + out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :] + out["intrinsics"] = intr + + # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the + # per-image consumer keeps its (B*S, H, W) interface. + for k, v in head_out.items(): + if k in ("ray", "ray_conf"): + # Keep multi-view shape for downstream pose work. + out[k] = v + elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S: + out[k] = v.reshape(B * S, *v.shape[2:]) + else: + out[k] = v + + if export_feat_layers: + out["aux_features"] = self._reshape_aux_features(aux_feats, H, W) + return out + + def _reshape_aux_features(self, aux_feats, H: int, W: int): + """Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C).""" + ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE + out = [] + for f in aux_feats: + B, S, N, C = f.shape + assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}" + out.append(f.reshape(B, S, ph, pw, C)) + return out diff --git a/comfy/ldm/depth_anything_3/preprocess.py b/comfy/ldm/depth_anything_3/preprocess.py new file mode 100644 index 000000000000..2238bd0d6a4d --- /dev/null +++ b/comfy/ldm/depth_anything_3/preprocess.py @@ -0,0 +1,128 @@ +"""Input/output preprocessing helpers for Depth Anything 3.""" + +from __future__ import annotations + +from typing import Tuple + +import torch + +import comfy.utils + +PATCH_SIZE = 14 + +# ImageNet normalization constants used during DA3 training. +_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]) +_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]) + + +def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int: + down = (x // patch) * patch + up = down + patch + return up if abs(up - x) <= abs(x - down) else down + + +def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]: + """Compute (target_h, target_w) for a single image. + upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method). + lower_bound_resize: scale shortest side to process_res, then round.""" + + if method == "upper_bound_resize": + longest = max(orig_h, orig_w) + scale = process_res / float(longest) + elif method == "lower_bound_resize": + shortest = min(orig_h, orig_w) + scale = process_res / float(shortest) + else: + raise ValueError(f"Unsupported process_res_method: {method}") + + new_w = max(1, _round_to_patch(int(round(orig_w * scale)))) + new_h = max(1, _round_to_patch(int(round(orig_h * scale)))) + return new_h, new_w + + +def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor: + assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + B, H, W, _ = image.shape + target_h, target_w = compute_target_size(H, W, process_res, method) + + # (B, H, W, 3) -> (B, 3, H, W) + x = image.movedim(-1, 1).contiguous() + if (target_h, target_w) != (H, W): + # Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale). + # Lanczos in ``common_upscale`` is anti-aliased and produces the + # closest pixel-wise match in a sweep across {bilinear, bicubic, + # area, lanczos, bislerp}. Used in both directions for simplicity. + x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",) + x = x.clamp(0.0, 1.0) + + mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + x = (x - mean) / std + return x + + +# ----------------------------------------------------------------------------- +# Output post-processing (sky-aware clipping for Mono/Metric variants) +# ----------------------------------------------------------------------------- + + +def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor: + """Boolean mask: True for non-sky pixels (sky probability < threshold).""" + return sky_prediction < threshold + + +def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor: + """Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor.""" + non_sky = compute_non_sky_mask(sky, threshold=threshold) + if non_sky.sum() <= 10 or (~non_sky).sum() <= 10: + return depth.clone() + + non_sky_depth = depth[non_sky] + if non_sky_depth.numel() > 100_000: + idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device) + sampled = non_sky_depth[idx] + else: + sampled = non_sky_depth + + max_depth = torch.quantile(sampled, quantile) + out = depth.clone() + out[~non_sky] = max_depth + return out + + +def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor: + """V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0).""" + if sky is not None: + mask = compute_non_sky_mask(sky) + if mask.any(): + valid = depth[mask] + else: + valid = depth.flatten() + else: + valid = depth.flatten() + + if valid.numel() > 100_000: + idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device) + sample = valid[idx] + else: + sample = valid + + lo = torch.quantile(sample, low_quantile) + hi = torch.quantile(sample, high_quantile) + rng = (hi - lo).clamp(min=1e-6) + norm = ((depth - lo) / rng).clamp(0.0, 1.0) + # Nearer pixels are brighter (1.0) + norm = 1.0 - norm + if sky is not None: + # Sky pixels become black (far / unknown) + sky_mask = ~compute_non_sky_mask(sky) + norm = torch.where(sky_mask, torch.zeros_like(norm), norm) + return norm + + +def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor: + """Simple per-frame min/max normalization with near=1.0 convention.""" + lo = depth.amin(dim=(-2, -1), keepdim=True) + hi = depth.amax(dim=(-2, -1), keepdim=True) + rng = (hi - lo).clamp(min=1e-6) + return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0) diff --git a/comfy/ldm/depth_anything_3/ray_pose.py b/comfy/ldm/depth_anything_3/ray_pose.py new file mode 100644 index 000000000000..90890f1da6ef --- /dev/null +++ b/comfy/ldm/depth_anything_3/ray_pose.py @@ -0,0 +1,272 @@ +"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + + +# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops. + + +def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Decompose A = Q @ L with Q orthogonal and L lower-triangular. + Implemented in terms of QR by reversing the columns/rows; the standard + trick from the upstream reference. Inputs A are (3, 3).""" + P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype) + A_tilde = A @ P + # CUDA QR is not implemented for fp16/bf16; upcast just for this call. + Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float()) + Q_tilde = Q_tilde.to(A.dtype) + R_tilde = R_tilde.to(A.dtype) + Q = Q_tilde @ P + L = P @ R_tilde @ P + d = torch.diag(L) + sign = torch.sign(d) + Q = Q * sign[None, :] # scale columns of Q + L = L * sign[:, None] # scale rows of L + return Q, L + + +def _homogenize_points(points: torch.Tensor) -> torch.Tensor: + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +# ----------------------------------------------------------------------------- +# Weighted-LSQ + RANSAC homography (batched) +# ----------------------------------------------------------------------------- + + +def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor: + """Solve a single H with weighted least-squares (DLT).""" + N = src_pts.shape[0] + if N < 4: + raise ValueError("At least 4 points are required to compute a homography.") + w = confident_weight.sqrt().unsqueeze(1) # (N, 1) + x = src_pts[:, 0:1] + y = src_pts[:, 1:2] + u = dst_pts[:, 0:1] + v = dst_pts[:, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1) + A = torch.cat([A1, A2], dim=0) # (2N, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[-1].reshape(3, 3) + return H / H[-1, -1] + + +def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor: + """Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3).""" + B, K, _ = src_pts_batch.shape + w = confident_weight_batch.sqrt().unsqueeze(2) + x = src_pts_batch[:, :, 0:1] + y = src_pts_batch[:, :, 1:2] + u = dst_pts_batch[:, :, 0:1] + v = dst_pts_batch[:, :, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2) + A = torch.cat([A1, A2], dim=1) # (B, 2K, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[:, -1].reshape(B, 3, 3) + return H / H[:, 2:3, 2:3] + + +def _ransac_find_homography_weighted_batched( + src_pts: torch.Tensor, # (B, N, 2) + dst_pts: torch.Tensor, # (B, N, 2) + confident_weight: torch.Tensor, # (B, N) + n_sample: int, + n_iter: int = 100, + reproj_threshold: float = 3.0, + num_sample_for_ransac: int = 8, + max_inlier_num: int = 10000, + rand_sample_iters_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices.""" + B, N, _ = src_pts.shape + assert N >= 4 + device = src_pts.device + + sorted_idx = torch.argsort(confident_weight, descending=True, dim=1) + candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample) + + if rand_sample_iters_idx is None: + rand_sample_iters_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] + for _ in range(n_iter)], + dim=0, + ) + + rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k) + b_idx = ( + torch.arange(B, device=device) + .view(B, 1, 1) + .expand(B, n_iter, num_sample_for_ransac) + ) + src_b = src_pts[b_idx, rand_idx] + dst_b = dst_pts[b_idx, rand_idx] + w_b = confident_weight[b_idx, rand_idx] + + cB, cN = src_b.shape[:2] + H_batch = _find_homography_weighted_lsq_batched( + src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1), + ).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3) + + src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2) + proj = torch.bmm( + src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3), + H_batch.reshape(-1, 3, 3).transpose(1, 2), + ) # (B*n_iter, N, 3) + proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2) + err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N) + inlier_mask = err < reproj_threshold + score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2) + best_idx = torch.argmax(score, dim=1) + best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx] + + # Refit with the inlier set (per-batch, since the inlier counts vary). + H_inlier_list = [] + for b in range(B): + mask = best_inlier_mask[b] + in_src = src_pts[b][mask] + in_dst = dst_pts[b][mask] + in_w = confident_weight[b][mask] + if in_src.shape[0] < 4: + # Fall back to identity when RANSAC fails to find enough inliers. + H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype)) + continue + sorted_w = torch.argsort(in_w, descending=True) + if len(sorted_w) > max_inlier_num: + keep = max(int(len(sorted_w) * 0.95), max_inlier_num) + sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]] + H_inlier_list.append( + _find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w]) + ) + return torch.stack(H_inlier_list, dim=0) + + +# ----------------------------------------------------------------------------- +# Camera-ray utilities +# ----------------------------------------------------------------------------- + + +def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor: + """Camera-space unit rays for an identity intrinsic on a 2x2 image plane.""" + dx = 1.0 / num_x + dy = 1.0 / num_y + # Centered camera-space coords directly (skip the K^-1 step since it's + # just a translation by -1 on x and y when K is identity-with-center=1). + y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype) + x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype) + yy, xx = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack((xx, yy), dim=-1) # (h, w, 2) + grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2) + return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1) + + +def _camray_to_caminfo( + camray: torch.Tensor, # (B, S, h, w, 6) + confidence: Optional[torch.Tensor] = None, # (B, S, h, w) + reproj_threshold: float = 0.2, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert per-pixel camera rays to per-view (R, T, focal, principal).""" + if confidence is None: + confidence = torch.ones_like(camray[..., 0]) + B, S, h, w, _ = camray.shape + device = camray.device + dtype = camray.dtype + + rays_target = camray[..., :3] # (B, S, h, w, 3) + rays_origin = _unproject_identity(h, w, B, S, device, dtype) + + # Flatten (B*S, h*w, *) for the RANSAC routine. + rays_target = rays_target.flatten(0, 1).flatten(1, 2) + rays_origin = rays_origin.flatten(0, 1).flatten(1, 2) + weights = confidence.flatten(0, 1).flatten(1, 2).clone() + + # Project to 2D in homogeneous form (the upstream calls this "perspective division"). + z_thresh = 1e-4 + mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh) + weights = torch.where(mask, weights, torch.zeros_like(weights)) + src = rays_origin.clone() + dst = rays_target.clone() + src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0]) + src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1]) + dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0]) + dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1]) + src = src[..., :2] + dst = dst[..., :2] + + N = src.shape[1] + n_iter = 100 + sample_ratio = 0.3 + num_sample_for_ransac = 8 + n_sample = max(num_sample_for_ransac, int(N * sample_ratio)) + rand_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)], + dim=0, + ) + + # Chunk along the view axis to keep peak memory predictable. + chunk = 2 + A_list = [] + for i in range(0, src.shape[0], chunk): + A = _ransac_find_homography_weighted_batched( + src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk], + n_sample=n_sample, n_iter=n_iter, + num_sample_for_ransac=num_sample_for_ransac, + reproj_threshold=reproj_threshold, + rand_sample_iters_idx=rand_idx, + max_inlier_num=8000, + ) + # Flip sign on dets that come out < 0 (so that the QL produces a + # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so + # do the comparison in fp32. + flip = torch.linalg.det(A.float()) < 0 + A = torch.where(flip[:, None, None], -A, A) + A_list.append(A) + A = torch.cat(A_list, dim=0) # (B*S, 3, 3) + + R_list, f_list, pp_list = [], [], [] + for i in range(A.shape[0]): + R, L = _ql_decomposition(A[i]) + L = L / L[2][2] + f_list.append(torch.stack((L[0][0], L[1][1]))) + pp_list.append(torch.stack((L[2][0], L[2][1]))) + R_list.append(R) + R = torch.stack(R_list).reshape(B, S, 3, 3) + focal = torch.stack(f_list).reshape(B, S, 2) + pp = torch.stack(pp_list).reshape(B, S, 2) + + # Translation: confidence-weighted average of camray direction(s). + cf = confidence.flatten(0, 1).flatten(1, 2) + T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1) + T = T / cf.sum(dim=-1, keepdim=True) + T = T.reshape(B, S, 3) + + # Match upstream output convention: focal -> 1/focal, pp + 1. + return R, T, 1.0 / focal, pp + 1.0 + + +def get_extrinsic_from_camray( + camray: torch.Tensor, # (B, S, h, w, 6) + conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w) + patch_size_y: int, + patch_size_x: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Wrap a 4x4 extrinsic + per-view focal + principal-point output.""" + if conf.ndim == 5 and conf.shape[-1] == 1: + conf = conf.squeeze(-1) + R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf) + extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4) + homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) + homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4) + extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4) + return extr, focal, pp diff --git a/comfy/ldm/depth_anything_3/reference_view_selector.py b/comfy/ldm/depth_anything_3/reference_view_selector.py new file mode 100644 index 000000000000..90f00be92875 --- /dev/null +++ b/comfy/ldm/depth_anything_3/reference_view_selector.py @@ -0,0 +1,87 @@ +"""Reference-view selection for the multi-view path of Depth Anything 3.""" + +from __future__ import annotations + +from typing import Literal + +import torch + + +RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"] + + +# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``. +# Reference selection only runs when there are at least this many views. +THRESH_FOR_REF_SELECTION: int = 3 + + +def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor: + """Pick a reference view index per batch element.""" + B, S, _, _ = x.shape + if S <= 1: + return torch.zeros(B, dtype=torch.long, device=x.device) + if strategy == "first": + return torch.zeros(B, dtype=torch.long, device=x.device) + if strategy == "middle": + return torch.full((B,), S // 2, dtype=torch.long, device=x.device) + + # Feature-based strategies: normalised cls/cam token per view. + img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C) + + if strategy == "saddle_balanced": + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S) + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S) + feat_norm = x[:, :, 0].norm(dim=-1) # (B,S) + feat_var = img_class_feat.var(dim=-1) # (B,S) + + def _normalize(metric): + mn = metric.min(dim=1, keepdim=True).values + mx = metric.max(dim=1, keepdim=True).values + return (metric - mn) / (mx - mn + 1e-8) + + sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var) + balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs() + return balance.argmin(dim=1) + + if strategy == "saddle_sim_range": + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + sim_max = sim_no_diag.max(dim=-1).values + sim_min = sim_no_diag.min(dim=-1).values + return (sim_max - sim_min).argmax(dim=1) + + raise ValueError( + f"Unknown reference view selection strategy: {strategy!r}. " + f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'" + ) + + +def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor: + """Reorder x so the reference view is at position 0 in axis S.""" + B, S = x.shape[0], x.shape[1] + if S <= 1: + return x + positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) + b_idx_exp = b_idx.unsqueeze(1) + reorder = torch.where( + (positions > 0) & (positions <= b_idx_exp), + positions - 1, + positions, + ) + reorder[:, 0] = b_idx + batch = torch.arange(B, device=x.device).unsqueeze(1) + return x[batch, reorder] + + +def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor: + """Inverse of reorder_by_reference.""" + B, S = x.shape[0], x.shape[1] + if S <= 1: + return x + target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) + b_idx_exp = b_idx.unsqueeze(1) + restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions) + restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp)) + batch = torch.arange(B, device=x.device).unsqueeze(1) + return x[batch, restore] diff --git a/comfy/ldm/depth_anything_3/transform.py b/comfy/ldm/depth_anything_3/transform.py new file mode 100644 index 000000000000..b735d7bec81c --- /dev/null +++ b/comfy/ldm/depth_anything_3/transform.py @@ -0,0 +1,160 @@ +"""Geometry / camera transform helpers for Depth Anything 3.""" + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F + + +# ----------------------------------------------------------------------------- +# Affine 4x4 helpers +# ----------------------------------------------------------------------------- + + +def as_homogeneous(ext: torch.Tensor) -> torch.Tensor: + """Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``.""" + if ext.shape[-2:] == (4, 4): + return ext + if ext.shape[-2:] == (3, 4): + ones = torch.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return torch.cat([ext, ones], dim=-2) + raise ValueError(f"Invalid affine shape: {ext.shape}") + + +def affine_inverse(A: torch.Tensor) -> torch.Tensor: + """Inverse of an affine matrix ``[R|T; 0 0 0 1]``.""" + R = A[..., :3, :3] + T = A[..., :3, 3:] + P = A[..., 3:, :] + return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2) + + +# ----------------------------------------------------------------------------- +# Quaternion <-> rotation matrix (xyzw / scalar-last) +# ----------------------------------------------------------------------------- + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """sqrt(max(0, x)) with a zero subgradient where x == 0.""" + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """Force the real part of a unit quaternion (xyzw) to be non-negative.""" + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """Convert quaternions (xyzw) to (...,3,3) rotation matrices.""" + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """Convert (...,3,3) rotation matrices to quaternions (xyzw).""" + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + # Reorder rijk -> xyzw (i.e. ijkr). + out = out[..., [1, 2, 3, 0]] + return standardize_quaternion(out) + + +# ----------------------------------------------------------------------------- +# Pose-encoding <-> extrinsics + intrinsics +# ----------------------------------------------------------------------------- + + +def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor: + """Pack (extr, intr, image_size) into the 9-D pose-encoding vector. + extrinsics: camera-to-world (c2w) (B,S,4,4) matrices, + intrinsics: pixel-space (B,S,3,3) matrices, + image_size_hw: is a (H, W) pair. + """ + R = extrinsics[..., :3, :3] + T = extrinsics[..., :3, 3] + quat = mat_to_quat(R) + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + + +def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]: + """Inverse of extri_intri_to_pose_encoding.""" + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + # Normalize to unit quaternion. CameraDec outputs raw values; a near-zero + # quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics. + quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6) + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + H, W = image_size_hw + fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6) + fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 + return extrinsics, intrinsics diff --git a/comfy/model_base.py b/comfy/model_base.py index 2a46d1fc1924..2289e081202c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,6 +65,7 @@ import comfy.ldm.sam3.detector import comfy.ldm.hidream_o1.model from comfy.ldm.hidream_o1.conditioning import build_extra_conds +import comfy.ldm.depth_anything_3.model import comfy.model_management import comfy.patcher_extension @@ -2319,6 +2320,12 @@ class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class DepthAnything3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net) + class ErnieImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 290938bd6d2c..7d0cab30819d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -862,6 +862,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + # Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py) + if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "DepthAnything3" + + patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)] + embed_dim = patch_w.shape[0] + depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.') + + # Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg). + backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim) + if backbone_name is None: + return None + dit_config["backbone_name"] = backbone_name + + # Detect DA3 extensions on top of vanilla DINOv2. + has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys + # qk-norm shows up as `attention.q_norm.weight` on enabled blocks. + qknorm_indices = [ + i for i in range(depth) + if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys + ] + qknorm_start = qknorm_indices[0] if qknorm_indices else -1 + + # The DA3 main-series configs always set alt_start == qknorm_start == rope_start. + # cat_token=True is implied by the presence of camera_token. + if has_camera_token: + dit_config["alt_start"] = qknorm_start + dit_config["rope_start"] = qknorm_start + dit_config["qknorm_start"] = qknorm_start + dit_config["cat_token"] = True + else: + dit_config["alt_start"] = -1 + dit_config["rope_start"] = -1 + dit_config["qknorm_start"] = -1 + dit_config["cat_token"] = False + + # Detect head type and config. + has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys + dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + dit_config["head_out_channels"] = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + if has_aux: + # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). + dit_config["head_type"] = "dualdpt" + dit_config["head_output_dim"] = 2 + dit_config["head_use_sky_head"] = False + else: + dit_config["head_type"] = "dpt" + dit_config["head_output_dim"] = state_dict[ + '{}head.scratch.output_conv2.2.weight'.format(key_prefix) + ].shape[0] + dit_config["head_use_sky_head"] = ( + '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys + ) + + # out_layers: hard-coded per upstream YAML config (depth-aware default). + if depth >= 24: + # vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT). + if has_aux: + dit_config["out_layers"] = [11, 15, 19, 23] + else: + dit_config["out_layers"] = [4, 11, 17, 23] + else: + # vits/vitb: 12 blocks + dit_config["out_layers"] = [5, 7, 9, 11] + + # Camera encoder/decoder presence (multi-view + pose path). + has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys + has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys + dit_config["has_cam_enc"] = has_cam_enc + dit_config["has_cam_dec"] = has_cam_dec + if has_cam_enc: + cam_enc_w = state_dict.get( + '{}cam_enc.pose_branch.fc2.weight'.format(key_prefix) + ) + if cam_enc_w is not None: + dit_config["cam_dim_out"] = cam_enc_w.shape[0] + if has_cam_dec: + cam_dec_w = state_dict.get( + '{}cam_dec.fc_t.weight'.format(key_prefix) + ) + if cam_dec_w is not None: + dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1] + return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image dit_config = {} dit_config["image_model"] = "ernie" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 42325d71cc7d..3be935577c88 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -2056,6 +2056,23 @@ def clip_target(self, state_dict={}): return None +class DepthAnything3(supported_models_base.BASE): + unet_config = { + "image_model": "DepthAnything3", + } + + # Mono path: no num_heads / num_head_channels needed. + unet_extra_config = {} + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.DepthAnything3(self, device=device) + + def clip_target(self, state_dict={}): + return None + + class ErnieImage(supported_models_base.BASE): unet_config = { "image_model": "ernie", @@ -2298,4 +2315,5 @@ def get_model(self, state_dict, prefix="", device=None): CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, + DepthAnything3, ] diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py new file mode 100644 index 000000000000..02011251591d --- /dev/null +++ b/comfy_extras/nodes_depth_anything_3.py @@ -0,0 +1,681 @@ +"""ComfyUI nodes for Depth Anything 3. +Model capability matrix: + +Variant head_type has_sky has_conf cam_dec +DA3-Small dualdpt False True yes +DA3-Base dualdpt False True yes +DA3-Mono-Large dpt True False no +DA3-Metric-Large dpt True False no (raw output is metres) +""" + +from __future__ import annotations + +import logging +from typing_extensions import override + +import torch + +import comfy.model_management as mm +import comfy.sd +import folder_paths +from comfy.ldm.colormap import turbo as _turbo +from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess +from comfy_api.latest import ComfyExtension, Types, io +from comfy.ldm.moge.geometry import triangulate_grid_mesh + +DA3ModelType = io.Custom("DA3_MODEL") +DA3Geometry = io.Custom("DA3_GEOMETRY") +DA3PointCloud = io.Custom("DA3_POINT_CLOUD") + +# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): +# +# Per-frame tensors - B = batch size in mono mode; B = S (number of views) in multi-view mode. +# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention) +# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present) +# "mode": str -- "mono" or "multiview" (always present) +# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only) +# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only) +# +# Multi-view only - S = number of views; the leading 1 is the scene dimension from the model. +# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices +# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics +# +# DA3_POINT_CLOUD is a dict: +# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back) +# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None +# "confidence": torch.Tensor (N,) -- raw confidence per point, or None + + +def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor: + """Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space.""" + H, W = depth.shape + u = torch.arange(W, dtype=torch.float32, device=depth.device) + v = torch.arange(H, dtype=torch.float32, device=depth.device) + u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W) + pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3) + rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix) + return rays * depth.unsqueeze(-1) # (H, W, 3) + + +def _da3_default_K(H: int, W: int) -> torch.Tensor: + """Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry).""" + fx = fy = float(W) * 0.7 + return torch.tensor([[fx, 0.0, (W - 1) / 2.0], + [0.0, fy, (H - 1) / 2.0], + [0.0, 0.0, 1.0]], dtype=torch.float32) + + +def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor: + """Return pixel-space K for batch element b, falling back to a default estimate.""" + if "intrinsics" in geometry: + # shape (1, S, 3, 3) - leading scene dimension from the multiview head + return geometry["intrinsics"][0, b].float() + logging.getLogger("comfy").warning( + "DA3_GEOMETRY has no intrinsics (mono-mode model). " + "Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate." + ) + return _da3_default_K(H, W) + + +def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None: + """Return the world-to-camera extrinsic for batch element b, or None in mono mode. + + The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4). + _da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing. + """ + if "extrinsics" not in geometry: + return None + return geometry["extrinsics"][0, b].float() + + +def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor: + """Transform (H,W,3) OpenCV camera-space points to world space.""" + E = E.to(points_cam.device).float() + if not torch.isfinite(E).all(): + logging.getLogger("comfy").warning( + "DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). " + "Falling back to camera-space coordinates." + ) + return points_cam + H, W, _ = points_cam.shape + R = E[:3, :3] # (3, 3) rotation + t = E[:3, 3] # (3,) translation + R_inv = R.T # rotation inverse = transpose for orthogonal R + t_inv = -(R_inv @ t) # (3,) + pts = points_cam.reshape(-1, 3) # (N, 3) + pts_world = pts @ R_inv.T + t_inv # (N, 3) + return pts_world.reshape(H, W, 3) + + +def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor: + """Map raw confidence to [0, 1] per image.""" + B = conf.shape[0] + out = [] + for i in range(B): + c = conf[i] + c_min, c_max = c.min(), c.max() + out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c)) + return torch.stack(out, dim=0) + + +def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor: + """Build (H,W) bool keep-mask from sky probability and confidence.""" + mask = torch.ones(H, W, dtype=torch.bool) + if use_sky_mask and "sky" in geometry: + mask = mask & (geometry["sky"][b] < 0.5) + if "confidence" in geometry and confidence_threshold > 0.0: + conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0] + mask = mask & (conf_norm >= confidence_threshold) + return mask + + +class LoadDA3Model(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadDA3Model", + display_name="Load Depth Anything 3", + category="model/loaders", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("geometry_estimation"), + ), + io.Combo.Input( + "weight_dtype", + options=["default", "fp16", "bf16", "fp32"], + default="default", + ), + ], + outputs=[DA3ModelType.Output()], + ) + + @classmethod + def execute(cls, model_name, weight_dtype) -> io.NodeOutput: + model_options = {} + if weight_dtype == "fp16": + model_options["dtype"] = torch.float16 + elif weight_dtype == "bf16": + model_options["dtype"] = torch.bfloat16 + elif weight_dtype == "fp32": + model_options["dtype"] = torch.float32 + + path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name) + model = comfy.sd.load_diffusion_model(path, model_options=model_options) + return io.NodeOutput(model) + + +def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"): + """Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None).""" + assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + + B, H, W, _ = image.shape + mm.load_model_gpu(model_patcher) + diffusion = model_patcher.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + depths, confs, skies = [], [], [] + for i in range(B): + single = image[i:i + 1].to(device) + x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method) + x = x.to(dtype=dtype) + with torch.no_grad(): + out = diffusion(x) + + depth_lr = out["depth"] + depth_full = torch.nn.functional.interpolate( + depth_lr.unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + depths.append(depth_full) + + if "depth_conf" in out: + conf_full = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + confs.append(conf_full) + if "sky" in out: + sky_full = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + skies.append(sky_full) + + depth = torch.cat(depths, dim=0) + confidence = torch.cat(confs, dim=0) if confs else None + sky = torch.cat(skies, dim=0) if skies else None + return depth, confidence, sky + + +class DA3Inference(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3Inference", + search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"], + display_name="Run Depth Anything 3", + category="image/geometry estimation", + description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.", + inputs=[ + DA3ModelType.Input("da3_model"), + io.Image.Input("image"), + io.Int.Input("resolution", default=504, min=140, max=2520, step=14, + tooltip="Resolution the model runs at (longest side, multiple of 14).\n" + "Lower = faster / less VRAM.\n" + "Higher = more detail.\n" + "Output is upsampled back to the original size."), + io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize", + tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n" + "lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."), + io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n" + "multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).", + options=[ + io.DynamicCombo.Option("mono", []), + io.DynamicCombo.Option("multiview", [ + io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced", + tooltip="Which view acts as the geometric anchor.\n" + "- saddle_balanced: the view most 'average' across all others (best general choice).\n" + "- saddle_sim_range: the view most visually distinct from the others.\n" + "- first / middle: fixed positional picks."), + io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec", + tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n" + "- cam_dec: learned from image features.\n" + "- ray_pose: derived geometrically from the model's 3D ray output.\n" + "Affects perspective correctness of the 3D output. Try both if results look distorted."), + ]), + ]), + ], + outputs=[ + DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n" + "Always has the keys: depth, image, mode.\n" + "Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."), + ], + ) + + @classmethod + def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput: + mode_val = mode["mode"] # "mono" or "multiview" + + if mode_val == "mono": + return cls._execute_mono(da3_model, image, resolution, resize_method) + + # Capability checks for multi-view mode. + diffusion = da3_model.model.diffusion_model + pose_method = mode["pose_method"] + ref_view_strategy = mode["ref_view_strategy"] + + has_cam_dec = diffusion.cam_dec is not None + has_dualdpt = diffusion.head_type == "dualdpt" + + if not has_cam_dec and not has_dualdpt: + raise ValueError( + "multi-view mode requires Small or Base model. The loaded model " + f"(head_type='{diffusion.head_type}') does not support cross-view " + "attention or camera pose estimation. Switch mode to 'mono', or " + "load Small or Base model for mult-view." + ) + + if pose_method == "cam_dec" and not has_cam_dec: + raise ValueError( + "pose_method='cam_dec' requires a camera decoder, but the loaded " + f"model (head_type='{diffusion.head_type}') does not have one. " + "Use pose_method='ray_pose' instead." + ) + if pose_method == "ray_pose" and not has_dualdpt: + raise ValueError( + "pose_method='ray_pose' requires a DualDPT head, but the loaded " + f"model has a '{diffusion.head_type}' head. " + "Use pose_method='cam_dec' instead." + ) + + return cls._execute_multiview( + da3_model, image, resolution, resize_method, + ref_view_strategy, pose_method, + ) + + @classmethod + def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method) + + geometry: dict = { + "depth": depth.contiguous(), + "image": image[..., :3].cpu(), + "mode": "mono", + } + if sky is not None: + geometry["sky"] = sky.contiguous() + if confidence is not None: + geometry["confidence"] = confidence.contiguous() + return io.NodeOutput(geometry) + + @classmethod + def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput: + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + S, H, W, _ = image.shape + + mm.load_model_gpu(model) + diffusion = model.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + # All views in a single forward pass: (1, S, 3, H', W'). + x = image.to(device) + x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method) + x = x.to(dtype=dtype).unsqueeze(0) + + use_ray_pose = (pose_method == "ray_pose") + with torch.no_grad(): + out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy) + + depth = torch.nn.functional.interpolate( + out["depth"].float().unsqueeze(1), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + + sky = None + if "sky" in out: + sky = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + + if "extrinsics" in out and "intrinsics" in out: + extrinsics = out["extrinsics"].float().cpu() + intrinsics = out["intrinsics"].float().cpu() + else: + extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone() + intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() + + geometry: dict = { + "depth": depth.contiguous(), + "image": image[..., :3].cpu(), + "mode": "multiview", + "extrinsics": extrinsics.contiguous(), + "intrinsics": intrinsics.contiguous(), + } + if sky is not None: + geometry["sky"] = sky.contiguous() + if "depth_conf" in out: + conf = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + geometry["confidence"] = conf.contiguous() + return io.NodeOutput(geometry) + + +class DA3Render(io.ComfyNode): + """Render a visualization from a DA3_GEOMETRY packet.""" + + _DEPTH_RENDER_INPUTS = [ + io.Combo.Input("normalization", + options=["v2_style", "min_max", "raw"], + default="v2_style", + tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n" + "- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n" + "- raw: no scaling,preserves metric units for Metric model."), + io.Boolean.Input("apply_sky_clip", default=False, + tooltip="Clip sky-region depth to the 99th percentile of foreground depth before normalisation. " + "Requires a sky key in the da3_geometry input (for Mono/Metric models only)."), + ] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3Render", + display_name="Render Depth Anything 3", + category="image/geometry estimation", + description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.DynamicCombo.Input("output", + tooltip="- depth: normalised greyscale depth image.\n" + "- depth_colored: depth mapped through the Turbo colormap.\n" + "- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n" + "- confidence: normalised depth confidence (for Small/Base models only).", + options=[ + io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS), + io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS), + io.DynamicCombo.Option("sky_mask", [ + io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."), + ]), + io.DynamicCombo.Option("confidence", [ + io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."), + ]), + ]), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, da3_geometry, output) -> io.NodeOutput: + output_val = output["output"] + + if output_val in ("depth", "depth_colored"): + normalization = output["normalization"] + apply_sky_clip = output["apply_sky_clip"] + if apply_sky_clip and "sky" not in da3_geometry: + raise ValueError( + "apply_sky_clip=True requires a sky tensor in the da3_geometry input, but none is present. " + "Run with Mono/Metric models or set apply_sky_clip=False." + ) + depth = da3_geometry["depth"] + sky = da3_geometry.get("sky") + if apply_sky_clip and sky is not None: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) + grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale + result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey + + elif output_val == "sky_mask": + if "sky" not in da3_geometry: + raise ValueError("geometry has no sky output; run with Mono/Metric models.") + sky = da3_geometry["sky"] + if output["colored"]: + result = _turbo(sky) + else: + result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous() + + elif output_val == "confidence": + if "confidence" not in da3_geometry: + raise ValueError("da3_geometry has no confidence output; run with Small/Base models.") + conf = _normalize_confidence(da3_geometry["confidence"]) + if output["colored"]: + result = _turbo(conf) + else: + result = conf.unsqueeze(-1).expand(*conf.shape, 3).contiguous() + + else: + raise ValueError(f"Unknown output mode: {output_val}") + + return io.NodeOutput(result.float()) + + @staticmethod + def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor: + """Normalise depth and pack as an (B,H,W,3) image tensor.""" + + N = depth.shape[0] + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style( + depth[i], sky_for_norm[i] if sky_for_norm is not None else None) + for i in range(N) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + out = norm.unsqueeze(-1).repeat(1, 1, 1, 3) + if normalization != "raw": + out = out.clamp(0.0, 1.0) + return out.contiguous() + + +class DA3GeometryToMesh(io.ComfyNode): + """Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3GeometryToMesh", + search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"], + display_name="Convert DA3 Geometry to Mesh", + category="image/geometry estimation", + description="Convert a depth map into a triangulated 3D mesh.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."), + io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."), + io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). " + "Used when the geometry has a confidence map (Small/Base models)."), + io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. Used when the geometry has a sky map (Mono/Metric models)."), + io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."), + ], + outputs=[io.Mesh.Output()], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index] # (H, W) + H, W = depth.shape + + # NaN/inf depth would propagate silently through unproject and produce an + # empty mesh; replace them with 0 here so those pixels are later excluded + # by the isfinite check inside triangulate_grid_mesh. + depth = depth.clone() + n_bad = (~torch.isfinite(depth)).sum().item() + if n_bad: + logging.getLogger("comfy").warning( + f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels " + f"({100*n_bad/(H*W):.1f}%) - zeroed before unproject." + ) + depth[~torch.isfinite(depth)] = 0.0 + logging.getLogger("comfy").debug( + f"DA3GeometryToMesh: depth[{batch_index}] range " + f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}" + ) + + K = _da3_get_K(da3_geometry, batch_index, H, W) + points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space + + # Apply world-to-camera inverse so multi-view frames share a common world frame. + E = _da3_get_extrinsic(da3_geometry, batch_index) + if E is not None: + points = _da3_apply_extrinsic(points, E) + + # Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + # Also exclude pixels where depth was invalid. + mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index]) + points = points.clone() + points[~mask] = float('inf') + + verts, faces, uvs = triangulate_grid_mesh( + points, + decimation=decimation, + discontinuity_threshold=discontinuity_threshold, + depth=depth, + ) + if verts.shape[0] == 0 or faces.shape[0] == 0: + raise ValueError( + "DA3GeometryToMesh produced an empty mesh. " + "Try raising discontinuity_threshold, lowering confidence_threshold, " + "or disabling use_sky_mask." + ) + + # OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back). + # Same transform as MoGePointMapToMesh perspective branch. + verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) + faces = faces[:, [0, 2, 1]].contiguous() + + tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None + mesh = Types.MESH( + vertices=verts.unsqueeze(0), + faces=faces.unsqueeze(0), + uvs=uvs.unsqueeze(0), + texture=tex, + ) + return io.NodeOutput(mesh) + + +class DA3GeometryToPointCloud(io.ComfyNode): + """Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3GeometryToPointCloud", + search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"], + display_name="Convert DA3 Geometry to Point Cloud", + category="image/geometry estimation", + description="Convert a depth map into a 3D point cloud.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."), + io.Boolean.Input("use_sky_mask", default=True, + tooltip="Exclude sky-probability pixels (sky >= 0.5). Used when the geometry has a sky map (Mono/Metric models)."), + io.Int.Input("downsample", default=1, min=1, max=16, + tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."), + ], + # TODO: add a proper PointCloud output type + outputs=[DA3PointCloud.Output(display_name="point_cloud")], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index].clone() # (H, W) + depth[~torch.isfinite(depth)] = 0.0 + H, W = depth.shape + + K = _da3_get_K(da3_geometry, batch_index, H, W) + + if downsample > 1: + depth = depth[::downsample, ::downsample].contiguous() + # Scale intrinsics to the downsampled grid. + K = K.clone() + K[0, :] /= downsample + K[1, :] /= downsample + + H_ds, W_ds = depth.shape + points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space + + # Apply world-to-camera inverse so multi-view frames share a common world frame. + E = _da3_get_extrinsic(da3_geometry, batch_index) + if E is not None: + points = _da3_apply_extrinsic(points, E) + + # Rebuild mask at downsampled resolution. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + if downsample > 1: + mask = mask[::downsample, ::downsample] + + mask = mask & torch.isfinite(depth) + + # OpenCV → glTF: flip Y and Z. + points_gltf = points.clone() + points_gltf[..., 1] *= -1.0 + points_gltf[..., 2] *= -1.0 + + pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)] + + colors_flat = None + if "image" in da3_geometry: + img = da3_geometry["image"][batch_index] # (H, W, 3) + if downsample > 1: + img = img[::downsample, ::downsample] + colors_flat = img.reshape(-1, 3)[mask.reshape(-1)] + + conf_flat = None + if "confidence" in da3_geometry: + conf = da3_geometry["confidence"][batch_index] # (H, W) + if downsample > 1: + conf = conf[::downsample, ::downsample] + conf_flat = conf.reshape(-1)[mask.reshape(-1)] + + if pts_flat.shape[0] == 0: + raise ValueError( + "DA3GeometryToPointCloud produced zero points after filtering. " + "Try lowering confidence_threshold or disabling use_sky_mask." + ) + + return io.NodeOutput({ + "points": pts_flat, + "colors": colors_flat, + "confidence": conf_flat, + }) + + +class DA3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoadDA3Model, + DA3Inference, + DA3Render, + DA3GeometryToMesh, + # DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type + ] + + +async def comfy_entrypoint() -> DA3Extension: + return DA3Extension() diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 422949531434..a63f0414b408 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -8,6 +8,7 @@ from comfy_api.latest import ComfyExtension, Types, io from typing_extensions import override +from comfy.ldm.colormap import turbo as _turbo from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.geometry import triangulate_grid_mesh from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid @@ -27,19 +28,6 @@ # "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present) -def _turbo(x: torch.Tensor) -> torch.Tensor: - """Anton Mikhailov polynomial approximation of the turbo colormap.""" - x = x.clamp(0.0, 1.0) - x2 = x * x - x3 = x2 * x - x4 = x2 * x2 - x5 = x4 * x - r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 - g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 - b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 - return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) - - def _normals_from_points(points: torch.Tensor) -> torch.Tensor: """Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback).""" finite = torch.isfinite(points).all(dim=-1) diff --git a/nodes.py b/nodes.py index fb6952badbaa..0d422d418af3 100644 --- a/nodes.py +++ b/nodes.py @@ -2459,7 +2459,8 @@ async def init_builtin_extra_nodes(): "nodes_moge.py", "nodes_mediapipe.py", "nodes_gaussian_splat.py", - "nodes_triposplat.py" + "nodes_triposplat.py", + "nodes_depth_anything_3.py", ] import_failed = [] From 5fcf7a4a0f7786912733f2ec8a6808ad614ea74e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:39:24 -0700 Subject: [PATCH 35/45] Always enable cuda malloc on cu130 and higher. (#14381) --- cuda_malloc.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c126..8c4422db82ed 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -2,6 +2,7 @@ import importlib.util from comfy.cli_args import args, PerformanceFeature import subprocess +import re #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): @@ -77,11 +78,24 @@ def cuda_malloc_supported(): except: pass +def get_raw_cuda_version(version_str): + match = re.search(r'\+cu(\d+)', version_str) + if match: + try: + return int(match.group(1)) + except: + pass + return None + if not args.cuda_malloc: try: if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc - args.cuda_malloc = cuda_malloc_supported() + cuda_version = get_raw_cuda_version(version) + if cuda_version is not None and cuda_version >= 130: + args.cuda_malloc = True + else: + args.cuda_malloc = cuda_malloc_supported() except: pass From 46d45aade1dfab6d5a3658f2650a4626f175be3a Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 10 Jun 2026 10:58:42 +0900 Subject: [PATCH 36/45] chore(openapi): sync shared API contract from cloud@ca12913 (#14367) --- openapi.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 2510f97d08d5..c27ed7adf1b8 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1960,8 +1960,8 @@ paths: schema: properties: hash: - description: Hash of the existing asset. Supports Blake3 (blake3:) or SHA256 (sha256:) formats - pattern: ^(blake3|sha256):[a-f0-9]{64}$ + description: 'Blake3 content hash of the existing asset (blake3: prefix)' + pattern: ^blake3:[a-f0-9]{64}$ type: string mime_type: description: MIME type of the asset (e.g., "image/png", "video/mp4") From f350acdf213a1b3cbeab2059888265b21590ce9f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:07:47 +0800 Subject: [PATCH 37/45] [Trainer/bug] Ensure model is not inference mode (CORE-72) (#13400) * Ensure model is not inference mode * force clone inside training mode to avoid inference tensor * Allow force deepcopy for model patcher --- comfy/model_patcher.py | 7 ++-- comfy_extras/nodes_train.py | 68 ++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b716a69e223a..d70b42bf8c06 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -379,10 +379,11 @@ def get_free_memory(self, device): def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) - def clone(self, disable_dynamic=False, model_override=None): + def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False): class_ = self.__class__ - if self.is_dynamic() and disable_dynamic: - class_ = ModelPatcher + if self.is_dynamic() and disable_dynamic or force_deepcopy: + if self.is_dynamic() and disable_dynamic: + class_ = ModelPatcher if model_override is None: if self.cached_patcher_init is None: raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 273f55e7c2f2..bb68da6fa044 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1149,45 +1149,45 @@ def execute( # Process conditioning positive = _process_conditioning(positive) - # Setup model and dtype - mp = model.clone() - use_grad_scaler = False - lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) - if training_dtype != "none": - dtype = node_helpers.string_to_torch_dtype(training_dtype) - mp.set_model_compute_dtype(dtype) - else: - # Detect model's native dtype for autocast - model_dtype = mp.model.get_dtype() - if model_dtype == torch.float16: - dtype = torch.float16 - # GradScaler only supports float16 gradients, not bfloat16. - # Only enable it when lora params will also be in float16. - if lora_dtype != torch.bfloat16: - use_grad_scaler = True - # Warn about fp16 accumulation instability during training - if PerformanceFeature.Fp16Accumulation in args.fast: - logging.warning( - "WARNING: FP16 model detected with fp16_accumulation enabled. " - "This combination can be numerically unstable during training and may cause NaN values. " - "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." - ) + with torch.inference_mode(False): + # Setup model and dtype + mp = model.clone(force_deepcopy=True) + use_grad_scaler = False + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + if training_dtype != "none": + dtype = node_helpers.string_to_torch_dtype(training_dtype) + mp.set_model_compute_dtype(dtype) else: - # For fp8, bf16, or other dtypes, use bf16 autocast - dtype = torch.bfloat16 + # Detect model's native dtype for autocast + model_dtype = mp.model.get_dtype() + if model_dtype == torch.float16: + dtype = torch.float16 + # GradScaler only supports float16 gradients, not bfloat16. + # Only enable it when lora params will also be in float16. + if lora_dtype != torch.bfloat16: + use_grad_scaler = True + # Warn about fp16 accumulation instability during training + if PerformanceFeature.Fp16Accumulation in args.fast: + logging.warning( + "WARNING: FP16 model detected with fp16_accumulation enabled. " + "This combination can be numerically unstable during training and may cause NaN values. " + "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." + ) + else: + # For fp8, bf16, or other dtypes, use bf16 autocast + dtype = torch.bfloat16 - # Prepare latents and compute counts - latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 - latents, num_images, multi_res = _prepare_latents_and_count( - latents, latents_dtype, bucket_mode - ) + # Prepare latents and compute counts + latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 + latents, num_images, multi_res = _prepare_latents_and_count( + latents, latents_dtype, bucket_mode + ) - # Validate and expand conditioning - positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) - with torch.inference_mode(False): # Setup models for training - mp.model.requires_grad_(False) + mp.model.requires_grad_(False).train() # Load existing LoRA weights if provided existing_weights, existing_steps = _load_existing_lora(existing_lora) From a76bb4380ee9fcc0fc96e5bc9fe25d66ad7ca412 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 9 Jun 2026 21:07:10 -0700 Subject: [PATCH 38/45] chore(assets): drop vestigial tags.tag_type column (#14248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tag_type was always "user" in practice — no code path ever set it to anything else (no system/seeded classification was wired up) and nothing queried it. The column, its ix_tags_tag_type index, and the TagUsage.type API field were dead weight, so they're removed. Adds alembic migration 0004 to drop the column and index. Verified: asset-seeder tests pass; migration applies cleanly on a fresh SQLite (tags retains only name; tag_type column + index dropped). Co-authored-by: guill --- alembic_db/versions/0004_drop_tag_type.py | 39 +++++++++++++++++++ app/assets/api/routes.py | 4 +- app/assets/api/schemas_out.py | 1 - app/assets/database/models.py | 3 -- app/assets/database/queries/tags.py | 13 +++---- app/assets/scanner.py | 2 +- app/assets/services/schemas.py | 1 - app/assets/services/tagging.py | 2 +- tests-unit/assets_test/queries/test_tags.py | 31 ++++++--------- .../assets_test/services/test_tagging.py | 10 ++--- .../assets_test/test_sync_references.py | 2 +- 11 files changed, 66 insertions(+), 42 deletions(-) create mode 100644 alembic_db/versions/0004_drop_tag_type.py diff --git a/alembic_db/versions/0004_drop_tag_type.py b/alembic_db/versions/0004_drop_tag_type.py new file mode 100644 index 000000000000..582bec4e8469 --- /dev/null +++ b/alembic_db/versions/0004_drop_tag_type.py @@ -0,0 +1,39 @@ +""" +Drop the vestigial tags.tag_type column. + +tag_type was always "user" in practice — no code path ever set it to anything +else (no system/seeded classification was ever wired up) and nothing queried it. +The column, its index (ix_tags_tag_type), and the corresponding API field were +dead weight, so they are removed. + +Revision ID: 0004_drop_tag_type +Revises: 0003_add_metadata_job_id +Create Date: 2026-06-03 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0004_drop_tag_type" +down_revision = "0003_add_metadata_job_id" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("tags") as batch_op: + batch_op.drop_index("ix_tags_tag_type") + batch_op.drop_column("tag_type") + + +def downgrade() -> None: + with op.batch_alter_table("tags") as batch_op: + batch_op.add_column( + sa.Column( + "tag_type", + sa.String(length=32), + nullable=False, + server_default="user", + ) + ) + batch_op.create_index("ix_tags_tag_type", ["tag_type"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 6555974e9253..252ddfe8f672 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -575,8 +575,8 @@ async def get_tags(request: web.Request) -> web.Response: ) tags = [ - schemas_out.TagUsage(name=name, count=count, type=tag_type) - for (name, tag_type, count) in rows + schemas_out.TagUsage(name=name, count=count) + for (name, count) in rows ] payload = schemas_out.TagsList( tags=tags, total=total, has_more=(query.offset + len(tags)) < total diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 0e748b90777c..143848329ae5 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -46,7 +46,6 @@ class AssetsList(BaseModel): class TagUsage(BaseModel): name: str count: int - type: str class TagsList(BaseModel): diff --git a/app/assets/database/models.py b/app/assets/database/models.py index a3af8a192373..9b61d309a8dc 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -227,7 +227,6 @@ class Tag(Base): __tablename__ = "tags" name: Mapped[str] = mapped_column(String(512), primary_key=True) - tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", @@ -240,7 +239,5 @@ class Tag(Base): overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) - def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index f4126dba8212..d41d73a10084 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None: raise ValueError(f"Unknown tags: {missing}") -def ensure_tags_exist( - session: Session, names: Iterable[str], tag_type: str = "user" -) -> None: +def ensure_tags_exist(session: Session, names: Iterable[str]) -> None: wanted = normalize_tags(list(names)) if not wanted: return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + rows = [{"name": n} for n in list(dict.fromkeys(wanted))] ins = ( sqlite.insert(Tag) .values(rows) @@ -97,7 +95,7 @@ def set_reference_tags( to_remove = [t for t in current if t not in desired] if to_add: - ensure_tags_exist(session, to_add, tag_type="user") + ensure_tags_exist(session, to_add) session.add_all( [ AssetReferenceTag( @@ -142,7 +140,7 @@ def add_tags_to_reference( return AddTagsResult(added=[], already_present=[], total_tags=total) if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") + ensure_tags_exist(session, norm) current = set(get_reference_tags(session, reference_id)) @@ -289,7 +287,6 @@ def list_tags_with_usage( q = ( select( Tag.name, - Tag.tag_type, func.coalesce(counts_sq.c.cnt, 0).label("count"), ) .select_from(Tag) @@ -331,7 +328,7 @@ def list_tags_with_usage( rows = (session.execute(q.limit(limit).offset(offset))).all() total = (session.execute(total_q)).scalar_one() - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + rows_norm = [(name, int(count or 0)) for (name, count) in rows] return rows_norm, int(total or 0) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 495c3044329d..2c1e978402c9 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -355,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: return 0 with create_session() as sess: if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") + ensure_tags_exist(sess, tag_pool) result = batch_insert_seed_assets(sess, specs=specs, owner_id="") sess.commit() return result.inserted_refs diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 0eb128f587d9..2a52d76cac2e 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -56,7 +56,6 @@ class IngestResult: class TagUsage(NamedTuple): name: str - tag_type: str count: int diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index 37b61275324b..5fa39d26af08 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -75,7 +75,7 @@ def list_tags( owner_id=owner_id, ) - return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total + return [TagUsage(name, count) for name, count in rows], total def list_tag_histogram( diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index 4ed99aa37c52..6222714d12b9 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -40,15 +40,15 @@ def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id class TestEnsureTagsExist: def test_creates_new_tags(self, session: Session): - ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + ensure_tags_exist(session, ["alpha", "beta"]) session.commit() tags = session.query(Tag).all() assert {t.name for t in tags} == {"alpha", "beta"} def test_is_idempotent(self, session: Session): - ensure_tags_exist(session, ["alpha"], tag_type="user") - ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"]) + ensure_tags_exist(session, ["alpha"]) session.commit() assert session.query(Tag).count() == 1 @@ -65,13 +65,6 @@ def test_empty_list_is_noop(self, session: Session): session.commit() assert session.query(Tag).count() == 0 - def test_tag_type_is_set(self, session: Session): - ensure_tags_exist(session, ["system-tag"], tag_type="system") - session.commit() - - tag = session.query(Tag).filter_by(name="system-tag").one() - assert tag.tag_type == "system" - class TestGetReferenceTags: def test_returns_empty_for_no_tags(self, session: Session): @@ -193,7 +186,7 @@ class TestMissingTagFunctions: def test_add_missing_tag_for_asset_id(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) session.commit() @@ -204,7 +197,7 @@ def test_add_missing_tag_for_asset_id(self, session: Session): def test_add_missing_tag_is_idempotent(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) add_missing_tag_for_asset_id(session, asset_id=asset.id) @@ -216,7 +209,7 @@ def test_add_missing_tag_is_idempotent(self, session: Session): def test_remove_missing_tag_for_asset_id(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) remove_missing_tag_for_asset_id(session, asset_id=asset.id) @@ -237,7 +230,7 @@ def test_returns_tags_with_counts(self, session: Session): rows, total = list_tags_with_usage(session) - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict["used"] == 1 assert tag_dict["unused"] == 0 assert total == 2 @@ -252,7 +245,7 @@ def test_exclude_zero_counts(self, session: Session): rows, total = list_tags_with_usage(session, include_zero=False) - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert "used" in tag_names assert "unused" not in tag_names @@ -262,7 +255,7 @@ def test_prefix_filter(self, session: Session): rows, total = list_tags_with_usage(session, prefix="alph") - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} def test_order_by_name(self, session: Session): @@ -271,7 +264,7 @@ def test_order_by_name(self, session: Session): rows, _ = list_tags_with_usage(session, order="name_asc") - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names == ["alpha", "middle", "zebra"] def test_owner_visibility(self, session: Session): @@ -287,13 +280,13 @@ def test_owner_visibility(self, session: Session): # Empty owner sees only shared rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict.get("shared-tag", 0) == 1 assert tag_dict.get("owner-tag", 0) == 0 # User1 sees both rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict.get("shared-tag", 0) == 1 assert tag_dict.get("owner-tag", 0) == 1 diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py index ab69e5dc1a70..fa121db3ecf9 100644 --- a/tests-unit/assets_test/services/test_tagging.py +++ b/tests-unit/assets_test/services/test_tagging.py @@ -141,7 +141,7 @@ def test_returns_tags_with_counts(self, mock_create_session, session: Session): rows, total = list_tags() - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict["used"] == 1 assert tag_dict["unused"] == 0 assert total == 2 @@ -155,7 +155,7 @@ def test_excludes_zero_counts(self, mock_create_session, session: Session): rows, total = list_tags(include_zero=False) - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert "used" in tag_names assert "unused" not in tag_names @@ -165,7 +165,7 @@ def test_prefix_filter(self, mock_create_session, session: Session): rows, _ = list_tags(prefix="alph") - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} def test_order_by_name(self, mock_create_session, session: Session): @@ -174,7 +174,7 @@ def test_order_by_name(self, mock_create_session, session: Session): rows, _ = list_tags(order="name_asc") - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names == ["alpha", "middle", "zebra"] def test_pagination(self, mock_create_session, session: Session): @@ -185,7 +185,7 @@ def test_pagination(self, mock_create_session, session: Session): assert total == 5 assert len(rows) == 2 - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names == ["b", "c"] def test_clamps_limit(self, mock_create_session, session: Session): diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py index 94cc255bcabb..2e85076e021d 100644 --- a/tests-unit/assets_test/test_sync_references.py +++ b/tests-unit/assets_test/test_sync_references.py @@ -95,7 +95,7 @@ def _make_asset( def _ensure_missing_tag(session: Session): """Ensure the 'missing' tag exists.""" if not session.get(Tag, "missing"): - session.add(Tag(name="missing", tag_type="system")) + session.add(Tag(name="missing")) session.flush() From 84e0692a3dd4748e197579da0e7fb0aa8510e363 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 9 Jun 2026 21:14:03 -0700 Subject: [PATCH 39/45] feat(assets): cursor-based pagination on GET /api/assets (#14014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * spec(assets): add cursor pagination params to GET /api/assets Add 'after' query param and 'next_cursor' response field for keyset pagination. Matches the cloud Go implementation (BE-893) so frontend sees a unified contract across runtimes. Offset/limit remain as a deprecated fallback. * feat(assets): add cursor encode/decode helpers for keyset pagination Port of cloud common/pagination/cursor.go. Wire format is base64url of {"s", "v", "id"} JSON; times are Unix microseconds UTC to match PostgreSQL timestamp precision. Includes a byte-identity fixture pinned against the cloud Go wire format so cross-runtime FE pagination can't silently drift. * feat(assets): thread cursor through schemas, service, and query layer list_assets_page accepts an opaque 'after' cursor and returns next_cursor when more pages are available. The query applies a keyset WHERE clause and a secondary ORDER BY id for deterministic tiebreak. Cursor sort field is validated against the request sort, and a last_access_time sort (OSS-only) falls back to offset/limit. Offset is ignored whenever a cursor is supplied. * feat(assets): wire cursor pagination through GET /api/assets handler Adds integration tests for: full cursor walk, invalid-cursor 400, sort/cursor mismatch 400, cursor-wins-over-offset, absent next_cursor when no more results, and pagination stability across deletes. * fix(assets): address cursor-review verified findings - Mint next_cursor on every cursor-supported sort, not only when 'after' was supplied. A first request (no 'after') previously returned next_cursor=None, leaving cursor mode unreachable from a clean start. - Over-fetch limit+1 so an exactly-full terminal page doesn't mint a spurious cursor pointing at a phantom next page. - Map crafted out-of-range microsecond cursors (OverflowError / OSError in datetime construction) to 400 INVALID_CURSOR instead of leaking 500. - Bump MAX_CURSOR_VALUE_LENGTH 256 -> 512 to match the AssetReference name column max; without this, a long-named asset minted a cursor the same server then refused on the next request. Cross-runtime byte identity with cloud is unaffected because no cloud cursor ever carries a value > 256 (cloud schema doesn't permit it). - Return None from _encode_next_cursor when the boundary row carries a NULL sort value (e.g. an Asset without size_bytes backfilled), instead of silently encoding 0 and mis-positioning the keyset. - Fix schemas_in.py comment so it matches actual handler behavior (last_access_time + 'after' raises 400, does not fall back). - Add AssetsApiError schema + 400 response to GET /api/assets in openapi.yaml so generated clients know the INVALID_CURSOR envelope. - Extend integration coverage: first-page mint, exact-multiple terminal page, cursor walks for created_at/updated_at/size sorts, datetime overflow surfaces as 400 not 500. - Add unit coverage for datetime overflow and 512-char round-trip. * feat(assets): bind cursor to sort order + Go-compat JSON escaping Address three needs-judgment items from the cursor-review judge synthesis: 1. Cursor wire format now includes an "o" key carrying the sort direction ("asc" / "desc") it was minted under. A request that replays the cursor with a flipped `order` parameter is rejected with 400 INVALID_CURSOR instead of silently walking the wrong direction. Legacy cursors without "o" still decode (the binding is best-effort until cloud mirrors the field — follow-up filed separately). 2. JSON serialization now escapes `<`, `>`, `&`, U+2028, U+2029 to mirror Go's default `json.Marshal` behavior. Without this, an asset name containing those characters produced different bytes on Python vs cloud Go. The escaped form is what both runtimes emit. 3. Add direct query-layer tests for the keyset tiebreaker — the secondary ORDER BY id branch was previously unexercised. Two scenarios: all rows share a primary sort value, and mixed ties straddle page boundaries. Both assert no row is dropped or duplicated across the walk. Wire-format note: Python cursors now differ from current cloud cursors by exactly the "o" key. Cloud follow-up will bring the two back into byte alignment. * fix(assets): address bot review comments - Soften offset param prose: it's not deprecated, just not preferred for sequential walks. Random-access UIs (jump-to-page, item count displays) legitimately still want offset, so dropping the 'deprecated' framing rather than promoting it to a machine-readable deprecated:true flag. - Add explicit HTTP status assertions before every json() / next_cursor read in test_list_cursor.py so a failing request surfaces as an HTTP error instead of a confusing KeyError on a 4xx/5xx body. * feat(assets): require cursor o field, drop legacy permissive path Cursor pagination hasn't shipped on either runtime yet — this PR is still draft and cloud's mirror is just behind it — so there are no legacy no-o cursors in the wild. Make o mandatory from day one rather than landing permissive and tightening later. decode_cursor now rejects any payload without o (or with a non-string o) as malformed. CursorPayload.order becomes a required str. Tests that constructed CursorPayload directly now pass order="desc"; test_legacy_cursor_without_order_accepted flips to test_cursor_without_order_rejected. * chore(assets): drop cross-repo prose from cursor comments Strip prose references to sibling Go implementations and external ticket IDs from cursor.py, the cursor tests, the keyset integration tests, asset_management's sort-field comment, and the legacy prompt_id alias comment. Pure docstring/comment scrub — no behavior or wire-format changes. x-runtime: [cloud] field annotations in openapi.yaml are unchanged; those are the spec's structural cross-runtime convention, not internal references. * test(assets): include 'o' in microsecond-boundary cursor payload The boundary test was building a cursor without the required `o` key, so decode failed on the missing-order branch before reaching the µs-overflow path the test is asserting. Both paths return 400 INVALID_CURSOR so the assertion passed for the wrong reason. Add `o` to the payload and matching `order=` to the request so the decode reaches the intended branch. * fix(assets): address ultrareview findings on cursor pagination Six fact-checked findings from the multi-model review pass: - Encoder/decoder length asymmetry: encode_cursor now rejects empty id, oversized id (>128), oversized value (>512), and invalid order tokens symmetrically with decode_cursor. Prevents the same server from minting a cursor it then 400s on the next request (e.g. a filesystem-scanned asset name >512 chars). The bad-order path now raises InvalidCursorError (still subclasses ValueError) so route-layer handling stays uniform. - Raw U+2028/U+2029 in cursor.py source: ripgrep treated those lines as line-terminators, confirming the bytes were the actual separators. Any editor save / autoformat / git tooling that normalizes invisibles would silently break the encoder. Replaced with explicit 
 / 
 Python escape sequences. - set(seen) == set(names) hid ordering regressions: a cursor walk that dropped a row at a page boundary or returned duplicates could pass. Reworked the assertion to (1) reject duplicates, (2) require full coverage, and (3) assert strict positional order for size sort, the only field with a clock-independent ordering. - Flaky time.sleep(0.05) between inserts: Windows CI clock resolution is ~15ms, so back-to-back inserts under load could collide and exercise the tiebreaker instead of the documented path. Removed the sleep and let the strengthened assertion above carry coverage / no-duplicates, with size sort carrying strict order. - Cursor error envelope diverged from the rest of routes.py: cursor 400s emitted {error: {code, message}} while every other 400 in the file emits {error: {code, message, details}} via _build_error_response. Switched to _build_error_response and added the details field to the AssetsApiError schema in openapi.yaml. - "Byte-identity fixtures" only checked substring containment, defeating the test class's stated purpose of pinning the wire format. Switched to exact-bytes equality against an inline expected payload string per fixture, so any whitespace / key-order / escape drift fails loudly. Also dropped Go / json.Marshal references from docstrings — the byte format is the contract, not the runtime that mints it. * fix(assets): cap cursors by encoded wire size, not just char count Char-count guards on value/id can still let multibyte or escape-heavy inputs blow past MAX_ENCODED_CURSOR_LENGTH once UTF-8 + escape expansion + base64url runs. A 512-character name of 'é' (2 bytes UTF-8) or '<' (serializes to the 6-byte '<' escape) passes the char check, mints a ~1500-byte cursor, then 400s when handed back on the next request. Compute the final encoded form and reject it before returning if it exceeds the wire cap. Adds regression tests for both inflation paths. * refactor(assets): extract cursor JSON escaping helper; size wire cap above per-field caps Addresses review feedback on cursor.py: - Extract the inline escape chain into _apply_wire_compatible_json_escapes() with a comment pinning it to the wire format's escape set, so the parity intent is explicit rather than reading as an ad-hoc transform. - Raise MAX_ENCODED_CURSOR_LENGTH to 8192 (comfortably above the ~5.2KB worst-case the per-field caps can produce) and drop the mint-time length guard. Encoder/decoder symmetry now holds by construction: the encoder can't produce a cursor the decode path rejects, so there is no confusing user-visible 'cursor too long' failure at mint time. - Rewrite the two over-wire-cap tests to assert worst-case multibyte and escape-heavy values mint and round-trip, instead of being rejected. * refactor(assets): drop cross-runtime cursor escaping; cursors are opaque The custom JSON escaping of <, >, &, U+2028, and U+2029 existed only to keep the encoded cursor byte-identical with the Cloud implementation of the same payload format. Cursors are opaque tokens, so byte-level compatibility across implementations is not needed — plain json.dumps output is sufficient. Remove the escaping helper and the byte-identity test fixtures that pinned the wire format; keep round-trip coverage for the affected characters. --------- Co-authored-by: guill --- app/assets/api/routes.py | 40 +- app/assets/api/schemas_in.py | 5 + app/assets/api/schemas_out.py | 2 + .../database/queries/asset_reference.py | 35 +- app/assets/services/asset_management.py | 94 ++++- app/assets/services/cursor.py | 213 +++++++++++ app/assets/services/schemas.py | 1 + .../queries/test_asset_reference_keyset.py | 112 ++++++ .../assets_test/services/test_cursor.py | 278 ++++++++++++++ tests-unit/assets_test/test_list_cursor.py | 349 ++++++++++++++++++ 10 files changed, 1112 insertions(+), 17 deletions(-) create mode 100644 app/assets/services/cursor.py create mode 100644 tests-unit/assets_test/queries/test_asset_reference_keyset.py create mode 100644 tests-unit/assets_test/services/test_cursor.py create mode 100644 tests-unit/assets_test/test_list_cursor.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 252ddfe8f672..544a614f2bd3 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -39,6 +39,7 @@ update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.cursor import InvalidCursorError from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() @@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu user_metadata=result.ref.user_metadata or {}, metadata=result.ref.system_metadata, job_id=result.ref.job_id, - prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat + prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility created_at=result.ref.created_at, updated_at=result.ref.updated_at, last_access_time=result.ref.last_access_time, @@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response: order_candidate = (q.order or "desc").lower() order = order_candidate if order_candidate in {"asc", "desc"} else "desc" - result = list_assets_page( - owner_id=USER_MANAGER.get_request_user_id(request), - include_tags=q.include_tags, - exclude_tags=q.exclude_tags, - name_contains=q.name_contains, - metadata_filter=q.metadata_filter, - limit=q.limit, - offset=q.offset, - sort=sort, - order=order, - ) + try: + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=sort, + order=order, + after=q.after, + ) + except InvalidCursorError as e: + return _build_error_response(400, "INVALID_CURSOR", str(e)) summaries = [_build_asset_response(item) for item in result.items] + # has_more semantics differ by mode: + # - cursor mode: a non-empty next_cursor means there are more results. + # - offset mode: derived from total - (offset + page size). + if q.after is not None: + has_more = result.next_cursor is not None + else: + has_more = (q.offset + len(summaries)) < result.total + payload = schemas_out.AssetsList( assets=summaries, total=result.total, - has_more=(q.offset + len(summaries)) < result.total, + has_more=has_more, + next_cursor=result.next_cursor, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 186a6ae1e5b1..af666746dce2 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 + # Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination + # is supported for sort values `created_at`, `updated_at`, `name`, `size`. + # Supplying `after` together with `sort=last_access_time` returns + # 400 INVALID_CURSOR; that sort only supports offset/limit. + after: str | None = None sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( "created_at" diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 143848329ae5..4e38e19d1824 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -41,6 +41,8 @@ class AssetsList(BaseModel): assets: list[Asset] total: int has_more: bool + # Opaque cursor for the next page. Omitted when there are no more results. + next_cursor: str | None = None class TagUsage(BaseModel): diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae5110f5..792411800e1b 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -266,9 +266,18 @@ def list_references_page( metadata_filter: dict | None = None, sort: str | None = None, order: str | None = None, + after_cursor_value: object | None = None, + after_cursor_id: str | None = None, ) -> tuple[list[AssetReference], dict[str, list[str]], int]: """List references with pagination, filtering, and sorting. + When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses + keyset pagination — ``offset`` is ignored and a WHERE clause selects rows + strictly after the given ``(sort_col, id)`` position in the active sort + direction. The cursor value must already be typed for the column + (datetime for time sorts, int for size, str for name); the caller decodes + the opaque cursor string and resolves to the typed value. + Returns (references, tag_map, total_count). """ base = ( @@ -297,9 +306,31 @@ def list_references_page( "size": Asset.size_bytes, } sort_col = sort_map.get(sort, AssetReference.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + descending = order == "desc" + + # Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor. + # Equivalent to: sort_col v OR (sort_col = v AND id cursor_id). + if after_cursor_value is not None and after_cursor_id is not None: + if descending: + keyset = sa.or_( + sort_col < after_cursor_value, + sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id), + ) + else: + keyset = sa.or_( + sort_col > after_cursor_value, + sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id), + ) + base = base.where(keyset) + + # Secondary ORDER BY id (matching the primary direction) gives the keyset + # comparison a deterministic tiebreaker on duplicate sort_col values. + id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc() + sort_exp = sort_col.desc() if descending else sort_col.asc() - base = base.order_by(sort_exp).limit(limit).offset(offset) + base = base.order_by(sort_exp, id_exp).limit(limit) + if after_cursor_id is None: + base = base.offset(offset) count_stmt = ( select(sa.func.count()) diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd99567dd..1072c95faaea 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -1,8 +1,19 @@ import contextlib import mimetypes import os +from datetime import timezone from typing import Sequence +from app.assets.services.cursor import ( + CursorPayload, + InvalidCursorError, + decode_cursor, + decode_cursor_int, + decode_cursor_time, + encode_cursor, + encode_cursor_from_time, +) + from app.assets.database.models import Asset from app.assets.database.queries import ( @@ -242,6 +253,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None: return extract_asset_data(asset) +# Sort fields that support cursor pagination. `last_access_time` is not +# in this list — it falls back to offset/limit. +_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size") + + def list_assets_page( owner_id: str = "", include_tags: Sequence[str] | None = None, @@ -252,7 +268,39 @@ def list_assets_page( offset: int = 0, sort: str = "created_at", order: str = "desc", + after: str | None = None, ) -> ListAssetsResult: + """List assets with optional cursor pagination. + + When ``after`` is supplied it overrides ``offset``. The cursor's sort field + must match ``sort`` and be in the cursor-supported allowlist; mismatches + raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR. + """ + cursor_value: object | None = None + cursor_id: str | None = None + # Mint next_cursor on every page where the sort is cursor-supported, not + # only when the request itself arrived with a cursor. Otherwise a first + # request (no `after`) returns next_cursor=None and the client can never + # enter cursor mode. + mint_cursor = sort in _CURSOR_SORT_FIELDS + + if after is not None: + if sort not in _CURSOR_SORT_FIELDS: + raise InvalidCursorError( + f"cursor pagination is not supported for sort={sort!r}" + ) + payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order) + if payload.sort_field != sort: + raise InvalidCursorError( + f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}" + ) + cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id + + # Over-fetch by one row so we can distinguish "exactly `limit` rows total + # remaining" from "more rows past this page" without a second query. Drop + # the sentinel before returning. + fetch_limit = limit + 1 if mint_cursor else limit + with create_session() as session: refs, tag_map, total = list_references_page( session, @@ -261,12 +309,22 @@ def list_assets_page( exclude_tags=exclude_tags, name_contains=name_contains, metadata_filter=metadata_filter, - limit=limit, + limit=fetch_limit, offset=offset, sort=sort, order=order, + after_cursor_value=cursor_value, + after_cursor_id=cursor_id, ) + next_cursor: str | None = None + if mint_cursor and len(refs) > limit: + # There's at least one more row past this page — mint a cursor from + # the last row of the page (i.e. index `limit - 1`, since we + # over-fetched), and drop the sentinel. + next_cursor = _encode_next_cursor(refs[limit - 1], sort, order) + refs = refs[:limit] + items: list[AssetSummaryData] = [] for ref in refs: items.append( @@ -277,7 +335,39 @@ def list_assets_page( ) ) - return ListAssetsResult(items=items, total=total) + return ListAssetsResult(items=items, total=total, next_cursor=next_cursor) + + +def _resolve_cursor_value(payload: CursorPayload) -> object: + """Map a decoded cursor payload to a column-typed Python value.""" + if payload.sort_field in ("created_at", "updated_at"): + # DB stores naive UTC; strip tzinfo so the comparison binds against a + # `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift. + return decode_cursor_time(payload).replace(tzinfo=None) + if payload.sort_field == "size": + return decode_cursor_int(payload) + return payload.value # name, str-typed + + +def _encode_next_cursor(ref, sort: str, order: str) -> str | None: + """Mint a cursor pointing at *ref* for the given sort dimension. + + Returns None when the boundary row carries a NULL sort value (e.g. an asset + record whose size_bytes hasn't been backfilled). Continuing pagination + across a NULL boundary is undefined under keyset ordering — better to + truncate cleanly here than to mint a cursor that mis-positions. + """ + if sort == "name": + return encode_cursor("name", ref.name, ref.id, order=order) + if sort == "size": + if ref.asset is None or ref.asset.size_bytes is None: + return None + return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order) + # created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding. + value = ref.created_at if sort == "created_at" else ref.updated_at + if value is None: + return None + return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order) def resolve_hash_to_path( diff --git a/app/assets/services/cursor.py b/app/assets/services/cursor.py new file mode 100644 index 000000000000..6c779152844d --- /dev/null +++ b/app/assets/services/cursor.py @@ -0,0 +1,213 @@ +"""Opaque keyset-pagination cursor for /api/assets. + +Payload JSON uses short keys to keep the encoded length small: + + {"s": , "v": , "id": , "o": } + +The `o` key binds the cursor to the sort direction it was minted under, +so replaying a `desc` cursor against an `asc` request fails with +``INVALID_CURSOR`` rather than silently walking the wrong direction. +`o` is mandatory on every payload — a cursor without it is rejected as +malformed. + +Encoding is base64url with no padding. Cursors are opaque tokens: the +payload format is internal to this server, and clients must treat a +cursor as a black box handed back via `next_cursor`. No byte-level +compatibility with any other implementation is required. + +Time values are serialized as Unix microseconds (UTC) — microsecond +precision is sufficient to round-trip the timestamps stored by the +database without rounding rows in the same millisecond bucket. +""" +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Iterable, Optional + + +class InvalidCursorError(ValueError): + """Raised on a malformed, oversized, or unsupported-sort-field cursor. + + Map to a 400 response with code ``INVALID_CURSOR`` at the handler. + """ + + +# Wire-format length caps. Cursors are user-controlled, so caps protect the +# decode path from oversized allocations and downstream SQL predicates from +# unbounded strings. +# +# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max +# (`String(512)`) — otherwise a long-named asset would mint a cursor the same +# server then refuses on the next request. +# +# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above +# the largest cursor the per-field caps can produce. Worst case is value + id +# at their caps with every character JSON-escaping to the six-byte `\uXXXX` +# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192 +# the encoder can never mint a cursor that exceeds it, so a freshly minted +# cursor always decodes on the next request and there is no user-visible +# "cursor too long" failure. +MAX_ENCODED_CURSOR_LENGTH = 8192 +MAX_CURSOR_VALUE_LENGTH = 512 +MAX_CURSOR_ID_LENGTH = 128 + + +@dataclass(frozen=True) +class CursorPayload: + sort_field: str + value: str + id: str + order: str + + +_VALID_ORDERS = ("asc", "desc") + + +def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str: + """Encode a cursor payload as a base64url (no-padding) string. + + `order` binds the cursor to the sort direction it was minted under so a + later request with a flipped `order` query parameter is rejected with + ``INVALID_CURSOR`` rather than silently walking the wrong direction. + """ + if order not in _VALID_ORDERS: + raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}") + # Symmetric input validation: the encoder must reject anything the + # decoder rejects, or the same server will mint cursors it then 400s on + # the next request. + if not id: + raise InvalidCursorError("id must be non-empty") + if len(id) > MAX_CURSOR_ID_LENGTH: + raise InvalidCursorError("id exceeds maximum length") + if len(value) > MAX_CURSOR_VALUE_LENGTH: + raise InvalidCursorError("value exceeds maximum length") + payload = {"s": sort_field, "v": value, "id": id, "o": order} + raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False) + # No mint-time length guard is needed: the per-field caps above bound the + # encoded length well below MAX_ENCODED_CURSOR_LENGTH (see its definition), + # so the encoder can never produce a cursor the decode path would reject. + return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii") + + +def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str: + """Encode a time-typed cursor at Unix microsecond precision. + + Accepts an aware datetime (any timezone) and normalizes to UTC. Naive + datetimes are rejected so callers can't accidentally encode the local + wall-clock value of a UTC-stored timestamp. + """ + if t.tzinfo is None: + raise ValueError("encode_cursor_from_time requires an aware datetime") + micros = _datetime_to_unix_micros(t.astimezone(timezone.utc)) + return encode_cursor(sort_field, str(micros), id, order=order) + + +def decode_cursor( + cursor: str, + allowed_sort_fields: Iterable[str], + expected_order: str | None = None, +) -> CursorPayload: + """Parse an opaque cursor. + + ``allowed_sort_fields`` is the endpoint's accepted sort-field list — a + cursor carrying a field outside this set is rejected so a cursor minted + for one column can't be replayed against another (e.g. a ``created_at`` + timestamp string compared against a ``name`` column). + + ``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the + payload's ``o`` field. ``o`` is required on every payload; a cursor + missing it is rejected as malformed. + + Passing no allowed fields rejects every cursor. + """ + if len(cursor) > MAX_ENCODED_CURSOR_LENGTH: + raise InvalidCursorError("cursor exceeds maximum length") + + try: + # urlsafe_b64decode requires correct padding; we strip on encode, so + # restore the trailing '=' pad here. + padding = "=" * (-len(cursor) % 4) + raw = base64.urlsafe_b64decode(cursor + padding) + except (ValueError, base64.binascii.Error) as e: + raise InvalidCursorError(f"encoding: {e}") from e + + try: + decoded = json.loads(raw) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise InvalidCursorError(f"payload: {e}") from e + + if not isinstance(decoded, dict): + raise InvalidCursorError("payload: expected object") + + sort_field = decoded.get("s") + value = decoded.get("v") + id = decoded.get("id") + order = decoded.get("o") + + if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str): + raise InvalidCursorError("payload: missing or non-string s/v/id") + + if id == "": + raise InvalidCursorError("missing id") + if len(id) > MAX_CURSOR_ID_LENGTH: + raise InvalidCursorError("id exceeds maximum length") + if len(value) > MAX_CURSOR_VALUE_LENGTH: + raise InvalidCursorError("value exceeds maximum length") + + if sort_field not in allowed_sort_fields: + raise InvalidCursorError(f"unsupported sort field {sort_field!r}") + + if not isinstance(order, str): + raise InvalidCursorError("missing or non-string o") + if order not in _VALID_ORDERS: + raise InvalidCursorError(f"unsupported order {order!r}") + if expected_order is not None and order != expected_order: + raise InvalidCursorError( + f"cursor order {order!r} does not match request order {expected_order!r}" + ) + + return CursorPayload(sort_field=sort_field, value=value, id=id, order=order) + + +def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime: + """Parse a time-typed cursor value as Unix microseconds, returning UTC.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + micros = int(payload.value) + except ValueError as e: + raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e + try: + return _unix_micros_to_datetime(micros) + except (OverflowError, OSError, ValueError) as e: + # Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up + # in fromtimestamp / datetime construction. Map to 400, not 500. + raise InvalidCursorError(f"value is out of representable range: {e}") from e + + +def decode_cursor_int(payload: Optional[CursorPayload]) -> int: + """Parse a cursor value as a base-10 integer.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + return int(payload.value) + except ValueError as e: + raise InvalidCursorError(f"value is not a valid integer: {e}") from e + + +_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) + + +def _datetime_to_unix_micros(t: datetime) -> int: + """Convert an aware UTC datetime to Unix microseconds (integer math).""" + delta = t - _EPOCH + return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds + + +def _unix_micros_to_datetime(micros: int) -> datetime: + """Convert Unix microseconds to a UTC datetime, preserving precision.""" + seconds, micro_remainder = divmod(micros, 1_000_000) + return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 2a52d76cac2e..4d2af8a02fe1 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -70,6 +70,7 @@ class AssetSummaryData: class ListAssetsResult: items: list[AssetSummaryData] total: int + next_cursor: str | None = None @dataclass(frozen=True) diff --git a/tests-unit/assets_test/queries/test_asset_reference_keyset.py b/tests-unit/assets_test/queries/test_asset_reference_keyset.py new file mode 100644 index 000000000000..d143d60f9436 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_reference_keyset.py @@ -0,0 +1,112 @@ +"""Keyset-pagination tiebreaker tests for list_references_page. + +When multiple rows share the same primary sort value (e.g. four assets +created in the same microsecond), the secondary `ORDER BY id` is what keeps +keyset pagination from losing or repeating rows. This file exercises that +branch directly against an in-memory SQLite session — engineering identical +timestamps via HTTP is unreliable enough that we work at the query layer. +""" +import uuid +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.asset_reference import list_references_page + + +def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference: + asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024) + session.add(asset) + session.flush() + ref = AssetReference( + id=str(uuid.uuid4()), + asset_id=asset.id, + owner_id=owner, + name=name, + file_path=f"/tmp/{name}", + created_at=created_at, + updated_at=created_at, + last_access_time=created_at, + is_missing=False, + ) + session.add(ref) + return ref + + +@pytest.mark.parametrize("order", ["desc", "asc"]) +def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str): + """Four rows with the SAME created_at must paginate cleanly under cursor + mode — no row dropped, no row repeated, despite the primary sort column + being non-discriminating. + """ + shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores + refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)] + session.commit() + + expected_ids = sorted([r.id for r in refs], reverse=(order == "desc")) + + # Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0). + seen: list[str] = [] + after_value = None + after_id = None + for _ in range(4): # generous loop bound; ought to be 2 iterations + page, _tag_map, _total = list_references_page( + session, + limit=2, + sort="created_at", + order=order, + after_cursor_value=after_value, + after_cursor_id=after_id, + ) + if not page: + break + seen.extend(p.id for p in page) + # Use the last row's (created_at, id) as the next cursor input. + last = page[-1] + after_value, after_id = last.created_at, last.id + if len(page) < 2: + break + + assert seen == expected_ids, ( + f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}" + ) + + +def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session): + """Some rows share a timestamp, some don't. The cursor must still walk + every row exactly once regardless of where ties sit relative to a + page boundary.""" + t1 = datetime(2024, 5, 20, 12, 0, 0) + t2 = datetime(2024, 5, 20, 12, 0, 1) + layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2 + refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)] + session.commit() + + all_ids = {r.id for r in refs} + seen_set: set[str] = set() + seen_list: list[str] = [] + after_value = None + after_id = None + for _ in range(6): + page, _, _ = list_references_page( + session, + limit=2, + sort="created_at", + order="desc", + after_cursor_value=after_value, + after_cursor_id=after_id, + ) + if not page: + break + for p in page: + assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk" + seen_set.add(p.id) + seen_list.append(p.id) + last = page[-1] + after_value, after_id = last.created_at, last.id + if len(page) < 2: + break + + assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}" diff --git a/tests-unit/assets_test/services/test_cursor.py b/tests-unit/assets_test/services/test_cursor.py new file mode 100644 index 000000000000..47970e1684cc --- /dev/null +++ b/tests-unit/assets_test/services/test_cursor.py @@ -0,0 +1,278 @@ +"""Tests for app.assets.services.cursor. + +Cursors are opaque tokens internal to this server — these tests cover +round-tripping, validation, and length caps, not any particular wire +byte layout. +""" +from __future__ import annotations + +import base64 +from datetime import datetime, timedelta, timezone + +import pytest + +from app.assets.services.cursor import ( + MAX_CURSOR_ID_LENGTH, + MAX_CURSOR_VALUE_LENGTH, + MAX_ENCODED_CURSOR_LENGTH, + CursorPayload, + InvalidCursorError, + decode_cursor, + decode_cursor_int, + decode_cursor_time, + encode_cursor, + encode_cursor_from_time, +) + + +ALLOWED = ("created_at", "updated_at", "name", "size") + + +class TestRoundTrip: + @pytest.mark.parametrize( + "sort_field, value, id", + [ + ("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"), + ("size", "1024", "asset-123"), + ("name", "my-asset.png", "asset-abc"), + ("name", "résumé.txt", "asset-uni"), + ("name", "foo<&>bar.png", "asset-html"), + ("name", 'quo"te\\back\nnewline.png', "asset-esc"), + ], + ) + def test_encode_decode(self, sort_field, value, id): + encoded = encode_cursor(sort_field, value, id) + assert encoded != "" + payload = decode_cursor(encoded, ALLOWED) + assert payload.sort_field == sort_field + assert payload.value == value + assert payload.id == id + + +class TestTimeCursor: + def test_microsecond_precision_preserved(self): + # Pick a time with non-zero microseconds — encoding at ms would lose the µs. + ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc) + encoded = encode_cursor_from_time("created_at", ts, "id-1") + payload = decode_cursor(encoded, ALLOWED) + # Value must be a microsecond integer string, not a millisecond one. + assert payload.value == "1716209600123456" + decoded = decode_cursor_time(payload) + assert decoded == ts + + def test_decode_returns_utc(self): + payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc") + decoded = decode_cursor_time(payload) + assert decoded.tzinfo == timezone.utc + + def test_naive_datetime_rejected_on_encode(self): + naive = datetime(2024, 5, 20, 12, 0, 0) + with pytest.raises(ValueError): + encode_cursor_from_time("created_at", naive, "id-1") + + def test_non_integer_value_rejected_on_decode(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc")) + + def test_none_payload_rejected(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(None) + + def test_non_utc_aware_normalized(self): + # Same instant, different timezone — must encode to the same micros. + utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc) + offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5))) + assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time( + "created_at", offset_ts, "x" + ) + + +class TestIntCursor: + def test_decode_int(self): + assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024 + + def test_decode_int_rejects_non_int(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc")) + + def test_decode_int_rejects_none(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(None) + + +class TestInvalidInputs: + def test_oversized_cursor(self): + oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1) + with pytest.raises(InvalidCursorError, match="maximum length"): + decode_cursor(oversized, ALLOWED) + + def test_not_base64(self): + with pytest.raises(InvalidCursorError): + decode_cursor("not base64!!!", ALLOWED) + + def test_not_json(self): + encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ALLOWED) + + def test_empty_id(self): + # Encoder rejects empty id symmetrically with the decoder, so build the + # payload manually to exercise the decoder's missing-id branch. + raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="missing id"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_id(self): + # Encoder enforces the cap symmetrically; hand-build to exercise decode. + big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1) + raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii") + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="id exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_value(self): + # Encoder enforces the cap symmetrically; hand-build to exercise decode. + big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1) + raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii") + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="value exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_unsupported_sort_field(self): + encoded = encode_cursor("execution_time", "1", "id-1") + with pytest.raises(InvalidCursorError, match="unsupported sort field"): + decode_cursor(encoded, ALLOWED) + + def test_no_allowed_fields_rejects_everything(self): + encoded = encode_cursor("created_at", "1", "id-1") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ()) + + def test_non_dict_payload_rejected(self): + encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="expected object"): + decode_cursor(encoded, ALLOWED) + + +class TestEncodeAtCapsFits: + def test_max_field_lengths_fit_wire_cap(self): + # Worst-case payload: value and id at their per-field caps, with a long + # sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH + # so the wire cap cannot reject a cursor the encoder mints at the per-field caps. + value = "v" * MAX_CURSOR_VALUE_LENGTH + id = "i" * MAX_CURSOR_ID_LENGTH + sort_field = "very_long_sort_field_name" + + encoded = encode_cursor(sort_field, value, id) + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, (sort_field,)) + assert payload.value == value + assert payload.id == id + + +class TestDatetimeOverflow: + """Crafted cursors with extreme micros must map to InvalidCursorError, + not OverflowError/OSError leaking as 500. + """ + + @pytest.mark.parametrize( + "micros_str", + [ + "999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders + "-999999999999999999999", # symmetric negative — pre-epoch overflow + ], + ) + def test_out_of_range_micros_rejected(self, micros_str): + encoded = encode_cursor("created_at", micros_str, "asset-x") + payload = decode_cursor(encoded, ALLOWED) + with pytest.raises(InvalidCursorError): + decode_cursor_time(payload) + + +class TestEncoderDecoderSymmetry: + """The encoder must never mint a cursor the decoder would reject, or the + same server would 400 on a cursor it just handed out. Per-field caps keep + the encoded length below the wire cap, so a freshly minted cursor always + round-trips. + """ + + def test_long_name_within_cap_round_trips(self): + """Assets allow names up to 512 chars (`String(512)`); the cursor + encoder must round-trip a value at that cap so a freshly minted + cursor never fails decode on the next request.""" + long_name = "n" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", long_name, "asset-x") + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == long_name + + def test_encoder_rejects_empty_id(self): + with pytest.raises(InvalidCursorError, match="id must be non-empty"): + encode_cursor("created_at", "1", "") + + def test_encoder_rejects_oversized_id(self): + with pytest.raises(InvalidCursorError, match="id exceeds maximum length"): + encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1)) + + def test_encoder_rejects_oversized_value(self): + with pytest.raises(InvalidCursorError, match="value exceeds maximum length"): + encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1") + + def test_multibyte_value_at_cap_round_trips(self): + """A value at the char-count cap made of multibyte characters + (e.g. 'é' = 2 UTF-8 bytes) stays under the wire cap, so it mints and + round-trips — the per-field caps, not a mint-time length check, are + what bound cursor size.""" + value = "é" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", value, "asset-multibyte") + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == value + + def test_escape_heavy_value_at_cap_round_trips(self): + """JSON escape expansion is the worst case: each control character + serializes to the six-byte `\\uXXXX` form. A value of 512 of them is + the largest a cursor can get, and it still fits the wire cap, mints, + and round-trips.""" + value = "\x01" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", value, "asset-escape") + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == value + + +class TestOrderBinding: + def test_order_baked_into_payload(self): + encoded = encode_cursor("created_at", "1", "id-1", order="asc") + payload = decode_cursor(encoded, ALLOWED) + assert payload.order == "asc" + + def test_mismatched_order_rejected(self): + encoded = encode_cursor("created_at", "1", "id-1", order="desc") + with pytest.raises(InvalidCursorError, match="does not match request order"): + decode_cursor(encoded, ALLOWED, expected_order="asc") + + def test_matching_order_accepted(self): + encoded = encode_cursor("created_at", "1", "id-1", order="desc") + payload = decode_cursor(encoded, ALLOWED, expected_order="desc") + assert payload.order == "desc" + + def test_invalid_order_token_rejected_on_encode(self): + with pytest.raises(ValueError): + encode_cursor("created_at", "1", "id-1", order="sideways") + + def test_invalid_order_token_rejected_on_decode(self): + # Hand-craft a payload with an illegal `o` value. + raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="unsupported order"): + decode_cursor(encoded, ALLOWED) + + def test_cursor_without_order_rejected(self): + """`o` is mandatory. A cursor minted without it is rejected as + malformed rather than silently walking the keyset in whatever + direction the request happens to ask for.""" + raw = b'{"s":"name","v":"x","id":"id-1"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="missing or non-string o"): + decode_cursor(encoded, ALLOWED, expected_order="desc") diff --git a/tests-unit/assets_test/test_list_cursor.py b/tests-unit/assets_test/test_list_cursor.py new file mode 100644 index 000000000000..a37019fd6adc --- /dev/null +++ b/tests-unit/assets_test/test_list_cursor.py @@ -0,0 +1,349 @@ +"""Integration tests for cursor-based pagination on GET /api/assets. + +These tests exercise the handler/service/query path end-to-end; +cursor-encoding-level tests live in +tests-unit/assets_test/services/test_cursor.py. +""" +import pytest +import requests + + +def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]: + names = [f"cursor_{i:02d}.safetensors" for i in range(count)] + for n in names: + asset_factory( + n, + ["models", "checkpoints", "unit-tests", tag], + {}, + make_asset_bytes(n, size=2048), + ) + return sorted(names) + + +def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk") + + params = { + "include_tags": "unit-tests,cursor-walk", + "sort": "name", + "order": "asc", + "limit": "2", + } + + seen: list[str] = [] + after: str | None = None + pages = 0 + while True: + page_params = dict(params) + if after is not None: + page_params["after"] = after + r = http.get(api_base + "/api/assets", params=page_params, timeout=120) + assert r.status_code == 200, r.text + body = r.json() + seen.extend(a["name"] for a in body["assets"]) + pages += 1 + after = body.get("next_cursor") + if after is None: + break + assert body["has_more"] is True + assert pages < 10, "guard against runaway cursor loop" + + assert seen == names, f"expected {names}, got {seen}" + # Last page should have has_more False + assert body["has_more"] is False + assert "next_cursor" not in body + + +def test_cursor_invalid_returns_400(http: requests.Session, api_base: str): + r = http.get( + api_base + "/api/assets", + params={"after": "not-a-real-cursor", "sort": "created_at"}, + timeout=120, + ) + assert r.status_code == 400, r.text + body = r.json() + assert body["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch") + + # Take a real cursor minted for sort=name. + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-mismatch", + "sort": "name", + "order": "asc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200 + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Replay against sort=created_at — should fail with INVALID_CURSOR. + r2 = http.get( + api_base + "/api/assets", + params={"after": cursor, "sort": "created_at"}, + timeout=120, + ) + assert r2.status_code == 400, r2.text + assert r2.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset") + + # Take a cursor that points past the first item. + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-vs-offset", + "sort": "name", + "order": "asc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Pass both 'after' and a large offset. Cursor must win; offset is ignored. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-vs-offset", + "sort": "name", + "order": "asc", + "limit": "1", + "after": cursor, + "offset": "999", + }, + timeout=120, + ) + assert r2.status_code == 200 + body = r2.json() + # Should land on the second name in sorted order — not skip ahead by 999. + assert [a["name"] for a in body["assets"]] == [names[1]] + + +def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust") + + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-exhaust", + "sort": "name", + "order": "asc", + "limit": "50", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["has_more"] is False + assert "next_cursor" not in body + + +def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """First-page request (no `after`) must still return `next_cursor` when + more rows exist, or pagination is unreachable from a cold start. + """ + _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page") + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"}, + timeout=120, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["has_more"] is True + assert body.get("next_cursor"), "first page must mint a cursor when more rows exist" + + +def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """When `total` is an exact multiple of `limit`, the final page must + NOT carry a next_cursor — there is nothing past it. + """ + _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple") + # Page 1 + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"}, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + # Page 2 — should exhaust the set with no cursor for a phantom page 3 + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor}, + timeout=120, + ) + assert r2.status_code == 200, r2.text + body = r2.json() + assert len(body["assets"]) == 2 + assert body["has_more"] is False + assert "next_cursor" not in body + + +@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"]) +def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """Cursor pagination must work for every sort field the contract claims. + + Without this, the `created_at` / `updated_at` (time-encoded micros) and + `size` (int-encoded) cursor paths go entirely unexercised end-to-end. + """ + # Sizes increase strictly by index, so `size desc` has a deterministic + # expected order. Time-based sorts (created_at / updated_at) can tie when + # rows are inserted faster than the DB's timestamp resolution; for those + # we check coverage and no-duplicates and let the keyset tiebreaker do + # the rest, instead of sleeping between inserts and asserting an order + # that depends on clock granularity. + names = [] + for i in range(4): + n = f"cursor_{sort_field}_{i:02d}.safetensors" + asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i)) + names.append(n) + + params = { + "include_tags": f"unit-tests,cursor-{sort_field}", + "sort": sort_field, + "order": "desc", + "limit": "2", + } + seen: list[str] = [] + after: str | None = None + pages = 0 + while True: + page_params = dict(params) + if after is not None: + page_params["after"] = after + r = http.get(api_base + "/api/assets", params=page_params, timeout=120) + assert r.status_code == 200, r.text + body = r.json() + seen.extend(a["name"] for a in body["assets"]) + after = body.get("next_cursor") + pages += 1 + if after is None: + break + assert pages < 10, "guard against runaway cursor loop" + + # No duplicates: a faulty keyset boundary that returns the same row across + # two pages must fail this check. + assert len(seen) == len(set(seen)), ( + f"cursor walk repeated rows for sort={sort_field}: {seen}" + ) + # Full coverage: every seeded asset reached exactly once. + assert set(seen) == set(names), ( + f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}" + ) + # Strict order check for the only field with a clock-independent ordering. + if sort_field == "size": + assert seen == list(reversed(names)), ( + f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}" + ) + + +def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """A cursor minted under desc order replayed against asc must 400, not + silently walk the wrong direction.""" + _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip") + + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-order-flip", + "sort": "name", + "order": "desc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Replay with order flipped to asc — server must reject the cursor. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-order-flip", + "sort": "name", + "order": "asc", + "limit": "1", + "after": cursor, + }, + timeout=120, + ) + assert r2.status_code == 400, r2.text + assert r2.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str): + """A cursor carrying an out-of-range microsecond timestamp must map to + 400 INVALID_CURSOR, not 500.""" + import base64 + import json + # 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR. + # `o` and `order=` must be set; otherwise decode fails earlier on the + # missing-order branch and the µs-overflow path is never exercised. + payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"} + raw = json.dumps(payload, separators=(",", ":")).encode("utf-8") + cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + r = http.get( + api_base + "/api/assets", + params={"after": cursor, "sort": "created_at", "order": "desc"}, + timeout=120, + ) + assert r.status_code == 400, r.text + assert r.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete") + + # Page 1. + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-delete", + "sort": "name", + "order": "asc", + "limit": "2", + }, + timeout=120, + ) + assert r.status_code == 200 + body = r.json() + page1_names = [a["name"] for a in body["assets"]] + cursor = body["next_cursor"] + assert cursor is not None + assert page1_names == names[:2] + + # Delete an item from page 1 (already returned) — cursor should still + # locate the next page from where it was minted, not re-index. + target_id = body["assets"][0]["id"] + d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120) + assert d.status_code in (200, 204), d.text + + # Page 2 via cursor. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-delete", + "sort": "name", + "order": "asc", + "limit": "2", + "after": cursor, + }, + timeout=120, + ) + assert r2.status_code == 200, r2.text + body2 = r2.json() + assert [a["name"] for a in body2["assets"]] == names[2:] From 039ed38ed10ad0072a13f6471e06913ed33d5e56 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 9 Jun 2026 21:52:14 -0700 Subject: [PATCH 40/45] fix(assets): remove unused delete_content param from deleteAsset (#14241) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(assets): remove unused delete_content param from deleteAsset The delete_content query param on DELETE /api/assets/{id} was introduced in #12125 and had its default flipped to false in #12621. In practice no client sends it: the frontend issues a bare DELETE /assets/{id}, so every real caller already gets the default soft-delete (the reference is hidden, content preserved). The only thing that set delete_content=true was this repo's own test teardown. Remove the param from the route and the OpenAPI spec so the contract matches what clients actually use (and lines up with the cloud surface). The route now always soft-deletes. The underlying delete_asset_reference helper keeps its delete_content_if_orphan option, so orphan reclamation remains available internally for a future GC path — it's just no longer exposed on the public endpoint. Tests that used delete_content=true for hard cleanup now soft-delete; test_delete_upon_reference_count asserts content preservation instead of orphan removal. * test/docs: address review on deleteAsset delete_content removal - Rename test_delete_upon_reference_count -> test_soft_delete_preserves_asset_identity_across_references; the old name implied last-ref cleanup, but it now verifies the opposite (soft delete preserves identity across references). - Strengthen the re-association assertion: also check asset_hash == src_hash so it proves content reuse rather than relying on the now-tautological created_new is False. - Document delete_asset_reference: the orphan-reclamation branch is intentionally internal-only; the public endpoint always soft-deletes. - Normalize the soft-delete comment phrasing. * test(assets): make seed content unique per test for isolation Removing the delete_content param means delete is always a soft delete, so content created by one test now survives into the next. The suite had been relying on hard-delete teardown for isolation, so shared fixed-content fixtures started colliding: seeded_asset (b"A"*4096) and make_asset_bytes (deterministic on name) produced the same hash every test, so the second seed deduped to the surviving asset and returned 200 instead of 201, cascading into ~14 failures/errors. Salt both fixtures with a per-test uuid so each test creates fresh content (created_new True, 201), while keeping content deterministic within a test (same name/size -> same bytes) and preserving exact byte length so size-based list/sort assertions are unaffected. --- app/assets/api/routes.py | 10 +++------- app/assets/services/asset_management.py | 10 ++++++++++ tests-unit/assets_test/conftest.py | 21 +++++++++++++++++---- tests-unit/assets_test/test_crud.py | 23 +++++++++++++---------- tests-unit/assets_test/test_downloads.py | 2 +- tests-unit/assets_test/test_tags_api.py | 4 ++-- 6 files changed, 46 insertions(+), 24 deletions(-) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 544a614f2bd3..7ef462f5cd04 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -533,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response: @_require_assets_feature_enabled async def delete_asset_route(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) - delete_content_param = request.query.get("delete_content") - delete_content = ( - False - if delete_content_param is None - else delete_content_param.lower() not in {"0", "false", "no"} - ) try: + # Deleting an asset is a soft delete of the reference; the underlying + # content is preserved (it may be shared with other references). deleted = delete_asset_reference( reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), - delete_content_if_orphan=delete_content, + delete_content_if_orphan=False, ) except Exception: logging.exception( diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 1072c95faaea..d4e4fc61c4ad 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -160,6 +160,16 @@ def delete_asset_reference( owner_id: str, delete_content_if_orphan: bool = True, ) -> bool: + """Delete an asset reference. + + With ``delete_content_if_orphan=False`` (a soft delete), the reference is + hidden and the underlying content is preserved. With ``True``, the content + is also removed once it becomes orphaned. + + Note: the public DELETE /api/assets/{id} endpoint always soft-deletes + (passes ``False``); the orphan-reclamation path is intentionally + internal-only, retained for a future GC/admin caller. + """ with create_session() as session: if not delete_content_if_orphan: # Soft delete: mark the reference as deleted but keep everything diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 9867b4e140a8..4aa20372f25a 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -6,6 +6,7 @@ import sys import tempfile import time +import uuid from pathlib import Path from typing import Callable, Iterator, Optional @@ -188,9 +189,17 @@ def _post_multipart_asset( @pytest.fixture def make_asset_bytes() -> Callable[[str, int], bytes]: + # Salt content per test so it never collides with assets left over from + # earlier tests. Delete is now always a soft delete (content is preserved), + # so the suite can no longer rely on hard-deleting content for isolation. + # Deterministic within a test: the same (name, size) yields the same bytes. + salt = uuid.uuid4().bytes + def _make(name: str, size: int = 8192) -> bytes: seed = sum(ord(c) for c in name) % 251 - return bytes((i * 31 + seed) % 256 for i in range(size)) + body = bytearray((i * 31 + seed) % 256 for i in range(size)) + body[: len(salt)] = salt[:size] + return bytes(body) return _make @@ -212,7 +221,7 @@ def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) @pytest.fixture @@ -227,7 +236,11 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas if tags is None: tags = ["models", "checkpoints", "unit-tests", "alpha"] meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} - files = {"file": (name, b"A" * 4096, "application/octet-stream")} + # Unique content per test so the seed always creates a fresh asset (201). + # Delete is now always a soft delete, so content from a prior test survives + # and would otherwise dedup this upload into an existing asset (200). + content = uuid.uuid4().bytes + b"A" * (4096 - 16) + files = {"file": (name, content, "application/octet-stream")} form_data = { "tags": json.dumps(tags), "name": name, @@ -260,4 +273,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index fd2e9a0984c6..36abb60eecb1 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -45,8 +45,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE (hard delete to also remove underlying asset and file) - rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + # Soft delete — the reference is hidden, content is preserved + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -60,7 +60,7 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede aid = seeded_asset["id"] asset_hash = seeded_asset["asset_hash"] - # Soft-delete (default, no delete_content param) + # Soft delete — the reference is hidden, content is preserved rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) assert rd.status_code == 204 @@ -81,11 +81,10 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede ids = [a["id"] for a in rl.json().get("assets", [])] assert aid not in ids - # Clean up: hard-delete the soft-deleted reference and orphaned asset - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + # The reference is already soft-deleted; content is preserved. -def test_delete_upon_reference_count( +def test_soft_delete_preserves_asset_identity_across_references( http: requests.Session, api_base: str, seeded_asset: dict ): # Create a second reference to the same asset via from-hash @@ -119,16 +118,20 @@ def test_delete_upon_reference_count( rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) assert rh2.status_code == 200 # asset identity preserved (soft delete) - # Re-associate via from-hash, then hard-delete -> orphan content removed + # Re-associate via from-hash: it must reuse the same preserved content + # (created_new False AND the same hash), proving the soft deletes did not + # destroy the underlying asset. Then soft-delete again -> still preserved. r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) assert r3.status_code == 201, r3.json() + assert r3.json()["created_new"] is False + assert r3.json()["asset_hash"] == src_hash # reused the surviving content aid3 = r3.json()["id"] - rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + rd3 = http.delete(f"{api_base}/api/assets/{aid3}", timeout=120) assert rd3.status_code == 204 rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh3.status_code == 404 # orphan content removed + assert rh3.status_code == 200 # content preserved (soft delete) def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -249,7 +252,7 @@ def test_concurrent_delete_same_asset_info_single_204( # Hit the same endpoint N times in parallel. n_tests = 4 - url = f"{api_base}/api/assets/{aid}?delete_content=false" + url = f"{api_base}/api/assets/{aid}" def _do_delete(delete_url): with requests.Session() as s: diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index 672ba97284dd..42c64a5fde71 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_tags_api.py b/tests-unit/assets_test/test_tags_api.py index 595bf29c66a1..9729b7d03a6a 100644 --- a/tests-unit/assets_test/test_tags_api.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Hard-delete the asset so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) + # Delete the asset reference so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false From 6d18f4adacea2304f1f6f4ff3c0279d13654ec5c Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 11 Jun 2026 03:54:32 +1000 Subject: [PATCH 41/45] main: force cudnn.benchmark to false (#14390) Some custom nodes try to set this true globally. It messes with dynamic VRAM with one-off spikes that can OOM but this is also very high risk for windows where such allocations might get serviced by shared memory fallback. Trump it. --- comfy/model_management.py | 6 ++++-- main.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9dc0a4e13cd8..55ddaab8e338 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -534,8 +534,10 @@ def aotriton_supported(gpu_arch): except: pass -if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: - torch.backends.cudnn.benchmark = True + +def set_cudnn_benchmark(): + if torch.cuda.is_available() and torch.backends.cudnn.is_available(): + torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast try: if torch_version_numeric >= (2, 5): diff --git a/main.py b/main.py index 7fcc8e97d520..0ad6603767c6 100644 --- a/main.py +++ b/main.py @@ -490,6 +490,11 @@ def start_comfyui(asyncio_loop=None): init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, init_api_nodes=not args.disable_api_nodes )) + + # Re-apply Comfy's cuDNN benchmark policy after custom-node imports. Benchmark + # mode can request near-card-sized autotune workspaces, and some custom nodes set it at import time. + comfy.model_management.set_cudnn_benchmark() + hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() From e5b7140dcc5a88a6ad673a249eed223238e45a2b Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 10 Jun 2026 16:55:25 -0700 Subject: [PATCH 42/45] feat(assets): add job_ids filter to GET /api/assets (#13998) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(assets): add job_ids filter to GET /api/assets Mirrors the existing cloud `job_ids` query param on the local Python server: clients can pass a comma-separated list (or repeated query params) of UUIDs to filter assets by their associated job. The `AssetReference.job_id` column already exists, so no migration is needed — this just plumbs the filter through schema → service → query. Marks the parameter as available in both runtimes by dropping the `[cloud-only]` description prefix and the `x-runtime: [cloud]` tag from the OpenAPI spec, per the OSS field-drift convention (absent runtime tag = populated by both local and cloud). * fix(assets): tighten job_ids — array schema, max_length, narrow except From cursor-reviews on the parent commit: - OpenAPI: declare job_ids as `type: array, items: string format: uuid` with `style: form, explode: true` so it matches the documented contract (and matches sibling include_tags/exclude_tags shape). Description now states both accepted shapes explicitly. - Schema: cap `job_ids` at 500 entries (max_length on the Pydantic field) so a client can't splice an unbounded list into the IN clauses. - Schema: drop `AttributeError` from the except — `raw` only contains `str` items by construction, so `uuid.UUID()` raises `ValueError` exclusively; the second clause was dead code. * fix(assets): tighten job_ids validator + add schema-level tests Aligns with the parallel hardening from draft PR #13848 (now closed as a duplicate). The validator now: - Raises ValueError on non-string list items (was: silently dropped). - Raises ValueError on non-string / non-list top-level values like dict or int (was: silently passed through to Pydantic's downstream coercion). Adds tests-unit/assets_test/queries/test_list_assets_query.py covering the validator end-to-end: CSV canonicalization, dedup order, default empty, invalid UUID, non-string list item, non-string non-list value, and the max_length=500 boundary. * feat(prompt): enforce canonical UUID prompt_id at job creation POST /prompt previously accepted any client-supplied prompt_id verbatim, str()-coercing even non-strings, and minting the literal job id "None" for an explicit JSON null. The new GET /api/assets job_ids filter matches stored job ids as canonical UUIDs exactly, so a non-UUID id minted a job whose assets could never be filtered. - validate_job_id (comfy_execution/jobs.py): requires a string in the canonical lowercase hyphenated UUID form; raises ValueError otherwise, including parseable-but-non-canonical spellings (uppercase, braced, URN, bare hex), which would otherwise be silently rewritten and then miss every exact-match lookup downstream (history keys, websocket correlation, /interrupt, the assets job_ids filter). - POST /prompt: absent or null prompt_id means the server mints uuid4; invalid means 400 invalid_prompt_id on the standard error envelope. - openapi.yaml: document the request-side prompt_id (format uuid, nullable) on PromptRequest. - tests: unit matrix for validate_job_id; integration tests against the booted server covering rejection, acceptance, and null handling. --------- Co-authored-by: guill --- app/assets/api/routes.py | 1 + app/assets/api/schemas_in.py | 36 ++++++++++ .../database/queries/asset_reference.py | 6 ++ app/assets/services/asset_management.py | 2 + comfy_execution/jobs.py | 21 ++++++ openapi.yaml | 5 ++ server.py | 18 ++++- .../assets_test/queries/test_asset_info.py | 50 ++++++++++++++ .../queries/test_list_assets_query.py | 60 ++++++++++++++++ .../assets_test/test_prompt_id_enforcement.py | 69 +++++++++++++++++++ tests/execution/test_jobs.py | 43 ++++++++++++ 11 files changed, 309 insertions(+), 2 deletions(-) create mode 100644 tests-unit/assets_test/queries/test_list_assets_query.py create mode 100644 tests-unit/assets_test/test_prompt_id_enforcement.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7ef462f5cd04..6c9a3200d2ea 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -219,6 +219,7 @@ async def list_assets_route(request: web.Request) -> web.Response: exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, + job_ids=q.job_ids, limit=q.limit, offset=q.offset, sort=sort, diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index af666746dce2..4ae18c65a523 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,4 +1,5 @@ import json +import uuid from dataclasses import dataclass from typing import Any, Literal @@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) name_contains: str | None = None + job_ids: list[str] = Field(default_factory=list, max_length=500) # Accept either a JSON string (query param) or a dict metadata_filter: dict[str, Any] | None = None @@ -86,6 +88,40 @@ def _split_csv_tags(cls, v): return out return v + @field_validator("job_ids", mode="before") + @classmethod + def _split_and_validate_job_ids(cls, v): + # Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params. + # Each entry must parse as a UUID; canonicalized to lowercase hyphenated form. + if v is None: + return [] + if isinstance(v, str): + raw = [t.strip() for t in v.split(",") if t.strip()] + elif isinstance(v, list): + raw = [] + for item in v: + if not isinstance(item, str): + raise ValueError( + f"job_ids entries must be strings, got {type(item).__name__}" + ) + raw.extend([t.strip() for t in item.split(",") if t.strip()]) + else: + raise ValueError( + f"job_ids must be a string or list of strings, got {type(v).__name__}" + ) + + out: list[str] = [] + seen: set[str] = set() + for s in raw: + try: + canonical = str(uuid.UUID(s)) + except ValueError as e: + raise ValueError(f"job_ids must be UUIDs: {s!r}") from e + if canonical not in seen: + seen.add(canonical) + out.append(canonical) + return out + @field_validator("metadata_filter", mode="before") @classmethod def _parse_metadata_json(cls, v): diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 792411800e1b..33ded8a1cc8c 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -264,6 +264,7 @@ def list_references_page( include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, metadata_filter: dict | None = None, + job_ids: Sequence[str] | None = None, sort: str | None = None, order: str | None = None, after_cursor_value: object | None = None, @@ -293,6 +294,9 @@ def list_references_page( escaped, esc = escape_sql_like_string(name_contains) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + if job_ids: + base = base.where(AssetReference.job_id.in_(list(job_ids))) + base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_metadata_filter(base, metadata_filter) @@ -345,6 +349,8 @@ def list_references_page( count_stmt = count_stmt.where( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) + if job_ids: + count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids))) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index d4e4fc61c4ad..53aec7a15e1f 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -274,6 +274,7 @@ def list_assets_page( exclude_tags: Sequence[str] | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, + job_ids: Sequence[str] | None = None, limit: int = 20, offset: int = 0, sort: str = "created_at", @@ -319,6 +320,7 @@ def list_assets_page( exclude_tags=exclude_tags, name_contains=name_contains, metadata_filter=metadata_filter, + job_ids=job_ids, limit=fetch_limit, offset=offset, sort=sort, diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735988..3fbcc3eb01be 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -3,6 +3,7 @@ Provides normalization and helper functions for job status tracking. """ +import uuid from typing import Optional from comfy_api.internal import prune_dict @@ -19,6 +20,26 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] +def validate_job_id(value) -> str: + """Validate a client-supplied job (prompt) id. + + Job ids must be UUIDs in the canonical lowercase hyphenated form. The id + is stored and compared verbatim everywhere downstream — history keys, + websocket events, /interrupt matching, and the assets ``job_ids`` filter + (a String(36) column matched exactly) — so accepting another spelling + would either rewrite the client's id behind its back or mint a job whose + outputs the filter can never find. Rejecting loudly beats both. + + Returns the id unchanged. Raises ValueError when the value is not a + string in canonical UUID form. + """ + if not isinstance(value, str): + raise ValueError(f"job id must be a string, got {type(value).__name__}") + if str(uuid.UUID(value)) != value: + raise ValueError("job id must be a UUID in canonical lowercase hyphenated form") + return value + + # Media types that can be previewed in the frontend PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) diff --git a/openapi.yaml b/openapi.yaml index c27ed7adf1b8..58614103a756 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -896,6 +896,11 @@ components: additionalProperties: true description: The workflow graph to execute type: object + prompt_id: + description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one. + format: uuid + nullable: true + type: string workflow_id: description: UUID identifying the cloud workflow entity to associate with this job type: string diff --git a/server.py b/server.py index a85c1e59147c..cc3b33a5c962 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import nodes import folder_paths import execution -from comfy_execution.jobs import JobStatus, get_job, get_all_jobs +from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id import uuid import urllib import json @@ -942,7 +942,21 @@ async def post_prompt(request): if "prompt" in json_data: prompt = json_data["prompt"] - prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) + client_prompt_id = json_data.get("prompt_id") + if client_prompt_id is None: + # Absent or explicit null: the server mints the id. + prompt_id = str(uuid.uuid4()) + else: + try: + prompt_id = validate_job_id(client_prompt_id) + except ValueError: + error = { + "type": "invalid_prompt_id", + "message": "prompt_id must be a valid UUID", + "details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one", + "extra_info": {} + } + return web.json_response({"error": error, "node_errors": {}}, status=400) partial_execution_targets = None if "partial_execution_targets" in json_data: diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e3422e6..ba729a2701d5 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -158,6 +158,56 @@ def test_sorting(self, session: Session): refs, _, _ = list_references_page(session, sort="name", order="asc") assert refs[0].name == "large" + def test_job_ids_filter(self, session: Session): + asset = _make_asset(session, "hash1") + job_a = str(uuid.uuid4()) + job_b = str(uuid.uuid4()) + ref_a = _make_reference(session, asset, name="from_job_a") + ref_a.job_id = job_a + ref_b = _make_reference(session, asset, name="from_job_b") + ref_b.job_id = job_b + _make_reference(session, asset, name="no_job") + session.commit() + + # Single job filter + refs, _, total = list_references_page(session, job_ids=[job_a]) + assert total == 1 + assert refs[0].name == "from_job_a" + + # Multi-job filter (IN) + refs, _, total = list_references_page(session, job_ids=[job_a, job_b]) + names = sorted(r.name for r in refs) + assert total == 2 + assert names == ["from_job_a", "from_job_b"] + + # Unknown job id matches nothing + refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())]) + assert total == 0 + assert refs == [] + + # Empty/None means no filter -> all three references + refs, _, total = list_references_page(session, job_ids=[]) + assert total == 3 + refs, _, total = list_references_page(session, job_ids=None) + assert total == 3 + + def test_job_ids_combined_with_other_filters(self, session: Session): + asset = _make_asset(session, "hash1") + job_a = str(uuid.uuid4()) + ref_match = _make_reference(session, asset, name="match.bin") + ref_match.job_id = job_a + ref_wrong_name = _make_reference(session, asset, name="other.bin") + ref_wrong_name.job_id = job_a + ref_wrong_job = _make_reference(session, asset, name="match.bin") + ref_wrong_job.job_id = str(uuid.uuid4()) + session.commit() + + refs, _, total = list_references_page( + session, job_ids=[job_a], name_contains="match" + ) + assert total == 1 + assert refs[0].id == ref_match.id + class TestFetchReferenceAssetAndTags: def test_returns_none_for_nonexistent(self, session: Session): diff --git a/tests-unit/assets_test/queries/test_list_assets_query.py b/tests-unit/assets_test/queries/test_list_assets_query.py new file mode 100644 index 000000000000..e8d3430e2eb7 --- /dev/null +++ b/tests-unit/assets_test/queries/test_list_assets_query.py @@ -0,0 +1,60 @@ +"""Schema-level unit tests for ListAssetsQuery (no DB required).""" +import uuid + +import pytest +from pydantic import ValidationError + +from app.assets.api.schemas_in import ListAssetsQuery + + +class TestJobIdsValidator: + def test_csv_string_parses_and_canonicalizes(self): + a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE" + b = "11111111-2222-3333-4444-555555555555" + q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"}) + # Canonicalized to lowercase + assert q.job_ids == [a.lower(), b] + + def test_repeated_query_params_as_list(self): + a = "11111111-1111-1111-1111-111111111111" + b = "22222222-2222-2222-2222-222222222222" + q = ListAssetsQuery.model_validate({"job_ids": [a, b]}) + assert q.job_ids == [a, b] + + def test_dedup_preserves_first_seen_order(self): + a = "11111111-1111-1111-1111-111111111111" + b = "22222222-2222-2222-2222-222222222222" + q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]}) + assert q.job_ids == [a, b] + + def test_default_empty(self): + q = ListAssetsQuery.model_validate({}) + assert q.job_ids == [] + + def test_invalid_uuid_rejected(self): + with pytest.raises(ValidationError) as exc: + ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"}) + assert "must be UUIDs" in str(exc.value) + + def test_non_string_list_item_rejected(self): + with pytest.raises(ValidationError) as exc: + ListAssetsQuery.model_validate( + {"job_ids": ["11111111-1111-1111-1111-111111111111", 42]} + ) + assert "must be strings" in str(exc.value) + + def test_non_string_non_list_value_rejected(self): + with pytest.raises(ValidationError) as exc: + ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}}) + assert "must be a string or list of strings" in str(exc.value) + + def test_max_length_enforced(self): + too_many = [str(uuid.uuid4()) for _ in range(501)] + with pytest.raises(ValidationError) as exc: + ListAssetsQuery.model_validate({"job_ids": too_many}) + assert exc.value.errors()[0]["type"] == "too_long" + + def test_max_length_boundary_accepted(self): + at_cap = [str(uuid.uuid4()) for _ in range(500)] + q = ListAssetsQuery.model_validate({"job_ids": at_cap}) + assert len(q.job_ids) == 500 diff --git a/tests-unit/assets_test/test_prompt_id_enforcement.py b/tests-unit/assets_test/test_prompt_id_enforcement.py new file mode 100644 index 000000000000..fb961beae957 --- /dev/null +++ b/tests-unit/assets_test/test_prompt_id_enforcement.py @@ -0,0 +1,69 @@ +"""POST /prompt enforces canonical-UUID job ids at creation time. + +Lives in assets_test because it uses this suite's booted-server fixture and +because the invariant exists for the assets pipeline: the GET /api/assets +``job_ids`` filter matches stored job ids exactly, so a job minted with a +non-canonical id would produce assets the filter can never find. + +The prompt bodies here are intentionally invalid workflows — prompt_id +validation happens before workflow validation, so a rejected id returns +``invalid_prompt_id`` while an accepted id falls through to the ordinary +workflow-validation error (proving it cleared the id check). +""" +import requests + + +def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response: + return http.post(api_base + "/prompt", json=body, timeout=30) + + +def _error_type(r: requests.Response) -> str: + return r.json()["error"]["type"] + + +def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str): + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str): + # Previously str()-coerced (123 became the job id "123"); must now be a 400, + # not a 500 from uuid.UUID choking on a non-string. + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str): + # Parseable as a UUID, but not the canonical lowercase form: rejected + # loudly rather than silently rewritten (downstream lookups match the + # stored id exactly). + r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_canonical_uuid_accepted(http: requests.Session, api_base: str): + # The id clears validation; the empty workflow then fails ordinary prompt + # validation, proving the request got past the id check. + r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" + + +def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str): + # Explicit null means "server generates" and must not be rejected as an + # invalid id. (The minted id itself is not observable here because the + # workflow is invalid; unit tests cover validate_job_id directly.) + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None}) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 814af5c1326e..30e47071d8c8 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -1,5 +1,7 @@ """Unit tests for comfy_execution/jobs.py""" +import pytest + from comfy_execution.jobs import ( JobStatus, is_previewable, @@ -10,9 +12,50 @@ get_outputs_summary, apply_sorting, has_3d_extension, + validate_job_id, ) +class TestValidateJobId: + """validate_job_id guards job creation: POST /prompt rejects ids it raises on.""" + + def test_canonical_form_passes_through(self): + cid = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7" + assert validate_job_id(cid) == cid + + @pytest.mark.parametrize( + "variant", + [ + "A1B2C3D4-E5F6-7A89-B0C1-D2E3F4A5B6C7", # uppercase + "{a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7}", # braced + "urn:uuid:a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", # URN + "a1b2c3d4e5f67a89b0c1d2e3f4a5b6c7", # bare hex + " a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7 ", # padded + ], + ) + def test_non_canonical_spellings_rejected(self, variant): + # uuid.UUID parses all of these, but accepting them would silently + # rewrite the client's id (history keys, websocket events, and the + # assets job_ids filter all match the stored form exactly). + with pytest.raises(ValueError): + validate_job_id(variant) + + @pytest.mark.parametrize( + "bad", + ["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"], + ) + def test_non_uuid_strings_rejected(self, bad): + with pytest.raises(ValueError): + validate_job_id(bad) + + @pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}]) + def test_non_strings_rejected(self, bad): + # uuid.UUID raises AttributeError/TypeError on non-strings; the helper + # must normalize those to ValueError so callers need one except clause. + with pytest.raises(ValueError): + validate_job_id(bad) + + class TestJobStatus: """Test JobStatus constants.""" From ce200c0850182722cfd6e0f9f9bd3f619e48281e Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 10 Jun 2026 17:04:52 -0700 Subject: [PATCH 43/45] feat(assets): include asset id in executed WebSocket message (#13862) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(assets): enrich executed WS message with asset metadata When --enable-assets is set, each file-type output entry in the `executed` WebSocket message now includes id, name, asset_hash, size, and mime_type — matching the shape already returned by /upload/image. The enrichment lives in comfy_execution/asset_enrichment.py (no torch dependency) and is called from both send sites in execution.py: freshly executed nodes register the file inline via register_file_in_place; cached node re-sends look up the existing AssetReference by file path to avoid re-hashing. Errors are caught per-entry so a failure never blocks the WS message from sending. * fix(assets): inject only id in executed WS message per Asset Identity RFC Per the Asset Identity RFC, the executed WebSocket payload should carry id alone — hash is already encoded in the filename, and name/preview_url/ size belong behind GET /api/assets/{id} rather than being pushed eagerly. Simplifies the DB lookup path: we only need ref.id, so the asset.hash null-check is no longer required as a fallback trigger. * fix(assets): reject path traversal when resolving output abs_path Subfolder/filename were joined and absolutized without containment check, so '..' segments or an absolute filename could escape the type's base directory and register an unrelated on-disk file as an asset. Add commonpath-based containment check; skip enrichment (warn, leave entry unchanged) when the resolved path escapes base. Catches ValueError from cross-drive paths on Windows. * docs(assets): drop Asset Identity RFC reference from docstring * docs(assets): trim docstring to what enrichment does, not what it doesn't * test(assets): use real platform paths so containment check works on Windows The previous test setup patched os.path.abspath to identity and used a POSIX-style '/output' base, which collided with Windows path separators in os.path.commonpath. Drop the abspath/join patches and use a real tempdir-rooted base so the containment check runs against actual platform paths. * refactor(assets): enrich at output-processing time, not in the WS send path Per review: enrichment lived inside the client_id-guarded send sites, so a headless run (no websocket client) never registered assets at all, and ui_outputs/history stored the un-enriched entries. Now output_ui is enriched once, right after the node produces it and before it is stored in ui_outputs — so registration happens regardless of connected clients, and the asset id flows into history and the execution cache for free. _send_cached_ui re-sends the stored (already-enriched) dict verbatim, which lets the DB-lookup-by-path fallback be deleted: every enrichment is now a fresh output, and register_file_in_place re-hashes on upsert so an overwritten path can never carry a stale id. --- comfy_execution/asset_enrichment.py | 66 ++++++ execution.py | 6 + .../execution_test/test_enrich_output.py | 205 ++++++++++++++++++ 3 files changed, 277 insertions(+) create mode 100644 comfy_execution/asset_enrichment.py create mode 100644 tests-unit/execution_test/test_enrich_output.py diff --git a/comfy_execution/asset_enrichment.py b/comfy_execution/asset_enrichment.py new file mode 100644 index 000000000000..38e9496a884b --- /dev/null +++ b/comfy_execution/asset_enrichment.py @@ -0,0 +1,66 @@ +"""Enrich executed-node output entries with asset id.""" +import logging +import os + + +def enrich_output_with_assets(output_ui: dict) -> dict: + """Register file-type output entries as assets and inject their ``id``. + + Runs at output-processing time, once per produced output, when + --enable-assets is set. Returns a new dict; entries without a resolvable + on-disk file path are left unchanged. Errors are caught per-entry so a + failure never blocks execution or the other entries. + """ + from comfy.cli_args import args + if not args.enable_assets: + return output_ui + + import folder_paths + from app.assets.services.ingest import register_file_in_place, DependencyMissingError + + enriched = {} + for key, entries in output_ui.items(): + if not isinstance(entries, list): + enriched[key] = entries + continue + new_entries = [] + for entry in entries: + if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry: + new_entries.append(entry) + continue + try: + base = folder_paths.get_directory_by_type(entry["type"]) + if base is None: + new_entries.append(entry) + continue + base_abs = os.path.abspath(base) + abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"])) + try: + if os.path.commonpath([base_abs, abs_path]) != base_abs: + raise ValueError("escapes base") + except ValueError: + logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename")) + new_entries.append(entry) + continue + if not os.path.isfile(abs_path): + new_entries.append(entry) + continue + + # Register unconditionally: the file was just produced, and + # register_file_in_place re-hashes so an overwritten path can + # never carry a stale id. + result = register_file_in_place( + abs_path=abs_path, + name=entry["filename"], + tags=[entry["type"]], + ) + + entry = dict(entry) + entry["id"] = result.ref.id + except DependencyMissingError: + logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename")) + except Exception: + logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True) + new_entries.append(entry) + enriched[key] = new_entries + return enriched diff --git a/execution.py b/execution.py index 5246d651cfa5..e6c6f39d6ff7 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext +from comfy_execution.asset_enrichment import enrich_output_with_assets from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger @@ -418,6 +419,7 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) + def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): if server.client_id is None: return @@ -552,6 +554,10 @@ async def await_completion(): asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: + # Enrich at output-processing time (not in the send path) so assets + # are registered even when no client is connected, and the asset id + # flows into ui_outputs and the cache alongside the raw entries. + output_ui = enrich_output_with_assets(output_ui) ui_outputs[unique_id] = { "meta": { "node_id": unique_id, diff --git a/tests-unit/execution_test/test_enrich_output.py b/tests-unit/execution_test/test_enrich_output.py new file mode 100644 index 000000000000..61490c49ec57 --- /dev/null +++ b/tests-unit/execution_test/test_enrich_output.py @@ -0,0 +1,205 @@ +"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py.""" +import os +import types +import unittest +from unittest.mock import MagicMock, patch + + +def _make_args(enable_assets: bool): + a = types.SimpleNamespace() + a.enable_assets = enable_assets + return a + + +def _make_register_result(ref_id="ref-id-2"): + result = MagicMock() + result.ref.id = ref_id + return result + + +# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on +# Windows and /tmp on POSIX, so containment via commonpath behaves naturally. +_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base") + + +def _mocked_modules(*, enable_assets=True, register_file_in_place=None, directory=_DEFAULT_BASE): + return { + "comfy.cli_args": MagicMock(args=_make_args(enable_assets)), + "folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=directory)), + "app.assets.services.ingest": MagicMock( + register_file_in_place=register_file_in_place or MagicMock(return_value=_make_register_result()), + DependencyMissingError=type("DependencyMissingError", (Exception,), {}), + ), + } + + +def _call(output_ui, *, enable_assets=True, file_exists=True, register_result=None, directory=_DEFAULT_BASE): + register_mock = MagicMock(return_value=register_result or _make_register_result()) + mocked = _mocked_modules( + enable_assets=enable_assets, + register_file_in_place=register_mock, + directory=directory, + ) + + # Only os.path.isfile is patched — abspath/join must run natively so the + # containment check sees real platform paths. + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=file_exists): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + return mod.enrich_output_with_assets(output_ui) + + +class TestEnrichOutputWithAssets(unittest.TestCase): + + def test_disabled_returns_unchanged(self): + output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]} + result = _call(output, enable_assets=False) + self.assertNotIn("id", result["images"][0]) + + def test_non_list_value_passed_through(self): + output = {"text": "hello"} + result = _call(output) + self.assertEqual(result["text"], "hello") + + def test_entry_without_filename_unchanged(self): + output = {"latent": [{"subfolder": "", "type": "output"}]} + result = _call(output) + self.assertNotIn("id", result["latent"][0]) + + def test_entry_without_type_unchanged(self): + output = {"data": [{"filename": "a.png", "subfolder": ""}]} + result = _call(output) + self.assertNotIn("id", result["data"][0]) + + def test_file_not_on_disk_unchanged(self): + output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]} + result = _call(output, file_exists=False) + self.assertNotIn("id", result["images"][0]) + + def test_unknown_type_returns_none_directory_unchanged(self): + output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]} + result = _call(output, directory=None) + self.assertNotIn("id", result["images"][0]) + + def test_register_injects_only_id(self): + reg = _make_register_result(ref_id="inline-ref") + output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]} + result = _call(output, register_result=reg) + img = result["images"][0] + self.assertEqual(img["id"], "inline-ref") + # Only id is injected — no asset_hash, name, preview_url, size + self.assertNotIn("asset_hash", img) + self.assertNotIn("name", img) + self.assertNotIn("preview_url", img) + self.assertNotIn("size", img) + + def test_register_called_per_entry(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + output = { + "images": [ + {"filename": "a.png", "subfolder": "", "type": "output"}, + {"filename": "b.png", "subfolder": "", "type": "output"}, + ] + } + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + mod.enrich_output_with_assets(output) + + self.assertEqual(register_mock.call_count, 2) + + def test_original_entry_not_mutated(self): + orig = {"filename": "a.png", "subfolder": "", "type": "output"} + output = {"images": [orig]} + _call(output) + self.assertNotIn("id", orig) + + def test_enrichment_error_does_not_block_sibling_entries(self): + call_count = [0] + good_reg = _make_register_result(ref_id="good-ref") + + def register_side_effect(abs_path, name, tags): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("boom") + return good_reg + + mocked = _mocked_modules(register_file_in_place=register_side_effect) + + output = { + "images": [ + {"filename": "bad.png", "subfolder": "", "type": "output"}, + {"filename": "good.png", "subfolder": "", "type": "output"}, + ] + } + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + imgs = result["images"] + self.assertNotIn("id", imgs[0]) + self.assertEqual(imgs[1]["id"], "good-ref") + + def test_multiple_output_keys_all_enriched(self): + output = { + "images": [{"filename": "a.png", "subfolder": "", "type": "output"}], + "videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}], + } + result = _call(output) + self.assertIn("id", result["images"][0]) + self.assertIn("id", result["videos"][0]) + + def test_none_entry_in_list_unchanged(self): + output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]} + result = _call(output) + self.assertIsNone(result["images"][0]) + self.assertIn("id", result["images"][1]) + + def test_path_traversal_subfolder_skipped(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + + output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]} + + # Do NOT patch os.path.abspath — real resolution is required for the containment check. + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + self.assertNotIn("id", result["images"][0]) + register_mock.assert_not_called() + + def test_absolute_filename_skipped(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + + # Absolute filename — os.path.join discards earlier components when a later one is absolute. + absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd") + output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]} + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + self.assertNotIn("id", result["images"][0]) + register_mock.assert_not_called() + + +if __name__ == "__main__": + unittest.main() From 431a1888d31114ef4959c8a9fb286a5cac8688f0 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 10 Jun 2026 19:23:01 -0700 Subject: [PATCH 44/45] revert(assets): drop job_ids filter from GET /api/assets (#14408) The job_ids query filter added in #13998 has no live consumer: the frontend Generated tab kept sourcing from GET /jobs, and the cloud side removed its equivalent filter from the shared asset spec. Carrying it on the local server only re-introduces Core<->Cloud drift on the shared contract, so remove it to match. Removed: the job_ids field + validator on ListAssetsQuery, the IN(...) clauses in list_references_page, the service/route passthrough, and the filter-only tests. Kept: the canonical-UUID prompt_id enforcement at job creation (also landed in #13998). It stands on its own -- job ids are matched verbatim by history keys, websocket correlation, and /interrupt -- and cloud inherits it by running core for execution, so no divergence is created. --- app/assets/api/routes.py | 1 - app/assets/api/schemas_in.py | 36 ----------- .../database/queries/asset_reference.py | 6 -- app/assets/services/asset_management.py | 2 - comfy_execution/jobs.py | 7 +-- .../assets_test/queries/test_asset_info.py | 50 ---------------- .../queries/test_list_assets_query.py | 60 ------------------- .../assets_test/test_prompt_id_enforcement.py | 8 +-- tests/execution/test_jobs.py | 4 +- 9 files changed, 9 insertions(+), 165 deletions(-) delete mode 100644 tests-unit/assets_test/queries/test_list_assets_query.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 6c9a3200d2ea..7ef462f5cd04 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -219,7 +219,6 @@ async def list_assets_route(request: web.Request) -> web.Response: exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, - job_ids=q.job_ids, limit=q.limit, offset=q.offset, sort=sort, diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 4ae18c65a523..af666746dce2 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,5 +1,4 @@ import json -import uuid from dataclasses import dataclass from typing import Any, Literal @@ -54,7 +53,6 @@ class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) name_contains: str | None = None - job_ids: list[str] = Field(default_factory=list, max_length=500) # Accept either a JSON string (query param) or a dict metadata_filter: dict[str, Any] | None = None @@ -88,40 +86,6 @@ def _split_csv_tags(cls, v): return out return v - @field_validator("job_ids", mode="before") - @classmethod - def _split_and_validate_job_ids(cls, v): - # Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params. - # Each entry must parse as a UUID; canonicalized to lowercase hyphenated form. - if v is None: - return [] - if isinstance(v, str): - raw = [t.strip() for t in v.split(",") if t.strip()] - elif isinstance(v, list): - raw = [] - for item in v: - if not isinstance(item, str): - raise ValueError( - f"job_ids entries must be strings, got {type(item).__name__}" - ) - raw.extend([t.strip() for t in item.split(",") if t.strip()]) - else: - raise ValueError( - f"job_ids must be a string or list of strings, got {type(v).__name__}" - ) - - out: list[str] = [] - seen: set[str] = set() - for s in raw: - try: - canonical = str(uuid.UUID(s)) - except ValueError as e: - raise ValueError(f"job_ids must be UUIDs: {s!r}") from e - if canonical not in seen: - seen.add(canonical) - out.append(canonical) - return out - @field_validator("metadata_filter", mode="before") @classmethod def _parse_metadata_json(cls, v): diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 33ded8a1cc8c..792411800e1b 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -264,7 +264,6 @@ def list_references_page( include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, metadata_filter: dict | None = None, - job_ids: Sequence[str] | None = None, sort: str | None = None, order: str | None = None, after_cursor_value: object | None = None, @@ -294,9 +293,6 @@ def list_references_page( escaped, esc = escape_sql_like_string(name_contains) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) - if job_ids: - base = base.where(AssetReference.job_id.in_(list(job_ids))) - base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_metadata_filter(base, metadata_filter) @@ -349,8 +345,6 @@ def list_references_page( count_stmt = count_stmt.where( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) - if job_ids: - count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids))) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 53aec7a15e1f..d4e4fc61c4ad 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -274,7 +274,6 @@ def list_assets_page( exclude_tags: Sequence[str] | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, - job_ids: Sequence[str] | None = None, limit: int = 20, offset: int = 0, sort: str = "created_at", @@ -320,7 +319,6 @@ def list_assets_page( exclude_tags=exclude_tags, name_contains=name_contains, metadata_filter=metadata_filter, - job_ids=job_ids, limit=fetch_limit, offset=offset, sort=sort, diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 3fbcc3eb01be..20ebae15585d 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -25,10 +25,9 @@ def validate_job_id(value) -> str: Job ids must be UUIDs in the canonical lowercase hyphenated form. The id is stored and compared verbatim everywhere downstream — history keys, - websocket events, /interrupt matching, and the assets ``job_ids`` filter - (a String(36) column matched exactly) — so accepting another spelling - would either rewrite the client's id behind its back or mint a job whose - outputs the filter can never find. Rejecting loudly beats both. + websocket events, and /interrupt matching — so accepting another spelling + would silently rewrite the client's id and then miss every exact-match + lookup. Rejecting loudly beats that. Returns the id unchanged. Raises ValueError when the value is not a string in canonical UUID form. diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index ba729a2701d5..fe510e3422e6 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -158,56 +158,6 @@ def test_sorting(self, session: Session): refs, _, _ = list_references_page(session, sort="name", order="asc") assert refs[0].name == "large" - def test_job_ids_filter(self, session: Session): - asset = _make_asset(session, "hash1") - job_a = str(uuid.uuid4()) - job_b = str(uuid.uuid4()) - ref_a = _make_reference(session, asset, name="from_job_a") - ref_a.job_id = job_a - ref_b = _make_reference(session, asset, name="from_job_b") - ref_b.job_id = job_b - _make_reference(session, asset, name="no_job") - session.commit() - - # Single job filter - refs, _, total = list_references_page(session, job_ids=[job_a]) - assert total == 1 - assert refs[0].name == "from_job_a" - - # Multi-job filter (IN) - refs, _, total = list_references_page(session, job_ids=[job_a, job_b]) - names = sorted(r.name for r in refs) - assert total == 2 - assert names == ["from_job_a", "from_job_b"] - - # Unknown job id matches nothing - refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())]) - assert total == 0 - assert refs == [] - - # Empty/None means no filter -> all three references - refs, _, total = list_references_page(session, job_ids=[]) - assert total == 3 - refs, _, total = list_references_page(session, job_ids=None) - assert total == 3 - - def test_job_ids_combined_with_other_filters(self, session: Session): - asset = _make_asset(session, "hash1") - job_a = str(uuid.uuid4()) - ref_match = _make_reference(session, asset, name="match.bin") - ref_match.job_id = job_a - ref_wrong_name = _make_reference(session, asset, name="other.bin") - ref_wrong_name.job_id = job_a - ref_wrong_job = _make_reference(session, asset, name="match.bin") - ref_wrong_job.job_id = str(uuid.uuid4()) - session.commit() - - refs, _, total = list_references_page( - session, job_ids=[job_a], name_contains="match" - ) - assert total == 1 - assert refs[0].id == ref_match.id - class TestFetchReferenceAssetAndTags: def test_returns_none_for_nonexistent(self, session: Session): diff --git a/tests-unit/assets_test/queries/test_list_assets_query.py b/tests-unit/assets_test/queries/test_list_assets_query.py deleted file mode 100644 index e8d3430e2eb7..000000000000 --- a/tests-unit/assets_test/queries/test_list_assets_query.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Schema-level unit tests for ListAssetsQuery (no DB required).""" -import uuid - -import pytest -from pydantic import ValidationError - -from app.assets.api.schemas_in import ListAssetsQuery - - -class TestJobIdsValidator: - def test_csv_string_parses_and_canonicalizes(self): - a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE" - b = "11111111-2222-3333-4444-555555555555" - q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"}) - # Canonicalized to lowercase - assert q.job_ids == [a.lower(), b] - - def test_repeated_query_params_as_list(self): - a = "11111111-1111-1111-1111-111111111111" - b = "22222222-2222-2222-2222-222222222222" - q = ListAssetsQuery.model_validate({"job_ids": [a, b]}) - assert q.job_ids == [a, b] - - def test_dedup_preserves_first_seen_order(self): - a = "11111111-1111-1111-1111-111111111111" - b = "22222222-2222-2222-2222-222222222222" - q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]}) - assert q.job_ids == [a, b] - - def test_default_empty(self): - q = ListAssetsQuery.model_validate({}) - assert q.job_ids == [] - - def test_invalid_uuid_rejected(self): - with pytest.raises(ValidationError) as exc: - ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"}) - assert "must be UUIDs" in str(exc.value) - - def test_non_string_list_item_rejected(self): - with pytest.raises(ValidationError) as exc: - ListAssetsQuery.model_validate( - {"job_ids": ["11111111-1111-1111-1111-111111111111", 42]} - ) - assert "must be strings" in str(exc.value) - - def test_non_string_non_list_value_rejected(self): - with pytest.raises(ValidationError) as exc: - ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}}) - assert "must be a string or list of strings" in str(exc.value) - - def test_max_length_enforced(self): - too_many = [str(uuid.uuid4()) for _ in range(501)] - with pytest.raises(ValidationError) as exc: - ListAssetsQuery.model_validate({"job_ids": too_many}) - assert exc.value.errors()[0]["type"] == "too_long" - - def test_max_length_boundary_accepted(self): - at_cap = [str(uuid.uuid4()) for _ in range(500)] - q = ListAssetsQuery.model_validate({"job_ids": at_cap}) - assert len(q.job_ids) == 500 diff --git a/tests-unit/assets_test/test_prompt_id_enforcement.py b/tests-unit/assets_test/test_prompt_id_enforcement.py index fb961beae957..86a755c9f000 100644 --- a/tests-unit/assets_test/test_prompt_id_enforcement.py +++ b/tests-unit/assets_test/test_prompt_id_enforcement.py @@ -1,9 +1,9 @@ """POST /prompt enforces canonical-UUID job ids at creation time. -Lives in assets_test because it uses this suite's booted-server fixture and -because the invariant exists for the assets pipeline: the GET /api/assets -``job_ids`` filter matches stored job ids exactly, so a job minted with a -non-canonical id would produce assets the filter can never find. +Lives in assets_test because it uses this suite's booted-server fixture. The +invariant itself is pipeline-wide: a job id is stored and compared verbatim +downstream — history keys, websocket correlation, and /interrupt matching — +so a job minted with a non-canonical id would miss every exact-match lookup. The prompt bodies here are intentionally invalid workflows — prompt_id validation happens before workflow validation, so a rejected id returns diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 30e47071d8c8..f7cb612e4f57 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -35,8 +35,8 @@ def test_canonical_form_passes_through(self): ) def test_non_canonical_spellings_rejected(self, variant): # uuid.UUID parses all of these, but accepting them would silently - # rewrite the client's id (history keys, websocket events, and the - # assets job_ids filter all match the stored form exactly). + # rewrite the client's id (history keys, websocket events, and + # /interrupt matching all match the stored form exactly). with pytest.raises(ValueError): validate_job_id(variant) From 74ee826790035be831c960e4c4bd60051273a99a Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 11 Jun 2026 12:15:53 +0900 Subject: [PATCH 45/45] chore(openapi): sync shared API contract from cloud@e3c52ad (#14406) --- openapi.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 58614103a756..6e203b1cd9e2 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1067,6 +1067,9 @@ components: comfyui_version: description: ComfyUI version type: string + deploy_environment: + description: How this ComfyUI instance is deployed (e.g. cloud, local-git, local-portable, local-desktop) + type: string embedded_python: description: Whether using embedded Python type: boolean