diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b3..f9c2d7e2 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -35,6 +35,8 @@ destroy_distributed, get_dp_group, get_draft_dp_group, + get_draft_sp_group, + get_sp_ring_group, get_tp_group, init_distributed, ) @@ -96,6 +98,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: choices=["sglang", "hf", "custom"], help="The backend of the target model", ) + model_group.add_argument( + "--shard-target-logits", + action="store_true", + default=False, + help="Shard target logits across TP ranks instead of materializing full target logits on every rank.", + ) # dataset arguments dataset_group = parser.add_argument_group("dataset") @@ -291,6 +299,7 @@ def build_target_model( torch_dtype=torch.bfloat16, device="cuda", cache_dir=args.model_download_dir, + shard_target_logits=args.shard_target_logits, **target_model_kwargs, trust_remote_code=args.trust_remote_code, ) @@ -339,8 +348,16 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + if args.shard_target_logits: + assert ( + args.target_model_backend == "sglang" + ), "--shard-target-logits is only supported for the SGLang backend" + if args.attention_backend == "usp": sp_sanity_check(args) + if args.train_data_path is not None and args.train_hidden_states_path is None: + sp_size = args.sp_ring_size * args.sp_ulysses_size + args.target_batch_size = (args.tp_size // sp_size) * args.batch_size def sp_sanity_check(args: Namespace) -> None: @@ -356,7 +373,16 @@ def sp_sanity_check(args: Namespace) -> None: f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}." ) - assert args.train_hidden_states_path is not None, f"USP only support offline mode" + is_online = args.train_data_path is not None and args.train_hidden_states_path is None + if is_online: + sp_size = args.sp_ring_size * args.sp_ulysses_size + assert args.shard_target_logits, "Online USP requires --shard-target-logits" + assert not args.is_vlm, "Online USP with sharded target logits does not support VLM yet" + assert ( + args.tp_size % sp_size == 0 + ), f"Online USP with sharded target logits requires tp_size ({args.tp_size}) to be divisible by SP size ({sp_size})" + else: + assert args.train_hidden_states_path is not None, f"USP only support offline mode" if args.eval_data_path is not None and args.eval_hidden_states_path is not None: raise ValueError( @@ -610,6 +636,18 @@ def run_forward( image_grid_thw = None if is_online: # we generate the eagle3 using the target model in an online fashion + tp_size = dist.get_world_size(get_tp_group()) + tp_rank = dist.get_rank(get_tp_group()) + sequence_parallel = args.attention_backend == "usp" + sp_group = get_draft_sp_group() if sequence_parallel else None + sp_rank = dist.get_rank(sp_group) if sequence_parallel else 0 + sp_size = dist.get_world_size(sp_group) if sequence_parallel else 1 + target_dp_rank = tp_rank // sp_size if sequence_parallel else tp_rank + target_dp_size = tp_size // sp_size if sequence_parallel else tp_size + ring_group = get_sp_ring_group() if sequence_parallel else None + sp_ring_rank = dist.get_rank(ring_group) if sequence_parallel else 0 + sp_ring_size = dist.get_world_size(ring_group) if sequence_parallel else 1 + # Handle VLM data: pixel_values and image_grid_thw are lists # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None if args.is_vlm: @@ -626,19 +664,43 @@ def run_forward( is_vlm=args.is_vlm, pixel_values=pixel_values, image_grid_thw=image_grid_thw, + dp_rank=target_dp_rank, + dp_size=target_dp_size, + sequence_parallel=sequence_parallel, + sp_rank=sp_rank, + sp_size=sp_size, + sp_ring_rank=sp_ring_rank, + sp_ring_size=sp_ring_size, + ttt_length=args.ttt_length, ) else: eagle3_data = target_model.generate_eagle3_data( input_ids=data["input_ids"].cuda(), attention_mask=data["attention_mask"].cuda(), loss_mask=data["loss_mask"].cuda(), + dp_rank=target_dp_rank, + dp_size=target_dp_size, + sequence_parallel=sequence_parallel, + sp_rank=sp_rank, + sp_size=sp_size, + sp_ring_rank=sp_ring_rank, + sp_ring_size=sp_ring_size, + ttt_length=args.ttt_length, ) - input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) - attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) - loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask) - target = get_dp_data_shard_from_tp(eagle3_data.target) - hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states) + if sequence_parallel or args.shard_target_logits: + input_ids = eagle3_data.input_ids + attention_mask = eagle3_data.attention_mask + loss_mask = eagle3_data.loss_mask + target = eagle3_data.target + hidden_states = eagle3_data.hidden_states + else: + input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) + attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) + loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask) + target = get_dp_data_shard_from_tp(eagle3_data.target) + hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states) + position_ids = eagle3_data.position_ids else: # we generate the logits using the hidden states loaded from disk attention_mask = data["attention_mask"].cuda() @@ -651,15 +713,14 @@ def run_forward( target.cuda() ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU. loss_mask = loss_mask.cuda() + position_ids = data["position_ids"].cuda() if "position_ids" in data else None plosses, _, acces = eagle3_model( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, target=target, hidden_states=hidden_states, - position_ids=( - data["position_ids"].cuda() if "position_ids" in data else None - ), + position_ids=position_ids, image_grid_thw=image_grid_thw, is_vlm=args.is_vlm, ) diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index 2acf50ba..2d2fcf49 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -47,6 +47,33 @@ class Eagle3TargetOutput: input_ids: torch.Tensor attention_mask: torch.Tensor last_hidden_states: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + + +def _sp_chunk_bounds( + seq_len: int, sp_rank: int, sp_size: int, ttt_length: int +) -> Tuple[int, int]: + chunk_size = (seq_len + sp_size - 1) // sp_size + start = sp_rank * chunk_size + end = min(start + chunk_size + ttt_length, seq_len) + return start, end + + +def _slice_sequence_for_sp( + tensor: torch.Tensor, sp_rank: int, sp_size: int, ttt_length: int +) -> torch.Tensor: + seq_len = tensor.shape[1] + chunk_size = (seq_len + sp_size - 1) // sp_size + local_len = chunk_size + ttt_length + start, end = _sp_chunk_bounds(seq_len, sp_rank, sp_size, ttt_length) + sliced = tensor[:, start:end].contiguous() + if sliced.shape[1] == local_len: + return sliced + padded_shape = list(sliced.shape) + padded_shape[1] = local_len + padded = torch.zeros(padded_shape, dtype=sliced.dtype, device=sliced.device) + padded[:, : sliced.shape[1]] = sliced + return padded class Eagle3TargetModel(ABC): @@ -301,6 +328,7 @@ def from_pretrained( device: str = None, cache_dir: Optional[str] = None, trust_remote_code: bool = False, + shard_target_logits: bool = False, **kwargs, ) -> "SGLangEagle3TargetModel": tp_size = dist.get_world_size(get_tp_group()) @@ -337,14 +365,21 @@ def from_pretrained( nccl_port=None, is_draft_worker=False, ) + tp_group = get_tp_group() wrap_eagle3_logits_processors_in_module( - model_runner.model, return_full_logits=False + model_runner.model, + return_full_logits=False, + shard_target_logits=shard_target_logits, + tp_group=tp_group, ) # Get hf_config from model_config for VLM attributes hf_config = getattr(model_config, "hf_config", None) - return cls(model_runner, hf_config=hf_config) + instance = cls(model_runner, hf_config=hf_config) + instance.shard_target_logits = shard_target_logits + instance.tp_group = tp_group + return instance def set_aux_hidden_states_layers( self, aux_hidden_states_layers: Optional[List[int]] = None @@ -358,12 +393,23 @@ def _extend( capture_aux_hidden_states: bool = True, return_last_hidden_states: bool = False, return_logits: bool = False, + shard_target_logits: bool = False, + sequence_parallel: bool = False, + sequence_rank: int = 0, + sequence_size: int = 1, + sp_ttt_length: int = 0, ): # set the logits processor for the model runner for name, module in self.model_runner.model.named_modules(): if isinstance(module, LogitsProcessorForEAGLE3): module.return_last_hidden_states = return_last_hidden_states module.return_logits = return_logits + module.shard_target_logits = shard_target_logits + module.sequence_parallel = sequence_parallel + module.sequence_rank = sequence_rank + module.sequence_size = sequence_size + module.sp_ttt_length = sp_ttt_length + module.tp_group = getattr(self, "tp_group", None) cache_params = CacheInitParams( disable=False, @@ -391,36 +437,69 @@ def _extend( aux_hidden_states_list = None input_lens = [len(req.origin_input_ids) for req in reqs] + local_input_lens = getattr( + eagle3_output.logits_output, "local_input_lens", None + ) + local_sample_indices = getattr( + eagle3_output.logits_output, "local_sample_indices", None + ) + if local_input_lens is None: + local_input_lens = input_lens + if local_sample_indices is None: + local_sample_indices = list(range(len(reqs))) + local_num_samples = len(local_input_lens) + if return_logits: - if hasattr(eagle3_output, "logits_output"): - raw_logits = eagle3_output.logits_output.logits + raw_logits = ( + eagle3_output.logits_output.logits + if hasattr(eagle3_output, "logits_output") + else eagle3_output.logits + ) + if raw_logits is not None and len(local_input_lens) > 0: + logits = torch.split(raw_logits, local_input_lens, dim=0) + elif raw_logits is not None: + logits = [] else: - raw_logits = eagle3_output.logits - logits = torch.split(raw_logits, input_lens, dim=0) + logits = [None] * local_num_samples else: - logits = [None] * len(reqs) + logits = [None] * local_num_samples if capture_aux_hidden_states: - raw_aux_hidden_states = ( - eagle3_output.logits_output.aux_hidden_states - ) # concat hidden shape: (total_tokens, H*3) - aux_hidden_states_list = torch.split( - raw_aux_hidden_states, input_lens, dim=0 - ) + raw_aux_hidden_states = eagle3_output.logits_output.aux_hidden_states + if raw_aux_hidden_states is not None and len(local_input_lens) > 0: + aux_hidden_states_list = torch.split( + raw_aux_hidden_states, local_input_lens, dim=0 + ) + elif raw_aux_hidden_states is not None: + aux_hidden_states_list = [] + else: + aux_hidden_states_list = [None] * local_num_samples else: - aux_hidden_states_list = [None] * len(reqs) + aux_hidden_states_list = [None] * local_num_samples if return_last_hidden_states: - last_hidden_states = torch.split( - eagle3_output.logits_output.last_hidden_states, input_lens, dim=0 - ) + raw_last_hidden_states = eagle3_output.logits_output.last_hidden_states + if raw_last_hidden_states is not None and len(local_input_lens) > 0: + last_hidden_states = torch.split( + raw_last_hidden_states, local_input_lens, dim=0 + ) + elif raw_last_hidden_states is not None: + last_hidden_states = [] + else: + last_hidden_states = [None] * local_num_samples else: - last_hidden_states = [None] * len(reqs) + last_hidden_states = [None] * local_num_samples # TODO: can we not clear? self.model_runner.req_to_token_pool.clear() self.model_runner.token_to_kv_pool_allocator.clear() - return logits, aux_hidden_states_list, last_hidden_states + return ( + logits, + aux_hidden_states_list, + last_hidden_states, + local_sample_indices, + local_input_lens, + ) def _maybe_prepare_mlp_sync_batch(self, batch: ScheduleBatch): if require_mlp_sync(self.model_runner.server_args): @@ -449,6 +528,11 @@ def extend( loss_mask: torch.Tensor, return_last_hidden_states: bool = False, return_logits: bool = True, + shard_target_logits: bool = False, + sequence_parallel: bool = False, + sequence_rank: int = 0, + sequence_size: int = 1, + sp_ttt_length: int = 0, ): sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) reqs, data_cache = [], [] @@ -477,13 +561,31 @@ def extend( data_cache.append([input_id_, attention_mask_, loss_mask_]) reqs.append(req) - logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + ( + logits_list, + aux_hidden_states_list, + last_hidden_states_list, + local_sample_indices, + local_input_lens, + ) = self._extend( reqs, capture_aux_hidden_states=True, return_last_hidden_states=return_last_hidden_states, return_logits=return_logits, + shard_target_logits=shard_target_logits, + sequence_parallel=sequence_parallel, + sequence_rank=sequence_rank, + sequence_size=sequence_size, + sp_ttt_length=sp_ttt_length, ) + if ( + shard_target_logits + and not sequence_parallel + and local_sample_indices != list(range(len(data_cache))) + ): + data_cache = [data_cache[i] for i in local_sample_indices] + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list def get_rope_index( @@ -541,6 +643,11 @@ def extend_vlm( return_logits: bool = True, pixel_values: Optional[List[torch.Tensor]] = None, image_grid_thw: Optional[List[torch.Tensor]] = None, + shard_target_logits: bool = False, + sequence_parallel: bool = False, + sequence_rank: int = 0, + sequence_size: int = 1, + sp_ttt_length: int = 0, ): """ Args: @@ -641,7 +748,8 @@ def extend_vlm( pad_value=self.image_token_id, # Required for placeholder tensor creation offsets=offset, # List of (start, end) tuples ) - mm_item.set("image_grid_thw", image_grid_thw_.cpu()) + if image_grid_thw_ is not None: + mm_item.set("image_grid_thw", image_grid_thw_.cpu()) mm_item.set_pad_value() mm_inputs = MultimodalInputs( mm_items=[mm_item], @@ -670,13 +778,31 @@ def extend_vlm( data_cache.append([input_id_, attention_mask_, loss_mask_]) reqs.append(req) - logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + ( + logits_list, + aux_hidden_states_list, + last_hidden_states_list, + local_sample_indices, + local_input_lens, + ) = self._extend( reqs, capture_aux_hidden_states=True, return_last_hidden_states=return_last_hidden_states, return_logits=return_logits, + shard_target_logits=shard_target_logits, + sequence_parallel=sequence_parallel, + sequence_rank=sequence_rank, + sequence_size=sequence_size, + sp_ttt_length=sp_ttt_length, ) + if ( + shard_target_logits + and not sequence_parallel + and local_sample_indices != list(range(len(data_cache))) + ): + data_cache = [data_cache[i] for i in local_sample_indices] + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list @torch.no_grad() @@ -688,93 +814,119 @@ def generate_eagle3_data( pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, is_vlm: bool = False, + dp_rank: Optional[int] = None, + dp_size: int = 1, + sequence_parallel: bool = False, + sp_rank: int = 0, + sp_size: int = 1, + sp_ring_rank: int = 0, + sp_ring_size: int = 1, + ttt_length: int = 0, ) -> Eagle3TargetOutput: - """ - return: - data_for_draft: List[Dict[str, torch.Tensor]] of draft_batch_size, draft_micro_batch_size = 1 - - input_ids: (1, seq_len) - - attention_mask: (1, seq_len) - - loss_mask: (1, seq_len) - - target: (1, seq_len, vocab_size) or (1, seq_len, hidden_size) - - hidden_states: (1, seq_len, hidden_size) - - pixel_values: (patch_len, patch_width) - - image_grid_thw (batch_size, 3) - """ + shard_target_logits = getattr(self, "shard_target_logits", False) if is_vlm: - data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( - self.extend_vlm( - input_ids, - attention_mask, - loss_mask, - return_last_hidden_states=False, - return_logits=True, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) + data_cache, logits_list, aux_hidden_states_list, _ = self.extend_vlm( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + shard_target_logits=shard_target_logits, + sequence_parallel=sequence_parallel, + sequence_rank=sp_rank, + sequence_size=sp_size, + sp_ttt_length=ttt_length, ) else: - data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( - self.extend( - input_ids, - attention_mask, - loss_mask, - return_last_hidden_states=False, - return_logits=True, - ) + data_cache, logits_list, aux_hidden_states_list, _ = self.extend( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + shard_target_logits=shard_target_logits, + sequence_parallel=sequence_parallel, + sequence_rank=sp_rank, + sequence_size=sp_size, + sp_ttt_length=ttt_length, ) - aux_hidden_states_out = [] - target_out = [] - loss_mask_out = [] - input_ids_out = [] - last_hidden_states_out = [] - for idx, (data, logits, aux_hidden_states, last_hidden_states) in enumerate( - zip( - data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list - ) - ): - aux_hidden_states_out.append(aux_hidden_states.unsqueeze(0)) - loss_mask_out.append(data[2]) - input_ids_out.append(data[0]) - - # when generating hidden states for offline training, we don't compute logits and only keep the last_hidden_states - # when training online, we don't keep the last_hidden_states and only keep the logits - if logits is not None: - target_out.append(logits.unsqueeze(0)) - else: - target_out.append(None) + kept_aux_hidden_states = [] + kept_targets = [] + kept_loss_masks = [] + kept_input_ids = [] + kept_attention_masks = [] + kept_position_ids = [] - if last_hidden_states is not None: - last_hidden_states_out.append(last_hidden_states.unsqueeze(0)) - else: - last_hidden_states_out.append(None) + for sample_idx, (data, logits, aux_hidden_states) in enumerate( + zip(data_cache, logits_list, aux_hidden_states_list) + ): + should_keep = True + if sequence_parallel and dp_rank is not None and dp_size > 1: + should_keep = (sample_idx % dp_size) == dp_rank + elif not shard_target_logits and dp_rank is not None and dp_size > 1: + should_keep = (sample_idx % dp_size) == dp_rank + + if should_keep: + input_id, attention_mask_, loss_mask_ = data + if sequence_parallel: + input_id = _slice_sequence_for_sp( + input_id, sp_rank, sp_size, ttt_length + ) + attention_mask_ = _slice_sequence_for_sp( + attention_mask_, sp_rank, sp_size, ttt_length + ) + loss_mask_ = _slice_sequence_for_sp( + loss_mask_, sp_rank, sp_size, ttt_length + ) + seq_len = input_id.shape[1] + sp_ulysses_size = max(1, sp_size // sp_ring_size) + usp_chunk_size = max(seq_len - ttt_length, 0) + ring_chunk = usp_chunk_size * sp_ulysses_size + ring_start = sp_ring_rank * ring_chunk + kept_position_ids.append( + torch.arange( + ring_start, + ring_start + ring_chunk, + dtype=torch.long, + device=input_id.device, + ).unsqueeze(0) + ) - aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0) + kept_aux_hidden_states.append(aux_hidden_states.unsqueeze(0)) + kept_loss_masks.append(loss_mask_) + kept_input_ids.append(input_id) + kept_attention_masks.append(attention_mask_) + if logits is not None: + kept_targets.append(logits.unsqueeze(0)) - loss_mask_out = torch.cat(loss_mask_out, dim=0) - input_ids_out = torch.cat(input_ids_out, dim=0) + aux_hidden_states_out = torch.cat(kept_aux_hidden_states, dim=0) + loss_mask_out = torch.cat(kept_loss_masks, dim=0) + input_ids_out = torch.cat(kept_input_ids, dim=0) + attention_mask_out = torch.cat(kept_attention_masks, dim=0) - if target_out[0] is not None: - target_out = torch.cat(target_out, dim=0) + if kept_targets: + target_out = torch.cat(kept_targets, dim=0) else: target_out = None - if last_hidden_states_out[0] is not None: - last_hidden_states_out = torch.cat(last_hidden_states_out, dim=0) - else: - last_hidden_states_out = None - target_out = padding(target_out, left=False) input_ids_out = padding(input_ids_out, left=False) loss_mask_out = loss_mask_out[..., None] + position_ids_out = ( + torch.cat(kept_position_ids, dim=0) if kept_position_ids else None + ) return Eagle3TargetOutput( hidden_states=aux_hidden_states_out, target=target_out, loss_mask=loss_mask_out, input_ids=input_ids_out, - attention_mask=attention_mask, - last_hidden_states=last_hidden_states_out, + attention_mask=attention_mask_out, + last_hidden_states=None, + position_ids=position_ids_out, ) diff --git a/specforge/modeling/target/sglang_backend/utils.py b/specforge/modeling/target/sglang_backend/utils.py index 87d384bf..7fc07628 100644 --- a/specforge/modeling/target/sglang_backend/utils.py +++ b/specforge/modeling/target/sglang_backend/utils.py @@ -6,6 +6,7 @@ from typing import List, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from sglang.srt.layers.logits_processor import ( LogitsMetadata, @@ -25,6 +26,219 @@ class ReplacedLogitsProcessorEagle3Output: logits: torch.Tensor aux_hidden_states: torch.Tensor last_hidden_states: Optional[torch.Tensor] = None + local_sample_indices: Optional[list] = None + local_input_lens: Optional[list] = None + + +def all_to_all_batch_sharded_logits( + logits: torch.Tensor, + group: dist.ProcessGroup, + input_lens: list, +) -> tuple: + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + num_samples = len(input_lens) + local_vocab_size = logits.shape[1] + max_seq_len = max(input_lens) + + if world_size == 1: + return logits, input_lens, list(range(num_samples)) + + padded_batch_size = ((num_samples + world_size - 1) // world_size) * world_size + num_pad_samples = padded_batch_size - num_samples + padded_logits = torch.zeros( + (padded_batch_size, max_seq_len, local_vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + + token_offset = 0 + for sample_idx, seq_len in enumerate(input_lens): + padded_logits[sample_idx, :seq_len] = logits[ + token_offset : token_offset + seq_len + ] + token_offset += seq_len + + padded_input_lens = input_lens + [0] * num_pad_samples + samples_per_rank = padded_batch_size // world_size + input_list = list(padded_logits.chunk(world_size, dim=0)) + output_buf = torch.empty( + (world_size, samples_per_rank, max_seq_len, local_vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + output_list = list(output_buf.unbind(0)) + dist.all_to_all(output_list, input_list, group=group) + local_logits_padded = torch.cat(output_list, dim=-1) + + start_sample = rank * samples_per_rank + end_sample = min(start_sample + samples_per_rank, num_samples) + local_sample_indices = list(range(start_sample, end_sample)) + local_input_lens = padded_input_lens[start_sample:end_sample] + + local_tokens = [] + for sample_idx, seq_len in enumerate(local_input_lens): + if seq_len > 0: + local_tokens.append(local_logits_padded[sample_idx, :seq_len]) + + if not local_tokens: + return ( + torch.empty( + (0, local_logits_padded.shape[-1]), + dtype=logits.dtype, + device=logits.device, + ), + [], + [], + ) + + return ( + torch.cat(local_tokens, dim=0), + [seq_len for seq_len in local_input_lens if seq_len > 0], + local_sample_indices, + ) + + +def all_to_all_sequence_sharded_logits( + logits: torch.Tensor, + group: dist.ProcessGroup, + input_lens: list, + ttt_length: int, + sequence_rank: int, + sequence_size: int, +) -> tuple: + world_size = dist.get_world_size(group) + num_samples = len(input_lens) + local_vocab_size = logits.shape[1] + + if world_size == 1: + return logits, input_lens, list(range(num_samples)) + + if world_size % sequence_size != 0: + raise ValueError( + f"TP size ({world_size}) must be divisible by SP size ({sequence_size}) " + "when sharding target logits for sequence parallel training." + ) + if sequence_rank < 0 or sequence_rank >= sequence_size: + raise ValueError( + f"sequence_rank must be in [0, {sequence_size}), got {sequence_rank}." + ) + + max_seq_len = max(input_lens) + chunk_size = (max_seq_len + sequence_size - 1) // sequence_size + local_len = chunk_size + ttt_length + padded_logits = torch.zeros( + (num_samples, max_seq_len, local_vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + + token_offset = 0 + for sample_idx, seq_len in enumerate(input_lens): + padded_logits[sample_idx, :seq_len] = logits[ + token_offset : token_offset + seq_len + ] + token_offset += seq_len + + sequence_rank_tensor = torch.tensor( + [sequence_rank], dtype=torch.int64, device=logits.device + ) + gathered_sequence_ranks = [ + torch.empty_like(sequence_rank_tensor) for _ in range(world_size) + ] + dist.all_gather(gathered_sequence_ranks, sequence_rank_tensor, group=group) + + send_buf = torch.zeros( + (world_size, num_samples, local_len, local_vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + for dst_rank in range(world_size): + dst_sequence_rank = int(gathered_sequence_ranks[dst_rank].item()) + start = dst_sequence_rank * chunk_size + for sample_idx, seq_len in enumerate(input_lens): + end = min(start + local_len, seq_len) + valid_len = max(0, end - start) + if valid_len > 0: + send_buf[dst_rank, sample_idx, :valid_len] = padded_logits[ + sample_idx, start:end + ] + + input_list = list(send_buf.unbind(0)) + output_buf = torch.empty( + (world_size, num_samples, local_len, local_vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + output_list = list(output_buf.unbind(0)) + dist.all_to_all(output_list, input_list, group=group) + local_logits_padded = torch.cat(output_list, dim=-1) + return ( + local_logits_padded.reshape( + num_samples * local_len, local_logits_padded.shape[-1] + ), + [local_len] * num_samples, + list(range(num_samples)), + ) + + +def slice_hidden_states_by_samples( + hidden_states: torch.Tensor, + input_lens: list, + local_sample_indices: list, +) -> torch.Tensor: + if not local_sample_indices: + return torch.empty( + (0, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + sample_offsets = [0] + for seq_len in input_lens: + sample_offsets.append(sample_offsets[-1] + seq_len) + + local_tokens = [] + for sample_idx in local_sample_indices: + start = sample_offsets[sample_idx] + end = sample_offsets[sample_idx + 1] + local_tokens.append(hidden_states[start:end]) + return torch.cat(local_tokens, dim=0) + + +def slice_hidden_states_by_sequence_chunks( + hidden_states: torch.Tensor, + input_lens: list, + sequence_rank: int, + sequence_size: int, + ttt_length: int, +) -> torch.Tensor: + if sequence_size == 1: + return hidden_states + + max_seq_len = max(input_lens) + chunk_size = (max_seq_len + sequence_size - 1) // sequence_size + local_len = chunk_size + ttt_length + sample_offsets = [0] + for seq_len in input_lens: + sample_offsets.append(sample_offsets[-1] + seq_len) + + local_tokens = torch.zeros( + (len(input_lens), local_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + start = sequence_rank * chunk_size + for sample_idx, seq_len in enumerate(input_lens): + end = min(start + local_len, seq_len) + valid_len = max(0, end - start) + if valid_len > 0: + sample_start = sample_offsets[sample_idx] + local_tokens[sample_idx, :valid_len] = hidden_states[ + sample_start + start : sample_start + end + ] + + return local_tokens.reshape(len(input_lens) * local_len, hidden_states.shape[1]) def replaced_logits_processor_forward_for_eagle3( @@ -37,6 +251,12 @@ def replaced_logits_processor_forward_for_eagle3( hidden_states_before_norm: Optional[torch.Tensor] = None, return_last_hidden_states: bool = False, return_logits: bool = False, + shard_target_logits: bool = False, + sequence_parallel: bool = False, + sequence_rank: int = 0, + sequence_size: int = 1, + sp_ttt_length: int = 0, + tp_group: Optional[dist.ProcessGroup] = None, ) -> LogitsProcessorOutput: """ This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py. @@ -79,11 +299,58 @@ def replaced_logits_processor_forward_for_eagle3( else: last_hidden_states = None + if hasattr(logits_metadata, "seq_lens"): + all_input_lens = list(logits_metadata.seq_lens) + elif hasattr(logits_metadata, "extend_seq_lens"): + all_input_lens = list(logits_metadata.extend_seq_lens) + else: + all_input_lens = [hidden_states.shape[0]] + + local_sample_indices = None + local_input_lens = None + if return_logits: - # Compute logits for both input and sampled tokens. + original_do_all_gather = self.do_tensor_parallel_all_gather + if ( + shard_target_logits + and tp_group is not None + and dist.get_world_size(tp_group) > 1 + ): + self.do_tensor_parallel_all_gather = False + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + if ( + shard_target_logits + and tp_group is not None + and dist.get_world_size(tp_group) > 1 + ): + if sequence_parallel: + # Convert sequence-major local-vocab logits into local-sequence full-vocab logits. + logits, local_input_lens, local_sample_indices = ( + all_to_all_sequence_sharded_logits( + logits, + tp_group, + all_input_lens, + sp_ttt_length, + sequence_rank, + sequence_size, + ) + ) + else: + # Convert batch-major local-vocab logits into local-batch full-vocab logits. + logits, local_input_lens, local_sample_indices = ( + all_to_all_batch_sharded_logits( + logits, tp_group, all_input_lens + ) + ) + else: + local_input_lens = all_input_lens + local_sample_indices = list(range(len(all_input_lens))) + self.do_tensor_parallel_all_gather = original_do_all_gather else: logits = None + local_input_lens = all_input_lens + local_sample_indices = list(range(len(all_input_lens))) # get the aux hidden states hidden_states_to_store: Optional[torch.Tensor] = None @@ -113,6 +380,44 @@ def replaced_logits_processor_forward_for_eagle3( else: assert False, "Should never reach" + if ( + shard_target_logits + and hidden_states_to_store is not None + and tp_group is not None + and dist.get_world_size(tp_group) > 1 + ): + if sequence_parallel: + hidden_states_to_store = slice_hidden_states_by_sequence_chunks( + hidden_states_to_store, + all_input_lens, + sequence_rank, + sequence_size, + sp_ttt_length, + ) + else: + hidden_states_to_store = slice_hidden_states_by_samples( + hidden_states_to_store, all_input_lens, local_sample_indices + ) + + if ( + shard_target_logits + and last_hidden_states is not None + and tp_group is not None + and dist.get_world_size(tp_group) > 1 + ): + if sequence_parallel: + last_hidden_states = slice_hidden_states_by_sequence_chunks( + last_hidden_states, + all_input_lens, + sequence_rank, + sequence_size, + sp_ttt_length, + ) + else: + last_hidden_states = slice_hidden_states_by_samples( + last_hidden_states, all_input_lens, local_sample_indices + ) + assert ( not logits_metadata.extend_return_logprob ), "extend_return_logprob is not supported" @@ -121,6 +426,8 @@ def replaced_logits_processor_forward_for_eagle3( logits=logits, aux_hidden_states=hidden_states_to_store, last_hidden_states=last_hidden_states, + local_sample_indices=local_sample_indices, + local_input_lens=local_input_lens, ) @@ -130,11 +437,23 @@ def __init__( logits_processor: LogitsProcessor, return_last_hidden_states: bool = False, return_logits: bool = False, + shard_target_logits: bool = False, + sequence_parallel: bool = False, + sequence_rank: int = 0, + sequence_size: int = 1, + sp_ttt_length: int = 0, + tp_group: Optional[dist.ProcessGroup] = None, ): super().__init__() self.logits_processor = logits_processor self.return_last_hidden_states = return_last_hidden_states self.return_logits = return_logits + self.shard_target_logits = shard_target_logits + self.sequence_parallel = sequence_parallel + self.sequence_rank = sequence_rank + self.sequence_size = sequence_size + self.sp_ttt_length = sp_ttt_length + self.tp_group = tp_group def forward( self, @@ -156,18 +475,36 @@ def forward( hidden_states_before_norm, self.return_last_hidden_states, self.return_logits, + self.shard_target_logits, + self.sequence_parallel, + self.sequence_rank, + self.sequence_size, + self.sp_ttt_length, + self.tp_group, ) return ret def wrap_eagle3_logits_processors_in_module( - module: nn.Module, return_full_logits: bool = False + module: nn.Module, + return_full_logits: bool = False, + shard_target_logits: bool = False, + tp_group: Optional[dist.ProcessGroup] = None, ): """ This function will wrap the SGLang's original logits processor with the modified one for EAGLE3. """ for name, submodule in module.named_modules(): if isinstance(submodule, LogitsProcessor): - wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) + wrapped = LogitsProcessorForEAGLE3( + submodule, + return_last_hidden_states=False, + return_logits=return_full_logits, + shard_target_logits=shard_target_logits, + tp_group=tp_group, + ) setattr(module, name, wrapped) - print(f"wrapped {name} with LogitsProcessorForEAGLE3") + print( + f"wrapped {name} with LogitsProcessorForEAGLE3 " + f"(shard_target_logits={shard_target_logits})" + )