1616
1717import pyarrow .parquet as pq
1818from io import BytesIO
19+ import numpy as np
1920import torch
2021import torchaudio
2122from torch .nn .utils .rnn import pad_sequence
2223import torch .nn .functional as F
2324import pyworld as pw
24-
25+ from cosyvoice . utils . onnx import embedding_extractor , online_feature
2526
2627AUDIO_FORMAT_SETS = {'flac' , 'mp3' , 'm4a' , 'ogg' , 'opus' , 'wav' , 'wma' }
2728
@@ -92,9 +93,9 @@ def filter(data,
9293 continue
9394 if len (sample ['text_token' ]) > token_max_length :
9495 continue
95- if len (sample ['speech_token' ]) == 0 :
96+ if online_feature is False and len (sample ['speech_token' ]) == 0 :
9697 continue
97- if 'reject_speech_token' in sample and len (sample ['reject_speech_token' ]) == 0 :
98+ if online_feature is False and 'reject_speech_token' in sample and len (sample ['reject_speech_token' ]) == 0 :
9899 continue
99100 if num_frames != 0 :
100101 if len (sample ['text_token' ]) / num_frames < min_output_input_ratio :
@@ -155,7 +156,7 @@ def truncate(data, truncate_length=24576, mode='train'):
155156
156157def compute_fbank (data ,
157158 feat_extractor ,
158- token_mel_ratio = 0 ,
159+ num_frames = - 1 ,
159160 mode = 'train' ):
160161 """ Extract fbank
161162
@@ -170,14 +171,11 @@ def compute_fbank(data,
170171 assert 'speech' in sample
171172 assert 'utt' in sample
172173 assert 'text_token' in sample
173- waveform = sample ['speech' ]
174- feat = feat_extractor (waveform ).squeeze (dim = 0 ).transpose (0 , 1 )
175- if token_mel_ratio != 0 :
176- # trim to align speech_token and speech_feat
177- token_len = int (min (feat .shape [0 ] / token_mel_ratio , sample ["speech_token" ].shape [0 ]))
178- feat = feat [:token_mel_ratio * token_len ]
179- sample ["speech_token" ] = sample ["speech_token" ][:token_len ]
180- sample ['speech_feat' ] = feat
174+ # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
175+ if num_frames != - 1 :
176+ index = int (np .ceil (sample ['speech' ].shape [1 ] / num_frames ))
177+ sample ['speech' ] = torch .concat ([sample ['speech' ], torch .zeros (1 , index * num_frames - sample ['speech' ].shape [1 ])], dim = 1 )
178+ sample ['speech_feat' ] = feat_extractor (sample ['speech' ]).squeeze (dim = 0 ).transpose (0 , 1 )
181179 yield sample
182180
183181
@@ -216,6 +214,10 @@ def parse_embedding(data, normalize, mode='train'):
216214 Iterable[{key, feat, label}]
217215 """
218216 for sample in data :
217+ if 'utt_embedding' not in sample and 'spk_embedding' not in sample :
218+ speech_16k = torchaudio .transforms .Resample (orig_freq = sample ['sample_rate' ], new_freq = 16000 )(sample ['speech' ])
219+ embedding = embedding_extractor .inference (speech_16k )
220+ sample ['spk_embedding' ] = sample ['utt_embedding' ] = embedding
219221 sample ['utt_embedding' ] = torch .tensor (sample ['utt_embedding' ], dtype = torch .float32 )
220222 sample ['spk_embedding' ] = torch .tensor (sample ['spk_embedding' ], dtype = torch .float32 )
221223 if normalize :
@@ -256,13 +258,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
256258 Iterable[{key, feat, label}]
257259 """
258260 buf = []
261+ yield_size = int (shuffle_size / 2 )
259262 for sample in data :
260263 buf .append (sample )
261264 if len (buf ) >= shuffle_size :
262265 random .shuffle (buf )
263- for x in buf :
266+ for x in buf [: yield_size ] :
264267 yield x
265- buf = [ ]
268+ buf = buf [ yield_size : ]
266269 # The sample left over
267270 random .shuffle (buf )
268271 for x in buf :
@@ -420,10 +423,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
420423 padding_value = 0 )
421424 batch ["pitch_feat" ] = pitch_feat
422425 batch ["pitch_feat_len" ] = pitch_feat_len
423- else :
424- # only gan train needs speech, delete it to save memory
425- del batch ["speech" ]
426- del batch ["speech_len" ]
427426 if dpo is True :
428427 reject_speech_token = [torch .tensor (sample [i ]['reject_speech_token' ]) for i in order ]
429428 reject_speech_token_len = torch .tensor ([i .size (0 ) for i in reject_speech_token ], dtype = torch .int32 )
0 commit comments