diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d678461e3a..32f3498091 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -80,12 +80,10 @@ std::pair load_library_from_path( } #ifdef SWIFTPM_BUNDLE -MTL::Library* try_load_bundle( +MTL::Library* try_load_bundle_path( MTL::Device* device, - NS::URL* url, + const std::string& bundle_path, const std::string& lib_name) { - std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + - SWIFTPM_BUNDLE + ".bundle"; auto bundle = NS::Bundle::alloc()->init( NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding)); if (bundle != nullptr) { @@ -100,6 +98,15 @@ MTL::Library* try_load_bundle( return nullptr; } +MTL::Library* try_load_bundle( + MTL::Device* device, + NS::URL* url, + const std::string& lib_name) { + std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + + SWIFTPM_BUNDLE + ".bundle"; + return try_load_bundle_path(device, bundle_path, lib_name); +} + MTL::Library* try_load_framework( MTL::Device* device, NS::URL* url, @@ -130,6 +137,19 @@ std::pair load_swiftpm_library( MTL::Device* device, const std::string& lib_name) { #ifdef SWIFTPM_BUNDLE + const std::string swiftpm_bundle_name = + std::string(SWIFTPM_BUNDLE) + ".bundle"; + + auto binary_dir = current_binary_dir(); + for (int i = 0; i < 4 && !binary_dir.empty(); ++i) { + MTL::Library* library = try_load_bundle_path( + device, (binary_dir / swiftpm_bundle_name).string(), lib_name); + if (library != nullptr) { + return {library, nullptr}; + } + binary_dir = binary_dir.parent_path(); + } + MTL::Library* library = try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name); if (library != nullptr) { @@ -138,6 +158,16 @@ std::pair load_swiftpm_library( auto bundles = NS::Bundle::allBundles(); for (int i = 0, c = (int)bundles->count(); i < c; i++) { auto bundle = reinterpret_cast(bundles->object(i)); + const auto bundle_url = bundle->bundleURL(); + const std::string bundle_path = bundle_url->fileSystemRepresentation(); + if (bundle_path.size() >= swiftpm_bundle_name.size() && + bundle_path.rfind(swiftpm_bundle_name) == + bundle_path.size() - swiftpm_bundle_name.size()) { + library = try_load_framework(device, bundle->resourceURL(), lib_name); + if (library != nullptr) { + return {library, nullptr}; + } + } library = try_load_bundle(device, bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr};