Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions csrc/include/aiter_opus_plus.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,19 @@ OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale)
constexpr float hi = 448.0f, lo = -448.0f;
#endif
float a = tmp[0], b = tmp[1];
#if defined(__gfx942__) || defined(__gfx950__)
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
"v_med3_f32 %2, %2, %3, %4\n"
"v_cvt_pk_fp8_f32 %0, %1, %2"
: "=v"(w), "+v"(a), "+v"(b)
: "v"(lo), "v"(hi));
return __builtin_bit_cast(fp8x2_t, static_cast<int16_t>(w));
#else
// Arches without packed fp8-cvt (RDNA3/3.5, host): compile-only stub.
// fp8 KV-cache is unused on these arches; never executed at runtime.
(void)a; (void)b; (void)lo; (void)hi; return fp8x2_t{};
#endif
}

template <typename S, std::enable_if_t<std::is_same_v<S, fp32x4_t>, bool> = true>
Expand All @@ -76,13 +82,17 @@ OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale)
fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale});
constexpr float hi = 57344.0f, lo = -57344.0f;
float a = tmp[0], b = tmp[1];
#if defined(__gfx942__) || defined(__gfx950__)
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
"v_med3_f32 %2, %2, %3, %4\n"
"v_cvt_pk_bf8_f32 %0, %1, %2"
: "=v"(w), "+v"(a), "+v"(b)
: "v"(lo), "v"(hi));
return __builtin_bit_cast(bf8x2_t, static_cast<int16_t>(w));
#else
(void)a; (void)b; (void)lo; (void)hi; return bf8x2_t{};
#endif
}

template <typename S, std::enable_if_t<std::is_same_v<S, fp32x4_t>, bool> = true>
Expand Down
10 changes: 10 additions & 0 deletions csrc/include/opus/opus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,12 +1125,22 @@ OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number<rm> = {}) {
// Template constexpr (packed variants, OPUS_CAST_DEFINE) survives because the check is deferred to instantiation.
// TODO: we may remove constexpr from cast in the future
// Convert a single fp32 value to fp8 and return it as an fp8_t.
//
// On device targets outside the architecture list below this is a compile-only
// stub: x is ignored and an fp8_t with all bits zero is returned.
// Everywhere else the packed conversion builtin is used with 0.0f as the second
// source, and only the low byte of the resulting word is kept.
OPUS_D auto fp32_to_fp8(const fp32_t& x) {
#if defined(__HIP_DEVICE_COMPILE__) && !(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx1250__))
    // RDNA3/3.5 (gfx1100/gfx115x) lack fp8-conversion-insts; compile-only
    // stub so headers build. BF16 code paths never invoke fp8 conversion.
    (void)x;
    return __builtin_bit_cast(fp8_t, static_cast<signed char>(0));
#else
    // The builtin's third operand is an input ("old" word). Pass a defined
    // zero instead of an uninitialized local so no indeterminate value is
    // read; only the low byte of the result is consumed below.
    const int packed = __builtin_amdgcn_cvt_pk_fp8_f32(x, 0.0f, /*old=*/0, /*sel=lo*/0);
    return __builtin_bit_cast(fp8_t, static_cast<signed char>(packed));
#endif
}
// Widen a single fp8 value to fp32.
//
// Host passes and device targets in the architecture list below use the
// fp8 -> f32 conversion builtin on byte 0 of a zero-extended word. Any other
// device target gets a compile-only stub that ignores x and returns 0.0f.
OPUS_D auto fp8_to_fp32(const fp8_t& x) {
#if !defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx1250__)
    // Reinterpret the fp8 storage as a raw byte, then zero-extend it to the
    // 32-bit operand the builtin expects.
    const unsigned char raw_byte = __builtin_bit_cast(unsigned char, x);
    const int src_word = static_cast<int>(raw_byte);
    return __builtin_amdgcn_cvt_f32_fp8(src_word, /*byte=*/0);
#else
    (void)x;
    return fp32_t(0.0f);
#endif
}
// Identity conversion: returns the fp32 input unchanged (by value).
OPUS_D constexpr auto fp32_to_fp32(const fp32_t& x)
{
    return x;
}
// Convert an fp32 value to i8_t with a plain static_cast.
// No clamping or explicit rounding is applied in this function.
// NOTE(review): presumably i8_t is an 8-bit integer, so out-of-range inputs
// would be the caller's responsibility - confirm against the i8_t definition.
OPUS_D constexpr auto fp32_to_i8(const fp32_t& x)
{
    return static_cast<i8_t>(x);
}
Expand Down
Loading