diff --git a/BUILD.bazel b/BUILD.bazel
index 0582fae7..3ec04e7f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -542,6 +542,31 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "kv_transcoding",
+    srcs = ["gemma/kv_transcoding.cc"],
+    hdrs = ["gemma/kv_transcoding.h"],
+    deps = [
+        ":activations",
+        ":basics",
+        ":configs",
+        ":kv_cache",
+        "//compression:types",
+        "@highway//:hwy",
+    ],
+)
+
+cc_test(
+    name = "kv_transcoding_test",
+    srcs = ["gemma/kv_transcoding_test.cc"],
+    deps = [
+        ":configs",
+        ":kv_transcoding",
+        "//testing/base/public:gunit_main",
+        "@highway//:hwy",
+    ],
+)
+
 cc_library(
     name = "activations",
     hdrs = ["gemma/activations.h"],
diff --git a/gemma/configs.h b/gemma/configs.h
index 0c5dbe8a..d988adcd 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -83,6 +83,17 @@ static inline bool EnumValid(LayerAttentionType type) {
   return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
 }
 
+// Values stated explicitly to allow for semantic reordering
+enum class KVEncoding {
+  kUnspecified = 0,
+  kF32 = 1,
+  kBF16 = 2,
+  kF32TwoTranspositions = 3,
+  kBF16TwoTranspositions = 4,
+  kInt8 = 5,
+  kInt8TwoTranspositions = 6,
+};
+
 enum class AttentionImpl {
   kOld,    // Previous Attention implementation
   kFlash,  // Flash Attention (default)
diff --git a/gemma/kv_transcoding.cc b/gemma/kv_transcoding.cc
new file mode 100644
index 00000000..b81ca817
--- /dev/null
+++ b/gemma/kv_transcoding.cc
@@ -0,0 +1,338 @@
+#include "gemma/kv_transcoding.h"
+
+#include
+#include
+#include
+#include
+
+#include "compression/types.h"
+#include "gemma/activations.h"
+#include "gemma/configs.h"
+#include "gemma/kv_cache.h"
+#include "util/basics.h"
+#include "hwy/base.h"
+#include "hwy/highway.h"
+
+namespace gcpp {
+
+std::optional GetTileSizeBytes(gcpp::KVEncoding encoding,
+                               size_t qkv_dim) {
+  constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
+  switch (encoding) {
+    case gcpp::KVEncoding::kInt8:
+    case gcpp::KVEncoding::kInt8TwoTranspositions:
+      // int8 K and V payloads, followed by one scale per token for K and V.
+      return qkv_dim * kTileSize * 2 * sizeof(int8_t) +
+             kTileSize * 2 * sizeof(gcpp::KV_microscale_t);
+    case gcpp::KVEncoding::kBF16:
+    case gcpp::KVEncoding::kBF16TwoTranspositions:
+      return qkv_dim * kTileSize * 2 * sizeof(gcpp::BF16);
+    case gcpp::KVEncoding::kF32:
+    case gcpp::KVEncoding::kF32TwoTranspositions:
+      return qkv_dim * kTileSize * 2 * sizeof(float);
+    default:
+      return std::nullopt;
+  }
+}
+
+namespace {
+constexpr size_t kTileSize = gcpp::KVCache::kTileSize;
+
+// Element index of K(dim, token) inside the K part of an encoded tile. Plain
+// layout is [qkv_dim, kTileSize]; transposed is [qkv_dim / 2, kTileSize, 2].
+inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim,
+                      size_t token) {
+  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
+  return transposed ? ((dim / 2) * kTileSize * 2 + token * 2 + (dim % 2))
+                    : (dim * kTileSize + token);
+}
+
+// Element index of V(dim, token) inside the V part of an encoded tile. Plain
+// layout is [kTileSize, qkv_dim]; transposed is [kTileSize / 2, qkv_dim, 2].
+inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim,
+                      size_t token) {
+  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
+  return transposed ? ((token / 2) * qkv_dim * 2 + dim * 2 + (token % 2))
+                    : (token * qkv_dim + dim);
+}
+
+// Maps `v * inv_scale` to int8: clamps to [-127, 127], then rounds to the
+// nearest integer (halves away from zero) instead of truncating toward zero.
+int8_t Quantize(float v, float inv_scale) {
+  float scaled = v * inv_scale;
+  if (scaled > 127.0f) return 127;
+  if (scaled < -127.0f) return -127;
+  return static_cast<int8_t>(scaled + (scaled < 0.0f ? -0.5f : 0.5f));
+}
+
+// Fills `out` from decode_k(dim, token) and decode_v(dim, token), per element.
+template
+inline void DecodeTileWithFn(size_t qkv_dim, DecodedTile* out,
+                             const DecodeKFn& decode_k,
+                             const DecodeVFn& decode_v) {
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      out->k_elem(token, dim) = decode_k(dim, token);
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      out->v_elem(token, dim) = decode_v(dim, token);
+    }
+  }
+}
+
+// Passes every K and V element of `decoded` to encode_k / encode_v.
+template
+inline void EncodeTileWithFn(size_t qkv_dim, const DecodedTile& decoded,
+                             const EncodeKFn& encode_k,
+                             const EncodeVFn& encode_v) {
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      encode_k(dim, token, decoded.k_elem(token, dim));
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      encode_v(dim, token, decoded.v_elem(token, dim));
+    }
+  }
+}
+
+// F32 tile: K values first, V values start at element qkv_dim * kTileSize.
+void EncodeTileF32(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
+                   hwy::Span out_encoded_tile_data) {
+  float* data = HWY_RCAST_ALIGNED(float*, out_encoded_tile_data.data());
+  const size_t v_start = qkv_dim * kTileSize;
+  EncodeTileWithFn(
+      qkv_dim, decoded,
+      [&](size_t dim, size_t token, float val)
+          HWY_ATTR { data[KOffset(transposed, qkv_dim, dim, token)] = val; },
+      [&](size_t dim, size_t token, float val) HWY_ATTR {
+        data[v_start + VOffset(transposed, qkv_dim, dim, token)] = val;
+      });
+}
+
+// Same element positions as EncodeTileF32, with each value converted to BF16.
+void EncodeTileBF16(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
+                    hwy::Span out_encoded_tile_data) {
+  gcpp::BF16* data =
+      HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data());
+  const size_t v_start = qkv_dim * kTileSize;
+  EncodeTileWithFn(
+      qkv_dim, decoded,
+      [&](size_t dim, size_t token, float val) HWY_ATTR {
+        data[KOffset(transposed, qkv_dim, dim, token)] =
+            hwy::ConvertScalarTo(val);
+      },
+      [&](size_t dim, size_t token, float val) HWY_ATTR {
+        data[v_start + VOffset(transposed, qkv_dim, dim, token)] =
+            hwy::ConvertScalarTo(val);
+      });
+}
+
+// Int8 tile: K bytes, then V bytes, then kTileSize K scales and kTileSize V
+// scales. Each token's scale is max|x| / 127, or 1 if the token is all zero.
+void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
+                    hwy::Span out_encoded_tile_data) {
+  int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data());
+  int8_t* v_data = k_data + qkv_dim * kTileSize;
+  gcpp::KV_microscale_t* scales =
+      HWY_RCAST_ALIGNED(gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
+  gcpp::KV_microscale_t* k_scales = scales;
+  gcpp::KV_microscale_t* v_scales = scales + kTileSize;
+
+  AlignedFloatVector k_max_abs(kTileSize, 0.0f);
+  AlignedFloatVector v_max_abs(kTileSize, 0.0f);
+
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      k_max_abs[token] =
+          std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim)));
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      v_max_abs[token] =
+          std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim)));
+    }
+  }
+
+  AlignedFloatVector inv_scales_k(kTileSize);
+  AlignedFloatVector inv_scales_v(kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    float scale_k = k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f;
+    k_scales[token] = hwy::ConvertScalarTo(scale_k);
+    inv_scales_k[token] = 1.0f / scale_k;
+
+    float scale_v = v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f;
+    v_scales[token] = hwy::ConvertScalarTo(scale_v);
+    inv_scales_v[token] = 1.0f / scale_v;
+  }
+
+  EncodeTileWithFn(
+      qkv_dim, decoded,
+      [&](size_t dim, size_t token, float val) HWY_ATTR {
+        k_data[KOffset(transposed, qkv_dim, dim, token)] =
+            Quantize(val, inv_scales_k[token]);
+      },
+      [&](size_t dim, size_t token, float val) HWY_ATTR {
+        v_data[VOffset(transposed, qkv_dim, dim, token)] =
+            Quantize(val, inv_scales_v[token]);
+      });
+}
+
+// Decoders: inverse of the encoders above, writing float32 values into `out`.
+void DecodeTileF32(bool transposed, size_t qkv_dim,
+                   hwy::Span encoded_tile_data, DecodedTile* out) {
+  const float* data = HWY_RCAST_ALIGNED(const float*, encoded_tile_data.data());
+  const size_t v_start = qkv_dim * kTileSize;
+  DecodeTileWithFn(
+      qkv_dim, out,
+      [&](size_t dim, size_t token)
+          HWY_ATTR { return data[KOffset(transposed, qkv_dim, dim, token)]; },
+      [&](size_t dim, size_t token) HWY_ATTR {
+        return data[v_start + VOffset(transposed, qkv_dim, dim, token)];
+      });
+}
+
+void DecodeTileBF16(bool transposed, size_t qkv_dim,
+                    hwy::Span encoded_tile_data, DecodedTile* out) {
+  const gcpp::BF16* data =
+      HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data());
+  const size_t v_start = qkv_dim * kTileSize;
+  DecodeTileWithFn(
+      qkv_dim, out,
+      [&](size_t dim, size_t token) HWY_ATTR {
+        return hwy::ConvertScalarTo(
+            data[KOffset(transposed, qkv_dim, dim, token)]);
+      },
+      [&](size_t dim, size_t token) HWY_ATTR {
+        return hwy::ConvertScalarTo(
+            data[v_start + VOffset(transposed, qkv_dim, dim, token)]);
+      });
+}
+
+void DecodeTileInt8(bool transposed, size_t qkv_dim,
+                    hwy::Span encoded_tile_data, DecodedTile* out) {
+  const int8_t* k_data =
+      HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data());
+  const int8_t* v_data = k_data + qkv_dim * kTileSize;
+  const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
+      const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
+  const gcpp::KV_microscale_t* k_scales = scales;
+  const gcpp::KV_microscale_t* v_scales = scales + kTileSize;
+
+  DecodeTileWithFn(
+      qkv_dim, out,
+      [&](size_t dim, size_t token) HWY_ATTR {
+        float scale = hwy::ConvertScalarTo(k_scales[token]);
+        return k_data[KOffset(transposed, qkv_dim, dim, token)] * scale;
+      },
+      [&](size_t dim, size_t token) HWY_ATTR {
+        float scale = hwy::ConvertScalarTo(v_scales[token]);
+        return v_data[VOffset(transposed, qkv_dim, dim, token)] * scale;
+      });
+}
+
+}  // namespace
+
+// True for the *TwoTranspositions encodings, false for every other value.
+bool IsTransposed(KVEncoding encoding) {
+  switch (encoding) {
+    case KVEncoding::kF32TwoTranspositions:
+    case KVEncoding::kBF16TwoTranspositions:
+    case KVEncoding::kInt8TwoTranspositions:
+      return true;
+    default:
+      return false;
+  }
+}
+
+hwy::AlignedUniquePtr AllocateEncodedTile(KVEncoding encoding,
+                                          size_t qkv_dim) {
+  std::optional size = GetTileSizeBytes(encoding, qkv_dim);
+  if (!size.has_value()) return hwy::AlignedUniquePtr();
+  return hwy::MakeUniqueAlignedArray(*size);
+}
+
+bool DecodeTile(KVEncoding encoding, hwy::Span encoded_tile_data,
+                size_t qkv_dim, DecodedTile* out) {
+  std::optional required_size_or = GetTileSizeBytes(encoding, qkv_dim);
+  if (!required_size_or.has_value()) return false;
+  size_t required_size = *required_size_or;
+  if (encoded_tile_data.size() < required_size) {
+    return false;
+  }
+  // `out` must be sized for this tile, or the writes below go out of range.
+  if (out == nullptr || out->qkv_dim != qkv_dim ||
+      out->tile_size < kTileSize) {
+    return false;
+  }
+  bool transposed = IsTransposed(encoding);
+  switch (encoding) {
+    case gcpp::KVEncoding::kF32:
+    case gcpp::KVEncoding::kF32TwoTranspositions: {
+      DecodeTileF32(transposed, qkv_dim, encoded_tile_data, out);
+      return true;
+    }
+    case gcpp::KVEncoding::kBF16:
+    case gcpp::KVEncoding::kBF16TwoTranspositions: {
+      DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out);
+      return true;
+    }
+    case gcpp::KVEncoding::kInt8:
+    case gcpp::KVEncoding::kInt8TwoTranspositions: {
+      DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out);
+      return true;
+    }
+    default:
+      return false;
+  }
+}
+
+bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
+                size_t qkv_dim, hwy::Span out_encoded_tile_data) {
+  std::optional required_size_or = GetTileSizeBytes(encoding, qkv_dim);
+  if (!required_size_or.has_value()) return false;
+  size_t required_size = *required_size_or;
+  if (out_encoded_tile_data.size() < required_size) {
+    return false;
+  }
+  // `decoded` must hold kTileSize tokens of exactly qkv_dim values each.
+  if (decoded.qkv_dim != qkv_dim || decoded.tile_size < kTileSize) {
+    return false;
+  }
+  bool transposed = IsTransposed(encoding);
+  switch (encoding) {
+    case gcpp::KVEncoding::kF32:
+    case gcpp::KVEncoding::kF32TwoTranspositions: {
+      EncodeTileF32(transposed, qkv_dim, decoded, out_encoded_tile_data);
+      return true;
+    }
+    case gcpp::KVEncoding::kBF16:
+    case gcpp::KVEncoding::kBF16TwoTranspositions: {
+      EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data);
+      return true;
+    }
+    case gcpp::KVEncoding::kInt8:
+    case gcpp::KVEncoding::kInt8TwoTranspositions: {
+      EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data);
+      return true;
+    }
+    default:
+      return false;
+  }
+}
+
+bool TranscodeTile(gcpp::KVEncoding src_encoding,
+                   hwy::Span src_data,
+                   gcpp::KVEncoding dst_encoding, hwy::Span dst_data,
+                   size_t qkv_dim) {
+  DecodedTile decoded(qkv_dim, kTileSize);
+  if (!DecodeTile(src_encoding, src_data, qkv_dim, &decoded)) return false;
+
+  return EncodeTile(dst_encoding, decoded, qkv_dim, dst_data);
+}
+
+}  // namespace gcpp
diff --git a/gemma/kv_transcoding.h b/gemma/kv_transcoding.h
new file mode 100644
index 00000000..67610d47
--- /dev/null
+++ b/gemma/kv_transcoding.h
@@ -0,0 +1,70 @@
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
+
+#include
+#include
+#include
+
+#include "gemma/configs.h"
+#include "hwy/aligned_allocator.h"
+
+namespace gcpp {
+
+// Returns the size in bytes of a single encoded KV cache tile (K and V data)
+// for the given encoding, or std::nullopt if the encoding is unsupported.
+std::optional GetTileSizeBytes(gcpp::KVEncoding encoding,
+                               size_t qkv_dim);
+
+// Canonical representation of a single tile of K and V data decoded to float32.
+// Layout: K and V are each a contiguous [tile_size, qkv_dim] array, so element
+// (token, dim) is stored at index token * qkv_dim + dim.
+struct DecodedTile {
+  std::vector> k;
+  std::vector> v;
+  size_t qkv_dim = 0;
+  size_t tile_size = 0;
+
+  DecodedTile() = default;  // Empty tile: no storage, both dimensions zero.
+  DecodedTile(size_t qkv_dim, size_t tile_size)
+      : k(qkv_dim * tile_size),
+        v(tile_size * qkv_dim),
+        qkv_dim(qkv_dim),
+        tile_size(tile_size) {}
+
+  float& k_elem(size_t token, size_t dim) { return k[token * qkv_dim + dim]; }
+  const float& k_elem(size_t token, size_t dim) const {
+    return k[token * qkv_dim + dim];  // No bounds check.
+  }
+
+  float& v_elem(size_t token, size_t dim) { return v[token * qkv_dim + dim]; }
+  const float& v_elem(size_t token, size_t dim) const {
+    return v[token * qkv_dim + dim];  // No bounds check.
+  }
+};
+
+// Allocates an aligned buffer sized by GetTileSizeBytes(encoding, qkv_dim).
+// Returns an empty (null) pointer if the encoding is unsupported.
+hwy::AlignedUniquePtr AllocateEncodedTile(gcpp::KVEncoding encoding,
+                                          size_t qkv_dim);
+
+// Decodes one tile's K and V from `encoded_tile_data` into float32 in `out`.
+// Returns false if the encoding is unsupported or the buffer is too small.
+bool DecodeTile(gcpp::KVEncoding encoding,
+                hwy::Span encoded_tile_data, size_t qkv_dim,
+                DecodedTile* out);
+
+// Encodes a tile's float32 K and V data into `out_encoded_tile_data`. Returns
+// false if the encoding is unsupported or the output buffer is too small.
+bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
+                size_t qkv_dim, hwy::Span out_encoded_tile_data);
+
+// Convenience utility to convert a tile directly from one encoding to another.
+// Returns false if either encoding is unsupported or a buffer is too small.
+bool TranscodeTile(gcpp::KVEncoding src_encoding,
+                   hwy::Span src_data,
+                   gcpp::KVEncoding dst_encoding, hwy::Span dst_data,
+                   size_t qkv_dim);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_
diff --git a/gemma/kv_transcoding_test.cc b/gemma/kv_transcoding_test.cc
new file mode 100644
index 00000000..5f8bb556
--- /dev/null
+++ b/gemma/kv_transcoding_test.cc
@@ -0,0 +1,369 @@
+#include "gemma/kv_transcoding.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gemma/configs.h"
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"  // For hwy::Span
+
+namespace gcpp {
+namespace {
+
+using ::testing::FloatNear;
+using ::testing::Pointwise;
+using ::testing::TestWithParam;
+using ::testing::Values;
+
+struct EncodingTestCase {
+  gcpp::KVEncoding encoding;
+  float tolerance;
+};
+
+class KVEncodingTest : public TestWithParam {};
+
+// Encoding then decoding a tile reproduces it within the case's tolerance.
+TEST_P(KVEncodingTest, EncodeDecodeRoundTrip) {
+  const auto& param = GetParam();
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 256;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  // Fill with dummy data within
+  // a reasonable float range to avoid saturation for INT8
+  const float pattern[] = {0.5f, 1.0f, 1.5f};
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      size_t i = dim * kTileSize + token;
+      original.k_elem(token, dim) = pattern[i % 3];
+      original.v_elem(token, dim) = pattern[i % 3];
+    }
+  }
+
+  std::optional tile_size_bytes =
+      GetTileSizeBytes(param.encoding, qkv_dim);
+  ASSERT_TRUE(tile_size_bytes.has_value());
+
+  std::vector encoded(*tile_size_bytes, 0);
+  EXPECT_TRUE(EncodeTile(param.encoding, original, qkv_dim,
+                         hwy::Span(encoded.data(), encoded.size())));
+
+  DecodedTile decoded(qkv_dim, kTileSize);
+  EXPECT_TRUE(DecodeTile(param.encoding,
+                         hwy::Span(encoded.data(), encoded.size()),
+                         qkv_dim, &decoded));
+
+  EXPECT_THAT(decoded.k, Pointwise(FloatNear(param.tolerance), original.k));
+  EXPECT_THAT(decoded.v, Pointwise(FloatNear(param.tolerance), original.v));
+}
+
+// Buffers one byte smaller than GetTileSizeBytes() are rejected in both directions.
+TEST_P(KVEncodingTest, SizeChecks) {
+  const auto& param = GetParam();
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 256;
+
+  DecodedTile decoded(qkv_dim, kTileSize);
+  std::optional required_size_or =
+      GetTileSizeBytes(param.encoding, qkv_dim);
+  ASSERT_TRUE(required_size_or.has_value());
+  size_t required_size = *required_size_or;
+
+  if (required_size > 0) {
+    std::vector too_small_encoded(required_size - 1, 0);
+    EXPECT_FALSE(EncodeTile(
+        param.encoding, decoded, qkv_dim,
+        hwy::Span(too_small_encoded.data(), too_small_encoded.size())));
+    EXPECT_FALSE(DecodeTile(param.encoding,
+                            hwy::Span(too_small_encoded.data(),
+                                      too_small_encoded.size()),
+                            qkv_dim, &decoded));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    AllEncodings, KVEncodingTest,
+    Values(EncodingTestCase{gcpp::KVEncoding::kF32, 1e-6f},
+           EncodingTestCase{gcpp::KVEncoding::kF32TwoTranspositions, 1e-6f},
+           EncodingTestCase{gcpp::KVEncoding::kBF16, 0.05f},
+           EncodingTestCase{gcpp::KVEncoding::kBF16TwoTranspositions, 0.05f},
+           EncodingTestCase{gcpp::KVEncoding::kInt8, 0.1f},
+           EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f}));
+
+// F32 -> BF16 transcoding keeps both K and V within 0.05 of the source.
+TEST(KVEncodingTest, ConvertTileFloat32ToBfloat16) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 256;
+  gcpp::KVEncoding src_encoding = gcpp::KVEncoding::kF32;
+  gcpp::KVEncoding dst_encoding = gcpp::KVEncoding::kBF16;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      size_t i = dim * kTileSize + token;
+      original.k_elem(token, dim) = std::sin(i) * 5.0f;
+      original.v_elem(token, dim) = std::cos(i) * 5.0f;
+    }
+  }
+
+  size_t src_size = GetTileSizeBytes(src_encoding, qkv_dim).value();
+  size_t dst_size = GetTileSizeBytes(dst_encoding, qkv_dim).value();
+
+  std::vector src_data(src_size);
+  std::vector dst_data(dst_size);
+
+  EXPECT_TRUE(EncodeTile(src_encoding, original, qkv_dim,
+                         hwy::Span(src_data.data(), src_data.size())));
+
+  EXPECT_TRUE(TranscodeTile(
+      src_encoding, hwy::Span(src_data.data(), src_data.size()),
+      dst_encoding, hwy::Span(dst_data.data(), dst_data.size()),
+      qkv_dim));
+
+  DecodedTile decoded(qkv_dim, kTileSize);
+  EXPECT_TRUE(DecodeTile(
+      dst_encoding, hwy::Span(dst_data.data(), dst_data.size()),
+      qkv_dim, &decoded));
+
+  EXPECT_THAT(decoded.k, Pointwise(FloatNear(0.05f), original.k));
+  EXPECT_THAT(decoded.v, Pointwise(FloatNear(0.05f), original.v));
+}
+
+// Transcoding between every pair of distinct encodings preserves K and V.
+TEST(KVEncodingTest, PairwiseConversion) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 256;
+
+  std::vector encodings = {
+      gcpp::KVEncoding::kF32, gcpp::KVEncoding::kF32TwoTranspositions,
+      gcpp::KVEncoding::kBF16, gcpp::KVEncoding::kBF16TwoTranspositions,
+      gcpp::KVEncoding::kInt8, gcpp::KVEncoding::kInt8TwoTranspositions};
+
+  for (auto src : encodings) {
+    for (auto dst : encodings) {
+      if (src == dst) continue;
+
+      DecodedTile original(qkv_dim, kTileSize);
+      const float pattern[] = {0.5f, 1.0f, 1.5f};
+      for (size_t token = 0; token < kTileSize; ++token) {
+        for (size_t dim = 0; dim < qkv_dim; ++dim) {
+          size_t i = dim * kTileSize + token;
+          original.k_elem(token, dim) = pattern[i % 3];
+          original.v_elem(token, dim) = pattern[i % 3];
+        }
+      }
+
+      size_t src_size = GetTileSizeBytes(src, qkv_dim).value();
+      size_t dst_size = GetTileSizeBytes(dst, qkv_dim).value();
+
+      std::vector src_data(src_size);
+      std::vector dst_data(dst_size);
+
+      ASSERT_TRUE(EncodeTile(src, original, qkv_dim,
+                             hwy::Span(src_data.data(), src_data.size())))
+          << "src=" << static_cast(src);
+
+      ASSERT_TRUE(TranscodeTile(
+          src, hwy::Span(src_data.data(), src_data.size()), dst,
+          hwy::Span(dst_data.data(), dst_data.size()), qkv_dim))
+          << "src=" << static_cast(src)
+          << " dst=" << static_cast(dst);
+
+      DecodedTile decoded(qkv_dim, kTileSize);
+      ASSERT_TRUE(DecodeTile(
+          dst, hwy::Span(dst_data.data(), dst_data.size()), qkv_dim,
+          &decoded))
+          << "dst=" << static_cast(dst);
+
+      float tolerance = 0.1f;  // Max tolerance for Int8
+      EXPECT_THAT(decoded.k, Pointwise(FloatNear(tolerance), original.k))
+          << "src=" << static_cast(src)
+          << " dst=" << static_cast(dst);
+      EXPECT_THAT(decoded.v, Pointwise(FloatNear(tolerance), original.v))
+          << "src=" << static_cast(src)
+          << " dst=" << static_cast(dst);
+    }
+  }
+}
+
+// Checks the exact element positions of the plain F32 layout.
+TEST(KVEncodingTest, LayoutValidationF32) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 4;
+  gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.k_elem(token, dim) = dim * kTileSize + token + 1;
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.v_elem(token, dim) =
+          token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
+    }
+  }
+
+  size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
+  std::vector encoded(size);
+
+  ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
+                         hwy::Span(encoded.data(), encoded.size())));
+
+  const float* data = reinterpret_cast(encoded.data());
+
+  // K should be row-major [qkv_dim, tile_size]
+  EXPECT_EQ(data[0], 1.0f);    // d=0, t=0
+  EXPECT_EQ(data[1], 2.0f);    // d=0, t=1
+  EXPECT_EQ(data[32], 33.0f);  // d=1, t=0
+
+  // V should be row-major [tile_size, qkv_dim]
+  size_t v_start = qkv_dim * kTileSize;
+  EXPECT_EQ(data[v_start], 129.0f);      // t=0, d=0
+  EXPECT_EQ(data[v_start + 1], 130.0f);  // t=0, d=1
+  EXPECT_EQ(data[v_start + 4], 133.0f);  // t=1, d=0
+}
+
+// Checks the exact element positions of the transposed F32 layout.
+TEST(KVEncodingTest, LayoutValidationF32TwoTranspositions) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 4;
+  gcpp::KVEncoding encoding = gcpp::KVEncoding::kF32TwoTranspositions;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.k_elem(token, dim) = dim * kTileSize + token + 1;
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.v_elem(token, dim) =
+          token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
+    }
+  }
+
+  size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
+  std::vector encoded(size);
+
+  ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
+                         hwy::Span(encoded.data(), encoded.size())));
+
+  const float* data = reinterpret_cast(encoded.data());
+
+  // K transposed: [qkv_dim/2, tile_size, 2]
+  EXPECT_EQ(data[0], 1.0f);    // d=0, t=0
+  EXPECT_EQ(data[1], 33.0f);   // d=1, t=0
+  EXPECT_EQ(data[2], 2.0f);    // d=0, t=1
+  EXPECT_EQ(data[3], 34.0f);   // d=1, t=1
+  EXPECT_EQ(data[64], 65.0f);  // d=2, t=0
+  EXPECT_EQ(data[65], 97.0f);  // d=3, t=0
+
+  // V transposed: [tile_size/2, qkv_dim, 2]
+  size_t v_start = qkv_dim * kTileSize;
+  EXPECT_EQ(data[v_start], 129.0f);      // t=0, d=0
+  EXPECT_EQ(data[v_start + 1], 133.0f);  // t=1, d=0
+  EXPECT_EQ(data[v_start + 2], 130.0f);  // t=0, d=1
+  EXPECT_EQ(data[v_start + 3], 134.0f);  // t=1, d=1
+}
+
+// Checks positions and quantized values of the plain Int8 layout.
+TEST(KVEncodingTest, LayoutValidationInt8) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 4;
+  gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.k_elem(token, dim) = dim * kTileSize + token + 1;
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.v_elem(token, dim) =
+          token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
+    }
+  }
+
+  size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
+  std::vector encoded(size);
+
+  ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
+                         hwy::Span(encoded.data(), encoded.size())));
+
+  const int8_t* data = reinterpret_cast(encoded.data());
+
+  // K should be row-major [qkv_dim, tile_size]
+  // K[3,0] = 97. Max for t=0 is 97. Scale = 97/127.
+  // Quantized K[3,0] = 127.
+  // K[3,0] is at offset 3 * 32 + 0 = 96.
+  EXPECT_EQ(data[96], 127);
+
+  // V should be row-major [tile_size, qkv_dim]
+  size_t v_start = qkv_dim * kTileSize;
+  // V[0,3] = 132. Max for t=0 is 132. Scale = 132/127.
+  // Quantized V[0,3] = 127.
+  // V[0,3] is at offset v_start + 0 * 4 + 3 = v_start + 3.
+  EXPECT_EQ(data[v_start + 3], 127);
+}
+
+// Checks positions and quantized values of the transposed Int8 layout.
+TEST(KVEncodingTest, LayoutValidationInt8TwoTranspositions) {
+  constexpr size_t kTileSize = 32;
+  constexpr size_t qkv_dim = 4;
+  gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8TwoTranspositions;
+
+  DecodedTile original(qkv_dim, kTileSize);
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.k_elem(token, dim) = dim * kTileSize + token + 1;
+    }
+  }
+  for (size_t token = 0; token < kTileSize; ++token) {
+    for (size_t dim = 0; dim < qkv_dim; ++dim) {
+      original.v_elem(token, dim) =
+          token * qkv_dim + dim + 1 + qkv_dim * kTileSize;
+    }
+  }
+
+  size_t size = GetTileSizeBytes(encoding, qkv_dim).value();
+  std::vector encoded(size);
+
+  ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim,
+                         hwy::Span(encoded.data(), encoded.size())));
+
+  const int8_t* data = reinterpret_cast(encoded.data());
+
+  // K transposed: [qkv_dim/2, tile_size, 2]
+  // K[0,0] = 1. Max for t=0 is 97. Scale = 97/127.
+  // Quantized K[0,0] = 1.
+  // K[1,0] = 33. Quantized K[1,0] = 33 / (97/127) = 43.21 -> 43.
+  // K[1,0] is at offset 1.
+  EXPECT_EQ(data[0], 1);
+  EXPECT_EQ(data[1], 43);
+
+  // V transposed: [tile_size/2, qkv_dim, 2]
+  size_t v_start = qkv_dim * kTileSize;
+  // V[0,0] = 129. Max for t=0 is 132. Scale = 132/127.
+  // Quantized V[0,0] = round(129 * 127 / 132) = 124.
+  // V[1,0] = 133. Max for t=1 is 136. Scale = 136/127.
+  // Quantized V[1,0] = round(133 * 127 / 136) = 124.
+  // In transposed layout, V[0,0] is at v_start. V[1,0] is at v_start + 1.
+  EXPECT_EQ(data[v_start], 124);
+  EXPECT_EQ(data[v_start + 1], 124);
+
+  // V[1,3] = 136. Max for t=1 is 136. Quantized = 127.
+  // Offset in transposed V: t/2*8 + d*2 + t%2.
+  // For t=1, d=3: 0*8 + 3*2 + 1 = 7.
+  EXPECT_EQ(data[v_start + 7], 127);
+}
+
+}  // namespace
+}  // namespace gcpp