From 2b9f35013d5e1c3bfb1216299598ab0b4febeef3 Mon Sep 17 00:00:00 2001 From: Curnane Date: Mon, 25 May 2026 11:24:46 +0800 Subject: [PATCH] feat: add automatic device detection for non-CUDA backends - Add get_device_type() to auto-detect npu/cuda/cpu via runtime check - Add get_local_device() to return torch.device for current LOCAL_RANK - Replace hardcoded .cuda() and device='cuda' in train_dflash.py with dynamic device selection - Use .to(device, non_blocking=True) for tensor movement to support both CUDA and Ascend NPU without code changes - Maintain backward compatibility: CUDA remains default when available --- scripts/train_dflash.py | 19 +++++++++++-------- specforge/utils.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 808e928c..461cfc53 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -33,7 +33,7 @@ from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead from specforge.optimizer import BF16Optimizer from specforge.tracker import create_tracker -from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank +from specforge.utils import get_device_type, get_last_checkpoint, get_local_device, print_on_rank0, print_with_rank def parse_args(): @@ -156,11 +156,14 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: if args.target_model_backend == "sglang": target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + device = get_local_device() + device_type = device.type + target_model = get_dflash_target_model( pretrained_model_name_or_path=args.target_model_path, backend=args.target_model_backend, torch_dtype=torch.bfloat16, - device="cuda" if args.target_model_backend == "hf" else None, + device=device_type if args.target_model_backend == "hf" else None, trust_remote_code=args.trust_remote_code, **target_model_kwargs, ) @@ -191,7 +194,7 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: draft_config._attn_implementation = args.attention_backend print_on_rank0(f"Using attention backend: {args.attention_backend}") - draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16) + draft_model = DFlashDraftModel(draft_config).to(device=device, dtype=torch.bfloat16) target_model.set_capture_layers(draft_model.target_layer_ids) @@ -423,7 +426,7 @@ def main(): args.target_model_path, embed_key=args.embedding_key, lm_head_key=args.lm_head_key, - device="cuda", + device=device_type, trust_remote_code=args.trust_remote_code, ) @@ -496,13 +499,13 @@ def main(): continue global_step += 1 - input_ids = data["input_ids"].cuda() - attention_mask = data["attention_mask"].cuda() - loss_mask = data["loss_mask"].cuda() + input_ids = data["input_ids"].to(device, non_blocking=True) + attention_mask = data["attention_mask"].to(device, non_blocking=True) + loss_mask = data["loss_mask"].to(device, non_blocking=True) target_output = target_model.generate_dflash_data( input_ids, attention_mask, loss_mask ) - hidden_states = target_output.hidden_states.cuda() # Ensure on GPU + hidden_states = target_output.hidden_states.to(device, non_blocking=True) loss, accuracy = dflash_model( input_ids=input_ids, diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c..21e07403 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -49,6 +49,36 @@ def load_config_from_file(config_path: str): return PretrainedConfig.from_dict(config) +def get_device_type() -> str: + """Auto-detect the available accelerator type. + + Priority: + 1. SPECFORGE_DEVICE environment variable + 2. NVIDIA CUDA (torch.cuda) + 3. Ascend NPU (torch.npu) + 4. CPU fallback + """ + dt = os.environ.get("SPECFORGE_DEVICE", None) + if dt: + return dt + if torch.cuda.is_available(): + return "cuda" + if hasattr(torch, "npu") and torch.npu.is_available(): + return "npu" + return "cpu" + + +def get_local_device() -> torch.device: + """Return the local torch.device for the current process rank.""" + device_type = get_device_type() + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if device_type == "cuda": + return torch.device("cuda", local_rank) + if device_type == "npu": + return torch.device("npu", local_rank) + return torch.device("cpu") + + def print_with_rank(message): if dist.is_available() and dist.is_initialized(): logger.info(f"rank {dist.get_rank()}: {message}")