I built a Mamba1 variant I call SM1 with d_state=1 that runs on Blackwell in pure PyTorch [P]
On Windows, mamba-ssm is not easily available, and it doesn't compile for sm_120. SM1 (Scalar Mamba1) replaces the entire selective scan with two native PyTorch ops, a cumulative product and a cumulative sum:
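    # L[:, t] is the running product of the per-step decay dA up to step t
    L = torch.cumprod(dA, dim=1)
    # h_t = L_t * (h0 + sum over s <= t of dBx_s / L_s); the clamp guards the division
    h = L * (h0.unsqueeze(1) + torch.cumsum(dBx / L.clamp(min=1e-6), dim=1))
    y = h * C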
This is the exact closed-form solution to the d_state=1 recurrence h_t = dA_t * h_{t-1} + dBx_t, obtained via variation of parameters. It is not an approximation: it matches the sequential computation up to floating-point precision, provided the cumulative product L stays above the 1e-6 clamp. d_state=2 breaks it. d_state=1 is the boundary where the closed form exists.
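The equivalence can be sanity-checked numerically. Below is a minimal sketch in plain Python, where lists stand in for tensors and `itertools.accumulate` stands in for `torch.cumprod` and `torch.cumsum`. The function names, the sequence length, and the random input ranges are illustrative assumptions, not taken from the SM1 code, and the sketch leaves out the clamp.

```python
from itertools import accumulate
import operator
import random


def sequential_scan(dA, dBx, h0):
    """Reference recurrence: h_t = dA_t * h_{t-1} + dBx_t."""
    h, out = h0, []
    for a, b in zip(dA, dBx):
        h = a * h + b
        out.append(h)
    return out


def closed_form_scan(dA, dBx, h0):
    """Same recurrence via a cumulative product and a cumulative sum."""
    L = list(accumulate(dA, operator.mul))               # plays the role of cumprod
    S = list(accumulate(b / l for b, l in zip(dBx, L)))  # plays the role of cumsum
    return [l * (h0 + s) for l, s in zip(L, S)]


random.seed(0)
T = 64                                                 # assumed sequence length
dA = [random.uniform(0.9, 0.999) for _ in range(T)]    # assumed decay factors in (0, 1)
dBx = [random.uniform(-1.0, 1.0) for _ in range(T)]    # assumed input terms
h0 = 0.5                                               # assumed initial state

seq = sequential_scan(dA, dBx, h0)
closed = closed_form_scan(dA, dBx, h0)
max_err = max(abs(a - b) for a, b in zip(seq, closed))
print(f"max abs difference: {max_err:.3e}")
```

With these inputs the two results differ only by rounding. For much longer sequences or smaller decay factors the running product L shrinks toward zero, which is the situation the clamp in the PyTorch version guards against.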
The Mamba1 scan intermediates have shape (B, T, F, S). SM1 eliminates S entirely, so the scan needs 16x less memory than a Mamba1 with d_state=16. The inference state for a 130M param model is about 14,080 floats (roughly 56 KB), with no KV cache and O(1) cost per token no matter how long the sequence gets.
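The numbers above can be reproduced with a few lines of arithmetic. This is a rough sketch that assumes float32 storage (4 bytes per value), which is what makes 14,080 floats come out near 56 KB. The shapes B, T, F and the helper names are hypothetical and only illustrate the ratio. The `step` function shows what a constant-cost per-token update looks like for a single scalar state.

```python
BYTES_PER_FLOAT32 = 4                     # assumption: state stored as float32

state_floats = 14_080                     # figure quoted in the post
state_bytes = state_floats * BYTES_PER_FLOAT32
print(state_bytes)                        # 56320 bytes, about 56 KB


def scan_elements(B, T, F, S):
    """Element count of a (B, T, F, S) scan intermediate."""
    return B * T * F * S


# Hypothetical batch size, sequence length and feature width.
B, T, F = 8, 2048, 1024
ratio = scan_elements(B, T, F, 16) // scan_elements(B, T, F, 1)
print(ratio)                              # 16


def step(h, dA_t, dBx_t, C_t):
    """One inference step for a scalar state: constant work, no history kept."""
    h = dA_t * h + dBx_t
    return h, h * C_t


h, y = step(1.0, 0.5, 0.25, 2.0)
```

The ratio is 16 for any B, T and F, since only the S axis changes between the two layouts.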
I am currently training it on 163K MIDI files, which comes to roughly 2.5B tokens in my custom format. The 130M param model fits in under half of the 16 GB on my RTX 5060 Ti.