Add Metal DLPack zero-copy sharing#3531
Conversation
|
Hi @XXXXRT666 — read through this PR after @awni redirected us here from #3548. The Wanted to offer some testing help that complements the PyTorch MPS bring-up you have: We maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports
If useful, once the PR converges I can:
Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready. (For the orthogonal |
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing). SHA: 33f52e635db5e6229060481d16a167230a1a474b PR: wjakob/nanobind#1338 Branch: metal-dlpack-cast
002360f to
4e16f1d
Compare
|
This would be super cool if it landed for end to end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on adding reading bytes from disk in raw Also, support for Quick question on the |
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing). SHA: 33f52e635db5e6229060481d16a167230a1a474b PR: wjakob/nanobind#1338 Branch: metal-dlpack-cast
https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data The data pointer points to the allocated data. This will be CUDA device pointer, |
4e16f1d to
a17cd99
Compare
|
One API question: should That would match the mental model used by NumPy/PyTorch more closely: |
| auto copied_shape = a.shape(); // |a| will be moved | ||
| auto dtype = a.dtype(); | ||
| return array( | ||
| std::move(copied_shape), |
There was a problem hiding this comment.
Also there is no need to use std::move since shapes are no longer heap allocated, we don't migrate the old code but try to use the new pattern in new code.
There was a problem hiding this comment.
I think this should be handled explicitly rather than relying on argument evaluation order. Since a.shape() and {std::move(a)} are evaluated as separate function arguments, C++ may evaluate {std::move(a)} first. In that case, a.shape() would be called after a has been moved from.
Linux and Windows CPU CI failed because of this
I think this is very good design. |
mlx 0.31.2mlx 0.32.0.dev20260523+4e8decde9benchmark function# --- mlx -> torch candidates ---------------------------------------------------
def mlx_to_torch_current(arr: mx.array, device: torch.device) -> torch.Tensor:
arr = mx.contiguous(arr)
mx.eval(arr)
buf = memoryview(arr)
dtype_map = {
mx.float32: torch.float32,
mx.float16: torch.float16,
mx.bfloat16: torch.bfloat16,
}
t = torch.frombuffer(buf, dtype=dtype_map[arr.dtype]).reshape(arr.shape)
if device.type == "mps":
t = t.to(device)
return t
def mlx_to_torch_dlpack_mps(arr: mx.array, device: torch.device) -> torch.Tensor:
mx.eval(arr)
t = torch.from_dlpack(arr)
if device.type == "mps":
t = t.to(device)
return t
def mlx_to_torch_dlpack_cpu(arr: mx.array, device: torch.device) -> torch.Tensor:
"""Force a CPU-typed capsule via `dl_device=(kDLCPU, 0)` (Phase 2+).
Falls back to the no-kwarg form for builds that don't accept it."""
mx.eval(arr)
try:
cap = arr.__dlpack__(dl_device=(1, 0))
except TypeError:
# Older builds: zero-arg lambda. Capsule is already kDLCPU there.
cap = arr.__dlpack__()
return torch.from_dlpack(cap)
# --- torch -> mlx candidates ---------------------------------------------------
def torch_to_mlx_current(t: torch.Tensor) -> mx.array:
if t.device.type != "cpu":
t = t.cpu()
t = t.detach()
if t.dtype == torch.bfloat16:
return mx.array(t)
return mx.array(t.numpy())
def torch_to_mlx_dlpack(t: torch.Tensor) -> mx.array:
"""Use mx.from_dlpack when the API exists; old MLX falls back to CPU copy."""
if hasattr(mx, "from_dlpack"):
return mx.from_dlpack(t)
return mx.array(t.detach().cpu()) |
I updated the PR to follow this design: |
This reverts commit eb74695.
0607c24 to
a4378cf
Compare
zcbenz
left a comment
There was a problem hiding this comment.
This basically looks good to me, thanks for the nice work!
|
I just realized that |
47326da to
4eaea96
Compare
4eaea96 to
a7c4a44
Compare
| auto src = static_cast<const SrcT*>(nd_array.data()); | ||
| auto dst = out.data<DstT>(); | ||
| for (size_t i = 0; i < out.size(); ++i) { | ||
| dst[i] = static_cast<DstT>(src[strided_offset(i, shape, strides)]); |
There was a problem hiding this comment.
This is going to be really slow, would it be practical to preserve the original strides and and just do memcpy?
out.set_data(
mx::allocator::malloc(data_size * itemsize),
data_size,
strides);
std::copy(src, src + data_size, dst);There was a problem hiding this comment.
Yes, I checked that PyTorch preserves the strides provided by DLPack on import, including for tensors that are not non_overlapping_and_dense. Strides only change if PyTorch performs a copy afterwards, for example when torch.as_tensor needs to change device or dtype, or when using torch.tensor, which always copies. In those cases, tensors that are not non_overlapping_and_dense are materialized into a compact layout.
I think we do not need to do that in MLX, and can simply preserve the original strides.
There was a problem hiding this comment.
A non_overlapping_and_dense tensor refers to a tensor that is not like x[::2] or torch.tensor([1]).expand(3)
| if (copy && !import_flags.row_contiguous) { | ||
| // Force the copy primitive to materialize the virtual strided input into a | ||
| // row-contiguous output instead of preserving a dense non-row layout. | ||
| import_flags.contiguous = false; |
There was a problem hiding this comment.
Can you collaborate on this? If the input is truly contiguous we shouldn't need this for copying, otherwise the flag was set for non-contiguous input.
There was a problem hiding this comment.
This makes the astype copy force the array to be row_contiguous. However, I think this is no longer necessary if we preserve the strides and do not require imported arrays to always be row_contiguous
210c578 to
f733754
Compare
|
CPU imports now copy the underlying storage span and preserve the For Metal:
|
|
The earlier CI failures came from three independent issues. First, the new Second, Third, the macOS Metal validation failure was from a test mutating the original PyTorch MPS tensor after exporting/importing it through DLPack; I think it's a problem with PyTorch METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -c 'import torch; x=torch.tensor([1.0,2.0,3.0], device="mps"); x+=1; torch.mps.synchronize(); print(x.cpu())'will cause the same error without |
Proposed changes
This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.
This PR builds on the merged DLPack import PR #3495 and requires nanobind support.
The main changes are:
byte_offset.mx.from_dlpack(..., copy=...)controls for Metal DLPack inputs.mx.array(...)zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require
torch.mps.synchronize()before MLX reads, and MLX writes requiremx.eval(...)before PyTorch reads.For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes