From dad59c8fa89eebb179bc37cb07585424e0a0685c Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 18 Jun 2026 15:43:52 +0200 Subject: [PATCH 1/2] feat: add MLX builds for LFM2.5 and privacy-filter models Wire HF-hosted MLX variants for LFM2.5 (350M, 1.2B), LFM2.5-VL (450M, 1.6B) and the privacy filters (openai, nemotron), defaulting to MLX on iOS alongside the existing XNNPACK builds. Runner support for the new builds: - vision_encoder reads its declared input dtype from method metadata and converts the fp32 pixels accordingly (fp32 passthrough / bf16 / fp16), instead of hardcoding Float. - multimodal prefiller splice handles fp32<->bf16 vision/text-embed dtype pairs (hybrid fp32 vision + bf16 decoder). - convert_from_float helper in util.h. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../common/runner/encoders/vision_encoder.cpp | 6 ++ .../common/runner/multimodal_prefiller.cpp | 16 ++++ .../common/runner/util.h | 32 ++++++++ .../src/constants/modelRegistry.ts | 75 ++++++++++++++++--- .../src/constants/modelUrls.ts | 8 ++ 5 files changed, 128 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index 09fb459661..5c382c5633 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -124,8 +125,13 @@ Result VisionEncoder::encode(const MultimodalInput &input) { sizes.insert(sizes.begin(), 1); } + auto vision_meta = ET_UNWRAP(module_->method_meta(kVisionEncoderMethod)); + const auto want_dtype = + ET_UNWRAP(vision_meta.input_tensor_meta(0)).scalar_type(); + auto image_tensor = ::executorch::extension::from_blob( chw.data(), sizes, ::executorch::aten::ScalarType::Float); + image_tensor = ET_UNWRAP(convert_from_float(image_tensor, 
want_dtype)); auto result = ET_UNWRAP(module_->execute(kVisionEncoderMethod, image_tensor)); auto out_tensor = result[0].toTensor(); diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp index 8b04dc39bf..4a8e1e0be9 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -204,6 +204,22 @@ bool MultimodalPrefiller::get_enable_dynamic_shape() const { for (size_t i = 0; i < visual_elems; ++i) { dst_f[i] = static_cast(src[i]); } + } else if (vision_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::BFloat16) { + // Hybrid VLM: fp32 vision encoder (e.g. XNNPACK) + bf16 decoder embeds. + const float *src = vision_tensor.const_data_ptr(); + auto *dst_b = reinterpret_cast<::executorch::aten::BFloat16 *>(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_b[i] = ::executorch::aten::BFloat16(src[i]); + } + } else if (vision_dtype == ::executorch::aten::ScalarType::BFloat16 && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = + vision_tensor.const_data_ptr<::executorch::aten::BFloat16>(); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_f[i] = static_cast(src[i]); + } } else { ET_CHECK_OR_RETURN_ERROR( false, InvalidState, diff --git a/packages/react-native-executorch/common/runner/util.h b/packages/react-native-executorch/common/runner/util.h index b1e707034b..fae5d8a336 100644 --- a/packages/react-native-executorch/common/runner/util.h +++ b/packages/react-native-executorch/common/runner/util.h @@ -170,6 +170,38 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr &src_tensor) { return bf16_tensor; } +/** + * Convert a Float tensor to `dtype` (Float passthrough, BFloat16, or Half). 
+ * Used to match an exported method's declared input dtype when preprocessing + * produces fp32 data. Returns InvalidArgument for unsupported targets. + */ +inline ::executorch::runtime::Result<::executorch::extension::TensorPtr> +convert_from_float(const ::executorch::extension::TensorPtr &src_tensor, + ::executorch::aten::ScalarType dtype) { + using ::executorch::aten::ScalarType; + if (dtype == ScalarType::Float) { + return src_tensor; + } + if (dtype == ScalarType::BFloat16) { + return convert_to_bfloat16(src_tensor); + } + ET_CHECK_OR_RETURN_ERROR(src_tensor->scalar_type() == ScalarType::Float, + InvalidArgument, + "convert_from_float only supports a Float source"); + ET_CHECK_OR_RETURN_ERROR(dtype == ScalarType::Half, InvalidArgument, + "Unsupported target dtype %hhd for pixel conversion", + static_cast(dtype)); + const auto num_elements = static_cast(src_tensor->numel()); + const float *float_data = src_tensor->const_data_ptr(); + auto half_tensor = + ::executorch::extension::empty_like(src_tensor, ScalarType::Half); + auto *half_data = half_tensor->mutable_data_ptr<::executorch::aten::Half>(); + for (size_t i = 0; i < num_elements; ++i) { + half_data[i] = ::executorch::aten::Half(float_data[i]); + } + return half_tensor; +} + } // namespace llm } // namespace extension } // namespace executorch diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts index eb0c98dae7..5a25c16bd6 100644 --- a/packages/react-native-executorch/src/constants/modelRegistry.ts +++ b/packages/react-native-executorch/src/constants/modelRegistry.ts @@ -260,6 +260,64 @@ const GEMMA4_E2B_MM_VARIANTS = { }, }; +const LFM2_5_350M_VARIANTS = { + mlx: { base: { ...M.LFM2_5_350M, modelSource: M.LFM2_5_350M_MLX_MODEL } }, + xnnpack: { base: M.LFM2_5_350M, quant: M.LFM2_5_350M_QUANTIZED }, +}; + +const LFM2_5_1_2B_INSTRUCT_VARIANTS = { + mlx: { + base: { + ...M.LFM2_5_1_2B_INSTRUCT, + modelSource: 
M.LFM2_5_1_2B_INSTRUCT_MLX_MODEL, + }, + }, + xnnpack: { + base: M.LFM2_5_1_2B_INSTRUCT, + quant: M.LFM2_5_1_2B_INSTRUCT_QUANTIZED, + }, +}; + +const LFM2_5_VL_1_6B_VARIANTS = { + mlx: { + base: { + ...M.LFM2_5_VL_1_6B_QUANTIZED, + modelSource: M.LFM2_5_VL_1_6B_MLX_MODEL, + }, + }, + xnnpack: { base: M.LFM2_5_VL_1_6B_QUANTIZED }, +}; + +const LFM2_5_VL_450M_VARIANTS = { + mlx: { + base: { + ...M.LFM2_5_VL_450M_QUANTIZED, + modelSource: M.LFM2_5_VL_450M_MLX_MODEL, + }, + }, + xnnpack: { base: M.LFM2_5_VL_450M_QUANTIZED }, +}; + +const PRIVACY_FILTER_OPENAI_VARIANTS = { + mlx: { + base: { + ...M.PRIVACY_FILTER_OPENAI, + modelSource: M.PRIVACY_FILTER_OPENAI_MLX_MODEL, + }, + }, + xnnpack: { base: M.PRIVACY_FILTER_OPENAI }, +}; + +const PRIVACY_FILTER_NEMOTRON_VARIANTS = { + mlx: { + base: { + ...M.PRIVACY_FILTER_NEMOTRON, + modelSource: M.PRIVACY_FILTER_NEMOTRON_MLX_MODEL, + }, + }, + xnnpack: { base: M.PRIVACY_FILTER_NEMOTRON }, +}; + const EFFICIENTNET_V2_S_VARIANTS = { xnnpack: { base: { @@ -594,11 +652,10 @@ export const models = { smollm2_1_360m: pair(M.SMOLLM2_1_360M, M.SMOLLM2_1_360M_QUANTIZED), smollm2_1_1_7b: pair(M.SMOLLM2_1_1_7B, M.SMOLLM2_1_1_7B_QUANTIZED), phi_4_mini_4b: pair(M.PHI_4_MINI_4B, M.PHI_4_MINI_4B_QUANTIZED), - lfm2_5_350m: pair(M.LFM2_5_350M, M.LFM2_5_350M_QUANTIZED), - lfm2_5_1_2b_instruct: pair( - M.LFM2_5_1_2B_INSTRUCT, - M.LFM2_5_1_2B_INSTRUCT_QUANTIZED - ), + lfm2_5_350m: variant(LFM2_5_350M_VARIANTS, { ios: 'mlx' }), + lfm2_5_1_2b_instruct: variant(LFM2_5_1_2B_INSTRUCT_VARIANTS, { + ios: 'mlx', + }), bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED), gemma4_e2b: variant(GEMMA4_E2B_VARIANTS, { ios: 'mlx', @@ -606,8 +663,8 @@ export const models = { }), // Multimodal LLMs — same hook/module as plain LLMs, listed here so users // pick a model by capability ("LLM") rather than by modality. 
- lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED), - lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED), + lfm2_5_vl_1_6b: variant(LFM2_5_VL_1_6B_VARIANTS, { ios: 'mlx' }), + lfm2_5_vl_450m: variant(LFM2_5_VL_450M_VARIANTS, { ios: 'mlx' }), gemma4_e2b_multimodal: variant(GEMMA4_E2B_MM_VARIANTS, { ios: 'mlx', android: 'vulkan', @@ -617,8 +674,8 @@ export const models = { efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS), }, privacy_filter: { - openai: base(M.PRIVACY_FILTER_OPENAI), - nemotron: base(M.PRIVACY_FILTER_NEMOTRON), + openai: variant(PRIVACY_FILTER_OPENAI_VARIANTS, { ios: 'mlx' }), + nemotron: variant(PRIVACY_FILTER_NEMOTRON_VARIANTS, { ios: 'mlx' }), }, object_detection: { ssdlite_320_mobilenet_v3_large: variant( diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 0e36f812ff..c74fb5a289 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -450,6 +450,7 @@ export const PHI_4_MINI_4B_QUANTIZED = { // LFM2.5-1.2B-Instruct const LFM2_5_1_2B_INSTRUCT_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/1_2b/xnnpack/lfm_2_5_1_2b_xnnpack_fp16.pte`; const LFM2_5_1_2B_INSTRUCT_QUANTIZED_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/1_2b/xnnpack/lfm_2_5_1_2b_xnnpack_8da4w.pte`; +export const LFM2_5_1_2B_INSTRUCT_MLX_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/1_2b/mlx/lfm_2_5_1_2b_mlx_int4.pte`; const LFM2_5_1_2B_TOKENIZER = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/1_2b/tokenizer.json`; const LFM2_5_1_2B_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/1_2b/tokenizer_config.json`; @@ -476,6 +477,7 @@ export const LFM2_5_1_2B_INSTRUCT_QUANTIZED = { // LFM2.5-350M const LFM2_5_350M_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/350m/xnnpack/lfm_2_5_350m_xnnpack_fp16.pte`; const LFM2_5_350M_QUANTIZED_MODEL = 
`${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/350m/xnnpack/lfm_2_5_350m_xnnpack_8da4w.pte`; +export const LFM2_5_350M_MLX_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/350m/mlx/lfm_2_5_350m_mlx_int4.pte`; const LFM2_5_350M_TOKENIZER = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/350m/tokenizer.json`; const LFM2_5_350M_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/350m/tokenizer_config.json`; @@ -527,11 +529,13 @@ export const BIELIK_V3_0_1_5B_QUANTIZED = { // LFM2.5-VL-1.6B const LFM2_VL_1_6B_QUANTIZED_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_1_6b/xnnpack/lfm_2_5_vl_1_6b_xnnpack_8da4w.pte`; +export const LFM2_5_VL_1_6B_MLX_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_1_6b/mlx/lfm_2_5_vl_1_6b_mlx_int4.pte`; const LFM2_VL_1_6B_TOKENIZER = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_1_6b/tokenizer.json`; const LFM2_VL_1_6B_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_1_6b/tokenizer_config.json`; // LFM2.5-VL-450M const LFM2_VL_450M_QUANTIZED_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_450m/xnnpack/lfm_2_5_vl_450m_xnnpack_8da4w.pte`; +export const LFM2_5_VL_450M_MLX_MODEL = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_450m/mlx/lfm_2_5_vl_450m_mlx_int4.pte`; const LFM2_VL_450M_TOKENIZER = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_450m/tokenizer.json`; const LFM2_VL_450M_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm-2.5/${PREVIOUS_VERSION_TAG}/vl_450m/tokenizer_config.json`; @@ -1281,6 +1285,8 @@ export const PRIVACY_FILTER_OPENAI = { tokenizerSource: `${URL_PREFIX}-privacy-filter-openai/${PREVIOUS_VERSION_TAG}/tokenizer.json`, } as const; +export const PRIVACY_FILTER_OPENAI_MLX_MODEL = `${URL_PREFIX}-privacy-filter-openai/${PREVIOUS_VERSION_TAG}/mlx/privacy_filter_openai_mlx_int4.pte`; + /** * OpenMed/privacy-filter-nemotron — extended PII detector with 55 entity * types (adds medical, financial, identity, technical, demographic, etc.). 
@@ -1293,6 +1299,8 @@ export const PRIVACY_FILTER_NEMOTRON = { tokenizerSource: `${URL_PREFIX}-privacy-filter-nemotron/${PREVIOUS_VERSION_TAG}/tokenizer.json`, } as const; +export const PRIVACY_FILTER_NEMOTRON_MLX_MODEL = `${URL_PREFIX}-privacy-filter-nemotron/${PREVIOUS_VERSION_TAG}/mlx/privacy_filter_nemotron_mlx_int8.pte`; + // Image generation /** From 2b4dce60a335bb144858ccc8accd3f6f44b4f355 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 18 Jun 2026 16:38:40 +0200 Subject: [PATCH 2/2] refactor: address review on vision dtype conversion - Extract convert_to_float16 (symmetric to convert_to_bfloat16); rewrite convert_from_float as a switch dispatching to the bf16/fp16 helpers with a Float passthrough, rejecting other targets. - Dedupe the image-embed splice's per-pair conversion loops into a single templated castCopy in multimodal_prefiller. - Read the vision input dtype once in getInputShape (ImageShape.dtype) instead of a second method_meta call in encode(). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../common/runner/encoders/vision_encoder.cpp | 10 ++-- .../common/runner/encoders/vision_encoder.h | 1 + .../common/runner/multimodal_prefiller.cpp | 59 +++++++++---------- .../common/runner/util.h | 52 ++++++++++------ 4 files changed, 68 insertions(+), 54 deletions(-) diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index 5c382c5633..7c9f184d12 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -71,6 +71,7 @@ Result VisionEncoder::getInputShape() const { .height = static_cast(dims[offset + 1]), .width = static_cast(dims[offset + 2]), .with_batch = with_batch, + .dtype = input_meta.scalar_type(), }; } @@ -125,13 +126,12 @@ Result VisionEncoder::encode(const MultimodalInput &input) { 
sizes.insert(sizes.begin(), 1); } - auto vision_meta = ET_UNWRAP(module_->method_meta(kVisionEncoderMethod)); - const auto want_dtype = - ET_UNWRAP(vision_meta.input_tensor_meta(0)).scalar_type(); - + // Preprocessing produces fp32 pixels; convert to the method's declared + // input dtype (`shape.dtype`, already read in getInputShape). Float is a + // passthrough, so the common path stays copy-free. auto image_tensor = ::executorch::extension::from_blob( chw.data(), sizes, ::executorch::aten::ScalarType::Float); - image_tensor = ET_UNWRAP(convert_from_float(image_tensor, want_dtype)); + image_tensor = ET_UNWRAP(convert_from_float(image_tensor, shape.dtype)); auto result = ET_UNWRAP(module_->execute(kVisionEncoderMethod, image_tensor)); auto out_tensor = result[0].toTensor(); diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.h b/packages/react-native-executorch/common/runner/encoders/vision_encoder.h index 54d43bb869..bdedd2ad11 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.h +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.h @@ -27,6 +27,7 @@ class VisionEncoder : public IEncoder { struct ImageShape { int32_t channels, height, width; bool with_batch; + ::executorch::aten::ScalarType dtype; }; // The method's output EValue aliases the runtime's reusable output buffer, diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp index 4a8e1e0be9..dd807e27c7 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -24,6 +24,20 @@ using ::executorch::runtime::Error; using ::executorch::runtime::EValue; using ::executorch::runtime::Result; +namespace { +// Element-wise convert `count` values from `src` (Src) into the raw byte +// buffer `dst` (interpreted as Dst). 
Used to splice an image-embed tensor of +// one dtype into the fused-embeds buffer of another. +template <typename Src, typename Dst> +void castCopy(const void *src, uint8_t *dst, size_t count) { + const auto *s = static_cast<const Src *>(src); + auto *d = reinterpret_cast<Dst *>(dst); + for (size_t i = 0; i < count; ++i) { + d[i] = static_cast<Dst>(s[i]); + } +} +} // namespace + MultimodalPrefiller::MultimodalPrefiller( Module &module, MultimodalDecoderRunner &decoder_runner, tokenizers::HFTokenizer &tokenizer, @@ -186,40 +200,23 @@ bool MultimodalPrefiller::get_enable_dynamic_shape() const { uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * static_cast(hidden) * embeds_elem_size; + using ::executorch::aten::ScalarType; + const void *src = vision_tensor.const_data_ptr(); if (vision_dtype == embeds_dtype) { - const uint8_t *src = - static_cast(vision_tensor.const_data_ptr()); std::memcpy(dst, src, visual_elems * embeds_elem_size); - } else if (vision_dtype == ::executorch::aten::ScalarType::Float && - embeds_dtype == ::executorch::aten::ScalarType::Half) { - const float *src = vision_tensor.const_data_ptr(); - auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); - for (size_t i = 0; i < visual_elems; ++i) { - dst_h[i] = ::executorch::aten::Half(src[i]); - } - } else if (vision_dtype == ::executorch::aten::ScalarType::Half && - embeds_dtype == ::executorch::aten::ScalarType::Float) { - const auto *src = vision_tensor.const_data_ptr<::executorch::aten::Half>(); - auto *dst_f = reinterpret_cast(dst); - for (size_t i = 0; i < visual_elems; ++i) { - dst_f[i] = static_cast(src[i]); - } - } else if (vision_dtype == ::executorch::aten::ScalarType::Float && - embeds_dtype == ::executorch::aten::ScalarType::BFloat16) { + } else if (vision_dtype == ScalarType::Float && + embeds_dtype == ScalarType::Half) { + castCopy<float, ::executorch::aten::Half>(src, dst, visual_elems); + } else if (vision_dtype == ScalarType::Half && + embeds_dtype == ScalarType::Float) { + castCopy<::executorch::aten::Half, float>(src, dst, visual_elems); + }
else if (vision_dtype == ScalarType::Float && + embeds_dtype == ScalarType::BFloat16) { // Hybrid VLM: fp32 vision encoder (e.g. XNNPACK) + bf16 decoder embeds. - const float *src = vision_tensor.const_data_ptr(); - auto *dst_b = reinterpret_cast<::executorch::aten::BFloat16 *>(dst); - for (size_t i = 0; i < visual_elems; ++i) { - dst_b[i] = ::executorch::aten::BFloat16(src[i]); - } - } else if (vision_dtype == ::executorch::aten::ScalarType::BFloat16 && - embeds_dtype == ::executorch::aten::ScalarType::Float) { - const auto *src = - vision_tensor.const_data_ptr<::executorch::aten::BFloat16>(); - auto *dst_f = reinterpret_cast(dst); - for (size_t i = 0; i < visual_elems; ++i) { - dst_f[i] = static_cast(src[i]); - } + castCopy<float, ::executorch::aten::BFloat16>(src, dst, visual_elems); + } else if (vision_dtype == ScalarType::BFloat16 && + embeds_dtype == ScalarType::Float) { + castCopy<::executorch::aten::BFloat16, float>(src, dst, visual_elems); } else { ET_CHECK_OR_RETURN_ERROR( false, InvalidState, diff --git a/packages/react-native-executorch/common/runner/util.h b/packages/react-native-executorch/common/runner/util.h index fae5d8a336..a0258a4bf9 100644 --- a/packages/react-native-executorch/common/runner/util.h +++ b/packages/react-native-executorch/common/runner/util.h @@ -170,6 +170,30 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr &src_tensor) { return bf16_tensor; } +/** + * Helper function to convert a float tensor to float16 (Half). + * Creates a new tensor with Half dtype and copies/converts the data.
+ */ +inline ::executorch::runtime::Result<::executorch::extension::TensorPtr> +convert_to_float16(const ::executorch::extension::TensorPtr &src_tensor) { + ET_CHECK_OR_RETURN_ERROR( + src_tensor->scalar_type() == ::executorch::aten::ScalarType::Float, + InvalidArgument, + "Float16 conversion only supported from Float source data"); + + const auto num_elements = static_cast<size_t>(src_tensor->numel()); + const float *float_data = src_tensor->const_data_ptr<float>(); + + auto half_tensor = ::executorch::extension::empty_like( + src_tensor, ::executorch::aten::ScalarType::Half); + auto *half_data = half_tensor->mutable_data_ptr<::executorch::aten::Half>(); + for (size_t i = 0; i < num_elements; ++i) { + half_data[i] = ::executorch::aten::Half(float_data[i]); + } + + return half_tensor; +} + /** * Convert a Float tensor to `dtype` (Float passthrough, BFloat16, or Half). * Used to match an exported method's declared input dtype when preprocessing @@ -179,27 +203,19 @@ inline ::executorch::runtime::Result<::executorch::extension::TensorPtr> convert_from_float(const ::executorch::extension::TensorPtr &src_tensor, ::executorch::aten::ScalarType dtype) { using ::executorch::aten::ScalarType; - if (dtype == ScalarType::Float) { + switch (dtype) { + case ScalarType::Float: return src_tensor; - } - if (dtype == ScalarType::BFloat16) { + case ScalarType::BFloat16: return convert_to_bfloat16(src_tensor); + case ScalarType::Half: + return convert_to_float16(src_tensor); + default: + ET_CHECK_OR_RETURN_ERROR( + false, InvalidArgument, + "Unsupported target dtype %hhd for float conversion", + static_cast<int8_t>(dtype)); } - ET_CHECK_OR_RETURN_ERROR(src_tensor->scalar_type() == ScalarType::Float, - InvalidArgument, - "convert_from_float only supports a Float source"); - ET_CHECK_OR_RETURN_ERROR(dtype == ScalarType::Half, InvalidArgument, - "Unsupported target dtype %hhd for pixel conversion", - static_cast(dtype)); - const auto num_elements = static_cast(src_tensor->numel()); - const float *float_data
= src_tensor->const_data_ptr(); - auto half_tensor = - ::executorch::extension::empty_like(src_tensor, ScalarType::Half); - auto *half_data = half_tensor->mutable_data_ptr<::executorch::aten::Half>(); - for (size_t i = 0; i < num_elements; ++i) { - half_data[i] = ::executorch::aten::Half(float_data[i]); - } - return half_tensor; } } // namespace llm