Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,15 +1202,16 @@ def put(self, item):
self.server.queue_updated()
self.not_empty.notify()

def get(self, timeout=None, worker_id=0):
    """Pop the next queued item and record it as currently running.

    Waits on ``self.not_empty`` while the queue is empty. If ``timeout``
    is given and the queue is still empty after a wait, ``None`` is
    returned instead of blocking again.

    Args:
        timeout: Value passed to ``self.not_empty.wait``; ``None`` waits
            until an item becomes available.
        worker_id: Identifier of the calling worker, stored alongside the
            running entry so several workers can share one queue.

    Returns:
        A ``(item, task_id)`` tuple, or ``None`` if the wait timed out
        with the queue still empty.
    """
    with self.not_empty:
        while not self.queue:
            self.not_empty.wait(timeout=timeout)
            if timeout is not None and not self.queue:
                return None
        # heappop removes the smallest item according to heap ordering.
        item = heapq.heappop(self.queue)
        i = self.task_counter
        # Store with worker_id to support multiple workers. The item is
        # deep-copied exactly once; the caller receives the original.
        self.currently_running[i] = {"worker_id": worker_id, "item": copy.deepcopy(item)}
        self.task_counter += 1
        self.server.queue_updated()
        return (item, i)
Expand All @@ -1223,7 +1224,12 @@ class ExecutionStatus(NamedTuple):
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
with self.mutex:
prompt = self.currently_running.pop(item_id)
running_entry = self.currently_running.pop(item_id)
# Support both old format (direct item) and new format (dict with worker_id and item)
if isinstance(running_entry, dict) and "item" in running_entry:
prompt = running_entry["item"]
else:
prompt = running_entry
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))

Expand All @@ -1247,13 +1253,23 @@ def get_current_queue(self):
with self.mutex:
out = []
for x in self.currently_running.values():
out += [x]
# Support both old format (direct item) and new format (dict with worker_id and item)
if isinstance(x, dict) and "item" in x:
out += [x["item"]]
else:
out += [x]
return (out, copy.deepcopy(self.queue))

# read-safe as long as queue items are immutable
def get_current_queue_volatile(self):
    """Return ``(running, queued)`` without deep-copying any queue item.

    ``running`` lists the items in ``self.currently_running``; entries
    stored as ``{"worker_id": ..., "item": ...}`` are unwrapped to the
    bare item. ``queued`` is a shallow copy of ``self.queue``. The item
    objects themselves are shared with the queue, not copied.
    """
    with self.mutex:
        # Support both old format (direct item) and new format (dict with worker_id and item)
        running = [
            x["item"] if isinstance(x, dict) and "item" in x else x
            for x in self.currently_running.values()
        ]
        queued = copy.copy(self.queue)
        return (running, queued)

Expand Down
12 changes: 9 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
return paths


def prompt_worker(q, server_instance):
def prompt_worker(q, server_instance, worker_id=0):
current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
Expand All @@ -297,7 +297,7 @@ def prompt_worker(q, server_instance):
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

queue_item = q.get(timeout=timeout)
queue_item = q.get(timeout=timeout, worker_id=worker_id)
if queue_item is not None:
item, item_id = queue_item
execution_start_time = time.perf_counter()
Expand Down Expand Up @@ -478,7 +478,13 @@ def start_comfyui(asyncio_loop=None):
prompt_server.add_routes()
hijack_progress(prompt_server)

threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
# Number of parallel workers - can be adjusted based on GPU memory and workload
# WARNING: Multiple workers will increase GPU memory usage significantly
NUM_WORKERS = int(os.environ.get("COMFYUI_NUM_WORKERS", "1"))
if NUM_WORKERS > 1:
logging.info(f"Starting {NUM_WORKERS} parallel prompt workers")
for worker_id in range(NUM_WORKERS):
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server, worker_id)).start()

if args.quick_test_for_ci:
exit(0)
Expand Down
Loading