diff --git a/gemma/configs.cc b/gemma/configs.cc index 271ca8fc..9cc33a00 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -736,4 +736,23 @@ AttentionImpl GetAttentionImpl(const std::string& impl_name) { return AttentionImpl::kFlash; } +std::string KVEncodingToString(KVEncoding encoding) { + switch (encoding) { + case KVEncoding::kF32: + return "F32"; + case KVEncoding::kBF16: + return "BF16"; + case KVEncoding::kF32TwoTranspositions: + return "F32TwoTranspositions"; + case KVEncoding::kBF16TwoTranspositions: + return "BF16TwoTranspositions"; + case KVEncoding::kInt8: + return "Int8"; + case KVEncoding::kInt8TwoTranspositions: + return "Int8TwoTranspositions"; + default: + return "Unknown"; + } +} + } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index 77317cc9..1ee66339 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -92,6 +92,13 @@ enum class KVEncoding { kInt8TwoTranspositions = 6, }; +// Returns a string representation of the KVEncoding. +// This representation is should not change and can be used for +// serialization or logging. +// Note that no reverse function exists to convert a string back to a +// KVEncoding. +std::string KVEncodingToString(KVEncoding encoding); + enum class AttentionImpl { kFlash = 0, // Flash Attention (default) kFlashTransposedQs,