From Machine Learning

Rewriting model inference with CUDA kernels: the bottleneck was not just GEMM [P]

I’ve been working on a CUDA-first inference runtime for small-batch / realtime ML workloads.

The core idea is simple: instead of treating PyTorch / TensorRT / generic graph runtimes as the main execution path, I rewrite the model inference path directly with C++/CUDA kernels.

This started from robotics / VLA workloads, but the problem is more general.

In small-batch inference, the bottleneck is often not just a single slow GEMM. A lot of latency comes from the runtime glue around the math:

  • fragmented small kernels
  • norm / residual / activation boundaries
  • quantize / dequantize overhead
  • layout transitions
  • Python / runtime scheduling
  • graph compiler fusion failures
  • precision conversion around FP8 / FP4 regions
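To make the "norm / residual boundaries" item concrete, here is a minimal sketch of what fusing a residual add with RMSNorm means. It is written in plain Python rather than CUDA so it runs anywhere. The function names and the choice of RMSNorm are illustrative assumptions, not FlashRT's code. On a GPU, the unfused path is two launches: the summed hidden state is written to global memory and then read back. The fused path is one launch that keeps the sum on-chip.

```python
import math

EPS = 1e-6  # hypothetical epsilon; real models define their own

def residual_add(x, r):
    # Unfused step 1: a separate "kernel" whose output is written back to memory.
    return [a + b for a, b in zip(x, r)]

def rms_norm(h, weight, eps=EPS):
    # Unfused step 2: a second "kernel" that reads the hidden state back in.
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)
    return [v / rms * w for v, w in zip(h, weight)]

def fused_residual_rms_norm(x, r, weight, eps=EPS):
    # Fused version: the sum stays local (registers / shared memory on a GPU),
    # so there is one launch and no extra round trip for the intermediate.
    h = [a + b for a, b in zip(x, r)]
    sum_sq = 0.0
    for v in h:
        sum_sq += v * v
    inv_rms = 1.0 / math.sqrt(sum_sq / len(h) + eps)
    normed = [v * inv_rms * w for v, w in zip(h, weight)]
    # Return both: the updated residual stream feeds the next block.
    return h, normed

if __name__ == "__main__":
    x, r, w = [1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.25, 0.0], [1.0, 0.5, 2.0, 1.0]
    print(rms_norm(residual_add(x, r), w))
    print(fused_residual_rms_norm(x, r, w)[1])
```

Both paths compute the same values up to floating-point rounding. The gain on real hardware comes from fewer launches and less memory traffic, not from different math.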

For cloud LLM serving, batching can hide a lot of this: the fixed per-step cost of launches, syncs, and conversions is amortized across many requests.

For robotics, VLA, world models, and other realtime workloads, batch size is usually 1. There is nowhere to hide. Every launch, sync, and format boundary shows up directly in latency.
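A toy latency model shows why batch size 1 leaves nowhere to hide. All numbers are hypothetical, not measurements: roughly 400 launches at about 5 µs each give a fixed overhead, and compute is assumed to grow linearly with batch size. That compute assumption is deliberately crude, since small-batch GEMMs are often weight-load-bound and barely grow with batch, which would make batching look even better.

```python
def step_latency_us(batch, fixed_overhead_us, compute_per_item_us):
    """Wall time of one forward step for the whole batch (toy model)."""
    return fixed_overhead_us + batch * compute_per_item_us

def overhead_share(batch, fixed_overhead_us, compute_per_item_us):
    """Fraction of the step spent on fixed overhead rather than math."""
    return fixed_overhead_us / step_latency_us(batch, fixed_overhead_us, compute_per_item_us)

if __name__ == "__main__":
    # Assumed: 400 launches x ~5 us = 2000 us fixed, 100 us of math per request.
    for b in (1, 8, 64):
        share = overhead_share(b, fixed_overhead_us=2000.0, compute_per_item_us=100.0)
        print(f"batch={b:3d}  overhead share={share:.0%}")
```

Under these assumptions, overhead is over 90% of the step at batch 1 and under a quarter at batch 64. Shrinking the fixed term (fusion, fewer launches) is the only lever left at batch 1.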

Some current results from my implementation:

Model / workload                 | Hardware    | FlashRT latency / throughput
-------------------------------- | ----------- | ------------------------------------
Pi0.5                            | Jetson Thor | ~44 ms
Pi0                              | Jetson Thor | ~46 ms
GROOT N1.6                       | Jetson Thor | ~41–45 ms
Pi0.5                            | RTX 5090    | ~17.6 ms
GROOT N1.6                       | RTX 5090    | ~12.5–13.1 ms
Pi0-FAST                         | RTX 5090    | ~2.39 ms/token
Qwen3.6 27B                      | RTX 5090    | ~129 tok/s (NVFP4)
Motus / Wan-style world model    | RTX 5090    | ~1.3 s baseline → target ~100 ms E2E
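The table mixes per-token latency and throughput. Converting between them is just a reciprocal, assuming a single decode stream (batch 1). The post does not state this assumption for the Qwen row. Under it, ~2.39 ms/token is about 418 tok/s and ~129 tok/s is about 7.75 ms/token. This puts the two rows on a shared axis. It does not make them directly comparable, since the models differ greatly in size.

```python
def ms_per_token_to_tok_per_s(ms_per_token):
    """Single-stream decode rate implied by a per-token latency."""
    return 1000.0 / ms_per_token

def tok_per_s_to_ms_per_token(tok_per_s):
    """Average per-token latency implied by a single-stream decode rate."""
    return 1000.0 / tok_per_s

if __name__ == "__main__":
    print(f"Pi0-FAST: {ms_per_token_to_tok_per_s(2.39):.0f} tok/s")
    print(f"Qwen3.6 27B: {tok_per_s_to_ms_per_token(129):.2f} ms/token")
```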

The Motus / world-model case is especially interesting.

The baseline path takes around 1.3 s end-to-end. The target is ~100 ms E2E, but the hard part is not simply “use a faster GEMM”. The bottlenecks are the VAE, joint attention, launch fragmentation, and a large amount of glue around the actual math.
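Amdahl's law makes the "not simply a faster GEMM" point quantitative. Going from ~1.3 s to ~100 ms is a 13x speedup. Even an infinitely fast GEMM caps the end-to-end gain at 1 / (1 - f), where f is the GEMM share of runtime. The post does not give a breakdown, so the GEMM shares below are hypothetical.

```python
def amdahl_speedup(fraction, local_speedup):
    """End-to-end speedup when only `fraction` of the runtime gets `local_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)

def min_accelerated_fraction(required_speedup):
    """Smallest runtime share that must be (infinitely) accelerated to hit the target."""
    return 1.0 - 1.0 / required_speedup

BASELINE_MS = 1300.0  # ~1.3 s end-to-end baseline from the post
TARGET_MS = 100.0     # ~100 ms end-to-end target from the post
REQUIRED = BASELINE_MS / TARGET_MS  # 13x

if __name__ == "__main__":
    # Hypothetical GEMM shares of the baseline runtime.
    for gemm_fraction in (0.5, 0.8, 0.9):
        best = amdahl_speedup(gemm_fraction, float("inf"))
        print(f"GEMM share {gemm_fraction:.0%}: GEMM-only ceiling {best:.1f}x vs {REQUIRED:.0f}x needed")
    print(f"GEMMs would need to be >= {min_accelerated_fraction(REQUIRED):.1%} of runtime")
```

Unless GEMMs are more than about 92% of the baseline, no GEMM-only optimization can reach 13x. The VAE, attention, and glue have to shrink too.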

One lesson from this work: lower precision is not automatically a win.

FP8 has been consistently useful. FP4 / NVFP4 is more mixed. It can reduce memory footprint and speed up some large GEMM regions, but if the FP4 region is small, discontinuous, or surrounded by conversion / scaling overhead, the end-to-end speedup can be tiny.
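The conversion / scaling overhead exists because block-scaled FP4 needs extra work at every region boundary. Below is a simplified fake-quantization sketch modeled loosely on NVFP4: 4-bit E2M1 values in 16-element blocks, with one scale per block. It makes three simplifying assumptions. The block scale stays a Python float instead of being encoded as FP8 E4M3, there is no second-level per-tensor scale, and ties round toward the smaller magnitude rather than to nearest-even. It is not FlashRT's kernel.

```python
# Representable E2M1 (FP4) magnitudes: 1 sign, 2 exponent, 1 mantissa bit.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_MAX = 6.0
BLOCK = 16  # elements per scale block (assumed, NVFP4-like)

def quantize_block(block):
    """Entering an FP4 region: amax reduction, scale, then round every element."""
    amax = max(abs(v) for v in block)
    scale = amax / FP4_MAX if amax > 0 else 1.0
    codes = []
    for v in block:
        mag = abs(v) / scale
        # Nearest representable value; min() keeps the first (smaller) on ties.
        q = min(E2M1_VALUES, key=lambda c: abs(c - mag))
        codes.append(q if v >= 0 else -q)
    return scale, codes

def dequantize_block(scale, codes):
    """Leaving the FP4 region: multiply the scale back in."""
    return [c * scale for c in codes]

def fake_quant(values):
    """Quantize then dequantize block by block, exposing the FP4 rounding error."""
    out = []
    for i in range(0, len(values), BLOCK):
        scale, codes = quantize_block(values[i:i + BLOCK])
        out.extend(dequantize_block(scale, codes))
    return out

if __name__ == "__main__":
    xs = [0.1 * i - 1.0 for i in range(32)]
    print(max(abs(a - b) for a, b in zip(xs, fake_quant(xs))))
```

Every small, isolated FP4 region pays the reduction, encode, and decode steps on entry and exit. That cost is what eats into the math speedup when the region is not large and fused.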

For example, in some VLA / world-model paths, FP4 gives only a few percent latency improvement over FP8 unless the FP4 region is large and deeply fused.
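A simple cost model shows why region size matters. It accounts for the time of the accelerated region, the untouched remainder, and a fixed conversion cost for entering and leaving the lower-precision region. The numbers are hypothetical, not measurements from the post: a 40 ms FP8 pipeline, an FP4 region assumed to run 1.6x faster than FP8, and made-up conversion costs.

```python
def e2e_latency_ms(total_ms, region_ms, region_speedup, conversion_ms):
    """Latency after speeding up one region and paying extra to enter/leave it."""
    untouched = total_ms - region_ms
    return untouched + region_ms / region_speedup + conversion_ms

def e2e_gain(total_ms, region_ms, region_speedup, conversion_ms):
    """Fractional end-to-end latency reduction (0.05 means 5% lower latency)."""
    return 1.0 - e2e_latency_ms(total_ms, region_ms, region_speedup, conversion_ms) / total_ms

def break_even_conversion_ms(region_ms, region_speedup):
    """Conversion cost at which the lower-precision region stops paying off."""
    return region_ms * (1.0 - 1.0 / region_speedup)

if __name__ == "__main__":
    # Small, isolated FP4 region: 8 ms of the 40 ms, 1 ms of conversion glue.
    print(f"small region: {e2e_gain(40.0, 8.0, 1.6, 1.0):.1%}")
    # Large, fused FP4 region: 30 ms of the 40 ms, 0.5 ms of conversion glue.
    print(f"large region: {e2e_gain(40.0, 30.0, 1.6, 0.5):.1%}")
```

With these assumptions, the small region yields about a 5% gain and the large fused one about 27%. The conversion cost also only needs to reach `break_even_conversion_ms` (3 ms for the small region) to erase the gain entirely.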

This changed how I think about inference optimization.

For large-batch cloud serving, generic runtimes and batching are often enough.

For realtime small-batch inference, the runtime overhead becomes the workload.

Curious if others have seen similar behavior with torch.compile, TensorRT, XLA, Triton, or custom CUDA kernels.

At what point do you stop trying to make a generic compiler optimize the model, and just rewrite the inference path directly?

Implementation: https://github.com/LiangSu8899/FlashRT

submitted by /u/Diligent-End-2711