Segfault with prefetching and MLX arrays in key transform #47

@awni

The following code segfaults on my machine (M1 Max, macOS 14.2):

Some observations:

  • Using NumPy arrays in place of MLX arrays in the key transform works fine
  • It only segfaults when prefetching is enabled

import mlx.core as mx
from mlx.data.datasets import load_cifar10

def get_cifar10(batch_size, root=None):
    tr = load_cifar10(root=root)

    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

    def normalize(x):
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    return tr_iter

if __name__ == "__main__":
    tr_iter = get_cifar10(256)
    for batch_counter, batch in enumerate(tr_iter):
        print(batch)
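Going by the first observation (NumPy instead of MLX works), one way to sidestep the crash is to keep the key transform NumPy-only. That way no MLX arrays are created inside the prefetch worker threads. The sketch below assumes that this is what "using NumPy in place of MLX" means. The issue does not confirm the root cause, so treat this as a workaround rather than a fix.

```python
import numpy as np

# Per-channel statistics copied from the repro, shaped (1, 1, 3) so they
# broadcast over the channel axis of an HWC image.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))


def normalize(x):
    """NumPy-only variant of the repro's `normalize` key transform.

    All arithmetic stays in NumPy, so nothing from mlx.core runs inside
    the pipeline's background (prefetch) threads.
    """
    x = x.astype(np.float32) / 255.0  # scale uint8 pixels to [0, 1]
    return (x - MEAN) / STD           # per-channel standardization


if __name__ == "__main__":
    # A fake 32x32 RGB uint8 image standing in for one CIFAR-10 sample.
    img = np.full((32, 32, 3), 255, dtype=np.uint8)
    out = normalize(img)
    print(out.shape, out.dtype)
```

In the repro, this function would be passed to `.key_transform("image", normalize)` unchanged. If MLX arrays are needed for training, each batch can be converted in the main loop, e.g. `mx.array(batch["image"])`. This assumes batches arrive as dicts of NumPy arrays, which is how mlx-data streams yield them.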
