Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion comfy_api/latest/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,8 @@ class V3Data(TypedDict):
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
extra_pnginfo: Any, dynprompt: Any,
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
auth_token_comfy_org: str, api_key_comfy_org: str,
comfy_usage_source: str = None, **kwargs):
self.unique_id = unique_id
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
self.prompt = prompt
Expand All @@ -1413,6 +1414,8 @@ def __init__(self, unique_id: str, prompt: Any,
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
self.api_key_comfy_org = api_key_comfy_org
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
self.comfy_usage_source = comfy_usage_source
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""

def __getattr__(self, key: str):
'''If hidden variable not found, return None.'''
Expand All @@ -1429,6 +1432,7 @@ def from_dict(cls, d: dict | None):
dynprompt=d.get(Hidden.dynprompt, None),
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
)

@classmethod
Expand All @@ -1451,6 +1455,8 @@ class Hidden(str, Enum):
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
comfy_usage_source = "COMFY_USAGE_SOURCE"
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""


@dataclass
Expand Down Expand Up @@ -1654,6 +1660,8 @@ def finalize(self):
self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org)
if Hidden.comfy_usage_source not in self.hidden:
self.hidden.append(Hidden.comfy_usage_source)
# if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node:
if Hidden.prompt not in self.hidden:
Expand Down
5 changes: 2 additions & 3 deletions comfy_api_nodes/nodes_sonilo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from comfy_api_nodes.util._helpers import (
default_base_url,
get_auth_header,
get_comfy_api_headers,
get_node_id,
is_processing_interrupted,
)
Expand Down Expand Up @@ -174,8 +174,7 @@ async def _stream_sonilo_music(
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))

headers: dict[str, str] = {}
headers.update(get_auth_header(cls))
headers = get_comfy_api_headers(cls)
headers.update(endpoint.headers)

node_id = get_node_id(cls)
Expand Down
25 changes: 25 additions & 0 deletions comfy_api_nodes/util/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from yarl import URL

from comfy.cli_args import args
from comfy.deploy_environment import get_deploy_environment
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO

Expand All @@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
return {}


def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
"""Source of the prompt that triggered this API node.

Defaults to "comfyui-api" when the submitting client didn't identify itself,
i.e. a direct API call to this server.
"""
return node_cls.hidden.comfy_usage_source or "comfyui-api"


def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.

Centralizes the shared header set so every Comfy API request sends a consistent
set and new shared headers only need to be added in one place. Intended for
relative/cloud URLs resolved against ``default_base_url()``; because the result
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
"""
return {
**get_auth_header(node_cls),
"Comfy-Env": get_deploy_environment(),
"Comfy-Usage-Source": get_usage_source(node_cls),
}


def default_base_url() -> str:
return getattr(args, "comfy_api_base", "https://api.comfy.org")

Expand Down
7 changes: 2 additions & 5 deletions comfy_api_nodes/util/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
from comfy_api.latest import IO
from server import PromptServer

from comfy.deploy_environment import get_deploy_environment

from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
get_comfy_api_headers,
get_node_id,
is_processing_interrupted,
sleep_with_interrupt,
Expand Down Expand Up @@ -645,8 +643,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):

payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
payload_headers["Comfy-Env"] = get_deploy_environment()
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers)

Expand Down
4 changes: 2 additions & 2 deletions comfy_api_nodes/util/download_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
get_comfy_api_headers,
is_processing_interrupted,
sleep_with_interrupt,
to_aiohttp_url,
Expand Down Expand Up @@ -64,7 +64,7 @@ async def download_url_to_bytesio(
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
headers = get_comfy_api_headers(cls)

while True:
attempt += 1
Expand Down
4 changes: 4 additions & 0 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def mark_missing():
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org.name in hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
if io.Hidden.comfy_usage_source.name in hidden:
hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None)
else:
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
Expand All @@ -216,6 +218,8 @@ def mark_missing():
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
if h[x] == "COMFY_USAGE_SOURCE":
input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data

Expand Down
5 changes: 5 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,11 @@ async def post_prompt(request):

if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]

if "comfy_usage_source" not in extra_data:
usage_source = request.headers.get("Comfy-Usage-Source")
if usage_source:
extra_data["comfy_usage_source"] = usage_source
if valid[0]:
outputs_to_execute = valid[2]
sensitive = {}
Expand Down
Loading