Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions src/runtime/extra/contrib/coreml/coreml_runtime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>

#include "../../../../support/bytes_io.h"
#include "coreml_runtime.h"
Expand Down Expand Up @@ -136,12 +137,12 @@
return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); });
} else if (name == "set_input") {
return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) {
const auto& input_name = args[0].operator std::string();
model_->SetInput(input_name, args[1]);
model_->SetInput(args[0].cast<std::string>(), args[1].cast<DLTensor*>());
});
Comment thread
tlopex marked this conversation as resolved.
} else if (name == "get_output") {
return ffi::Function(
[this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetOutput(args[0]); });
return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) {
*rv = model_->GetOutput(args[0].cast<int>());
});
Comment thread
tlopex marked this conversation as resolved.
} else if (name == "get_num_outputs") {
return ffi::Function(
[this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetNumOutputs(); });
Expand All @@ -154,32 +155,19 @@
NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data
options:NSJSONReadingAllowFragments
error:nil];
NSffi::Array<NSString*>* input_names = json[@"inputs"];
NSArray<NSString*>* input_names = json[@"inputs"];

// Copy input tensors to corresponding data entries.
for (auto i = 0; i < args.size() - 1; ++i) {
TVM_FFI_ICHECK(args[i].type_code() == kTVMDLTensorHandle ||
args[i].type_code() == kTVMTensorHandle)
<< "Expect Tensor or DLTensor as inputs\n";
if (args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMTensorHandle) {
model_->SetInput([input_names[i] UTF8String], args[i]);
} else {
LOG(FATAL) << "Not implemented";
}
model_->SetInput([input_names[i] UTF8String], args[i].cast<DLTensor*>());
}
Comment thread
tlopex marked this conversation as resolved.

// Execute the subgraph.
model_->Invoke();

// TODO: Support multiple outputs.
Tensor out = model_->GetOutput(0);
if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) {
DLTensor* arg = args[args.size() - 1];
out.CopyTo(arg);
} else {
Tensor arg = args[args.size() - 1];
out.CopyTo(arg);
}
out.CopyTo(args[args.size() - 1].cast<DLTensor*>());
*rv = out;
});
} else {
Expand All @@ -196,7 +184,7 @@
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm.coreml_runtime.create", [](ffi::PackedArgs args, ffi::Any* rv) {
*rv = CoreMLRuntimeCreate(args[0], args[1]);
*rv = CoreMLRuntimeCreate(args[0].cast<std::string>(), args[1].cast<std::string>());
});
Comment thread
tlopex marked this conversation as resolved.
}

Expand Down
Loading