A Hackable ML Compiler Stack in 5,000 Lines of Python [P]
Hey r/MachineLearning,
The modern ML (LLM) compiler stack is brutal. TVM is 500K+ lines of C++. PyTorch piles Dynamo, Inductor, and Triton on top of each other. Then there's XLA, MLIR, Halide, Mojo. There is no tutorial that covers the high-level design of an ML compiler without dropping you straight into the guts of one of these frameworks.
I built a reference compiler from scratch in ~5K lines of pure Python that emits raw CUDA. It takes a model such as TinyLlama or Qwen2.5-7B and lowers it to a sequence of CUDA kernels through six IRs. The goal isn't to beat Triton; it's to build a compiler that is hackable and easy to follow.
Full article: A Principled ML Compiler Stack in 5,000 Lines of Python
Repo: deplodock
The pipeline consists of six IRs, each closer to the hardware than the last. Below, one PyTorch expression is walked through every stage (real reference-compiler output, with names shortened for brevity and comments added):
```python
torch.relu(torch.matmul(x + bias, w))  # x: (16, 64), bias: (64,), w: (64, 16)
```

Torch IR. The captured FX graph, a 1:1 mirror of the PyTorch ops:
```
bias_bc = bias[j] -> (16, 64) float32
add = add(x, bias_bc) -> (16, 64) float32
matmul = matmul(add, w, has_bias=False) -> (16, 16) float32
relu = relu(matmul) -> (16, 16) float32
```

Tensor IR. Every op is decomposed into Elementwise / Reduction / IndexMap. The op surface is minimal and unified, so future frontends (ONNX, JAX) can plug in without touching downstream passes:
```
bias_bc = bias[j] -> (16, 64) float32
w_bc = w[j, k] -> (16, 64, 16) float32
add = add(x, bias_bc) -> (16, 64) float32
add_bc = add[i, j] -> (16, 64, 16) float32
prod = multiply(add_bc, w_bc) -> (16, 64, 16) float32
red = sum(prod, axis=-2) -> (16, 1, 16) float32
matmul = red[i, na, j] -> (16, 16) float32
relu = relu(matmul) -> (16, 16) float32
```

The (16, 64, 16) intermediate looks ruinous, but it is never materialized; the next stage fuses it away.
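This is a small pure-Python sketch of why the decomposition is harmless. It is my own illustration, not deplodock code. It runs the Tensor IR ops literally, materializing the 16×64×16 product, and compares the result with the single fused loop nest that fusion turns them into.

Both versions sum over K in the same left-to-right order, so the results match exactly. The fused version simply never allocates the 16,384-element intermediate. For brevity, the keep-dim (16, 1, 16) reduction and the `red[i, na, j]` IndexMap that follows it are collapsed into a plain (16, 16) result here.

```python
import operator
import random
from functools import reduce

M, K, N = 16, 64, 16  # shapes from the example: x (M, K), bias (K,), w (K, N)

def tensor_ir_style(x, bias, w):
    """Evaluate each Tensor IR op literally, materializing every intermediate."""
    bias_bc = [[bias[k] for k in range(K)] for _ in range(M)]              # IndexMap -> (M, K)
    add = [[x[i][k] + bias_bc[i][k] for k in range(K)] for i in range(M)]  # Elementwise add
    add_bc = [[[add[i][k]] * N for k in range(K)] for i in range(M)]       # IndexMap -> (M, K, N)
    w_bc = [[list(w[k]) for k in range(K)] for _ in range(M)]              # IndexMap -> (M, K, N)
    prod = [[[add_bc[i][k][j] * w_bc[i][k][j] for j in range(N)]
             for k in range(K)] for i in range(M)]                         # Elementwise, (M, K, N)
    # Reduction over K as a left-to-right fold (same order as the fused loop)
    red = [[reduce(operator.add, (prod[i][k][j] for k in range(K)), 0.0)
            for j in range(N)] for i in range(M)]
    return [[max(0.0, v) for v in row] for row in red]                     # Elementwise relu

def fused_loop_nest(x, bias, w):
    """One loop nest mirroring the Loop IR: no intermediate buffers."""
    out = [[0.0] * N for _ in range(M)]
    for a0 in range(M):          # free (M)
        for a1 in range(N):      # free (N)
            acc0 = 0.0
            for a2 in range(K):  # reduce (K)
                v0 = x[a0][a2] + bias[a2]   # prologue, inside the reduce
                acc0 += v0 * w[a2][a1]
            out[a0][a1] = max(0.0, acc0)    # epilogue, outside the reduce
    return out

random.seed(0)
x = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
bias = [random.uniform(-1, 1) for _ in range(K)]
w = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(K)]
print(tensor_ir_style(x, bias, w) == fused_loop_nest(x, bias, w))  # True
print(M * K * N)  # 16384: elements in the intermediate the fused nest never allocates
```

The fold is spelled out with `functools.reduce` rather than `sum()`. Recent Python versions use compensated summation for floats in `sum()`, which would break the exact comparison.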
Loop IR. Each kernel is a single loop nest, with adjacent ops fused into it. The prologue, broadcast multiply, reduction, output layout, and epilogue all collapse into one loop nest with no intermediate buffers.
```
=== merged_relu -> relu ===
for a0 in 0..16:            # free (M)
  for a1 in 0..16:          # free (N)
    for a2 in 0..64:        # reduce (K)
      in0 = load bias[a2]
      in1 = load x[a0, a2]
      in2 = load w[a2, a1]
      v0 = add(in1, in0)    # prologue (inside reduce)
      v1 = multiply(v0, in2)
      acc0 <- add(acc0, v1)
    v2 = relu(acc0)         # epilogue (outside reduce)
    merged_relu[a0, a1] = v2
```

Tile IR. The first GPU-aware IR. Loop axes are scheduled onto threads/blocks, Stage hoists shared inputs into shared memory, and a 2×2 register tile lets each thread accumulate four outputs at once. The K axis is tiled into two outer iterations of a 32-wide reduction. Three Stage annotations carry the heaviest optimizations:
- `buffers=2@a2`: double-buffer the smem allocation along the `a2` K-tile loop, so loads for iteration `a2+1` overlap compute for `a2`.
- `async`: emit `cp.async.ca.shared.global` so the warp doesn't block on global→smem transfers; pairs with `commit_group`/`wait_group` fences in Kernel IR.
- `pad=(0, 1, 0)`: add 1 element of padding to the middle smem dim so warp-wide loads don't all hit the same bank.
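The `pad=(0, 1, 0)` bullet can be checked with plain arithmetic.

In the generated kernel, `a0 = tid / 8` and `a1 = tid % 8`. The 32 lanes of the first warp therefore cover `a0` in 0..3, and the `x_smem` read in the reduce loop depends only on `a0` and the loop-uniform `a3`. A warp thus touches four distinct words, one row apart.

The sketch below assumes the usual 32 banks of 4-byte words. Under that assumption, a 64-float row stride puts all four words in the same bank, while the padded 66-float stride (33 cells × 2) spreads them across banks. The helper names are hypothetical, and the sketch ignores the `a2` double-buffer offset, which shifts every lane equally.

```python
NUM_BANKS = 32  # assumption: 32 shared-memory banks, each 4 bytes wide

def conflict_degree(word_addrs):
    """Worst-case number of distinct 4-byte words that land in one bank.

    Lanes that read the *same* word are served by one broadcast, so only
    distinct addresses count; 1 means conflict-free.
    """
    per_bank = {}
    for addr in set(word_addrs):
        per_bank.setdefault(addr % NUM_BANKS, set()).add(addr)
    return max(len(words) for words in per_bank.values())

def x_smem_read(row_stride, a3=0, cell=0):
    """Word addresses lanes 0..31 read for x_smem[a0*row_stride + a3*2 + cell]."""
    addrs = []
    for lane in range(32):
        a0 = lane // 8  # a1 = lane % 8 does not appear in the x_smem index
        addrs.append(a0 * row_stride + a3 * 2 + cell)
    return addrs

print(conflict_degree(x_smem_read(64)))  # 4: unpadded rows, 4-way conflict
print(conflict_degree(x_smem_read(66)))  # 1: rows padded to 33 cells x 2 floats
```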
```
kernel k_relu_reduce Tile(axes=(a0:8=THREAD, a1:8=THREAD)):
  for a2 in 0..2:  # K-tile
    # meta: double-buffered, sync (small, no async needed)
    bias_smem = Stage(bias, origin=((a2 * 32)), slab=(a3:32@0)) buffers=2@a2
    x_smem = Stage(x, origin=(0, (a2 * 32)), slab=(a0:8@0, a3:32@1, cell:2@0)) pad=(0, 1, 0) buffers=2@a2 async
    w_smem = Stage(w, origin=((a2 * 32), 0), slab=(a3:32@0, a1:8@1, cell:2@1)) buffers=2@a2 async
    # reduce
    for a3 in 0..32:
      in0 = load bias_smem[a2, a3]
      in1 = load x_smem[a2, a0, a3, 0]; in2 = load x_smem[a2, a0, a3, 1]
      in3 = load w_smem[a2, a3, a1, 0]; in4 = load w_smem[a2, a3, a1, 1]
      # prologue, reused 2× across N
      v0 = add(in1, in0); v1 = add(in2, in0)
      # 2×2 register tile
      acc0 <- add(acc0, multiply(v0, in3))
      acc1 <- add(acc1, multiply(v0, in4))
      acc2 <- add(acc2, multiply(v1, in3))
      acc3 <- add(acc3, multiply(v1, in4))
  # epilogue
  relu[a0*2,     a1*2    ] = relu(acc0)
  relu[a0*2,     a1*2 + 1] = relu(acc1)
  relu[a0*2 + 1, a1*2    ] = relu(acc2)
  relu[a0*2 + 1, a1*2 + 1] = relu(acc3)
```

Kernel IR. The schedule is materialized into hardware primitives:

- THREAD/BLOCK become threadIdx/blockIdx.
- An async Stage becomes an Smem allocation plus a cp.async fill with commit/wait fences.
- A sync Stage becomes a strided fill loop.

It is framework-agnostic; the same IR could lower to Metal or HIP:
```
kernel k_relu_reduce Tile(axes=(a0:8=THREAD, a1:8=THREAD)):
  Init(acc0..acc3, op=add)
  for a2 in 0..2:  # K-tile
    Smem bias_smem[2, 32] (float)
    StridedLoop(flat = a0*8 + a1; < 32; += 64):
      bias_smem[a2, flat] = load bias[a2*32 + flat]
    Sync
    # pad row to 33 to kill bank conflicts
    Smem x_smem[2, 8, 33, 2] (float)
    StridedLoop(flat = a0*8 + a1; < 512; += 64):
      cp.async x_smem[a2, flat/64, (flat/2)%32, flat%2] <- x[flat/64*2 + flat%2, a2*32 + (flat/2)%32]
    cp.async.commit_group; cp.async.wait_group(0); Sync
    Smem w_smem[2, 32, 8, 2] (float)
    StridedLoop(flat = a0*8 + a1; < 512; += 64):
      cp.async w_smem[a2, flat/16, (flat/2)%8, flat%2] <- w[a2*32 + flat/16, (flat/2)%8*2 + flat%2]
    cp.async.commit_group; cp.async.wait_group(0); Sync
    for a3 in 0..32:  # reduce
      ...
```

CUDA. A one-to-one tree walk over Kernel IR, ready for nvcc. The bias add, the K-axis reduction, the 2×2 register tile, and the ReLU activation all live in one kernel. That means one HBM read each of x, bias, and w, one HBM write of relu, and no intermediates between ops.
```cuda
extern "C" __global__ __launch_bounds__(256) void k_relu_reduce(
    const float* bias, const float* x, const float* w, float* relu) {
  long long tid = blockIdx.x * blockDim.x + threadIdx.x;
  // The __syncthreads() calls below sit inside this guard, which is only
  // well-defined if every thread of the block takes the branch (e.g. one 64-thread block).
  if (tid < 64) {
    int a0 = tid / 8;
    int a1 = tid % 8;
    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
    #pragma unroll
    for (int a2 = 0; a2 < 2; a2++) {
      __shared__ float bias_smem[64];
      for (int f = a0*8 + a1; f < 32; f += 64)
        bias_smem[a2*32 + f] = bias[a2*32 + f];
      __syncthreads();
      // padded to avoid bank conflicts
      __shared__ float x_smem[1056];
      for (int f = a0*8 + a1; f < 512; f += 64) {
        unsigned int addr = __cvta_generic_to_shared(
            &x_smem[a2*528 + f/64*66 + f/2%32*2 + f%2]);
        asm volatile(
            "cp.async.ca.shared.global [%0], [%1], 4;\n"
            :: "r"(addr), "l"(&x[(f/64*2 + f%2)*64 + (a2*32 + f/2%32)])
            : "memory");
      }
      asm volatile("cp.async.commit_group;\n" ::: "memory");
      asm volatile("cp.async.wait_group 0;\n" ::: "memory");
      __syncthreads();
      __shared__ float w_smem[1024];
      for (int f = a0*8 + a1; f < 512; f += 64) {
        unsigned int addr = __cvta_generic_to_shared(
            &w_smem[a2*512 + f/16*16 + f/2%8*2 + f%2]);
        asm volatile(
            "cp.async.ca.shared.global [%0], [%1], 4;\n"
            :: "r"(addr), "l"(&w[(a2*32 + f/16)*16 + (f/2%8*2 + f%2)])
            : "memory");
      }
      asm volatile("cp.async.commit_group;\n" ::: "memory");
      asm volatile("cp.async.wait_group 0;\n" ::: "memory");
      __syncthreads();
      #pragma unroll
      for (int a3 = 0; a3 < 32; a3++) {
        float in0 = bias_smem[a2*32 + a3];
        float in1 = x_smem[a2*528 + a0*66 + a3*2    ];
        float in2 = x_smem[a2*528 + a0*66 + a3*2 + 1];
        float in3 = w_smem[a2*512 + a3*16 + a1*2    ];
        float in4 = w_smem[a2*512 + a3*16 + a1*2 + 1];
        float v0 = in1 + in0;
        float v1 = in2 + in0;
        acc0 += v0 * in3;
        acc1 += v0 * in4;
        acc2 += v1 * in3;
        acc3 += v1 * in4;
      }
    }
    relu[a0*2*16 + a1*2        ] = fmaxf(0.0f, acc0);
    relu[a0*2*16 + a1*2 + 1    ] = fmaxf(0.0f, acc1);
    relu[(a0*2+1)*16 + a1*2    ] = fmaxf(0.0f, acc2);
    relu[(a0*2+1)*16 + a1*2 + 1] = fmaxf(0.0f, acc3);
  }
}
```

Every stage is printable on demand. No GPU needed.
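Because the index arithmetic is the easiest part of generated code to get wrong, it can also be checked on a CPU. Here is a minimal pure-Python emulation of `k_relu_reduce`. It is my own sketch, not part of deplodock, and it assumes a single block with 64 active threads and flat row-major inputs.

The emulation runs each phase for all threads before starting the next one, which is the ordering the `__syncthreads()` barriers enforce. It then compares the output against a naive `relu((x + bias) @ w)`. It says nothing about timing or async overlap; it only checks that the smem layouts and the 2×2 tile write the right elements.

```python
import math
import random

def emulate_k_relu_reduce(bias, x, w):
    """Run the 64 threads of k_relu_reduce one after another, phase by phase.

    bias: 64 floats; x: 16*64 and w: 64*16 floats, flat row-major.
    """
    threads = range(64)  # tid < 64; a0 = tid // 8, a1 = tid % 8
    bias_smem, x_smem, w_smem = [0.0] * 64, [0.0] * 1056, [0.0] * 1024
    acc = [[0.0] * 4 for _ in threads]  # acc0..acc3 per thread
    for a2 in range(2):  # K-tile
        for tid in threads:  # bias fill (only tids 0..31 do any work)
            for f in range(tid, 32, 64):
                bias_smem[a2*32 + f] = bias[a2*32 + f]
        for tid in threads:  # x fill into rows padded to 66 floats
            for f in range(tid, 512, 64):
                dst = a2*528 + (f // 64)*66 + ((f // 2) % 32)*2 + f % 2
                x_smem[dst] = x[((f // 64)*2 + f % 2)*64 + a2*32 + (f // 2) % 32]
        for tid in threads:  # w fill, unpadded
            for f in range(tid, 512, 64):
                dst = a2*512 + (f // 16)*16 + ((f // 2) % 8)*2 + f % 2
                w_smem[dst] = w[(a2*32 + f // 16)*16 + ((f // 2) % 8)*2 + f % 2]
        for tid in threads:  # reduce over this K-tile
            a0, a1 = tid // 8, tid % 8
            for a3 in range(32):
                in0 = bias_smem[a2*32 + a3]
                in1 = x_smem[a2*528 + a0*66 + a3*2]
                in2 = x_smem[a2*528 + a0*66 + a3*2 + 1]
                in3 = w_smem[a2*512 + a3*16 + a1*2]
                in4 = w_smem[a2*512 + a3*16 + a1*2 + 1]
                v0, v1 = in1 + in0, in2 + in0
                acc[tid][0] += v0 * in3
                acc[tid][1] += v0 * in4
                acc[tid][2] += v1 * in3
                acc[tid][3] += v1 * in4
    relu = [0.0] * (16 * 16)
    for tid in threads:  # epilogue: each thread writes its 2x2 tile
        a0, a1 = tid // 8, tid % 8
        relu[a0*2*16 + a1*2] = max(0.0, acc[tid][0])
        relu[a0*2*16 + a1*2 + 1] = max(0.0, acc[tid][1])
        relu[(a0*2 + 1)*16 + a1*2] = max(0.0, acc[tid][2])
        relu[(a0*2 + 1)*16 + a1*2 + 1] = max(0.0, acc[tid][3])
    return relu

def reference(bias, x, w):
    """Naive relu((x + bias) @ w) on the same flat layouts."""
    out = [0.0] * (16 * 16)
    for i in range(16):
        for j in range(16):
            acc = 0.0
            for k in range(64):
                acc += (x[i*64 + k] + bias[k]) * w[k*16 + j]
            out[i*16 + j] = max(0.0, acc)
    return out

random.seed(0)
bias = [random.uniform(-1, 1) for _ in range(64)]
x = [random.uniform(-1, 1) for _ in range(16 * 64)]
w = [random.uniform(-1, 1) for _ in range(64 * 16)]
got, want = emulate_k_relu_reduce(bias, x, w), reference(bias, x, w)
print(all(math.isclose(g, r, rel_tol=1e-12, abs_tol=1e-12) for g, r in zip(got, want)))  # True
```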
```
deplodock compile -c "torch.relu(torch.matmul(torch.randn(16,64) + torch.randn(64), torch.randn(64,16)))" --ir <stage>
```

Here `<stage>` is one of `tensor`, `loop`, `tile`, `kernel`, or `cuda`.

Benchmarking against eager PyTorch and torch.compile on the attention-score softmax at Qwen-block size (28 heads, 2048×2048 scores), where the compiler ties torch.compile:
```
deplodock run --bench -c "torch.nn.Softmax(dim=-1)(torch.randn(1,28,2048,2048))"
```

End-to-end compilation of a real model:
```
deplodock compile Qwen/Qwen2.5-7B
```

The linked article goes through the design in detail:

- RMSNorm walked through every IR.
- The σ-based fusion algorithm with its blowup guard.
- Validation against torch.compile on TinyLlama and Qwen2.5-7B blocks.

The forthcoming second part will cover the codegen internals.
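The article walks RMSNorm through every IR. As a rough illustration of just the Tensor IR step, here is RMSNorm, y = x / sqrt(mean(x²) + eps) · weight. It is built from three helpers standing in for Elementwise, Reduction, and IndexMap, and is checked against the direct formula. This is a self-made sketch, not taken from the article or the repo; the helper names, signatures, and `eps=1e-6` are assumptions.

```python
import math

# Hypothetical stand-ins for the three Tensor IR op kinds; names and
# signatures are illustrative only, not deplodock's API.
def elementwise(fn, *operands):
    """Apply fn position-wise across equally shaped 2-D operands."""
    return [[fn(*vals) for vals in zip(*rows)] for rows in zip(*operands)]

def reduction_sum(a):
    """Sum over the last axis, keeping it as size 1: (R, C) -> (R, 1)."""
    return [[sum(row)] for row in a]

def index_map_broadcast(a, rows, cols):
    """Broadcast a (1, C) or (R, 1) operand to (rows, cols) by remapping indices."""
    src_rows, src_cols = len(a), len(a[0])
    return [[a[i if src_rows > 1 else 0][j if src_cols > 1 else 0]
             for j in range(cols)] for i in range(rows)]

def rms_norm_decomposed(x, weight, eps=1e-6):
    R, C = len(x), len(x[0])
    sq = elementwise(lambda v: v * v, x)                            # Elementwise
    ss = reduction_sum(sq)                                          # Reduction -> (R, 1)
    inv = elementwise(lambda s: 1.0 / math.sqrt(s / C + eps), ss)   # Elementwise on (R, 1)
    inv_bc = index_map_broadcast(inv, R, C)                         # IndexMap (R, 1) -> (R, C)
    w_bc = index_map_broadcast([weight], R, C)                      # IndexMap (C,) -> (R, C)
    return elementwise(lambda v, s, g: v * s * g, x, inv_bc, w_bc)  # Elementwise

def rms_norm_reference(x, weight, eps=1e-6):
    """Direct formula, row by row."""
    out = []
    for row in x:
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v / rms * g for v, g in zip(row, weight)])
    return out

x = [[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5]]
weight = [1.0, 0.5, 2.0, 1.0]
a, b = rms_norm_decomposed(x, weight), rms_norm_reference(x, weight)
print(all(math.isclose(p, q, rel_tol=1e-12) for ra, rb in zip(a, b) for p, q in zip(ra, rb)))
```

Only the final Elementwise touches the full (R, C) shape three ways at once. That makes it the obvious place for a fusion pass to fold the broadcasts back into a single loop nest, as in the matmul example.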