Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 70 additions & 9 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
destroy_distributed,
get_dp_group,
get_draft_dp_group,
get_draft_sp_group,
get_sp_ring_group,
get_tp_group,
init_distributed,
)
Expand Down Expand Up @@ -96,6 +98,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
choices=["sglang", "hf", "custom"],
help="The backend of the target model",
)
model_group.add_argument(
"--shard-target-logits",
action="store_true",
default=False,
help="Shard target logits across TP ranks instead of materializing full target logits on every rank.",
)

# dataset arguments
dataset_group = parser.add_argument_group("dataset")
Expand Down Expand Up @@ -291,6 +299,7 @@ def build_target_model(
torch_dtype=torch.bfloat16,
device="cuda",
cache_dir=args.model_download_dir,
shard_target_logits=args.shard_target_logits,
**target_model_kwargs,
trust_remote_code=args.trust_remote_code,
)
Expand Down Expand Up @@ -339,8 +348,16 @@ def sanity_check(args: Namespace) -> None:
"""
args.dp_size = dist.get_world_size() // args.tp_size
args.target_batch_size = args.tp_size * args.batch_size
if args.shard_target_logits:
assert (
args.target_model_backend == "sglang"
), "--shard-target-logits is only supported for the SGLang backend"

if args.attention_backend == "usp":
sp_sanity_check(args)
if args.train_data_path is not None and args.train_hidden_states_path is None:
sp_size = args.sp_ring_size * args.sp_ulysses_size
args.target_batch_size = (args.tp_size // sp_size) * args.batch_size


def sp_sanity_check(args: Namespace) -> None:
Expand All @@ -356,7 +373,16 @@ def sp_sanity_check(args: Namespace) -> None:
f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}."
)

assert args.train_hidden_states_path is not None, f"USP only support offline mode"
is_online = args.train_data_path is not None and args.train_hidden_states_path is None
if is_online:
sp_size = args.sp_ring_size * args.sp_ulysses_size
assert args.shard_target_logits, "Online USP requires --shard-target-logits"
assert not args.is_vlm, "Online USP with sharded target logits does not support VLM yet"
assert (
args.tp_size % sp_size == 0
), f"Online USP with sharded target logits requires tp_size ({args.tp_size}) to be divisible by SP size ({sp_size})"
else:
assert args.train_hidden_states_path is not None, f"USP only support offline mode"

if args.eval_data_path is not None and args.eval_hidden_states_path is not None:
raise ValueError(
Expand Down Expand Up @@ -610,6 +636,18 @@ def run_forward(
image_grid_thw = None
if is_online:
# Online mode: generate the EAGLE3 training data on the fly with the target model.
tp_size = dist.get_world_size(get_tp_group())
tp_rank = dist.get_rank(get_tp_group())
sequence_parallel = args.attention_backend == "usp"
sp_group = get_draft_sp_group() if sequence_parallel else None
sp_rank = dist.get_rank(sp_group) if sequence_parallel else 0
sp_size = dist.get_world_size(sp_group) if sequence_parallel else 1
target_dp_rank = tp_rank // sp_size if sequence_parallel else tp_rank
target_dp_size = tp_size // sp_size if sequence_parallel else tp_size
ring_group = get_sp_ring_group() if sequence_parallel else None
sp_ring_rank = dist.get_rank(ring_group) if sequence_parallel else 0
sp_ring_size = dist.get_world_size(ring_group) if sequence_parallel else 1

# Handle VLM data: pixel_values and image_grid_thw are lists
# pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
if args.is_vlm:
Expand All @@ -626,19 +664,43 @@ def run_forward(
is_vlm=args.is_vlm,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
dp_rank=target_dp_rank,
dp_size=target_dp_size,
sequence_parallel=sequence_parallel,
sp_rank=sp_rank,
sp_size=sp_size,
sp_ring_rank=sp_ring_rank,
sp_ring_size=sp_ring_size,
ttt_length=args.ttt_length,
)
else:
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
dp_rank=target_dp_rank,
dp_size=target_dp_size,
sequence_parallel=sequence_parallel,
sp_rank=sp_rank,
sp_size=sp_size,
sp_ring_rank=sp_ring_rank,
sp_ring_size=sp_ring_size,
ttt_length=args.ttt_length,
)

input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
target = get_dp_data_shard_from_tp(eagle3_data.target)
hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
if sequence_parallel or args.shard_target_logits:
input_ids = eagle3_data.input_ids
attention_mask = eagle3_data.attention_mask
loss_mask = eagle3_data.loss_mask
target = eagle3_data.target
hidden_states = eagle3_data.hidden_states
else:
input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
target = get_dp_data_shard_from_tp(eagle3_data.target)
hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
position_ids = eagle3_data.position_ids
else:
# we generate the logits using the hidden states loaded from disk
attention_mask = data["attention_mask"].cuda()
Expand All @@ -651,15 +713,14 @@ def run_forward(
target.cuda()
) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
loss_mask = loss_mask.cuda()
position_ids = data["position_ids"].cuda() if "position_ids" in data else None
plosses, _, acces = eagle3_model(
input_ids=input_ids,
attention_mask=attention_mask,
loss_mask=loss_mask,
target=target,
hidden_states=hidden_states,
position_ids=(
data["position_ids"].cuda() if "position_ids" in data else None
),
position_ids=position_ids,
image_grid_thw=image_grid_thw,
is_vlm=args.is_vlm,
)
Expand Down
Loading