FlashRT GEMM and FP8 Quant Epilogues

Fused GEMM and FP8 quantization epilogue kernels from FlashRT.

The main surface today is the post-GEMM FP8 quantization epilogue: it takes BF16 input, adds an optional BF16 bias, applies GELU (tanh approximation), and writes FP8 e4m3 quantized output. A second primitive applies per-channel BF16 scaling followed by FP8 quantization.

The package also exposes BF16 GEMM wrappers using cuBLASLt fused bias and GELU epilogues. These wrappers are shape-sensitive and should be evaluated for the target workload before promotion.

Planned Features

  • FP8 GEMM with fused bias and activation epilogues.
  • NVFP4 GEMM with fused bias and activation epilogues.
  • Quantized output epilogues for low-latency inference pipelines.
  • Generic APIs for Transformer, VLA, and diffusion model linear blocks.

Current API

  • bf16_gemm_bias(a, b, bias, out=None)
  • bf16_gemm_bias_gelu(a, b, bias, out=None)
  • bias_gelu_quantize_fp8_static_bf16(input, bias, scale, out=None)
  • channel_scale_quantize_fp8_static_bf16(input, channel_scale, scale, out=None)
  • gelu_quantize_fp8_static_bf16(input, scale, out=None)
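
The arithmetic behind the quantization helpers listed above can be sketched in plain Python as a reading aid. This is a scalar reference under stated assumptions, not the FlashRT implementation: the names round_to_e4m3, bias_gelu_quantize_ref, and channel_scale_quantize_ref are hypothetical, the convention that scale divides the value before rounding is assumed, and saturation at ±448 (the largest finite e4m3fn value) is assumed rather than confirmed by the source. The real kernels operate on BF16 tensors on the GPU; this sketch uses Python floats.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of the e4m3fn format


def gelu_tanh(x: float) -> float:
    """GELU with the tanh approximation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def round_to_e4m3(v: float) -> float:
    """Round to the nearest e4m3fn-representable value.

    ASSUMPTION: out-of-range values saturate to +/-448.
    """
    if v == 0.0 or math.isnan(v):
        return v
    a = min(abs(v), FP8_E4M3_MAX)
    _, e = math.frexp(a)        # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)        # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)     # 3 mantissa bits give 8 steps per binade
    q = round(a / step) * step  # Python's round() rounds ties to even
    return math.copysign(min(q, FP8_E4M3_MAX), v)


def bias_gelu_quantize_ref(row, bias, scale):
    """ASSUMED semantics: q = fp8(gelu_tanh(x + bias) / scale)."""
    return [round_to_e4m3(gelu_tanh(x + b) / scale) for x, b in zip(row, bias)]


def channel_scale_quantize_ref(row, channel_scale, scale):
    """ASSUMED semantics: q = fp8(x * channel_scale[c] / scale)."""
    return [round_to_e4m3(x * s / scale) for x, s in zip(row, channel_scale)]


row = [-1.0, 0.0, 0.5, 2.0]
bias = [0.0, 0.0, 0.5, 0.0]
print(bias_gelu_quantize_ref(row, bias, scale=0.5))
print(channel_scale_quantize_ref(row, [1.0, 1.0, 2.0, 0.5], scale=0.25))
```

A reference like this is mainly useful for checking a kernel's output on a handful of values; whether the actual helpers multiply or divide by scale should be confirmed against the package before relying on it.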

Performance Notes

The FP8 quantization epilogue helpers show the strongest results across the local shape suite. The BF16 GEMM epilogue wrappers are shape-sensitive: benchmark them against torch.addmm (bias only) and GELU applied to the torch.addmm result (bias plus GELU) on the target shapes before promotion.

The v1 performance message for this package should center on the FP8 quantization helpers. GEMM epilogue numbers should be reported per shape, not as a broad claim.
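
Per-shape reporting of the kind described above can be sketched with a small timing harness. This is a minimal illustration, not a tool shipped with the package: median_ms and compare_per_shape are hypothetical names, and the harness is deliberately library-agnostic so that the caller supplies the closures to time and a sync callable.

```python
import statistics
import time


def median_ms(fn, sync=lambda: None, warmup=3, iters=20):
    """Median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for any queued asynchronous work before stopping the clock
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def compare_per_shape(shapes, make_candidate, make_baseline, sync=lambda: None):
    """Time candidate and baseline closures for each shape; one row per shape."""
    rows = []
    for shape in shapes:
        cand = median_ms(make_candidate(shape), sync)
        base = median_ms(make_baseline(shape), sync)
        speedup = base / cand if cand > 0 else float("inf")
        rows.append({"shape": shape, "candidate_ms": cand,
                     "baseline_ms": base, "speedup": speedup})
    return rows


# Stand-in workloads so the sketch runs without a GPU.
rows = compare_per_shape(
    shapes=[(64, 64), (128, 128)],
    make_candidate=lambda s: (lambda: sum(range(s[0] * 10))),
    make_baseline=lambda s: (lambda: sum(range(s[0] * 20))),
)
for r in rows:
    print(r["shape"], round(r["speedup"], 2))
```

On a CUDA device, sync would typically be torch.cuda.synchronize, the candidate closure would call one of the GEMM wrappers, and the baseline closure would call the torch.addmm path named above, with inputs allocated outside the timed closure.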

Hardware

CUDA GPUs with BF16 tensor core support are expected for the GEMM path. FP8 output helpers additionally require PyTorch and hardware/runtime support for torch.float8_e4m3fn.
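
A quick way to check these prerequisites on a given machine is a small runtime probe. The sketch below is an assumption-laden convenience, not part of the package: probe_runtime is a hypothetical name, and it only reports what PyTorch itself exposes. Note that torch.cuda.is_bf16_supported() may report emulated BF16 support in some PyTorch versions, so a True value is not proof of native BF16 tensor cores.

```python
import importlib.util


def probe_runtime():
    """Report which prerequisites the local PyTorch install exposes."""
    report = {"torch": False, "float8_e4m3fn": False, "cuda": False, "bf16": False}
    if importlib.util.find_spec("torch") is None:
        return report  # PyTorch is not installed
    import torch

    report["torch"] = True
    # The FP8 output dtype only exists in sufficiently recent PyTorch builds.
    report["float8_e4m3fn"] = hasattr(torch, "float8_e4m3fn")
    report["cuda"] = bool(torch.cuda.is_available())
    if report["cuda"]:
        # May include emulated support depending on the PyTorch version.
        report["bf16"] = bool(torch.cuda.is_bf16_supported())
    return report


print(probe_runtime())
```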

Validation

Validated HF builder targets currently include torch211-cxx11-cu128-x86_64-linux, torch211-cxx11-cu126-x86_64-linux, and torch211-cxx11-cu130-x86_64-linux. Builder ABI checks passed for all three variants. A host-side correctness smoke test passed on an RTX 5090 with PyTorch 2.9.1+cu128. See VALIDATION.md for the full record and remaining gaps.

See examples/fp8_quant_epilogue_block.py for a minimal HF-style module using the FP8 quantization epilogue helpers.
