2424from cosyvoice .utils .common import fade_in_out
2525from cosyvoice .utils .file_utils import convert_onnx_to_trt , export_cosyvoice2_vllm
2626from cosyvoice .utils .common import TrtContextWrapper
27+ from cosyvoice .utils .device import get_device , get_stream_context , get_autocast_context , empty_cache
2728
2829
2930class CosyVoiceModel :
@@ -33,7 +34,7 @@ def __init__(self,
3334 flow : torch .nn .Module ,
3435 hift : torch .nn .Module ,
3536 fp16 : bool = False ):
36- self .device = torch . device ( 'cuda' if torch . cuda . is_available () else 'cpu' )
37+ self .device = get_device ( )
3738 self .llm = llm
3839 self .flow = flow
3940 self .hift = hift
@@ -52,7 +53,7 @@ def __init__(self,
5253 # rtf and decoding related
5354 self .stream_scale_factor = 1
5455 assert self .stream_scale_factor >= 1 , 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
55- self .llm_context = torch . cuda . stream ( torch . cuda . Stream ( self .device )) if torch . cuda . is_available () else nullcontext ( )
56+ self .llm_context = get_stream_context ( self .device )
5657 self .lock = threading .Lock ()
5758 # dict used to store session related variable
5859 self .tts_speech_token_dict = {}
@@ -100,7 +101,7 @@ def get_trt_kwargs(self):
100101
101102 def llm_job (self , text , prompt_text , llm_prompt_speech_token , llm_embedding , uuid ):
102103 cur_silent_token_num , max_silent_token_num = 0 , 5
103- with self .llm_context , torch . cuda . amp . autocast (self .fp16 is True and hasattr (self .llm , 'vllm' ) is False ):
104+ with self .llm_context , get_autocast_context (self .fp16 is True and hasattr (self .llm , 'vllm' ) is False , self . device ):
104105 if isinstance (text , Generator ):
105106 assert (self .__class__ .__name__ != 'CosyVoiceModel' ) and not hasattr (self .llm , 'vllm' ), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
106107 token_generator = self .llm .inference_bistream (text = text ,
@@ -133,7 +134,7 @@ def vc_job(self, source_speech_token, uuid):
133134 self .llm_end_dict [uuid ] = True
134135
135136 def token2wav (self , token , prompt_token , prompt_feat , embedding , uuid , finalize = False , speed = 1.0 ):
136- with torch . cuda . amp . autocast (self .fp16 ):
137+ with get_autocast_context (self .fp16 , self . device ):
137138 tts_mel , self .flow_cache_dict [uuid ] = self .flow .inference (token = token .to (self .device , dtype = torch .int32 ),
138139 token_len = torch .tensor ([token .shape [1 ]], dtype = torch .int32 ).to (self .device ),
139140 prompt_token = prompt_token .to (self .device ),
@@ -237,9 +238,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
237238 self .mel_overlap_dict .pop (this_uuid )
238239 self .hift_cache_dict .pop (this_uuid )
239240 self .flow_cache_dict .pop (this_uuid )
240- if torch .cuda .is_available ():
241- torch .cuda .empty_cache ()
242- torch .cuda .current_stream ().synchronize ()
241+ empty_cache (self .device )
243242
244243
245244class CosyVoice2Model (CosyVoiceModel ):
@@ -249,7 +248,7 @@ def __init__(self,
249248 flow : torch .nn .Module ,
250249 hift : torch .nn .Module ,
251250 fp16 : bool = False ):
252- self .device = torch . device ( 'cuda' if torch . cuda . is_available () else 'cpu' )
251+ self .device = get_device ( )
253252 self .llm = llm
254253 self .flow = flow
255254 self .hift = hift
@@ -266,7 +265,7 @@ def __init__(self,
266265 # speech fade in out
267266 self .speech_window = np .hamming (2 * self .source_cache_len )
268267 # rtf and decoding related
269- self .llm_context = torch . cuda . stream ( torch . cuda . Stream ( self .device )) if torch . cuda . is_available () else nullcontext ( )
268+ self .llm_context = get_stream_context ( self .device )
270269 self .lock = threading .Lock ()
271270 # dict used to store session related variable
272271 self .tts_speech_token_dict = {}
@@ -290,7 +289,7 @@ def load_vllm(self, model_dir):
290289 del self .llm .llm .model .model .layers
291290
292291 def token2wav (self , token , prompt_token , prompt_feat , embedding , token_offset , uuid , stream = False , finalize = False , speed = 1.0 ):
293- with torch . cuda . amp . autocast (self .fp16 ):
292+ with get_autocast_context (self .fp16 , self . device ):
294293 tts_mel , _ = self .flow .inference (token = token .to (self .device , dtype = torch .int32 ),
295294 token_len = torch .tensor ([token .shape [1 ]], dtype = torch .int32 ).to (self .device ),
296295 prompt_token = prompt_token .to (self .device ),
@@ -389,9 +388,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
389388 self .tts_speech_token_dict .pop (this_uuid )
390389 self .llm_end_dict .pop (this_uuid )
391390 self .hift_cache_dict .pop (this_uuid )
392- if torch .cuda .is_available ():
393- torch .cuda .empty_cache ()
394- torch .cuda .current_stream ().synchronize ()
391+ empty_cache (self .device )
395392
396393
397394class CosyVoice3Model (CosyVoice2Model ):
@@ -401,7 +398,7 @@ def __init__(self,
401398 flow : torch .nn .Module ,
402399 hift : torch .nn .Module ,
403400 fp16 : bool = False ):
404- self .device = torch . device ( 'cuda' if torch . cuda . is_available () else 'cpu' )
401+ self .device = get_device ( )
405402 self .llm = llm
406403 self .flow = flow
407404 self .hift = hift
@@ -413,7 +410,7 @@ def __init__(self,
413410 self .stream_scale_factor = 2
414411 assert self .stream_scale_factor >= 1 , 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
415412 # rtf and decoding related
416- self .llm_context = torch . cuda . stream ( torch . cuda . Stream ( self .device )) if torch . cuda . is_available () else nullcontext ( )
413+ self .llm_context = get_stream_context ( self .device )
417414 self .lock = threading .Lock ()
418415 # dict used to store session related variable
419416 self .tts_speech_token_dict = {}
@@ -423,7 +420,7 @@ def __init__(self,
423420 self .silent_tokens = [1 , 2 , 28 , 29 , 55 , 248 , 494 , 2241 , 2242 , 2322 , 2323 ]
424421
425422 def token2wav (self , token , prompt_token , prompt_feat , embedding , token_offset , uuid , stream = False , finalize = False , speed = 1.0 ):
426- with torch . cuda . amp . autocast (self .fp16 ):
423+ with get_autocast_context (self .fp16 , self . device ):
427424 tts_mel , _ = self .flow .inference (token = token .to (self .device , dtype = torch .int32 ),
428425 token_len = torch .tensor ([token .shape [1 ]], dtype = torch .int32 ).to (self .device ),
429426 prompt_token = prompt_token .to (self .device ),
0 commit comments