[R] Hybrid attention for small code models: 50x faster inference, but data scaling still dominates
TLDR: Forked PyTorch and Triton internals. Changed the attention stack so the first layer is linear, the middle is quadratic, and the last layer is linear.
Inference got much faster with only a small perplexity hit in my tests.
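Read literally, that layout would look something like the hypothetical helper below (my naming, not the fork's actual API):

```python
def attention_kind(layer_idx: int, n_layers: int) -> str:
    """Hypothetical per-layer layout implied by the TLDR: linear attention in the
    first and last layers, standard quadratic attention everywhere in between."""
    if layer_idx in (0, n_layers - 1):
        return "linear"
    return "quadratic"
```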
I trained a 25.6M parameter Rust-focused language model from scratch using a byte-level GPT-style decoder.
The main result is that increasing dataset size mattered more than any architectural change.
Expanding the corpus from about 31MB of core Rust sources to roughly 173MB by adding a few hundred crates produced a much larger improvement than anything else. Training converged faster and reached a lower validation loss, while architectural changes had a smaller effect.
Final validation loss is 0.82 with perplexity 2.15. The best checkpoint appears around step 18.5k, with mild overfitting afterward.
Each layer replaces standard attention with a hybrid mechanism that combines local windowed attention and a GRU-like recurrent state, mixed through a learned gate. The local path captures short-range syntax, while the recurrent path carries compressed long-range information.
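For concreteness, here is a minimal single-head sketch of what such a layer could look like; `HybridAttention`, the GRU recurrent path, the window size, and the sigmoid gate are my illustration of the description above, not the actual forked kernels:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridAttention(nn.Module):
    """Sketch: gated mix of local windowed attention and a GRU-like recurrent state."""
    def __init__(self, d_model: int, window: int = 128):
        super().__init__()
        self.window = window
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)   # recurrent long-range path
        self.gate = nn.Linear(2 * d_model, d_model)              # learned mixing gate
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):                                        # x: (batch, seq, d_model)
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Local path: causal attention restricted to the last `window` tokens.
        scores = q @ k.transpose(-2, -1) / D ** 0.5               # (B, T, T)
        idx = torch.arange(T, device=x.device)
        causal = idx[None, :] <= idx[:, None]
        local = (idx[:, None] - idx[None, :]) < self.window
        scores = scores.masked_fill(~(causal & local), float("-inf"))
        local_out = F.softmax(scores, dim=-1) @ v                 # (B, T, D)

        # Recurrent path: compressed long-range information carried by a GRU.
        rec_out, _ = self.rnn(x)                                   # (B, T, D)

        # Learned gate decides, per position, how much of each path to keep.
        g = torch.sigmoid(self.gate(torch.cat([local_out, rec_out], dim=-1)))
        return self.out(g * local_out + (1 - g) * rec_out)
```

In this sketch the gate is computed per position from both paths, so the model can lean on the local window for syntax-heavy spans and on the recurrent summary when longer context matters.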
This hybrid attention did not clearly improve generation quality compared to a standard setup. However, it had a large impact on inference efficiency.
With a KV cache that keeps a small recent window in VRAM and compresses older tokens, inference improved from 5.6 tokens per second to 286 tokens per second on a 4060 Ti. This is about a 50x speedup without an obvious drop in output quality.
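As a rough illustration of that caching policy (the running-mean compression and the class name are assumptions on my part, not the exact scheme):

```python
import torch

class CompressingKVCache:
    """Sketch: keep the last `window` K/V pairs exact in VRAM; fold older
    entries into a single running-mean summary (the compression rule here
    is an assumption for illustration)."""
    def __init__(self, window: int = 128):
        self.window = window
        self.keys, self.values = [], []            # exact recent entries
        self.summary_k = self.summary_v = None     # compressed older context
        self.n_compressed = 0

    def append(self, k: torch.Tensor, v: torch.Tensor):
        self.keys.append(k)
        self.values.append(v)
        if len(self.keys) > self.window:
            old_k, old_v = self.keys.pop(0), self.values.pop(0)
            # Fold the evicted token into a running average instead of keeping it.
            n = self.n_compressed
            if n == 0:
                self.summary_k, self.summary_v = old_k.clone(), old_v.clone()
            else:
                self.summary_k = (self.summary_k * n + old_k) / (n + 1)
                self.summary_v = (self.summary_v * n + old_v) / (n + 1)
            self.n_compressed = n + 1

    def get(self):
        ks = ([self.summary_k] if self.summary_k is not None else []) + self.keys
        vs = ([self.summary_v] if self.summary_v is not None else []) + self.values
        return torch.stack(ks), torch.stack(vs)
```

The point is that per-token memory and attention cost stay bounded by the window instead of growing with the full sequence, which is where most of the throughput gain comes from.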
The model produces plausible Rust syntax and structure, but semantic consistency is still weak and repetition is common.
Next steps are to run ablations comparing hybrid, local-only, and recurrent-only variants, evaluate earlier checkpoints for generation quality, add code-specific evaluation such as parsing or compilation, and test longer context and BPE tokenization.
I would be interested in feedback on evaluation methods beyond perplexity for small code models, whether hybrid local and recurrent attention has worked well in practice for code generation, and whether further gains at this scale are more likely to come from more data, longer context, or architectural changes.