205 changes: 191 additions & 14 deletions bench/00_misc/fltflt_arithmetic.cu
@@ -70,6 +70,72 @@ static void add_gops_per_sec_summary(nvbench::state &state, double ops_per_op =
s.set_float64("value", total_ops / seconds / 1e9);
}

// Bump a value to the next bit pattern by adding 1 to its integer
// representation. Used to vary a loop input across iterations without
// charging the benchmark for an unrelated fp add on every iteration.
//
// Only `double` and `fltflt` overloads are provided: on most GPUs an
// int64 add is significantly faster than an fp64 add, and the fltflt
// alternative would dispatch through fltflt_add (~20 fp32 ops). For
// float, fp32 add and int32 add run at the same rate, so call sites
// should keep `x = x + small`.
//
// For fltflt, only the hi component is bumped, leaving the pair
// non-canonical -- benches only care that the value differs from the
// previous iteration. Adding 1 at the largest-positive bit pattern (a
// NaN encoding) would overflow the signed integer, which is UB; call
// sites here never reach that value.
__device__ __forceinline__ void bump_ulp(double &x) {
x = __longlong_as_double(__double_as_longlong(x) + 1LL);
}
__device__ __forceinline__ void bump_ulp(fltflt &x) {
x.hi = __int_as_float(__float_as_int(x.hi) + 1);
}
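
// Illustrative sanity check (a sketch, not part of the production code):
// for a positive, finite double the bump lands exactly on the next
// representable value, i.e. what nextafter(x, +infinity) returns. Assumes
// the CUDA device-side nextafter() and <math.h>'s INFINITY are in scope.
__device__ __forceinline__ bool bump_ulp_matches_nextafter(double x) {
  double bumped = x;
  bump_ulp(bumped);
  return bumped == nextafter(x, INFINITY);
}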

// Compute-bound kernel used solely to spin GPU clocks up to steady
// state. Self-contained (doesn't depend on any of the iterative_*
// kernels below) so it can be called from warmup_gpu_once() before
// they're defined.
__global__ void clock_warmup_kernel(float *out, int N)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float acc = static_cast<float>(idx) * 0.001f;
#pragma unroll 1
for (int i = 0; i < N; i++) {
acc = acc * 1.0000001f + 1.0e-6f;
}
if (idx == 0) out[0] = acc;
}

// Idempotent process-level GPU warmup. Only the first call in a
// process actually runs warmup launches; subsequent calls are no-ops.
// Brings GPU clocks to steady state before the first nvbench timing
// window so the *first* benchmark to execute (whichever one that may
// be) does not get charged for clock ramp-up.
//
// inner_iters is sized so each launch runs ~50 ms on Blackwell, and
// four launches give ~200 ms total -- comfortably past the GPU's clock
// ramp window. (FMA-bound work retires near peak fp32 throughput, so
// undersized warmups would otherwise be over in a few ms.)
static void warmup_gpu_once()
{
static bool warmed = false;
if (warmed) return;
warmed = true;

constexpr int block_size = 256;
constexpr int grid_size = 1024;
constexpr int inner_iters = 2'000'000;

float *tmp = nullptr;
MATX_CUDA_CHECK(cudaMalloc(&tmp, sizeof(float)));
for (int w = 0; w < 4; w++) {
clock_warmup_kernel<<<grid_size, block_size>>>(tmp, inner_iters);
MATX_CUDA_CHECK_LAST_ERROR();
}
MATX_CUDA_CHECK(cudaDeviceSynchronize());
MATX_CUDA_CHECK(cudaFree(tmp));
}

template <typename T>
__global__ void iterative_add_kernel(T* __restrict__ result, int64_t size, int32_t iterations)
{
@@ -277,10 +343,15 @@ __global__ void iterative_fma_kernel(T* __restrict__ result, int64_t size, int32
}

//==============================================================================
// Addition Benchmark
// Addition Throughput Benchmark
//
// Many independent accumulators (ILP_FACTOR=8) and outer-loop unrolling
// expose maximum instruction-level parallelism. Latency-hiding fully covers
// per-call dependency chains, so this measures *throughput*: ops/sec when
// the warp scheduler always has independent work in flight.
//==============================================================================
template <typename PrecisionType>
void fltflt_bench_add(nvbench::state &state, nvbench::type_list<PrecisionType>)
void fltflt_bench_add_throughput(nvbench::state &state, nvbench::type_list<PrecisionType>)
{
const index_t size = static_cast<index_t>(state.get_int64("Array Size"));
const int32_t iterations = static_cast<int32_t>(state.get_int64("Iterations"));
@@ -296,6 +367,7 @@ void fltflt_bench_add(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

// Benchmark execution
@@ -306,7 +378,7 @@ void fltflt_bench_add(nvbench::state &state, nvbench::type_list<PrecisionType>)
add_gops_per_sec_summary(state);
}

NVBENCH_BENCH_TYPES(fltflt_bench_add, NVBENCH_TYPE_AXES(precision_types))
NVBENCH_BENCH_TYPES(fltflt_bench_add_throughput, NVBENCH_TYPE_AXES(precision_types))
.add_int64_power_of_two_axis("Array Size", nvbench::range(24, 24, 1))
.add_int64_axis("Iterations", {250});

@@ -328,6 +400,7 @@ void fltflt_bench_sub(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -359,6 +432,7 @@ void fltflt_bench_mul(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -390,6 +464,7 @@ void fltflt_bench_div(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -421,6 +496,7 @@ void fltflt_bench_sqrt(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -492,6 +568,7 @@ void fltflt_bench_sqrt_fast(nvbench::state &state, nvbench::type_list<PrecisionT
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -574,6 +651,7 @@ void fltflt_bench_norm3d(nvbench::state &state, nvbench::type_list<PrecisionType
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -606,6 +684,7 @@ void fltflt_bench_abs(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -637,6 +716,7 @@ void fltflt_bench_fma(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -704,6 +784,7 @@ void fltflt_bench_madd(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -778,6 +859,7 @@ void fltflt_bench_round(nvbench::state &state, nvbench::type_list<PrecisionType>
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -829,10 +911,13 @@ __global__ void iterative_fmod_kernel(T* __restrict__ result, int64_t size, int3
asm volatile("" : "+d"(val[ilp]));
}
}
if constexpr (std::is_same_v<T, fltflt>) {
init_val = init_val + 2048.0f;
if constexpr (std::is_same_v<T, float>) {
// fp32 add is full-rate, no benefit from a bit-twiddle here.
init_val += 2048.0f;
} else {
init_val += static_cast<T>(2048.0f);
// Bit-pattern bump avoids an fp64 add (or full fltflt_add) on
// every iteration just to defeat hoisting of the fmod call.
bump_ulp(init_val);
}
}

@@ -860,6 +945,7 @@ void fltflt_bench_fmod(nvbench::state &state, nvbench::type_list<PrecisionType>)
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -910,7 +996,14 @@ __global__ void iterative_trunc_kernel(T* __restrict__ result, int64_t size, int
asm volatile("" : "+d"(val[ilp]));
}
}
init_val = init_val + static_cast<T>(2048.0f);
if constexpr (std::is_same_v<T, float>) {
// fp32 add is full-rate, no benefit from a bit-twiddle here.
init_val += 2048.0f;
} else {
// Bit-pattern bump avoids an fp64 add (or full fltflt_add) on
// every iteration just to defeat hoisting of the trunc call.
bump_ulp(init_val);
}
}

T result_val = val[0];
@@ -937,6 +1030,7 @@ void fltflt_bench_trunc(nvbench::state &state, nvbench::type_list<PrecisionType>
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -987,7 +1081,14 @@ __global__ void iterative_floor_kernel(T* __restrict__ result, int64_t size, int
asm volatile("" : "+d"(val[ilp]));
}
}
init_val = init_val + static_cast<T>(2048.0f);
if constexpr (std::is_same_v<T, float>) {
// fp32 add is full-rate, no benefit from a bit-twiddle here.
init_val += 2048.0f;
} else {
// Bit-pattern bump avoids an fp64 add (or full fltflt_add) on
// every iteration just to defeat hoisting of the floor call.
bump_ulp(init_val);
}
}

T result_val = val[0];
Expand All @@ -1014,6 +1115,7 @@ void fltflt_bench_floor(nvbench::state &state, nvbench::type_list<PrecisionType>
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -1050,7 +1152,15 @@ __global__ void iterative_cast2dbl_kernel(double* __restrict__ result, int64_t s
acc[ilp] = static_cast<double>(src_val);
asm volatile("" : "+d"(acc[ilp]));
}
src_val = src_val + static_cast<T>(0.0001);
if constexpr (std::is_same_v<T, float>) {
// fp32 add is full-rate, no benefit from a bit-twiddle here.
src_val = src_val + static_cast<T>(0.0001);
} else {
// Vary src_val via bit-pattern bump -- keeps the cast2dbl cost
// uncontaminated by an unrelated fp64 add or fltflt_add per
// iteration.
bump_ulp(src_val);
}
}

double result_val = acc[0];
Expand All @@ -1077,6 +1187,7 @@ void fltflt_bench_cast2dbl(nvbench::state &state, nvbench::type_list<PrecisionTy
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -1113,12 +1224,14 @@ __global__ void iterative_cast2fltflt_kernel(fltflt* __restrict__ result, int64_
acc[ilp] = static_cast<fltflt>(src_val);
asm volatile("" : "+f"(acc[ilp].hi), "+f"(acc[ilp].lo));
}
// For double, increment the bit pattern to get the next representable value
// so the loop anti-aliasing doesn't introduce a double-precision add.
if constexpr (cuda::std::is_same_v<T, double>) {
src_val = __longlong_as_double(__double_as_longlong(src_val) + 1LL);
} else {
if constexpr (std::is_same_v<T, float>) {
// fp32 add is full-rate, no benefit from a bit-twiddle here.
src_val = src_val + static_cast<T>(0.0001);
} else {
// Vary src_val via bit-pattern bump -- keeps the cast2fltflt
// cost uncontaminated by an unrelated fp64 add or fltflt_add
// per iteration.
bump_ulp(src_val);
}
}

@@ -1146,6 +1259,7 @@ void fltflt_bench_cast2fltflt(nvbench::state &state, nvbench::type_list<Precisio
constexpr int block_size = 256;
int grid_size = static_cast<int>((size + block_size - 1) / block_size);

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
@@ -1158,3 +1272,66 @@ void fltflt_bench_cast2fltflt(nvbench::state &state, nvbench::type_list<Precisio
NVBENCH_BENCH_TYPES(fltflt_bench_cast2fltflt, NVBENCH_TYPE_AXES(precision_types))
.add_int64_power_of_two_axis("Array Size", nvbench::range(24, 24, 1))
.add_int64_axis("Iterations", {250});

//==============================================================================
// Addition Latency Benchmark
//
// Mirrors fltflt_bench_add_throughput but with the opposite scheduling
// posture: a single in-flight accumulator per thread, no ILP, no inner-loop
// unroll, and step varies per iteration (so the compiler cannot hoist or
// reassociate the chain). Each iteration's input depends on the previous
// iteration's output, so per-call dependency chains directly drive runtime.
//
// For fltflt this exposes the depth difference between the production
// fltflt_add (Zhang & Aiken SC'25 Fig 2 FPAN, critical path ~10 fp32 ops)
// and Thall's df64_add (~13 fp32 ops). The "Blocks" axis sweeps the
// latency->throughput transition: at Blocks=1 only one warp runs on one SM,
// fully exposing the chain, while at Blocks=1024 the scheduler has many
// warps in flight and latency is partially hidden.
//==============================================================================
template <typename PrecisionType>
__global__ void chain_add_kernel(int N, PrecisionType *__restrict__ out)
{
// Construct via float so the same expression compiles for float, double,
// and fltflt (each has a constructor accepting a float).
PrecisionType acc{1.0f};
#pragma unroll 1
for (int i = 0; i < N; i++) {
// step varies per iteration to defeat loop-invariant hoisting and force
// a true data dependency on the running accumulator.
const PrecisionType step{static_cast<float>(i + 1)};
acc = acc + step; // dispatches to PrecisionType's operator+
}
out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}
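
// For reference, a Thall-style df64 addition sketch (illustrative only, not
// the production fltflt_add): the two-sum steps plus the two renormalizations
// form a serial chain of roughly 13 dependent fp32 ops, which is the latency
// this benchmark exposes when only a single warp is resident.
__device__ __forceinline__ float2 two_sum_sketch(float a, float b)
{
  float s = a + b;
  float bb = s - a;
  float err = (a - (s - bb)) + (b - bb);  // exact rounding error of a + b
  return make_float2(s, err);
}
__device__ __forceinline__ float2 df64_add_sketch(float2 a, float2 b)
{
  float2 s = two_sum_sketch(a.x, b.x);    // hi + hi
  float2 t = two_sum_sketch(a.y, b.y);    // lo + lo (independent of s)
  s.y += t.x;
  float hi = s.x + s.y;                   // quick-two-sum renormalization
  float lo = s.y - (hi - s.x);
  lo += t.y;
  float hi2 = hi + lo;                    // second renormalization
  return make_float2(hi2, lo - (hi2 - hi));
}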

template <typename PrecisionType>
void fltflt_bench_add_latency(nvbench::state &state, nvbench::type_list<PrecisionType>)
{
const int chain_len = static_cast<int>(state.get_int64("Chain Length"));
const int blocks = static_cast<int>(state.get_int64("Blocks"));
constexpr int threads = 32; // exactly one warp per block

cudaExecutor exec{0};
const size_t total_threads = static_cast<size_t>(blocks) * threads;
auto result = make_tensor<PrecisionType>({static_cast<index_t>(total_threads)});

state.add_element_count(static_cast<int64_t>(chain_len) * total_threads, "ops");

warmup_gpu_once();
exec.sync();

state.exec([&](nvbench::launch &launch) {
chain_add_kernel<PrecisionType>
<<<blocks, threads, 0, (cudaStream_t)launch.get_stream()>>>(
chain_len, result.Data());
});
}

NVBENCH_BENCH_TYPES(fltflt_bench_add_latency, NVBENCH_TYPE_AXES(precision_types))
.add_int64_axis("Chain Length", {1024, 4096, 16384})
// Blocks=1 : 1 warp on 1 SM, all other SMs idle -- latency fully exposed
// Blocks=4 : 4 warps on 4 SMs, each SM has 1 warp -- still latency-bound
// Blocks=160 : ~1 block per SM on a ~160-SM device -- partial latency hiding
// Blocks=1024: many blocks per SM -- throughput-bound, latency hides
.add_int64_axis("Blocks", {1, 4, 160, 1024});