From 2a1f338c66d9a1107de9cab5d2921641ed250991 Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Thu, 16 Apr 2026 11:51:29 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=B9=B6=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=B9=B6=E8=A1=8C=E6=89=A7=E8=A1=8C=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E7=9A=84=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=85=81=E8=AE=B8?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E9=80=9A=E8=BF=87=20--parallel=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=AE=BE=E7=BD=AE=E5=B9=B6=E5=8F=91=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/cli_args.py | 2 ++ comfy_api/latest/__init__.py | 3 +- comfy_execution/progress.py | 64 ++++++++++++++++++++---------------- execution.py | 48 +++++++++++++-------------- main.py | 24 +++++++++----- 5 files changed, 79 insertions(+), 62 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723cf2..53024fe1ad88 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -148,6 +148,8 @@ def from_string(cls, value: str): parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") +parser.add_argument("--parallel", type=int, default=1, metavar="N", help="Run N workflows concurrently. Useful when workflows are network-bound (e.g. BizyAir nodes) and GPU/VRAM is not the bottleneck. Default: 1 (sequential).") + parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 04973fea0df7..65341c888f15 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -51,6 +51,7 @@ async def set_progress( Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK """ executing_context = get_executing_context() + prompt_id = executing_context.prompt_id if executing_context is not None else None if node_id is None and executing_context is not None: node_id = executing_context.node_id if node_id is None: @@ -78,7 +79,7 @@ async def set_progress( preview_size = None if ignore_size_limit else args.preview_size to_display = (image_format, to_display, preview_size) - get_progress_state().update_progress( + get_progress_state(prompt_id).update_progress( node_id=node_id, value=value, max_value=max_value, diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a33507f5..ce871c09683d 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading from typing import TypedDict, Dict, Optional, Tuple from typing_extensions import override from PIL import Image @@ -152,9 +153,10 @@ class WebUIProgressHandler(ProgressHandler): Handler that sends progress updates to the WebUI via WebSockets. """ - def __init__(self, server_instance): + def __init__(self, server_instance, client_id: Optional[str] = None): super().__init__("webui") self.server_instance = server_instance + self.client_id = client_id def set_registry(self, registry: "ProgressRegistry"): self.registry = registry @@ -183,7 +185,7 @@ def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressStat # Send a combined progress_state message with all node states # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( - "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id + "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.client_id ) @override @@ -206,10 +208,9 @@ def update_handler( if self.registry: self._send_progress_state(prompt_id, self.registry.nodes) if image: - # Only send new format if client supports it - if feature_flags.supports_feature( + if self.client_id is not None and feature_flags.supports_feature( self.server_instance.sockets_metadata, - self.server_instance.client_id, + self.client_id, "supports_preview_metadata", ): metadata = { @@ -226,7 +227,7 @@ def update_handler( self.server_instance.send_sync( BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), - self.server_instance.client_id, + self.client_id, ) @override @@ -240,9 +241,10 @@ class ProgressRegistry: Registry that maintains node progress state and notifies registered handlers. """ - def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", client_id: Optional[str] = None): self.prompt_id = prompt_id self.dynprompt = dynprompt + self.client_id = client_id self.nodes: Dict[str, NodeProgressState] = {} self.handlers: Dict[str, ProgressHandler] = {} @@ -319,32 +321,38 @@ def reset_handlers(self) -> None: for handler in self.handlers.values(): handler.reset() -# Global registry instance -global_progress_registry: ProgressRegistry | None = None +# Per-prompt progress registries for parallel execution support +_progress_registries: Dict[str, ProgressRegistry] = {} +_progress_registries_lock = threading.Lock() -def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: - global global_progress_registry +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", client_id: Optional[str] = None) -> None: + with _progress_registries_lock: + if prompt_id in _progress_registries: + _progress_registries[prompt_id].reset_handlers() + _progress_registries[prompt_id] = ProgressRegistry(prompt_id, dynprompt, client_id=client_id) - # Reset existing handlers if registry exists - if global_progress_registry is not None: - global_progress_registry.reset_handlers() - # Create new registry - global_progress_registry = ProgressRegistry(prompt_id, dynprompt) - - -def add_progress_handler(handler: ProgressHandler) -> None: - registry = get_progress_state() +def add_progress_handler(handler: ProgressHandler, prompt_id: Optional[str] = None) -> None: + registry = get_progress_state(prompt_id) handler.set_registry(registry) registry.register_handler(handler) -def get_progress_state() -> ProgressRegistry: - global global_progress_registry - if global_progress_registry is None: - from comfy_execution.graph import DynamicPrompt +def remove_progress_state(prompt_id: str) -> None: + with _progress_registries_lock: + reg = _progress_registries.pop(prompt_id, None) + if reg is not None: + reg.reset_handlers() - global_progress_registry = ProgressRegistry( - prompt_id="", dynprompt=DynamicPrompt({}) - ) - return global_progress_registry + +def get_progress_state(prompt_id: Optional[str] = None) -> ProgressRegistry: + if prompt_id is not None: + with _progress_registries_lock: + reg = _progress_registries.get(prompt_id) + if reg is not None: + return reg + with _progress_registries_lock: + if _progress_registries: + return next(iter(_progress_registries.values())) + from comfy_execution.graph import DynamicPrompt + return ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({})) diff --git a/execution.py b/execution.py index 5e02dffb204f..cabed2a56955 100644 --- a/execution.py +++ b/execution.py @@ -36,7 +36,7 @@ ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input -from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, remove_progress_state, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io @@ -416,15 +416,15 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) -def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): - if server.client_id is None: +def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs, client_id): + if client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, client_id) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs, client_id=None): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -434,8 +434,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) - get_progress_state().finish_progress(unique_id) + _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs, client_id) + get_progress_state(prompt_id).finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -478,11 +478,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, del pending_subgraph_results[unique_id] has_subgraph = False else: - get_progress_state().start_progress(unique_id) + get_progress_state(prompt_id).start_progress(unique_id) input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) - if server.client_id is not None: + if client_id is not None: server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, client_id) obj = await caches.objects.get(unique_id) if obj is None: @@ -522,7 +522,7 @@ def execution_block_cb(block): "current_inputs": [], "current_outputs": [], } - server.send_sync("execution_error", mes, server.client_id) + server.send_sync("execution_error", mes, client_id) return ExecutionBlocker(None) else: return block @@ -558,8 +558,8 @@ async def await_completion(): }, "output": output_ui } - if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if client_id is not None: + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, client_id) if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -640,7 +640,7 @@ async def await_completion(): return (ExecutionResult.FAILURE, error_details, ex) - get_progress_state().finish_progress(unique_id) + get_progress_state(prompt_id).finish_progress(unique_id) executed.add(unique_id) return (ExecutionResult.SUCCESS, None, None) @@ -656,6 +656,7 @@ def reset(self): self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True + self.client_id = None def add_message(self, event, data: dict, broadcast: bool): data = { @@ -663,8 +664,8 @@ def add_message(self, event, data: dict, broadcast: bool): "timestamp": int(time.time() * 1000), } self.status_messages.append((event, data)) - if self.server.client_id is not None or broadcast: - self.server.send_sync(event, data, self.server.client_id) + if self.client_id is not None or broadcast: + self.server.send_sync(event, data, self.client_id) def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -715,10 +716,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= nodes.interrupt_processing(False) - if "client_id" in extra_data: - self.server.client_id = extra_data["client_id"] - else: - self.server.client_id = None + self.client_id = extra_data.get("client_id", None) + self.server.client_id = self.client_id self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) @@ -731,8 +730,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= try: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) + reset_progress_state(prompt_id, dynamic_prompt, client_id=self.client_id) + add_progress_handler(WebUIProgressHandler(self.server, client_id=self.client_id), prompt_id=prompt_id) is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) @@ -767,7 +766,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs, client_id=self.client_id) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -791,7 +790,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= cached = await self.caches.outputs.get(node_id) if cached is not None: display_node_id = dynamic_prompt.get_display_node_id(node_id) - _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs) + _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs, client_id=self.client_id) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} @@ -809,6 +808,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= finally: comfy.memory_management.set_ram_cache_release_state(None, 0) self._notify_prompt_lifecycle("end", prompt_id) + remove_progress_state(prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/main.py b/main.py index 12b04719d572..15fc2157d0c7 100644 --- a/main.py +++ b/main.py @@ -321,8 +321,8 @@ def prompt_worker(q, server_instance): status_str='success' if e.success else 'error', completed=e.success, messages=e.status_messages), process_item=remove_sensitive) - if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + if e.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, e.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -385,21 +385,24 @@ def hook(value, total, preview_image, prompt_id=None, node_id=None): prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id + + registry = get_progress_state(prompt_id) + client_id = registry.client_id if registry else None + progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} - get_progress_state().update_progress(node_id, value, total, preview_image) + registry.update_progress(node_id, value, total, preview_image) - server_instance.send_sync("progress", progress, server_instance.client_id) + server_instance.send_sync("progress", progress, client_id) if preview_image is not None: - # Only send old method if client doesn't support preview metadata - if not feature_flags.supports_feature( + if client_id is None or not feature_flags.supports_feature( server_instance.sockets_metadata, - server_instance.client_id, + client_id, "supports_preview_metadata", ): server_instance.send_sync( BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, - server_instance.client_id, + client_id, ) comfy.utils.set_progress_bar_global_hook(hook) @@ -478,7 +481,10 @@ def start_comfyui(asyncio_loop=None): prompt_server.add_routes() hijack_progress(prompt_server) - threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start() + num_workers = max(1, args.parallel) + logging.info(f"Starting {num_workers} prompt worker(s)") + for i in range(num_workers): + threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,), name=f"prompt_worker-{i}").start() if args.quick_test_for_ci: exit(0) From 81020110f8e915bcda57b6768ea48b7159e8f22e Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Fri, 17 Apr 2026 17:14:57 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=BD=93=20prompt=5Fid?= =?UTF-8?q?=20=E5=AF=B9=E5=BA=94=E7=9A=84=20registry=20=E5=B7=B2=E8=A2=AB?= =?UTF-8?q?=20remove=5Fprogress=5Fstate()=20=E5=88=A0=E9=99=A4=EF=BC=88?= =?UTF-8?q?=E5=B7=A5=E4=BD=9C=E6=B5=81=E5=AE=8C=E6=88=90=E6=97=B6=E8=B0=83?= =?UTF-8?q?=E7=94=A8=EF=BC=89=EF=BC=8C=E6=88=96=E8=80=85=20prompt=5Fid=20i?= =?UTF-8?q?s=20None=20=E6=97=B6=EF=BC=8C=E4=BC=9A=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E5=8F=A6=E4=B8=80=E4=B8=AA=E5=B7=A5=E4=BD=9C=E6=B5=81=E7=9A=84?= =?UTF-8?q?=20registry=E3=80=82=E8=BF=99=E5=8F=AF=E8=83=BD=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=EF=BC=9A=20-=20=E8=BF=9B=E5=BA=A6=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E5=86=99=E5=85=A5=E9=94=99=E8=AF=AF=E7=9A=84=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E7=8A=B6=E6=80=81=20-=20WebSocket=20=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E5=8F=91=E7=BB=99=E9=94=99=E8=AF=AF=E7=9A=84=E5=AE=A2?= =?UTF-8?q?=E6=88=B7=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy_execution/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index ce871c09683d..3f33a314e038 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -351,8 +351,8 @@ def get_progress_state(prompt_id: Optional[str] = None) -> ProgressRegistry: reg = _progress_registries.get(prompt_id) if reg is not None: return reg - with _progress_registries_lock: - if _progress_registries: - return next(iter(_progress_registries.values())) + # with _progress_registries_lock: + # if _progress_registries: + # return next(iter(_progress_registries.values())) from comfy_execution.graph import DynamicPrompt return ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({})) From 5a97bda3c645e917a8567207b5c724fbff763308 Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Tue, 21 Apr 2026 16:15:35 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E4=B8=AD=E5=AF=B9=E4=B8=AD=E6=96=AD=E5=A4=84=E7=90=86?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81=EF=BC=8C=E5=85=81=E8=AE=B8=E6=8C=89?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=20ID=20=E6=B3=A8=E5=86=8C=E5=92=8C=E6=B3=A8?= =?UTF-8?q?=E9=94=80=E6=B4=BB=E5=8A=A8=E6=8F=90=E7=A4=BA=EF=BC=8C=E6=94=B9?= =?UTF-8?q?=E8=BF=9B=E5=B9=B6=E8=A1=8C=E6=89=A7=E8=A1=8C=E7=9A=84=E4=B8=AD?= =?UTF-8?q?=E6=96=AD=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/model_management.py | 72 ++++++++++++++++++++++++++++++++++++--- execution.py | 6 ++-- main.py | 3 +- nodes.py | 4 +-- server.py | 3 +- 5 files changed, 78 insertions(+), 10 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index bcf1399c4691..577377625fd7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1807,22 +1807,86 @@ class InterruptProcessingException(Exception): interrupt_processing_mutex = threading.RLock() interrupt_processing = False -def interrupt_current_processing(value=True): +# Prompt-scoped interrupts prevent one worker from consuming another worker's cancel signal. +interrupted_prompt_ids = set() +active_prompt_ids = set() + + +def _resolve_interrupt_prompt_id(prompt_id=None): + if prompt_id is not None: + return prompt_id + + try: + from comfy_execution.utils import get_executing_context + executing_context = get_executing_context() + except Exception: + executing_context = None + + if executing_context is not None: + return executing_context.prompt_id + return None + + +def register_active_prompt(prompt_id): + global active_prompt_ids + global interrupt_processing_mutex + with interrupt_processing_mutex: + # A restarted prompt_id should begin in a clean state. + active_prompt_ids.add(prompt_id) + interrupted_prompt_ids.discard(prompt_id) + + +def unregister_active_prompt(prompt_id): + global active_prompt_ids + global interrupt_processing_mutex + with interrupt_processing_mutex: + active_prompt_ids.discard(prompt_id) + interrupted_prompt_ids.discard(prompt_id) + + +def interrupt_current_processing(value=True, prompt_id=None): global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: - interrupt_processing = value + if prompt_id is None: + if value: + if active_prompt_ids: + # Global interrupt fan-outs to every prompt that is currently running. + interrupted_prompt_ids.update(active_prompt_ids) + interrupt_processing = False + else: + # Preserve the legacy global flag for code paths that still have no prompt context. + interrupt_processing = True + else: + interrupt_processing = False + interrupted_prompt_ids.clear() + return + + if value: + interrupted_prompt_ids.add(prompt_id) + else: + interrupted_prompt_ids.discard(prompt_id) -def processing_interrupted(): + +def processing_interrupted(prompt_id=None): global interrupt_processing global interrupt_processing_mutex + resolved_prompt_id = _resolve_interrupt_prompt_id(prompt_id) with interrupt_processing_mutex: + if resolved_prompt_id is not None: + return resolved_prompt_id in interrupted_prompt_ids return interrupt_processing -def throw_exception_if_processing_interrupted(): + +def throw_exception_if_processing_interrupted(prompt_id=None): global interrupt_processing global interrupt_processing_mutex + resolved_prompt_id = _resolve_interrupt_prompt_id(prompt_id) with interrupt_processing_mutex: + if resolved_prompt_id is not None: + if resolved_prompt_id in interrupted_prompt_ids: + raise InterruptProcessingException() + return if interrupt_processing: interrupt_processing = False raise InterruptProcessingException() diff --git a/execution.py b/execution.py index cabed2a56955..a390a85d8828 100644 --- a/execution.py +++ b/execution.py @@ -713,8 +713,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): set_preview_method(extra_data.get("preview_method")) - - nodes.interrupt_processing(False) + # Register before node execution starts so targeted/global interrupts can see this prompt immediately. + comfy.model_management.register_active_prompt(prompt_id) self.client_id = extra_data.get("client_id", None) self.server.client_id = self.client_id @@ -809,6 +809,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= comfy.memory_management.set_ram_cache_release_state(None, 0) self._notify_prompt_lifecycle("end", prompt_id) remove_progress_state(prompt_id) + # Drop prompt-scoped interrupt state once this execution is fully finished. + comfy.model_management.unregister_active_prompt(prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/main.py b/main.py index 15fc2157d0c7..34734780a401 100644 --- a/main.py +++ b/main.py @@ -380,11 +380,12 @@ def hook(value, total, preview_image, prompt_id=None, node_id=None): prompt_id = executing_context.prompt_id if node_id is None and executing_context is not None: node_id = executing_context.node_id - comfy.model_management.throw_exception_if_processing_interrupted() if prompt_id is None: prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id + # Progress callbacks can run outside the main execute() stack, so re-check cancellation by prompt_id here. + comfy.model_management.throw_exception_if_processing_interrupted(prompt_id) registry = get_progress_state(prompt_id) client_id = registry.client_id if registry else None diff --git a/nodes.py b/nodes.py index 299b3d7585fb..e5aafc2e838c 100644 --- a/nodes.py +++ b/nodes.py @@ -51,8 +51,8 @@ def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() -def interrupt_processing(value=True): - comfy.model_management.interrupt_current_processing(value) +def interrupt_processing(value=True, prompt_id=None): + comfy.model_management.interrupt_current_processing(value, prompt_id=prompt_id) MAX_RESOLUTION=16384 diff --git a/server.py b/server.py index 881da8e66ec5..be32b0002e2c 100644 --- a/server.py +++ b/server.py @@ -1003,7 +1003,8 @@ async def post_interrupt(request): break if should_interrupt: - nodes.interrupt_processing() + # Forward the target prompt_id so parallel workers only cancel the requested run. + nodes.interrupt_processing(prompt_id=prompt_id) else: logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") else: From e03e9504e98344e1e567482ae31a56be6b959f91 Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Tue, 21 Apr 2026 16:34:25 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E7=89=B9=E5=AE=9A?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E4=B8=AD=E6=96=AD=E7=9A=84=E5=AE=A2=E6=88=B7?= =?UTF-8?q?=E7=AB=AF=E6=89=80=E6=9C=89=E6=9D=83=E9=AA=8C=E8=AF=81=EF=BC=8C?= =?UTF-8?q?=E8=A6=81=E6=B1=82=E6=8F=90=E4=BE=9B=20client=5Fid=EF=BC=8C?= =?UTF-8?q?=E6=8B=92=E7=BB=9D=E6=97=A0=E6=95=88=E7=9A=84=E5=85=A8=E5=B1=80?= =?UTF-8?q?=E4=B8=AD=E6=96=AD=E8=AF=B7=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 74 insertions(+), 8 deletions(-) diff --git a/server.py b/server.py index be32b0002e2c..7f8522f7d913 100644 --- a/server.py +++ b/server.py @@ -991,26 +991,92 @@ async def post_interrupt(request): # Check if a specific prompt_id was provided for targeted interruption prompt_id = json_data.get('prompt_id') if prompt_id: + request_client_id = json_data.get('client_id') + if request_client_id is None: + return web.json_response( + { + "error": { + "type": "missing_client_id", + "message": "client_id is required when interrupting a specific prompt", + "details": "Provide the client_id that originally submitted the prompt", + "extra_info": {"prompt_id": prompt_id}, + } + }, + status=400, + ) + currently_running, _ = self.prompt_queue.get_current_queue() # Check if the prompt_id matches any currently running prompt - should_interrupt = False + matching_item = None for item in currently_running: # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) if item[1] == prompt_id: - logging.info(f"Interrupting prompt {prompt_id}") - should_interrupt = True + matching_item = item break - if should_interrupt: + if matching_item is not None: + prompt_client_id = matching_item[3].get('client_id') + if prompt_client_id is None: + return web.json_response( + { + "error": { + "type": "prompt_owner_unknown", + "message": "Cannot interrupt prompt because it has no recorded client_id owner", + "details": "Only prompts submitted with a client_id can be ownership-validated for interrupt", + "extra_info": {"prompt_id": prompt_id}, + } + }, + status=403, + ) + + if prompt_client_id != request_client_id: + logging.warning( + f"Rejected interrupt for prompt {prompt_id}: requester client_id {request_client_id} does not own prompt" + ) + return web.json_response( + { + "error": { + "type": "prompt_not_owned_by_client", + "message": "client_id does not own the target prompt", + "details": "Interrupt is only allowed for the client_id that originally submitted the prompt", + "extra_info": { + "prompt_id": prompt_id, + "client_id": request_client_id, + }, + } + }, + status=403, + ) + + logging.info(f"Interrupting prompt {prompt_id} for client_id {request_client_id}") # Forward the target prompt_id so parallel workers only cancel the requested run. nodes.interrupt_processing(prompt_id=prompt_id) else: - logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") + return web.json_response( + { + "error": { + "type": "prompt_not_running", + "message": "Prompt is not currently running", + "details": "Only actively running prompts can be interrupted", + "extra_info": {"prompt_id": prompt_id}, + } + }, + status=404, + ) else: - # No prompt_id provided, do a global interrupt - logging.info("Global interrupt (no prompt_id specified)") - nodes.interrupt_processing() + logging.warning("Rejected global interrupt without prompt_id") + return web.json_response( + { + "error": { + "type": "global_interrupt_disabled", + "message": "Global interrupt is disabled when ownership validation is enforced", + "details": "Provide both prompt_id and the owning client_id to interrupt a workflow", + "extra_info": {}, + } + }, + status=403, + ) return web.Response(status=200) From de7d0ddfcd1d8396c7ee10b760b314f646e2b995 Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Tue, 21 Apr 2026 16:37:48 +0800 Subject: [PATCH 5/6] =?UTF-8?q?Revert=20"=E5=A2=9E=E5=BC=BA=E7=89=B9?= =?UTF-8?q?=E5=AE=9A=E6=8F=90=E7=A4=BA=E4=B8=AD=E6=96=AD=E7=9A=84=E5=AE=A2?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E6=89=80=E6=9C=89=E6=9D=83=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=EF=BC=8C=E8=A6=81=E6=B1=82=E6=8F=90=E4=BE=9B=20client=5Fid?= =?UTF-8?q?=EF=BC=8C=E6=8B=92=E7=BB=9D=E6=97=A0=E6=95=88=E7=9A=84=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E4=B8=AD=E6=96=AD=E8=AF=B7=E6=B1=82"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit e03e9504e98344e1e567482ae31a56be6b959f91. --- server.py | 82 ++++++------------------------------------------------- 1 file changed, 8 insertions(+), 74 deletions(-) diff --git a/server.py b/server.py index 7f8522f7d913..be32b0002e2c 100644 --- a/server.py +++ b/server.py @@ -991,92 +991,26 @@ async def post_interrupt(request): # Check if a specific prompt_id was provided for targeted interruption prompt_id = json_data.get('prompt_id') if prompt_id: - request_client_id = json_data.get('client_id') - if request_client_id is None: - return web.json_response( - { - "error": { - "type": "missing_client_id", - "message": "client_id is required when interrupting a specific prompt", - "details": "Provide the client_id that originally submitted the prompt", - "extra_info": {"prompt_id": prompt_id}, - } - }, - status=400, - ) - currently_running, _ = self.prompt_queue.get_current_queue() # Check if the prompt_id matches any currently running prompt - matching_item = None + should_interrupt = False for item in currently_running: # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) if item[1] == prompt_id: - matching_item = item + logging.info(f"Interrupting prompt {prompt_id}") + should_interrupt = True break - if matching_item is not None: - prompt_client_id = matching_item[3].get('client_id') - if prompt_client_id is None: - return web.json_response( - { - "error": { - "type": "prompt_owner_unknown", - "message": "Cannot interrupt prompt because it has no recorded client_id owner", - "details": "Only prompts submitted with a client_id can be ownership-validated for interrupt", - "extra_info": {"prompt_id": prompt_id}, - } - }, - status=403, - ) - - if prompt_client_id != request_client_id: - logging.warning( - f"Rejected interrupt for prompt {prompt_id}: requester client_id {request_client_id} does not own prompt" - ) - return web.json_response( - { - "error": { - "type": "prompt_not_owned_by_client", - "message": "client_id does not own the target prompt", - "details": "Interrupt is only allowed for the client_id that originally submitted the prompt", - "extra_info": { - "prompt_id": prompt_id, - "client_id": request_client_id, - }, - } - }, - status=403, - ) - - logging.info(f"Interrupting prompt {prompt_id} for client_id {request_client_id}") + if should_interrupt: # Forward the target prompt_id so parallel workers only cancel the requested run. nodes.interrupt_processing(prompt_id=prompt_id) else: - return web.json_response( - { - "error": { - "type": "prompt_not_running", - "message": "Prompt is not currently running", - "details": "Only actively running prompts can be interrupted", - "extra_info": {"prompt_id": prompt_id}, - } - }, - status=404, - ) + logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") else: - logging.warning("Rejected global interrupt without prompt_id") - return web.json_response( - { - "error": { - "type": "global_interrupt_disabled", - "message": "Global interrupt is disabled when ownership validation is enforced", - "details": "Provide both prompt_id and the owning client_id to interrupt a workflow", - "extra_info": {}, - } - }, - status=403, - ) + # No prompt_id provided, do a global interrupt + logging.info("Global interrupt (no prompt_id specified)") + nodes.interrupt_processing() return web.Response(status=200) From b965a471002597e8ebcf921a4a06ad8fe62f853f Mon Sep 17 00:00:00 2001 From: Liu Zhao Date: Tue, 21 Apr 2026 16:41:39 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E7=A6=81=E7=94=A8=E5=85=A8=E5=B1=80?= =?UTF-8?q?=E4=B8=AD=E6=96=AD=E4=BB=A5=E9=81=BF=E5=85=8D=E5=B9=B6=E8=A1=8C?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E4=B8=AD=E6=96=AD=E5=AF=BC=E8=87=B4=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=E8=BF=90=E8=A1=8C=E4=B8=AD=E7=9A=84=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E8=A2=AB=E6=84=8F=E5=A4=96=E4=B8=AD=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server.py b/server.py index be32b0002e2c..6ba009100959 100644 --- a/server.py +++ b/server.py @@ -1008,9 +1008,11 @@ async def post_interrupt(request): else: logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") else: - # No prompt_id provided, do a global interrupt - logging.info("Global interrupt (no prompt_id specified)") - nodes.interrupt_processing() + # # No prompt_id provided, do a global interrupt + # logging.info("Global interrupt (no prompt_id specified)") + # nodes.interrupt_processing() + # 禁用全局中断,避免并行全局中断导致多个运行中的任务被意外中断 + pass return web.Response(status=200)