diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index d3beb9ec2..fe6a63b31 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -92,20 +92,23 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False): # Do not use concat, it may cause memory format changed and trt infer with wrong results! # NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype - x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype) - mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype) - mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype) - t_in = torch.zeros([2], device=x.device, dtype=spks.dtype) - spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype) - cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype) + bsz = x.size(0) + x_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype) + mask_in = torch.zeros([2 * bsz, 1, x.size(2)], device=x.device, dtype=spks.dtype) + mu_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype) + t_in = torch.zeros([2 * bsz], device=x.device, dtype=spks.dtype) + spks_in = torch.zeros([2 * bsz, 80], device=x.device, dtype=spks.dtype) + cond_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype) for step in range(1, len(t_span)): # Classifier-Free Guidance inference introduced in VoiceBox - x_in[:] = x - mask_in[:] = mask - mu_in[0] = mu + x_in[:bsz] = x + x_in[bsz:] = x + mask_in[:bsz] = mask + mask_in[bsz:] = mask + mu_in[:bsz] = mu t_in[:] = t.unsqueeze(0) - spks_in[0] = spks - cond_in[0] = cond + spks_in[:bsz] = spks + cond_in[:bsz] = cond dphi_dt = self.forward_estimator( x_in, mask_in, mu_in, t_in,