Skip to content

Commit f26cde5

Browse files
committed
update
1 parent 66b80db commit f26cde5

7 files changed

Lines changed: 90 additions & 73 deletions

File tree

cosyvoice/cli/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ def __init__(self,
256256
self.fp16 = fp16
257257
# NOTE must match training static_chunk_size
258258
self.token_hop_len = 25
259+
# NOTE increase token_hop_len incrementally to avoid duplicate inference
260+
self.token_max_hop_len = 4 * self.token_hop_len
261+
self.stream_scale_factor = 2
262+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
259263
# hift cache
260264
self.mel_cache_len = 8
261265
self.source_cache_len = int(self.mel_cache_len * 480)
@@ -353,6 +357,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
353357
stream=stream,
354358
finalize=False)
355359
token_offset += this_token_hop_len
360+
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
356361
yield {'tts_speech': this_tts_speech.cpu()}
357362
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
358363
break
@@ -403,6 +408,10 @@ def __init__(self,
403408
self.fp16 = fp16
404409
# NOTE must match training static_chunk_size
405410
self.token_hop_len = 25
411+
# NOTE increase token_hop_len incrementally to avoid duplicate inference
412+
self.token_max_hop_len = 4 * self.token_hop_len
413+
self.stream_scale_factor = 2
414+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
406415
# rtf and decoding related
407416
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
408417
self.lock = threading.Lock()

cosyvoice/dataset/processor.py

Lines changed: 54 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pyarrow.parquet as pq
1818
from io import BytesIO
1919
import numpy as np
20+
import whisper
2021
import torch
2122
import torchaudio
2223
from torch.nn.utils.rnn import pad_sequence
@@ -179,6 +180,23 @@ def compute_fbank(data,
179180
yield sample
180181

181182

183+
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
184+
""" Extract whisper fbank
185+
186+
Args:
187+
data: Iterable[{key, wav, label, sample_rate}]
188+
189+
Returns:
190+
Iterable[{key, feat, label}]
191+
"""
192+
for sample in data:
193+
if num_frames != -1:
194+
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
195+
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
196+
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
197+
yield sample
198+
199+
182200
def compute_f0(data, sample_rate, hop_size, mode='train'):
183201
""" Extract f0
184202
@@ -215,11 +233,12 @@ def parse_embedding(data, normalize, mode='train'):
215233
"""
216234
for sample in data:
217235
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
218-
speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
219-
embedding = embedding_extractor.inference(speech_16k)
236+
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
237+
embedding = embedding_extractor.inference(sample['speech_16k'])
220238
sample['spk_embedding'] = sample['utt_embedding'] = embedding
221-
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
222-
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
239+
else:
240+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
241+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
223242
if normalize:
224243
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
225244
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
@@ -242,8 +261,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
242261
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
243262
if 'instruct' in sample:
244263
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
245-
else:
246-
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
247264
yield sample
248265

249266

@@ -371,66 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
371388
"""
372389
for sample in data:
373390
assert isinstance(sample, list)
374-
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
375-
dtype=torch.int32)
376-
order = torch.argsort(speech_feat_len, descending=True)
377-
378-
utts = [sample[i]['utt'] for i in order]
379-
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
380-
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
381-
speech = pad_sequence(speech, batch_first=True, padding_value=0)
382-
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
383-
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
384-
speech_token = pad_sequence(speech_token,
385-
batch_first=True,
386-
padding_value=0)
387-
speech_feat = [sample[i]['speech_feat'] for i in order]
388-
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
389-
speech_feat = pad_sequence(speech_feat,
390-
batch_first=True,
391-
padding_value=0)
392-
text = [sample[i]['text'] for i in order]
391+
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
392+
batch = {}
393+
batch['utts'] = [sample[i]['utt'] for i in order]
394+
batch['text'] = [sample[i]['text'] for i in order]
393395
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
394-
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
395-
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
396-
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
397-
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
398-
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
399-
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
400-
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
401-
batch = {
402-
"utts": utts,
403-
"speech": speech,
404-
"speech_len": speech_len,
405-
"speech_token": speech_token,
406-
"speech_token_len": speech_token_len,
407-
"speech_feat": speech_feat,
408-
"speech_feat_len": speech_feat_len,
409-
"text": text,
410-
"text_token": text_token,
411-
"text_token_len": text_token_len,
412-
"instruct_token": instruct_token,
413-
"instruct_token_len": instruct_token_len,
414-
"utt_embedding": utt_embedding,
415-
"spk_embedding": spk_embedding,
416-
}
396+
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
397+
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
398+
speech_feat = [sample[i]['speech_feat'] for i in order]
399+
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
400+
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
401+
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
402+
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
403+
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
404+
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
405+
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
406+
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
407+
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
408+
whisper_feat = [torch.tensor(sample[i]['whisper_feat']) for i in order]
409+
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
410+
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
411+
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
412+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
413+
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
414+
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
417415
if gan is True:
418-
# in gan train, we need pitch_feat
416+
# in gan train, we need speech/pitch_feat
417+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
418+
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
419+
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
419420
pitch_feat = [sample[i]['pitch_feat'] for i in order]
420-
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
421-
pitch_feat = pad_sequence(pitch_feat,
422-
batch_first=True,
423-
padding_value=0)
424-
batch["pitch_feat"] = pitch_feat
425-
batch["pitch_feat_len"] = pitch_feat_len
421+
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
422+
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
426423
if dpo is True:
427424
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
428-
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
429-
reject_speech_token = pad_sequence(reject_speech_token,
430-
batch_first=True,
431-
padding_value=0)
432-
batch['reject_speech_token'] = reject_speech_token
433-
batch['reject_speech_token_len'] = reject_speech_token_len
425+
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
426+
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
434427
if use_spk_embedding is True:
435428
batch["embedding"] = batch["spk_embedding"]
436429
else:

cosyvoice/flow/flow.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import logging
14+
import os, logging
1515
import random
1616
from typing import Dict, Optional
1717
import torch
1818
import torch.nn as nn
1919
from torch.nn import functional as F
2020
from omegaconf import DictConfig
2121
from cosyvoice.utils.mask import make_pad_mask
22-
from cosyvoice.utils.onnx import SpeechTokenExtractor
22+
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
2323

2424

2525
class MaskedDiffWithXvec(torch.nn.Module):
@@ -180,14 +180,19 @@ def __init__(self,
180180
self.only_mask_loss = only_mask_loss
181181
self.token_mel_ratio = token_mel_ratio
182182
self.pre_lookahead_len = pre_lookahead_len
183+
if online_feature is True:
184+
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
183185

184186
def forward(
185187
self,
186188
batch: dict,
187189
device: torch.device,
188190
) -> Dict[str, Optional[torch.Tensor]]:
189-
token = batch['speech_token'].to(device)
190-
token_len = batch['speech_token_len'].to(device)
191+
if 'speech_token' not in batch:
192+
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
193+
else:
194+
token = batch['speech_token'].to(device)
195+
token_len = batch['speech_token_len'].to(device)
191196
feat = batch['speech_feat'].to(device)
192197
feat_len = batch['speech_feat_len'].to(device)
193198
embedding = batch['embedding'].to(device)
@@ -309,6 +314,8 @@ def __init__(self,
309314
self.decoder = decoder
310315
self.only_mask_loss = only_mask_loss
311316
self.token_mel_ratio = token_mel_ratio
317+
if online_feature is True:
318+
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
312319

313320
def forward(
314321
self,

cosyvoice/llm/llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import queue
15+
import os, queue
1616
import random
1717
import time
1818
import threading
@@ -28,7 +28,7 @@
2828
from cosyvoice.utils.common import th_accuracy
2929
from cosyvoice.utils.file_utils import logging
3030
from cosyvoice.utils.mask import make_pad_mask
31-
from cosyvoice.utils.onnx import SpeechTokenExtractor
31+
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
3232

3333

3434
class TransformerLM(torch.nn.Module):
@@ -301,6 +301,8 @@ def __init__(
301301
# 5. vllm related
302302
self.stop_token_ids = [speech_token_size + i for i in range(3)]
303303
self.vllm_output_queue = {}
304+
if online_feature is True:
305+
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
304306

305307
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
306308
lm_target, lm_input = [], []
@@ -667,6 +669,8 @@ def __init__(
667669
# 5. vllm related
668670
self.stop_token_ids = [speech_token_size + i for i in range(200)]
669671
self.vllm_output_queue = {}
672+
if online_feature is True:
673+
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
670674

671675
def forward(
672676
self,

cosyvoice/utils/onnx.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@ def __init__(self, model_path):
1818
sess_options=option,
1919
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
2020

21-
def inference(self, feat, feat_lengths, device):
22-
ort_out = self.speech_tokenizer_session.run(None,
21+
def inference(self, feat, feat_lengths):
22+
speech_token = self.speech_tokenizer_session.run(None,
2323
{self.speech_tokenizer_session.get_inputs()[0].name:
24-
feat.detach().cpu().numpy(),
24+
feat.transpose(1, 2).detach().cpu().numpy(),
2525
self.speech_tokenizer_session.get_inputs()[1].name:
26-
feat_lengths.detach().cpu().numpy()})
27-
speech_token, speech_token_embedding = ort_out[0], ort_out[1]
28-
return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
26+
feat_lengths.detach().cpu().numpy()})[0]
27+
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
2928

3029

3130
class EmbeddingExtractor():

examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
159159
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
160160
feat_extractor: !ref <feat_extractor>
161161
num_frames: 960
162+
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
163+
num_frames: 960
162164
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
163165
sample_rate: !ref <sample_rate>
164166
hop_size: 480
@@ -183,6 +185,7 @@ data_pipeline: [
183185
!ref <resample>,
184186
!ref <compute_fbank>,
185187
!ref <parse_embedding>,
188+
!ref <compute_whisper_fbank>,
186189
!ref <shuffle>,
187190
!ref <sort>,
188191
!ref <batch>,

examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
149149
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
150150
feat_extractor: !ref <feat_extractor>
151151
num_frames: 960
152+
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
152153
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
153154
sample_rate: !ref <sample_rate>
154155
hop_size: 480
@@ -173,6 +174,7 @@ data_pipeline: [
173174
!ref <resample>,
174175
!ref <compute_fbank>,
175176
!ref <parse_embedding>,
177+
!ref <compute_whisper_fbank>,
176178
!ref <shuffle>,
177179
!ref <sort>,
178180
!ref <batch>,

0 commit comments

Comments
 (0)