[REFACTOR][CUDA] Phase out cuda_common.h #19770
Conversation
The CUDA runtime can rely on the shared tvm-ffi CUDA error helper instead of carrying a TVM-local common header. This PR removes the legacy header by moving the CUDA workspace state into cuda_device_api.cc, replacing CUDA_CALL users with TVM_FFI_CHECK_CUDA_ERROR, and inlining the remaining CUDA driver checks at their call sites.
There was a problem hiding this comment.
Code Review
This pull request migrates the CUDA runtime and contrib modules to use the new FFI-based error checking (TVM_FFI_CHECK_CUDA_ERROR) and removes the legacy cuda_common.h header. The review feedback highlights several critical issues regarding exception safety and error handling: potential memory and resource leaks on exception paths in cuda_ipc_memory.cc, tensorrt_calibrator.h, and cublas_utils.cc; potential corruption of the thread-local stream environment in cuda_graph_builtin.cc if stream capture fails; and potential undefined behavior or crashes in cuda_device_api.cc and cuda_module.cc if cuGetErrorName returns a null pointer.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers)); | ||
| TVM_FFI_CHECK_CUDA_ERROR( | ||
| cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice)); | ||
| NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->global_comm, | ||
| /*stream=*/nullptr)); | ||
| std::vector<char> serial_handles(CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 0); | ||
| CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst, | ||
| CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, cudaMemcpyDefault)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(serial_handles.data(), d_dst, | ||
| CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, | ||
| cudaMemcpyDefault)); | ||
| std::vector<cudaIpcMemHandle_t> handles(ctx->worker->num_workers); | ||
| for (int i = 0; i < ctx->worker->num_workers; ++i) { | ||
| memcpy(handles[i].reserved, &serial_handles[i * CUDA_IPC_HANDLE_SIZE], CUDA_IPC_HANDLE_SIZE); | ||
| } | ||
| CUDA_CALL(cudaFree(d_src)); | ||
| CUDA_CALL(cudaFree(d_dst)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaFree(d_src)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaFree(d_dst)); |
There was a problem hiding this comment.
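If the error check on cudaMalloc for d_dst, cudaMemcpy, or ncclAllGather fails and throws, the GPU memory already allocated for d_src (and d_dst, once allocated) is leaked, because the cudaFree calls at the end are never reached. We should wrap each allocation in a std::unique_ptr with a custom deleter so that cudaFree is also called on exception paths.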
void *d_src = nullptr, *d_dst = nullptr;
struct CUDAFreeDeleter {
void operator()(void* ptr) const { if (ptr) cudaFree(ptr); }
};
TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE));
std::unique_ptr<void, CUDAFreeDeleter> src_guard(d_src);
TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers));
std::unique_ptr<void, CUDAFreeDeleter> dst_guard(d_dst);
TVM_FFI_CHECK_CUDA_ERROR(
cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice));
NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->global_comm,
/*stream=*/nullptr));
std::vector<char> serial_handles(CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 0);
TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(serial_handles.data(), d_dst,
CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers,
cudaMemcpyDefault));
std::vector<cudaIpcMemHandle_t> handles(ctx->worker->num_workers);
for (int i = 0; i < ctx->worker->num_workers; ++i) {
memcpy(handles[i].reserved, &serial_handles[i * CUDA_IPC_HANDLE_SIZE], CUDA_IPC_HANDLE_SIZE);
}
return handles;
| TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id_)); | ||
| TVM_FFI_CHECK_SAFE_CALL( | ||
| TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_, | ||
| reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_))); | ||
| CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); |
There was a problem hiding this comment.
If cudaStreamBeginCapture throws an exception, the constructor of CUDACaptureStream fails and the destructor is never called. This leaves the thread-local stream environment pointing to the destroyed capture_stream_ member, causing subsequent operations to crash or exhibit undefined behavior. We should catch any exception and restore the stream environment before rethrowing.
TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id_));
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
try {
TVM_FFI_CHECK_CUDA_ERROR(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
} catch (...) {
TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr);
throw;
}
| CUresult result = cuDeviceGetName(&name[0], name.size(), dev.device_id); | ||
| if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { | ||
| const char* msg; | ||
| cuGetErrorName(result, &msg); | ||
| TVM_FFI_THROW(CUDAError) << "cuDeviceGetName failed with error: " << msg; | ||
| } |
There was a problem hiding this comment.
The cuGetErrorName function can return CUDA_ERROR_INVALID_VALUE if the error code is unrecognized, which sets msg to nullptr. Streaming a null pointer to TVM_FFI_THROW results in undefined behavior (and potential crashes). It is safer to check if msg is valid or use a fallback string.
CUresult result = cuDeviceGetName(&name[0], name.size(), dev.device_id);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char* msg = nullptr;
cuGetErrorName(result, &msg);
TVM_FFI_THROW(CUDAError) << "cuDeviceGetName failed with error: " << (msg ? msg : "unknown error");
}
| CUresult result = cuModuleLoadData(&(module_[device_id]), code_.data()); | ||
| if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { | ||
| const char* msg; | ||
| cuGetErrorName(result, &msg); | ||
| TVM_FFI_THROW(CUDAError) << "cuModuleLoadData failed with error: " << msg; | ||
| } |
There was a problem hiding this comment.
The cuGetErrorName function can return CUDA_ERROR_INVALID_VALUE if the error code is unrecognized, which sets msg to nullptr. Streaming a null pointer to TVM_FFI_THROW results in undefined behavior (and potential crashes). It is safer to check if msg is valid or use a fallback string.
| CUresult result = cuModuleLoadData(&(module_[device_id]), code_.data()); | |
| if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { | |
| const char* msg; | |
| cuGetErrorName(result, &msg); | |
| TVM_FFI_THROW(CUDAError) << "cuModuleLoadData failed with error: " << msg; | |
| } | |
| CUresult result = cuModuleLoadData(&(module_[device_id]), code_.data()); | |
| if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { | |
| const char* msg = nullptr; | |
| cuGetErrorName(result, &msg); | |
| TVM_FFI_THROW(CUDAError) << "cuModuleLoadData failed with error: " << (msg ? msg : "unknown error"); | |
| } |
| std::vector<float*> data_host(bindings.size(), nullptr); | ||
| for (size_t i = 0; i < bindings.size(); ++i) { | ||
| data_host[i] = new float[batch_size_ * binding_sizes[i]]; | ||
| CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i], | ||
| batch_size_ * binding_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i], | ||
| batch_size_ * binding_sizes[i] * sizeof(float), | ||
| cudaMemcpyDeviceToHost)); | ||
| } |
There was a problem hiding this comment.
If cudaMemcpy throws an exception, any previously allocated data_host[i] in the loop will be leaked because they are not yet pushed to data_ and the destructor only cleans up data_. We should use a try-catch block to free them on failure.
std::vector<float*> data_host(bindings.size(), nullptr);
try {
for (size_t i = 0; i < bindings.size(); ++i) {
data_host[i] = new float[batch_size_ * binding_sizes[i]];
TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i],
batch_size_ * binding_sizes[i] * sizeof(float),
cudaMemcpyDeviceToHost));
}
} catch (...) {
for (float* ptr : data_host) {
delete[] ptr;
}
throw;
}
| CuBlasLtThreadEntry::CuBlasLtThreadEntry() { | ||
| CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); | ||
| CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc)); | ||
| CUDA_CALL(cudaMalloc(&workspace_ptr, workspace_size)); | ||
| TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&workspace_ptr, workspace_size)); | ||
| } |
There was a problem hiding this comment.
If cudaMalloc throws an exception, the constructor fails and the destructor is not called. This causes handle and matmul_pref_desc to be leaked. We should wrap the constructor body in a try-catch block to perform cleanup on failure.
CuBlasLtThreadEntry::CuBlasLtThreadEntry() {
  // Releases whatever has been created so far. A constructor that exits via an
  // exception never has its destructor run, so the cleanup has to happen here.
  // The preference descriptor is destroyed before the handle, i.e. in reverse
  // order of creation.
  const auto release_partial_state = [this]() {
    if (matmul_pref_desc) {
      cublasLtMatmulPreferenceDestroy(matmul_pref_desc);
    }
    if (handle) {
      cublasLtDestroy(handle);
    }
  };

  // Creating the handle is outside the guarded region: if this check fails,
  // nothing has been acquired yet and there is nothing to release.
  CHECK_CUBLAS_ERROR(cublasLtCreate(&handle));

  try {
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc));
    // Allocate the workspace buffer of `workspace_size` bytes.
    TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&workspace_ptr, workspace_size));
  } catch (...) {
    release_partial_state();
    throw;  // propagate the original failure unchanged
  }
}
Summary
The CUDA runtime can rely on the shared tvm-ffi CUDA error helper instead of carrying a TVM-local common header. This keeps CUDA error handling aligned with the FFI CUDA support and removes the remaining reason for src/backend/cuda/runtime/cuda_common.h.
Main changes: