Skip to content

Commit 66b80db

Browse files
committed
online feature
1 parent 1822c5c commit 66b80db

22 files changed

Lines changed: 133 additions & 116 deletions

File tree

cosyvoice/bin/train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_args():
4949
parser.add_argument('--train_data', required=True, help='train data file')
5050
parser.add_argument('--cv_data', required=True, help='cv data file')
5151
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
52+
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
5253
parser.add_argument('--checkpoint', help='checkpoint model')
5354
parser.add_argument('--model_dir', required=True, help='save model dir')
5455
parser.add_argument('--tensorboard_dir',
@@ -96,6 +97,7 @@ def get_args():
9697
@record
9798
def main():
9899
args = get_args()
100+
os.environ['onnx_path'] = args.onnx_path
99101
logging.basicConfig(level=logging.DEBUG,
100102
format='%(asctime)s %(levelname)s %(message)s')
101103
# gan train has some special initialization logic
@@ -104,12 +106,10 @@ def main():
104106
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
105107
if gan is True:
106108
override_dict.pop('hift')
107-
try:
108-
with open(args.config, 'r') as f:
109-
configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
110-
except Exception:
111-
with open(args.config, 'r') as f:
112-
configs = load_hyperpyyaml(f, overrides=override_dict)
109+
if args.qwen_pretrain_path is not None:
110+
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
111+
with open(args.config, 'r') as f:
112+
configs = load_hyperpyyaml(f, overrides=override_dict)
113113
if gan is True:
114114
configs['train_conf'] = configs['train_conf_gan']
115115
configs['train_conf'].update(vars(args))

cosyvoice/dataset/processor.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616

1717
import pyarrow.parquet as pq
1818
from io import BytesIO
19+
import numpy as np
1920
import torch
2021
import torchaudio
2122
from torch.nn.utils.rnn import pad_sequence
2223
import torch.nn.functional as F
2324
import pyworld as pw
24-
25+
from cosyvoice.utils.onnx import embedding_extractor, online_feature
2526

2627
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
2728

@@ -92,9 +93,9 @@ def filter(data,
9293
continue
9394
if len(sample['text_token']) > token_max_length:
9495
continue
95-
if len(sample['speech_token']) == 0:
96+
if online_feature is False and len(sample['speech_token']) == 0:
9697
continue
97-
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
98+
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
9899
continue
99100
if num_frames != 0:
100101
if len(sample['text_token']) / num_frames < min_output_input_ratio:
@@ -155,7 +156,7 @@ def truncate(data, truncate_length=24576, mode='train'):
155156

156157
def compute_fbank(data,
157158
feat_extractor,
158-
token_mel_ratio=0,
159+
num_frames=-1,
159160
mode='train'):
160161
""" Extract fbank
161162
@@ -170,14 +171,11 @@ def compute_fbank(data,
170171
assert 'speech' in sample
171172
assert 'utt' in sample
172173
assert 'text_token' in sample
173-
waveform = sample['speech']
174-
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
175-
if token_mel_ratio != 0:
176-
# trim to align speech_token and speech_feat
177-
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
178-
feat = feat[:token_mel_ratio * token_len]
179-
sample["speech_token"] = sample["speech_token"][:token_len]
180-
sample['speech_feat'] = feat
174+
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
175+
if num_frames != -1:
176+
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
177+
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
178+
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
181179
yield sample
182180

183181

@@ -216,6 +214,10 @@ def parse_embedding(data, normalize, mode='train'):
216214
Iterable[{key, feat, label}]
217215
"""
218216
for sample in data:
217+
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
218+
speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
219+
embedding = embedding_extractor.inference(speech_16k)
220+
sample['spk_embedding'] = sample['utt_embedding'] = embedding
219221
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
220222
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
221223
if normalize:
@@ -256,13 +258,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
256258
Iterable[{key, feat, label}]
257259
"""
258260
buf = []
261+
yield_size = int(shuffle_size / 2)
259262
for sample in data:
260263
buf.append(sample)
261264
if len(buf) >= shuffle_size:
262265
random.shuffle(buf)
263-
for x in buf:
266+
for x in buf[:yield_size]:
264267
yield x
265-
buf = []
268+
buf = buf[yield_size:]
266269
# The sample left over
267270
random.shuffle(buf)
268271
for x in buf:
@@ -420,10 +423,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
420423
padding_value=0)
421424
batch["pitch_feat"] = pitch_feat
422425
batch["pitch_feat_len"] = pitch_feat_len
423-
else:
424-
# only gan train needs speech, delete it to save memory
425-
del batch["speech"]
426-
del batch["speech_len"]
427426
if dpo is True:
428427
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
429428
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)

cosyvoice/flow/flow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch.nn import functional as F
2020
from omegaconf import DictConfig
2121
from cosyvoice.utils.mask import make_pad_mask
22+
from cosyvoice.utils.onnx import SpeechTokenExtractor
2223

2324

2425
class MaskedDiffWithXvec(torch.nn.Module):

cosyvoice/llm/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from cosyvoice.utils.common import th_accuracy
2929
from cosyvoice.utils.file_utils import logging
3030
from cosyvoice.utils.mask import make_pad_mask
31+
from cosyvoice.utils.onnx import SpeechTokenExtractor
3132

3233

3334
class TransformerLM(torch.nn.Module):

cosyvoice/utils/onnx.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import onnxruntime
2+
import torch, random
3+
from torch import nn
4+
import os
5+
import whisper
6+
import numpy as np
7+
import torchaudio.compliance.kaldi as kaldi
8+
import torch.nn.functional as F
9+
10+
11+
class SpeechTokenExtractor():
    """Wrap the speech-tokenizer ONNX model for online speech-token extraction.

    One CUDA inference session is created per process, pinned to the GPU
    given by the ``LOCAL_RANK`` environment variable (defaults to 0) so
    DDP workers do not collide on the same device.
    """

    def __init__(self, model_path):
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Single intra-op thread: many dataloader workers may each hold a session.
        option.intra_op_num_threads = 1
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=option,
            providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])

    def inference(self, feat, feat_lengths, device):
        """Run the tokenizer on a batch of features.

        Args:
            feat: feature tensor fed to the model's first ONNX input
                (assumes a batched feature layout expected by the exported
                tokenizer — TODO confirm against the export script).
            feat_lengths: per-utterance frame counts, second ONNX input.
            device: torch device the outputs are moved to.

        Returns:
            Tuple ``(speech_token, speech_token_len)`` where the token length
            is ``feat_lengths / 2`` cast to int32 (the tokenizer halves the
            frame rate).
        """
        inputs = self.speech_tokenizer_session.get_inputs()
        ort_out = self.speech_tokenizer_session.run(
            None,
            {inputs[0].name: feat.detach().cpu().numpy(),
             inputs[1].name: feat_lengths.detach().cpu().numpy()})
        # ort_out[1] (the token embedding) is not needed here; keep only the tokens.
        speech_token = ort_out[0]
        return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
29+
30+
31+
class EmbeddingExtractor():
    """CPU-side campplus speaker-embedding extractor backed by an ONNX session."""

    def __init__(self, model_path):
        opts = onnxruntime.SessionOptions()
        opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 1
        # Cap the input audio at 10 seconds of 16 kHz samples.
        self.max_len = 10 * 16000
        self.campplus_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"])

    def inference(self, speech):
        """Return the speaker embedding for a ``(1, num_samples)`` 16 kHz waveform."""
        # Over-long utterances are randomly cropped to max_len samples.
        num_samples = speech.shape[1]
        if num_samples > self.max_len:
            start_index = random.randint(0, num_samples - self.max_len)
            speech = speech[:, start_index: start_index + self.max_len]
        # 80-dim mean-normalized fbank features, no dither.
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        input_name = self.campplus_session.get_inputs()[0].name
        ort_out = self.campplus_session.run(
            None, {input_name: feat.unsqueeze(dim=0).cpu().numpy()})
        embedding = ort_out[0].flatten().tolist()
        return torch.tensor(embedding).to(speech.device)
53+
54+
# Singleton: the extractor is created at most once per process, at import time.
# The 'onnx_path' environment variable is set by cosyvoice/bin/train.py.
onnx_path = os.environ.get('onnx_path')
if onnx_path is None:
    embedding_extractor = None
    online_feature = False
else:
    embedding_extractor = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx'))
    online_feature = True

examples/libritts/cosyvoice/cosyvoice

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/libritts/cosyvoice/local/prepare_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def main():
5454
parser.add_argument('--des_dir',
5555
type=str)
5656
parser.add_argument('--instruct',
57-
type=str)
57+
type=str,
58+
default='')
5859
args = parser.parse_args()
5960
main()

examples/libritts/cosyvoice/run.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ fi
2727
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
2828
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
2929
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
30-
tools/extract_embedding.py --dir data/$x \
30+
../../../tools/extract_embedding.py --dir data/$x \
3131
--onnx_path $pretrained_model_dir/campplus.onnx
3232
done
3333
fi
3434

3535
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
3636
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
3737
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
38-
tools/extract_speech_token.py --dir data/$x \
38+
../../../tools/extract_speech_token.py --dir data/$x \
3939
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
4040
done
4141
fi
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
4444
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
4545
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
4646
mkdir -p data/$x/parquet
47-
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
47+
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
4848
--num_processes 10 \
4949
--src_dir data/$x \
5050
--des_dir data/$x/parquet
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
6969
for model in llm flow hifigan; do
7070
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
7171
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
72-
cosyvoice/bin/train.py \
72+
../../../cosyvoice/bin/train.py \
7373
--train_engine $train_engine \
7474
--config conf/cosyvoice.yaml \
7575
--train_data data/train.data.list \

examples/libritts/cosyvoice/tools

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ tokenize: !name:cosyvoice.dataset.processor.tokenize
139139
get_tokenizer: !ref <get_tokenizer>
140140
allowed_special: !ref <allowed_special>
141141
filter: !name:cosyvoice.dataset.processor.filter
142-
max_length: 40960
142+
max_length: 6000
143143
min_length: 100
144144
token_max_length: 200
145145
token_min_length: 1
@@ -158,7 +158,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
158158
center: False
159159
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
160160
feat_extractor: !ref <feat_extractor>
161-
token_mel_ratio: 2
161+
num_frames: 960
162162
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
163163
sample_rate: !ref <sample_rate>
164164
hop_size: 480

0 commit comments

Comments
 (0)