diff --git a/src/runtime/extra/contrib/coreml/coreml_runtime.mm b/src/runtime/extra/contrib/coreml/coreml_runtime.mm index 82f7c51c9ea7..a72948b250a7 100644 --- a/src/runtime/extra/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/extra/contrib/coreml/coreml_runtime.mm @@ -22,6 +22,7 @@ */ #include #include +#include #include "../../../../support/bytes_io.h" #include "coreml_runtime.h" @@ -136,12 +137,12 @@ return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); } else if (name == "set_input") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { - const auto& input_name = args[0].operator std::string(); - model_->SetInput(input_name, args[1]); + model_->SetInput(args[0].cast(), args[1].cast()); }); } else if (name == "get_output") { - return ffi::Function( - [this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetOutput(args[0]); }); + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = model_->GetOutput(args[0].cast()); + }); } else if (name == "get_num_outputs") { return ffi::Function( [this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetNumOutputs(); }); @@ -154,18 +155,11 @@ NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data options:NSJSONReadingAllowFragments error:nil]; - NSffi::Array* input_names = json[@"inputs"]; + NSArray* input_names = json[@"inputs"]; // Copy input tensors to corresponding data entries. for (auto i = 0; i < args.size() - 1; ++i) { - TVM_FFI_ICHECK(args[i].type_code() == kTVMDLTensorHandle || - args[i].type_code() == kTVMTensorHandle) - << "Expect Tensor or DLTensor as inputs\n"; - if (args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMTensorHandle) { - model_->SetInput([input_names[i] UTF8String], args[i]); - } else { - LOG(FATAL) << "Not implemented"; - } + model_->SetInput([input_names[i] UTF8String], args[i].cast()); } // Execute the subgraph. @@ -173,13 +167,7 @@ // TODO: Support multiple outputs. Tensor out = model_->GetOutput(0); - if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) { - DLTensor* arg = args[args.size() - 1]; - out.CopyTo(arg); - } else { - Tensor arg = args[args.size() - 1]; - out.CopyTo(arg); - } + out.CopyTo(args[args.size() - 1].cast()); *rv = out; }); } else { @@ -196,7 +184,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.coreml_runtime.create", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = CoreMLRuntimeCreate(args[0], args[1]); + *rv = CoreMLRuntimeCreate(args[0].cast(), args[1].cast()); }); }