Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 61 additions & 22 deletions graph_net/agent/model_fetcher/huggingface_fetcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""HuggingFace model fetcher implementation"""

import os
import signal
import time
from pathlib import Path
from typing import Optional
Expand All @@ -17,6 +18,13 @@
ModelFetchError,
)


class _DownloadTimeoutError(Exception):
    """Signal that ``snapshot_download`` overran its overall wall-clock budget.

    Raised from the SIGALRM handler that ``download`` installs around the
    ``snapshot_download`` call. It is private to this module: ``download``
    catches it and re-raises it as ``ModelFetchError``.
    """


# Network-related exceptions that are worth retrying
_RETRYABLE_ERRORS = (
ConnectionError,
Expand Down Expand Up @@ -133,6 +141,10 @@ def download(self, model_id: str) -> Path:
"""
Download model from HuggingFace Hub with retry on network errors.

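A **hard overall timeout** (120 seconds, enforced with `signal.alarm`) guards
each `snapshot_download` attempt so that if a single file hangs indefinitely
(e.g. the TCP connection stays open but no data arrives), the call is aborted
instead of blocking forever. Because it relies on SIGALRM, this guard only
works on Unix and only when `download` is called from the main thread.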

Args:
model_id: HuggingFace model ID (e.g., "bert-base-uncased")

Expand All @@ -148,35 +160,58 @@ def download(self, model_id: str) -> Path:
"Please install it with: pip install huggingface_hub"
)

# Set the per-request download timeout so a single slow request cannot hang
# forever. A value already present in the environment is kept; otherwise we
# raise huggingface_hub's 10s default to 30s to accommodate slow networks
# while still bounding each request.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = os.environ.get(
"HF_HUB_DOWNLOAD_TIMEOUT", "30"
)

last_err = None
for attempt in range(1, self.max_retries + 1):
try:
# Set endpoint for this call if configured
if self.endpoint:
os.environ["HF_ENDPOINT"] = self.endpoint

local_dir = snapshot_download(
repo_id=model_id,
cache_dir=str(self.cache_dir) if self.cache_dir else None,
token=self.token,
ignore_patterns=[
"*.bin",
"*.safetensors",
"*.pt",
"*.pth",
"*.gguf",
"*.ot",
"*.zip",
"*.tflite",
"*.mlmodel",
"*.onnx",
"*.msgpack",
"flax_model*",
"tf_model*",
"rust_model*",
],
)
return Path(local_dir)
# Hard overall timeout for the entire snapshot_download call.
# HF_HUB_DOWNLOAD_TIMEOUT=30 only controls individual HTTP requests;
# huggingface_hub's internal retry/resume logic can still loop forever
# when a connection stalls without raising an exception. We therefore
# enforce a 120-second wall-clock ceiling on the whole operation.
def _alarm_handler(signum, frame):
raise _DownloadTimeoutError(
f"Overall download timeout (120s) exceeded for {model_id}"
)

old_handler = signal.signal(signal.SIGALRM, _alarm_handler)
signal.alarm(120)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subprocess_graph_extractor.py 为了加timeout控制,用了subprocess.Popen,subprocess可以控制真正杀死子进程。不过这个任务轻量一些,应该也还好

try:
local_dir = snapshot_download(
repo_id=model_id,
cache_dir=str(self.cache_dir) if self.cache_dir else None,
token=self.token,
ignore_patterns=[
"*.bin",
"*.safetensors",
"*.pt",
"*.pth",
"*.gguf",
"*.ot",
"*.zip",
"*.tflite",
"*.mlmodel",
"*.onnx",
"*.msgpack",
"flax_model*",
"tf_model*",
"rust_model*",
],
)
return Path(local_dir)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)

except _RETRYABLE_ERRORS as e:
last_err = e
Expand All @@ -203,6 +238,10 @@ def download(self, model_id: str) -> Path:
raise ModelFetchError(
f"Failed to download model {model_id} after {self.max_retries} retries: {e}"
) from e
except _DownloadTimeoutError as e:
raise ModelFetchError(
f"Failed to download model {model_id}: {e}"
) from e
except Exception as e:
raise ModelFetchError(
f"Failed to download model {model_id}: {e}",
Expand Down
Loading