@@ -367,8 +367,11 @@ def forward(
367367 """
368368 text_token = batch ['text_token' ].to (device )
369369 text_token_len = batch ['text_token_len' ].to (device )
370- speech_token = batch ['speech_token' ].to (device )
371- speech_token_len = batch ['speech_token_len' ].to (device )
370+ if 'speech_token' not in batch :
371+ speech_token , speech_token_len = self .speech_token_extractor .inference (batch ['whisper_feat' ], batch ['whisper_feat_len' ], device )
372+ else :
373+ speech_token = batch ['speech_token' ].to (device )
374+ speech_token_len = batch ['speech_token_len' ].to (device )
372375
373376 # 1. encode text_token
374377 text_token_emb = self .llm .model .model .embed_tokens (text_token )
@@ -686,8 +689,12 @@ def forward(
686689 """
687690 text_token = batch ['text_token' ].to (device )
688691 text_token_len = batch ['text_token_len' ].to (device )
689- speech_token = batch ['speech_token' ].to (device )
690- speech_token_len = batch ['speech_token_len' ].to (device )
692+ if 'speech_token' not in batch :
693+ speech_token , speech_token_len = self .speech_token_extractor .inference (batch ['whisper_feat' ], batch ['whisper_feat_len' ], device )
694+ else :
695+ speech_token = batch ['speech_token' ].to (device )
696+ speech_token_len = batch ['speech_token_len' ].to (device )
697+
691698 # NOTE should append instruct_token to sequence, not implemented yet
692699 instruct_token = batch ['instruct_token' ].to (device )
693700 instruct_token_len = batch ['instruct_token_len' ].to (device )
0 commit comments