Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,31 @@ cc_test(
],
)

# Library that encodes, decodes and transcodes KV-cache tiles between the
# KVEncoding formats declared in gemma/configs.h.
cc_library(
name = "kv_transcoding",
srcs = ["gemma/kv_transcoding.cc"],
hdrs = ["gemma/kv_transcoding.h"],
deps = [
":activations",
":basics",
":configs",
":kv_cache",
"//compression:types",
"@highway//:hwy",
],
)

# Unit tests for the :kv_transcoding library.
cc_test(
name = "kv_transcoding_test",
srcs = ["gemma/kv_transcoding_test.cc"],
deps = [
":configs",
":kv_transcoding",
"//testing/base/public:gunit_main",
"@highway//:hwy",
],
)

cc_library(
name = "activations",
hdrs = ["gemma/activations.h"],
Expand Down
11 changes: 11 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
}

// Storage encodings for a KV-cache tile.
// Numeric values are stated explicitly so that the enumerators can be
// reordered semantically in the source without changing their values.
// In the "TwoTranspositions" variants, K interleaves pairs of adjacent
// dimensions and V interleaves pairs of adjacent tokens.
enum class KVEncoding {
kUnspecified = 0,  // No encoding selected; not a valid tile format.
kF32 = 1,  // float elements.
kBF16 = 2,  // BF16 elements.
kF32TwoTranspositions = 3,  // float elements, interleaved K/V layout.
kBF16TwoTranspositions = 4,  // BF16 elements, interleaved K/V layout.
kInt8 = 5,  // int8 elements plus per-token scales.
kInt8TwoTranspositions = 6,  // int8 + per-token scales, interleaved layout.
};

enum class AttentionImpl {
kOld, // Previous Attention implementation
kFlash, // Flash Attention (default)
Expand Down
320 changes: 320 additions & 0 deletions gemma/kv_transcoding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
#include "gemma/kv_transcoding.h"

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <optional>

#include "compression/types.h"
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/kv_cache.h"
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/highway.h"

namespace gcpp {

// Returns the number of bytes one encoded KV tile occupies for `encoding`,
// or std::nullopt if `encoding` is not a supported tile format (for example
// kUnspecified).
//
// A tile holds K and V (hence the factor 2) for KVCache::kTileSize tokens with
// `qkv_dim` elements per token. The int8 formats additionally store one
// KV_microscale_t per token for K and one per token for V.
std::optional<size_t> GetTileSizeBytes(gcpp::KVEncoding encoding,
                                       size_t qkv_dim) {
  using Enc = gcpp::KVEncoding;
  constexpr size_t kTokensPerTile = gcpp::KVCache::kTileSize;

  if (encoding == Enc::kF32 || encoding == Enc::kF32TwoTranspositions) {
    return qkv_dim * kTokensPerTile * 2 * sizeof(float);
  }
  if (encoding == Enc::kBF16 || encoding == Enc::kBF16TwoTranspositions) {
    return qkv_dim * kTokensPerTile * 2 * sizeof(gcpp::BF16);
  }
  if (encoding == Enc::kInt8 || encoding == Enc::kInt8TwoTranspositions) {
    const size_t code_bytes = qkv_dim * kTokensPerTile * 2 * sizeof(int8_t);
    const size_t scale_bytes =
        kTokensPerTile * 2 * sizeof(gcpp::KV_microscale_t);
    return code_bytes + scale_bytes;
  }
  return std::nullopt;
}

namespace {
// Number of tokens stored in one KV-cache tile; file-local shorthand for
// KVCache::kTileSize.
constexpr size_t kTileSize = gcpp::KVCache::kTileSize;

// Returns the element index (not byte offset) of K[dim, token] relative to the
// start of the tile's K region.
//   - Not transposed: laid out as [dim][token].
//   - Transposed: two adjacent dims are interleaved for each token, i.e.
//     [dim / 2][token][dim % 2].
inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim,
                      size_t token) {
  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
  if (!transposed) {
    return dim * kTileSize + token;
  }
  const size_t dim_pair = dim / 2;
  const size_t lane_in_pair = dim % 2;
  return dim_pair * kTileSize * 2 + token * 2 + lane_in_pair;
}

// Returns the element index (not byte offset) of V[token, dim] relative to the
// start of the tile's V region.
//   - Not transposed: laid out as [token][dim].
//   - Transposed: two adjacent tokens are interleaved for each dim, i.e.
//     [token / 2][dim][token % 2].
inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim,
                      size_t token) {
  HWY_DASSERT(dim < qkv_dim && token < kTileSize);
  if (!transposed) {
    return token * qkv_dim + dim;
  }
  const size_t token_pair = token / 2;
  const size_t lane_in_pair = token % 2;
  return token_pair * qkv_dim * 2 + dim * 2 + lane_in_pair;
}

// Quantizes `v` to a symmetric int8 code in [-127, 127].
//
// `inv_scale` is the reciprocal of the quantization step, so the ideal code is
// v * inv_scale. The result is rounded to the nearest integer (ties away from
// zero) rather than truncated: truncation biases every magnitude downward and
// makes re-encoding an already-quantized value drift (e.g. 49.999996 -> 49).
//
// NaN products (NaN input, or inf * 0 when a token's maximum is infinite) are
// encoded as 0, because converting NaN to an integer is undefined behavior.
int8_t Quantize(float v, float inv_scale) {
  const float scaled = v * inv_scale;
  if (scaled != scaled) return 0;  // NaN: no meaningful code, avoid UB.
  if (scaled >= 127.0f) return 127;
  if (scaled <= -127.0f) return -127;
  // |scaled| < 127 here, so after the half-step bias the truncating cast
  // yields a value in [-127, 127].
  const float biased = scaled < 0.0f ? scaled - 0.5f : scaled + 0.5f;
  return static_cast<int8_t>(biased);
}

// Fills every K and V element of `out` from per-element decoder callbacks.
//
// `decode_k(dim, token)` and `decode_v(dim, token)` return the decoded float
// for the given coordinates. All K elements are produced first, then all V
// elements; within each, iteration is token-major.
template <typename DecodeKFn, typename DecodeVFn>
inline void DecodeTileWithFn(size_t qkv_dim, DecodedTile* out,
                             const DecodeKFn& decode_k,
                             const DecodeVFn& decode_v) {
  // Visits each (token, dim) coordinate of the tile once, token-major.
  const auto for_each_elem = [qkv_dim](const auto& visit) {
    for (size_t t = 0; t < kTileSize; ++t) {
      for (size_t d = 0; d < qkv_dim; ++d) {
        visit(t, d);
      }
    }
  };

  for_each_elem(
      [&](size_t t, size_t d) { out->k_elem(t, d) = decode_k(d, t); });
  for_each_elem(
      [&](size_t t, size_t d) { out->v_elem(t, d) = decode_v(d, t); });
}

// Feeds every K and V element of `decoded` to per-element encoder callbacks.
//
// `encode_k(dim, token, value)` and `encode_v(dim, token, value)` are expected
// to store `value` at the given coordinates. All K elements are visited first,
// then all V elements; within each, iteration is token-major.
template <typename EncodeKFn, typename EncodeVFn>
inline void EncodeTileWithFn(size_t qkv_dim, const DecodedTile& decoded,
                             const EncodeKFn& encode_k,
                             const EncodeVFn& encode_v) {
  // Visits each (token, dim) coordinate of the tile once, token-major.
  const auto for_each_elem = [qkv_dim](const auto& visit) {
    for (size_t t = 0; t < kTileSize; ++t) {
      for (size_t d = 0; d < qkv_dim; ++d) {
        visit(t, d);
      }
    }
  };

  for_each_elem(
      [&](size_t t, size_t d) { encode_k(d, t, decoded.k_elem(t, d)); });
  for_each_elem(
      [&](size_t t, size_t d) { encode_v(d, t, decoded.v_elem(t, d)); });
}

void EncodeTileF32(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
float* data = HWY_RCAST_ALIGNED(float*, out_encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val)
HWY_ATTR { data[KOffset(transposed, qkv_dim, dim, token)] = val; },
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[v_start + VOffset(transposed, qkv_dim, dim, token)] = val;
});
}

void EncodeTileBF16(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
gcpp::BF16* data =
HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[KOffset(transposed, qkv_dim, dim, token)] =
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
},
[&](size_t dim, size_t token, float val) HWY_ATTR {
data[v_start + VOffset(transposed, qkv_dim, dim, token)] =
hwy::ConvertScalarTo<hwy::bfloat16_t>(val);
});
}

void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded,
hwy::Span<char> out_encoded_tile_data) {
int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data());
int8_t* v_data = k_data + qkv_dim * kTileSize;
gcpp::KV_microscale_t* scales =
HWY_RCAST_ALIGNED(gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
gcpp::KV_microscale_t* k_scales = scales;
gcpp::KV_microscale_t* v_scales = scales + kTileSize;

AlignedFloatVector k_max_abs(kTileSize, 0.0f);
AlignedFloatVector v_max_abs(kTileSize, 0.0f);

for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
k_max_abs[token] =
std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim)));
}
}
for (size_t token = 0; token < kTileSize; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
v_max_abs[token] =
std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim)));
}
}

AlignedFloatVector inv_scales_k(kTileSize);
AlignedFloatVector inv_scales_v(kTileSize);
for (size_t token = 0; token < kTileSize; ++token) {
float scale_k = k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f;
k_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_k);
inv_scales_k[token] = 1.0f / scale_k;

float scale_v = v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f;
v_scales[token] = hwy::ConvertScalarTo<gcpp::KV_microscale_t>(scale_v);
inv_scales_v[token] = 1.0f / scale_v;
}

EncodeTileWithFn(
qkv_dim, decoded,
[&](size_t dim, size_t token, float val) HWY_ATTR {
k_data[KOffset(transposed, qkv_dim, dim, token)] =
Quantize(val, inv_scales_k[token]);
},
[&](size_t dim, size_t token, float val) HWY_ATTR {
v_data[VOffset(transposed, qkv_dim, dim, token)] =
Quantize(val, inv_scales_v[token]);
});
}

void DecodeTileF32(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const float* data = HWY_RCAST_ALIGNED(const float*, encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token)
HWY_ATTR { return data[KOffset(transposed, qkv_dim, dim, token)]; },
[&](size_t dim, size_t token) HWY_ATTR {
return data[v_start + VOffset(transposed, qkv_dim, dim, token)];
});
}

void DecodeTileBF16(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const gcpp::BF16* data =
HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data());
const size_t v_start = qkv_dim * kTileSize;
DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token) HWY_ATTR {
return hwy::ConvertScalarTo<float>(
data[KOffset(transposed, qkv_dim, dim, token)]);
},
[&](size_t dim, size_t token) HWY_ATTR {
return hwy::ConvertScalarTo<float>(
data[v_start + VOffset(transposed, qkv_dim, dim, token)]);
});
}

void DecodeTileInt8(bool transposed, size_t qkv_dim,
hwy::Span<const char> encoded_tile_data, DecodedTile* out) {
const int8_t* k_data =
HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data());
const int8_t* v_data = k_data + qkv_dim * kTileSize;
const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED(
const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim);
const gcpp::KV_microscale_t* k_scales = scales;
const gcpp::KV_microscale_t* v_scales = scales + kTileSize;

DecodeTileWithFn(
qkv_dim, out,
[&](size_t dim, size_t token) HWY_ATTR {
float scale = hwy::ConvertScalarTo<float>(k_scales[token]);
return k_data[KOffset(transposed, qkv_dim, dim, token)] * scale;
},
[&](size_t dim, size_t token) HWY_ATTR {
float scale = hwy::ConvertScalarTo<float>(v_scales[token]);
return v_data[VOffset(transposed, qkv_dim, dim, token)] * scale;
});
}

} // namespace

// Returns true iff `encoding` is one of the "TwoTranspositions" formats, which
// use the interleaved K/V element layout. Every other value, including
// kUnspecified, yields false.
bool IsTransposed(KVEncoding encoding) {
  return encoding == KVEncoding::kF32TwoTranspositions ||
         encoding == KVEncoding::kBF16TwoTranspositions ||
         encoding == KVEncoding::kInt8TwoTranspositions;
}

// Allocates an aligned byte buffer sized for one tile in `encoding`.
// Returns an empty (null) pointer when `encoding` is not a supported format.
hwy::AlignedUniquePtr<char[]> AllocateEncodedTile(KVEncoding encoding,
                                                  size_t qkv_dim) {
  const std::optional<size_t> num_bytes = GetTileSizeBytes(encoding, qkv_dim);
  if (num_bytes.has_value()) {
    return hwy::MakeUniqueAlignedArray<char>(*num_bytes);
  }
  return hwy::AlignedUniquePtr<char[]>();
}

// Decodes one encoded KV tile into `out`.
//
// Parameters:
//   encoding          - format of `encoded_tile_data`.
//   encoded_tile_data - at least GetTileSizeBytes(encoding, qkv_dim) bytes.
//                       NOTE(review): the per-format decoders cast the pointer
//                       with HWY_RCAST_ALIGNED, so it presumably must be
//                       aligned for the element type - confirm with callers.
//   qkv_dim           - elements per token in K and in V. Must be even for the
//                       transposed formats.
//   out               - destination tile; must be non-null.
//                       NOTE(review): presumably must have been constructed
//                       for (qkv_dim, kTileSize) - confirm against DecodedTile.
//
// Returns true on success. Returns false without touching `out` when `out` is
// null, the encoding is unsupported, the buffer is too small, or a transposed
// format is requested with an odd `qkv_dim`.
bool DecodeTile(KVEncoding encoding, hwy::Span<const char> encoded_tile_data,
                size_t qkv_dim, DecodedTile* out) {
  if (out == nullptr) return false;

  const std::optional<size_t> required_size =
      GetTileSizeBytes(encoding, qkv_dim);
  if (!required_size.has_value()) return false;
  if (encoded_tile_data.size() < *required_size) return false;

  const bool transposed = IsTransposed(encoding);
  // The transposed K layout interleaves pairs of adjacent dims. With an odd
  // qkv_dim the last (unpaired) dim would index past the K region into the V
  // region, so K and V data would be mixed up silently.
  if (transposed && (qkv_dim % 2) != 0) return false;

  switch (encoding) {
    case gcpp::KVEncoding::kF32:
    case gcpp::KVEncoding::kF32TwoTranspositions:
      DecodeTileF32(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    case gcpp::KVEncoding::kBF16:
    case gcpp::KVEncoding::kBF16TwoTranspositions:
      DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    case gcpp::KVEncoding::kInt8:
    case gcpp::KVEncoding::kInt8TwoTranspositions:
      DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out);
      return true;
    default:
      return false;
  }
}

// Encodes `decoded` into `out_encoded_tile_data` using `encoding`.
//
// Parameters:
//   encoding              - target format.
//   decoded               - source tile.
//                           NOTE(review): presumably must have been
//                           constructed for (qkv_dim, kTileSize) - confirm
//                           against DecodedTile.
//   qkv_dim               - elements per token in K and in V. Must be even for
//                           the transposed formats.
//   out_encoded_tile_data - destination; at least
//                           GetTileSizeBytes(encoding, qkv_dim) bytes.
//                           NOTE(review): the per-format encoders cast the
//                           pointer with HWY_RCAST_ALIGNED, so it presumably
//                           must be aligned for the element type - confirm
//                           with callers.
//
// Returns true on success. Returns false without writing anything when the
// encoding is unsupported, the buffer is too small, or a transposed format is
// requested with an odd `qkv_dim`.
bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded,
                size_t qkv_dim, hwy::Span<char> out_encoded_tile_data) {
  const std::optional<size_t> required_size =
      GetTileSizeBytes(encoding, qkv_dim);
  if (!required_size.has_value()) return false;
  if (out_encoded_tile_data.size() < *required_size) return false;

  const bool transposed = IsTransposed(encoding);
  // The transposed K layout interleaves pairs of adjacent dims. With an odd
  // qkv_dim the last (unpaired) dim would be stored past the K region, inside
  // the V region, where the subsequent V stores would clobber it.
  if (transposed && (qkv_dim % 2) != 0) return false;

  switch (encoding) {
    case gcpp::KVEncoding::kF32:
    case gcpp::KVEncoding::kF32TwoTranspositions:
      EncodeTileF32(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    case gcpp::KVEncoding::kBF16:
    case gcpp::KVEncoding::kBF16TwoTranspositions:
      EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    case gcpp::KVEncoding::kInt8:
    case gcpp::KVEncoding::kInt8TwoTranspositions:
      EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data);
      return true;
    default:
      return false;
  }
}

// Converts one tile from `src_encoding` to `dst_encoding` by decoding it into
// a temporary float tile and re-encoding that into `dst_data`.
// Returns false if either the decode or the encode step fails; the encode step
// is not attempted when decoding fails.
bool TranscodeTile(gcpp::KVEncoding src_encoding,
                   hwy::Span<const char> src_data,
                   gcpp::KVEncoding dst_encoding, hwy::Span<char> dst_data,
                   size_t qkv_dim) {
  DecodedTile scratch(qkv_dim, kTileSize);
  return DecodeTile(src_encoding, src_data, qkv_dim, &scratch) &&
         EncodeTile(dst_encoding, scratch, qkv_dim, dst_data);
}

} // namespace gcpp
Loading
Loading