Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
from specforge.utils import get_device_type, get_last_checkpoint, get_local_device, print_on_rank0, print_with_rank


def parse_args():
Expand Down Expand Up @@ -156,11 +156,14 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
if args.target_model_backend == "sglang":
target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs()

device = get_local_device()
device_type = device.type

target_model = get_dflash_target_model(
pretrained_model_name_or_path=args.target_model_path,
backend=args.target_model_backend,
torch_dtype=torch.bfloat16,
device="cuda" if args.target_model_backend == "hf" else None,
device=device_type if args.target_model_backend == "hf" else None,
trust_remote_code=args.trust_remote_code,
**target_model_kwargs,
)
Expand Down Expand Up @@ -191,7 +194,7 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
draft_config._attn_implementation = args.attention_backend
print_on_rank0(f"Using attention backend: {args.attention_backend}")

draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
draft_model = DFlashDraftModel(draft_config).to(device=device, dtype=torch.bfloat16)

target_model.set_capture_layers(draft_model.target_layer_ids)

Expand Down Expand Up @@ -423,7 +426,7 @@ def main():
args.target_model_path,
embed_key=args.embedding_key,
lm_head_key=args.lm_head_key,
device="cuda",
device=device_type,
trust_remote_code=args.trust_remote_code,
)

Expand Down Expand Up @@ -496,13 +499,13 @@ def main():
continue
global_step += 1

input_ids = data["input_ids"].cuda()
attention_mask = data["attention_mask"].cuda()
loss_mask = data["loss_mask"].cuda()
input_ids = data["input_ids"].to(device, non_blocking=True)
attention_mask = data["attention_mask"].to(device, non_blocking=True)
loss_mask = data["loss_mask"].to(device, non_blocking=True)
target_output = target_model.generate_dflash_data(
input_ids, attention_mask, loss_mask
)
hidden_states = target_output.hidden_states.cuda() # Ensure on GPU
hidden_states = target_output.hidden_states.to(device, non_blocking=True)

loss, accuracy = dflash_model(
input_ids=input_ids,
Expand Down
30 changes: 30 additions & 0 deletions specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,36 @@ def load_config_from_file(config_path: str):
return PretrainedConfig.from_dict(config)


def get_device_type() -> str:
"""Auto-detect the available accelerator type.

Priority:
1. SPECFORGE_DEVICE environment variable
2. NVIDIA CUDA (torch.cuda)
3. Ascend NPU (torch.npu)
4. CPU fallback
"""
dt = os.environ.get("SPECFORGE_DEVICE", None)
if dt:
return dt
if torch.cuda.is_available():
return "cuda"
if hasattr(torch, "npu") and torch.npu.is_available():
return "npu"
return "cpu"


def get_local_device() -> torch.device:
"""Return the local torch.device for the current process rank."""
device_type = get_device_type()
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if device_type == "cuda":
return torch.device("cuda", local_rank)
if device_type == "npu":
return torch.device("npu", local_rank)
return torch.device("cpu")


def print_with_rank(message):
if dist.is_available() and dist.is_initialized():
logger.info(f"rank {dist.get_rank()}: {message}")
Expand Down