Skip to content

Commit 84e4172

Browse files
committed
update
1 parent f26cde5 commit 84e4172

4 files changed

Lines changed: 20 additions & 13 deletions

File tree

cosyvoice/flow/flow.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ def forward(
189189
device: torch.device,
190190
) -> Dict[str, Optional[torch.Tensor]]:
191191
if 'speech_token' not in batch:
192-
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
192+
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
193193
else:
194194
token = batch['speech_token'].to(device)
195195
token_len = batch['speech_token_len'].to(device)
@@ -322,8 +322,11 @@ def forward(
322322
batch: dict,
323323
device: torch.device,
324324
) -> Dict[str, Optional[torch.Tensor]]:
325-
token = batch['speech_token'].to(device)
326-
token_len = batch['speech_token_len'].to(device)
325+
if 'speech_token' not in batch:
326+
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
327+
else:
328+
token = batch['speech_token'].to(device)
329+
token_len = batch['speech_token_len'].to(device)
327330
feat = batch['speech_feat'].to(device)
328331
feat_len = batch['speech_feat_len'].to(device)
329332
embedding = batch['embedding'].to(device)

cosyvoice/llm/llm.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -367,8 +367,11 @@ def forward(
367367
"""
368368
text_token = batch['text_token'].to(device)
369369
text_token_len = batch['text_token_len'].to(device)
370-
speech_token = batch['speech_token'].to(device)
371-
speech_token_len = batch['speech_token_len'].to(device)
370+
if 'speech_token' not in batch:
371+
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
372+
else:
373+
speech_token = batch['speech_token'].to(device)
374+
speech_token_len = batch['speech_token_len'].to(device)
372375

373376
# 1. encode text_token
374377
text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -686,8 +689,12 @@ def forward(
686689
"""
687690
text_token = batch['text_token'].to(device)
688691
text_token_len = batch['text_token_len'].to(device)
689-
speech_token = batch['speech_token'].to(device)
690-
speech_token_len = batch['speech_token_len'].to(device)
692+
if 'speech_token' not in batch:
693+
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
694+
else:
695+
speech_token = batch['speech_token'].to(device)
696+
speech_token_len = batch['speech_token_len'].to(device)
697+
691698
# NOTE should append instruct_token to sequence, not implemented yet
692699
instruct_token = batch['instruct_token'].to(device)
693700
instruct_token_len = batch['instruct_token_len'].to(device)

cosyvoice/utils/onnx.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,7 @@
11
import onnxruntime
22
import torch, random
3-
from torch import nn
43
import os
5-
import whisper
6-
import numpy as np
74
import torchaudio.compliance.kaldi as kaldi
8-
import torch.nn.functional as F
95

106

117
class SpeechTokenExtractor():
@@ -18,13 +14,13 @@ def __init__(self, model_path):
1814
sess_options=option,
1915
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
2016

21-
def inference(self, feat, feat_lengths):
17+
def inference(self, feat, feat_lengths, device):
2218
speech_token = self.speech_tokenizer_session.run(None,
2319
{self.speech_tokenizer_session.get_inputs()[0].name:
2420
feat.transpose(1, 2).detach().cpu().numpy(),
2521
self.speech_tokenizer_session.get_inputs()[1].name:
2622
feat_lengths.detach().cpu().numpy()})[0]
27-
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
23+
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
2824

2925

3026
class EmbeddingExtractor():

examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
150150
feat_extractor: !ref <feat_extractor>
151151
num_frames: 960
152152
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
153+
num_frames: 960
153154
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
154155
sample_rate: !ref <sample_rate>
155156
hop_size: 480

0 commit comments

Comments
 (0)