1717import pyarrow .parquet as pq
1818from io import BytesIO
1919import numpy as np
20+ import whisper
2021import torch
2122import torchaudio
2223from torch .nn .utils .rnn import pad_sequence
@@ -179,6 +180,23 @@ def compute_fbank(data,
179180 yield sample
180181
181182
def compute_whisper_fbank(data, num_frames=-1, mode='train', n_mels=128):
    """ Extract whisper log-mel fbank features.

    Args:
        data: Iterable[{speech, sample_rate, ...}] — each sample must carry a
            `speech` tensor and its `sample_rate`; other keys pass through.
        num_frames: when != -1, the speech length (dim 1 of `speech`) must be
            an exact multiple of it, so the feature stays aligned with the
            speech_token sequence. Default -1 disables the check.
        mode: unused here; kept so the signature matches the other
            compute_* pipeline stages.
        n_mels: number of mel bins passed to whisper.log_mel_spectrogram
            (default 128, the value used by large whisper models).

    Returns:
        Iterable[{..., speech_16k, whisper_feat}] — `speech_16k` is the input
        resampled to 16 kHz, `whisper_feat` has shape (frames, n_mels).

    Raises:
        ValueError: if num_frames != -1 and the speech length is not a
            multiple of it.
    """
    # Resample modules are loop-invariant per source sample rate; cache them
    # instead of rebuilding one per sample.
    resamplers = {}
    for sample in data:
        if num_frames != -1 and sample['speech'].shape[1] % num_frames != 0:
            # Explicit raise instead of `assert`: asserts vanish under -O.
            raise ValueError('speech length is not aligned with speech_token')
        sr = sample['sample_rate']
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        sample['speech_16k'] = resamplers[sr](sample['speech'])
        # (time, n_mels) layout: squeeze the batch dim, then transpose.
        sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=n_mels).squeeze(dim=0).transpose(0, 1)
        yield sample
199+
182200def compute_f0 (data , sample_rate , hop_size , mode = 'train' ):
183201 """ Extract f0
184202
@@ -215,11 +233,12 @@ def parse_embedding(data, normalize, mode='train'):
215233 """
216234 for sample in data :
217235 if 'utt_embedding' not in sample and 'spk_embedding' not in sample :
218- speech_16k = torchaudio .transforms .Resample (orig_freq = sample ['sample_rate' ], new_freq = 16000 )(sample ['speech' ])
219- embedding = embedding_extractor .inference (speech_16k )
236+ sample [ ' speech_16k' ] = torchaudio .transforms .Resample (orig_freq = sample ['sample_rate' ], new_freq = 16000 )(sample ['speech' ])
237+ embedding = embedding_extractor .inference (sample [ ' speech_16k' ] )
220238 sample ['spk_embedding' ] = sample ['utt_embedding' ] = embedding
221- sample ['utt_embedding' ] = torch .tensor (sample ['utt_embedding' ], dtype = torch .float32 )
222- sample ['spk_embedding' ] = torch .tensor (sample ['spk_embedding' ], dtype = torch .float32 )
239+ else :
240+ sample ['utt_embedding' ] = torch .tensor (sample ['utt_embedding' ], dtype = torch .float32 )
241+ sample ['spk_embedding' ] = torch .tensor (sample ['spk_embedding' ], dtype = torch .float32 )
223242 if normalize :
224243 sample ['utt_embedding' ] = F .normalize (sample ['utt_embedding' ], dim = 0 )
225244 sample ['spk_embedding' ] = F .normalize (sample ['spk_embedding' ], dim = 0 )
@@ -242,8 +261,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
242261 sample ['text_token' ] = tokenizer .encode (sample ['text' ], allowed_special = allowed_special )
243262 if 'instruct' in sample :
244263 sample ['instruct_token' ] = tokenizer .encode (sample ['instruct' ], allowed_special = allowed_special )
245- else :
246- sample ['instruct_token' ] = tokenizer .encode ('' , allowed_special = allowed_special )
247264 yield sample
248265
249266
@@ -371,66 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
371388 """
372389 for sample in data :
373390 assert isinstance (sample , list )
374- speech_feat_len = torch .tensor ([x ['speech_feat' ].size (1 ) for x in sample ],
375- dtype = torch .int32 )
376- order = torch .argsort (speech_feat_len , descending = True )
377-
378- utts = [sample [i ]['utt' ] for i in order ]
379- speech = [sample [i ]['speech' ].squeeze (dim = 0 ) for i in order ]
380- speech_len = torch .tensor ([i .size (0 ) for i in speech ], dtype = torch .int32 )
381- speech = pad_sequence (speech , batch_first = True , padding_value = 0 )
382- speech_token = [torch .tensor (sample [i ]['speech_token' ]) for i in order ]
383- speech_token_len = torch .tensor ([i .size (0 ) for i in speech_token ], dtype = torch .int32 )
384- speech_token = pad_sequence (speech_token ,
385- batch_first = True ,
386- padding_value = 0 )
387- speech_feat = [sample [i ]['speech_feat' ] for i in order ]
388- speech_feat_len = torch .tensor ([i .size (0 ) for i in speech_feat ], dtype = torch .int32 )
389- speech_feat = pad_sequence (speech_feat ,
390- batch_first = True ,
391- padding_value = 0 )
392- text = [sample [i ]['text' ] for i in order ]
391+ order = torch .argsort (torch .tensor ([x ['speech' ].size (1 ) for x in sample ], dtype = torch .int32 ), descending = True )
392+ batch = {}
393+ batch ['utts' ] = [sample [i ]['utt' ] for i in order ]
394+ batch ['text' ] = [sample [i ]['text' ] for i in order ]
393395 text_token = [torch .tensor (sample [i ]['text_token' ]) for i in order ]
394- text_token_len = torch .tensor ([i .size (0 ) for i in text_token ], dtype = torch .int32 )
395- text_token = pad_sequence (text_token , batch_first = True , padding_value = 0 )
396- instruct_token = [torch .tensor (sample [i ]['instruct_token' ]) for i in order ]
397- instruct_token_len = torch .tensor ([i .size (0 ) for i in instruct_token ], dtype = torch .int32 )
398- instruct_token = pad_sequence (instruct_token , batch_first = True , padding_value = 0 )
399- utt_embedding = torch .stack ([sample [i ]['utt_embedding' ] for i in order ], dim = 0 )
400- spk_embedding = torch .stack ([sample [i ]['spk_embedding' ] for i in order ], dim = 0 )
401- batch = {
402- "utts" : utts ,
403- "speech" : speech ,
404- "speech_len" : speech_len ,
405- "speech_token" : speech_token ,
406- "speech_token_len" : speech_token_len ,
407- "speech_feat" : speech_feat ,
408- "speech_feat_len" : speech_feat_len ,
409- "text" : text ,
410- "text_token" : text_token ,
411- "text_token_len" : text_token_len ,
412- "instruct_token" : instruct_token ,
413- "instruct_token_len" : instruct_token_len ,
414- "utt_embedding" : utt_embedding ,
415- "spk_embedding" : spk_embedding ,
416- }
396+ batch ['text_token_len' ] = torch .tensor ([i .size (0 ) for i in text_token ], dtype = torch .int32 )
397+ batch ['text_token' ] = pad_sequence (text_token , batch_first = True , padding_value = 0 )
398+ speech_feat = [sample [i ]['speech_feat' ] for i in order ]
399+ batch ['speech_feat_len' ] = torch .tensor ([i .size (0 ) for i in speech_feat ], dtype = torch .int32 )
400+ batch ['speech_feat' ] = pad_sequence (speech_feat , batch_first = True , padding_value = 0 )
401+ batch ['utt_embedding' ] = torch .stack ([sample [i ]['utt_embedding' ] for i in order ], dim = 0 )
402+ batch ['spk_embedding' ] = torch .stack ([sample [i ]['spk_embedding' ] for i in order ], dim = 0 )
403+ if torch .tensor (['instruct_token' in sample [i ] for i in order ]).all ():
404+ instruct_token = [torch .tensor (sample [i ]['instruct_token' ]) for i in order ]
405+ batch ['instruct_token_len' ] = torch .tensor ([i .size (0 ) for i in instruct_token ], dtype = torch .int32 )
406+ batch ['instruct_token' ] = pad_sequence (instruct_token , batch_first = True , padding_value = 0 )
407+ if torch .tensor (['whisper_feat' in sample [i ] for i in order ]).all ():
408+ whisper_feat = [torch .tensor (sample [i ]['whisper_feat' ]) for i in order ]
409+ batch ['whisper_feat_len' ] = torch .tensor ([i .size (0 ) for i in whisper_feat ], dtype = torch .int32 )
410+ batch ['whisper_feat' ] = pad_sequence (whisper_feat , batch_first = True , padding_value = 0 )
411+ if torch .tensor (['speech_token' in sample [i ] for i in order ]).all ():
412+ speech_token = [torch .tensor (sample [i ]['speech_token' ]) for i in order ]
413+ batch ['speech_token_len' ] = torch .tensor ([i .size (0 ) for i in speech_token ], dtype = torch .int32 )
414+ batch ['speech_token' ] = pad_sequence (speech_token , batch_first = True , padding_value = 0 )
417415 if gan is True :
418- # in gan train, we need pitch_feat
416+ # in gan train, we need speech/pitch_feat
417+ speech = [sample [i ]['speech' ].squeeze (dim = 0 ) for i in order ]
418+ batch ['speech_len' ] = torch .tensor ([i .size (0 ) for i in speech ], dtype = torch .int32 )
419+ batch ['speech' ] = pad_sequence (speech , batch_first = True , padding_value = 0 )
419420 pitch_feat = [sample [i ]['pitch_feat' ] for i in order ]
420- pitch_feat_len = torch .tensor ([i .size (0 ) for i in pitch_feat ], dtype = torch .int32 )
421- pitch_feat = pad_sequence (pitch_feat ,
422- batch_first = True ,
423- padding_value = 0 )
424- batch ["pitch_feat" ] = pitch_feat
425- batch ["pitch_feat_len" ] = pitch_feat_len
421+ batch ['pitch_feat_len' ] = torch .tensor ([i .size (0 ) for i in pitch_feat ], dtype = torch .int32 )
422+ batch ['pitch_feat' ] = pad_sequence (pitch_feat , batch_first = True , padding_value = 0 )
426423 if dpo is True :
427424 reject_speech_token = [torch .tensor (sample [i ]['reject_speech_token' ]) for i in order ]
428- reject_speech_token_len = torch .tensor ([i .size (0 ) for i in reject_speech_token ], dtype = torch .int32 )
429- reject_speech_token = pad_sequence (reject_speech_token ,
430- batch_first = True ,
431- padding_value = 0 )
432- batch ['reject_speech_token' ] = reject_speech_token
433- batch ['reject_speech_token_len' ] = reject_speech_token_len
425+ batch ['reject_speech_token_len' ] = torch .tensor ([i .size (0 ) for i in reject_speech_token ], dtype = torch .int32 )
426+ batch ['reject_speech_token' ] = pad_sequence (reject_speech_token , batch_first = True , padding_value = 0 )
434427 if use_spk_embedding is True :
435428 batch ["embedding" ] = batch ["spk_embedding" ]
436429 else :
0 commit comments