Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def from_string(cls, value: str):

parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")

parser.add_argument("--parallel", type=int, default=1, metavar="N", help="Run N workflows concurrently. Useful when workflows are network-bound (e.g. BizyAir nodes) and GPU/VRAM is not the bottleneck. Default: 1 (sequential).")

parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
Expand Down
72 changes: 68 additions & 4 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,22 +1807,86 @@ class InterruptProcessingException(Exception):
interrupt_processing_mutex = threading.RLock()

interrupt_processing = False
def interrupt_current_processing(value=True):
# Prompt-scoped interrupts prevent one worker from consuming another worker's cancel signal.
interrupted_prompt_ids = set()
active_prompt_ids = set()


def _resolve_interrupt_prompt_id(prompt_id=None):
if prompt_id is not None:
return prompt_id

try:
from comfy_execution.utils import get_executing_context
executing_context = get_executing_context()
except Exception:
executing_context = None

if executing_context is not None:
return executing_context.prompt_id
return None


def register_active_prompt(prompt_id):
global active_prompt_ids
global interrupt_processing_mutex
with interrupt_processing_mutex:
# A restarted prompt_id should begin in a clean state.
active_prompt_ids.add(prompt_id)
interrupted_prompt_ids.discard(prompt_id)


def unregister_active_prompt(prompt_id):
global active_prompt_ids
global interrupt_processing_mutex
with interrupt_processing_mutex:
active_prompt_ids.discard(prompt_id)
interrupted_prompt_ids.discard(prompt_id)


def interrupt_current_processing(value=True, prompt_id=None):
global interrupt_processing
global interrupt_processing_mutex
with interrupt_processing_mutex:
interrupt_processing = value
if prompt_id is None:
if value:
if active_prompt_ids:
# Global interrupt fan-outs to every prompt that is currently running.
interrupted_prompt_ids.update(active_prompt_ids)
interrupt_processing = False
else:
# Preserve the legacy global flag for code paths that still have no prompt context.
interrupt_processing = True
else:
interrupt_processing = False
interrupted_prompt_ids.clear()
return

if value:
interrupted_prompt_ids.add(prompt_id)
else:
interrupted_prompt_ids.discard(prompt_id)

def processing_interrupted():

def processing_interrupted(prompt_id=None):
global interrupt_processing
global interrupt_processing_mutex
resolved_prompt_id = _resolve_interrupt_prompt_id(prompt_id)
with interrupt_processing_mutex:
if resolved_prompt_id is not None:
return resolved_prompt_id in interrupted_prompt_ids
return interrupt_processing

def throw_exception_if_processing_interrupted():

def throw_exception_if_processing_interrupted(prompt_id=None):
global interrupt_processing
global interrupt_processing_mutex
resolved_prompt_id = _resolve_interrupt_prompt_id(prompt_id)
with interrupt_processing_mutex:
if resolved_prompt_id is not None:
if resolved_prompt_id in interrupted_prompt_ids:
raise InterruptProcessingException()
return
if interrupt_processing:
interrupt_processing = False
raise InterruptProcessingException()
3 changes: 2 additions & 1 deletion comfy_api/latest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def set_progress(
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
executing_context = get_executing_context()
prompt_id = executing_context.prompt_id if executing_context is not None else None
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
if node_id is None:
Expand Down Expand Up @@ -78,7 +79,7 @@ async def set_progress(
preview_size = None if ignore_size_limit else args.preview_size
to_display = (image_format, to_display, preview_size)

get_progress_state().update_progress(
get_progress_state(prompt_id).update_progress(
node_id=node_id,
Comment on lines 53 to 83

Copilot AI Apr 17, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_progress() derives prompt_id from executing_context, but when there is no executing context it becomes None and the code still calls get_progress_state(prompt_id). With the new per-prompt registry design this risks updating the wrong prompt’s registry (or a dummy/no-op registry) and makes behavior dependent on internal fallback logic. Consider requiring an executing context (raise if missing), or extending the API to accept an explicit prompt_id and using that instead of None.

Copilot uses AI. Check for mistakes.
value=value,
max_value=max_value,
Expand Down
64 changes: 36 additions & 28 deletions comfy_execution/progress.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import threading
from typing import TypedDict, Dict, Optional, Tuple
from typing_extensions import override
from PIL import Image
Expand Down Expand Up @@ -152,9 +153,10 @@ class WebUIProgressHandler(ProgressHandler):
Handler that sends progress updates to the WebUI via WebSockets.
"""

def __init__(self, server_instance):
def __init__(self, server_instance, client_id: Optional[str] = None):
super().__init__("webui")
self.server_instance = server_instance
self.client_id = client_id

def set_registry(self, registry: "ProgressRegistry"):
self.registry = registry
Expand Down Expand Up @@ -183,7 +185,7 @@ def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressStat
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.client_id
)

@override
Expand All @@ -206,10 +208,9 @@ def update_handler(
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
if image:
# Only send new format if client supports it
if feature_flags.supports_feature(
if self.client_id is not None and feature_flags.supports_feature(
self.server_instance.sockets_metadata,
self.server_instance.client_id,
self.client_id,
"supports_preview_metadata",
):
metadata = {
Expand All @@ -226,7 +227,7 @@ def update_handler(
self.server_instance.send_sync(
BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
(image, metadata),
self.server_instance.client_id,
self.client_id,
)

@override
Expand All @@ -240,9 +241,10 @@ class ProgressRegistry:
Registry that maintains node progress state and notifies registered handlers.
"""

def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", client_id: Optional[str] = None):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.client_id = client_id
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}

Expand Down Expand Up @@ -319,32 +321,38 @@ def reset_handlers(self) -> None:
for handler in self.handlers.values():
handler.reset()

# Global registry instance
global_progress_registry: ProgressRegistry | None = None
# Per-prompt progress registries for parallel execution support
_progress_registries: Dict[str, ProgressRegistry] = {}
_progress_registries_lock = threading.Lock()

def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global global_progress_registry
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", client_id: Optional[str] = None) -> None:
with _progress_registries_lock:
if prompt_id in _progress_registries:
_progress_registries[prompt_id].reset_handlers()
_progress_registries[prompt_id] = ProgressRegistry(prompt_id, dynprompt, client_id=client_id)

# Reset existing handlers if registry exists
if global_progress_registry is not None:
global_progress_registry.reset_handlers()

# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)


def add_progress_handler(handler: ProgressHandler) -> None:
registry = get_progress_state()
def add_progress_handler(handler: ProgressHandler, prompt_id: Optional[str] = None) -> None:
registry = get_progress_state(prompt_id)
handler.set_registry(registry)
registry.register_handler(handler)


def get_progress_state() -> ProgressRegistry:
global global_progress_registry
if global_progress_registry is None:
from comfy_execution.graph import DynamicPrompt
def remove_progress_state(prompt_id: str) -> None:
with _progress_registries_lock:
reg = _progress_registries.pop(prompt_id, None)
if reg is not None:
reg.reset_handlers()

global_progress_registry = ProgressRegistry(
prompt_id="", dynprompt=DynamicPrompt({})
)
return global_progress_registry

def get_progress_state(prompt_id: Optional[str] = None) -> ProgressRegistry:
if prompt_id is not None:
with _progress_registries_lock:
reg = _progress_registries.get(prompt_id)
if reg is not None:
return reg
# with _progress_registries_lock:
# if _progress_registries:
# return next(iter(_progress_registries.values()))
from comfy_execution.graph import DynamicPrompt
return ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
Comment thread
ruabbit233 marked this conversation as resolved.
54 changes: 28 additions & 26 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, remove_progress_state, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
Expand Down Expand Up @@ -416,15 +416,15 @@ def _is_intermediate_output(dynprompt, node_id):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)

def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
if server.client_id is None:
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs, client_id):
if client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui

async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs, client_id=None):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
Expand All @@ -434,8 +434,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id)
if cached is not None:
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
get_progress_state().finish_progress(unique_id)
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs, client_id)
get_progress_state(prompt_id).finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)

Expand Down Expand Up @@ -478,11 +478,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
del pending_subgraph_results[unique_id]
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
get_progress_state(prompt_id).start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
if client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, client_id)

obj = await caches.objects.get(unique_id)
if obj is None:
Expand Down Expand Up @@ -522,7 +522,7 @@ def execution_block_cb(block):
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
server.send_sync("execution_error", mes, client_id)
return ExecutionBlocker(None)
else:
return block
Expand Down Expand Up @@ -558,8 +558,8 @@ async def await_completion():
},
"output": output_ui
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
Expand Down Expand Up @@ -640,7 +640,7 @@ async def await_completion():

return (ExecutionResult.FAILURE, error_details, ex)

get_progress_state().finish_progress(unique_id)
get_progress_state(prompt_id).finish_progress(unique_id)
executed.add(unique_id)

return (ExecutionResult.SUCCESS, None, None)
Expand All @@ -656,15 +656,16 @@ def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True
self.client_id = None

def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
"timestamp": int(time.time() * 1000),
}
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)
if self.client_id is not None or broadcast:
self.server.send_sync(event, data, self.client_id)

def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
Expand Down Expand Up @@ -712,13 +713,11 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
set_preview_method(extra_data.get("preview_method"))
# Register before node execution starts so targeted/global interrupts can see this prompt immediately.
comfy.model_management.register_active_prompt(prompt_id)

nodes.interrupt_processing(False)

if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
self.server.client_id = None
self.client_id = extra_data.get("client_id", None)
self.server.client_id = self.client_id
Comment on lines +719 to +720

Copilot AI Apr 17, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PromptExecutor.execute_async() still writes self.server.client_id = self.client_id. With multiple prompt workers this global field will be raced/overwritten by concurrent executions, which can break any server logic that relies on a single “currently executing client” (e.g., WebSocket reconnect sending the current node). Consider removing this global mutation in parallel mode and instead tracking executing state per prompt/client (e.g., a mapping on the server keyed by client_id/prompt_id).

Copilot uses AI. Check for mistakes.

self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
Expand All @@ -731,8 +730,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
reset_progress_state(prompt_id, dynamic_prompt, client_id=self.client_id)
add_progress_handler(WebUIProgressHandler(self.server, client_id=self.client_id), prompt_id=prompt_id)
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
Expand Down Expand Up @@ -767,7 +766,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
break

assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs, client_id=self.client_id)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
Expand All @@ -791,7 +790,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
cached = await self.caches.outputs.get(node_id)
if cached is not None:
display_node_id = dynamic_prompt.get_display_node_id(node_id)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs, client_id=self.client_id)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

ui_outputs = {}
Expand All @@ -809,6 +808,9 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id)
remove_progress_state(prompt_id)
# Drop prompt-scoped interrupt state once this execution is fully finished.
comfy.model_management.unregister_active_prompt(prompt_id)


async def validate_inputs(prompt_id, prompt, item, validated):
Expand Down
Loading
Loading