Understanding Mamba and selective state-space models as bilinear control systems

After reading the very nice Mamba paper I want to write down some thoughts on different ways of thinking about the selective SSM mechanism.

Selective SSM

A discrete-time state-space model:

\begin{align} x(t+1) &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) \end{align}

where $x\in\mathbb{R}^n$, $u\in\mathbb{R}^m$, $y\in\mathbb{R}^p$. Matrices are of appropriate dimensions.

The “selective SSM” says that $A$ and $B$ are functions of the input. If those functions are linear, then the most general form of a selective SSM would have $A$ and $B$ as,

\begin{align} A(u(t)) &= \text {reshape} _{n^2 \rightarrow n \times n}(Qu(t)) \\ B(u(t)) &= \text {reshape} _{nm \rightarrow n \times m}(Ru(t)) \end{align}

where $Q\in\mathbb{R}^{n^2\times m}$ and $R\in\mathbb{R}^{nm\times m}$. The $\operatorname{reshape}$ operator is just a reshaping of the vector into a matrix and is linear.
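
As a sanity check on dimensions, here's a minimal numpy sketch of one step of this general form (the sizes and random parameters are purely illustrative, not anything from the paper):

```python
import numpy as np

n, m, p = 4, 3, 2                      # illustrative dimensions
rng = np.random.default_rng(0)

Q = rng.standard_normal((n * n, m))    # parameters behind A(u)
R = rng.standard_normal((n * m, m))    # parameters behind B(u)
C = rng.standard_normal((p, n))        # readout, kept input-independent

def step(x, u):
    """One step of the general (unsimplified) selective SSM."""
    A_u = (Q @ u).reshape(n, n)        # A(u(t)) = reshape_{n^2 -> n x n}(Q u(t))
    B_u = (R @ u).reshape(n, m)        # B(u(t)) = reshape_{nm -> n x m}(R u(t))
    y = C @ x                          # y(t) = C x(t)
    return A_u @ x + B_u @ u, y        # x(t+1), y(t)

x = np.zeros(n)
for t in range(5):
    x, y = step(x, rng.standard_normal(m))
```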

But the Mamba paper uses a much simpler version, which we outline below.

Simplified selective SSM

The Mamba SSM makes two simplifications: $A(u(t))$ is diagonal, and each of the $m$ input channels gets its own independent $n$-dimensional state, so the selective SSM is really $m$ copies of a small diagonal SSM running in parallel.

Here, instead of writing $m$-many different $A$s we can just write one big $A$ of size $nm\times nm$.

So let’s reset our notational workspace and rewrite the full selective SSM:

\begin{align} x(t+1) &= A(u(t))x(t) + (I_m\otimes B(u(t)))u(t) \\ y(t) &= (I_m\otimes C)x(t) \end{align}

Or equivalently,

\begin{align} x(t+1) &= \operatorname{diag}(Qu(t)) x(t) + (I_m\otimes Ru(t))u(t) \\ y(t) &= (I_m\otimes C)x(t) \end{align}

where \begin{align} A(t) &= \operatorname{diag}(Qu(t)) \\ B(t) &= Ru(t) \end{align} with $Q\in\mathbb{R}^{nm\times m}$ and $R\in\mathbb{R}^{n\times m}$, and $\otimes$ is the Kronecker product. We also take $p=m$, i.e. the outputs have the same dimension as the inputs.

The point of writing out the last pair of equations is to express $x(t+1)$ explicitly as a function of $x(t)$ and $u(t)$, without the intermediate $A$ and $B$, since $Q$ and $R$ are the real parameters of the system.

For clarity, the dimensions above are $x\in\mathbb{R}^{nm}$, $u\in\mathbb{R}^m$, $y\in\mathbb{R}^m$, $Q\in\mathbb{R}^{nm\times m}$, $R\in\mathbb{R}^{n\times m}$, and $C\in\mathbb{R}^{1\times n}$.
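
With those dimensions in hand, here's a minimal numpy sketch of one step of the simplified form (again, shapes and random parameters are just for illustration):

```python
import numpy as np

n, m = 4, 2                                  # per-channel state size, number of channels
rng = np.random.default_rng(0)

Q = rng.standard_normal((n * m, m))          # A(t) = diag(Q u(t))
R = rng.standard_normal((n, m))              # B(t) = R u(t)
C = rng.standard_normal((1, n))              # per-channel readout

def step(x, u):
    """One step of the simplified selective SSM; x lives in R^{nm}."""
    A_t = np.diag(Q @ u)                         # nm x nm, diagonal
    B_t = np.kron(np.eye(m), (R @ u)[:, None])   # nm x m, one copy of Ru(t) per channel
    y = np.kron(np.eye(m), C) @ x                # y(t) in R^m
    return A_t @ x + B_t @ u, y

x = np.zeros(n * m)
x, y = step(x, rng.standard_normal(m))
```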

Example

It’s easier to see what’s going on by looking at a single coordinate, say $x_1$. We have $mn$-many scalar state variables that are independent of each other.

For notational simplicity let’s assume $m=2$, just two input coordinates. And $n$ is arbitrary.

Then the dynamics of $x_1(t)$ are:

\begin{align} x_1(t+1) = Q_{1,1}u_1(t)x_1(t) + Q_{1,2}u_2(t)x_1(t) + R_{1,1}u_1^2(t) + R_{1,2}u_1(t)u_2(t) \end{align}

where $Q_{i,j}$ and $R_{i,j}$ are entries of $Q$ and $R$.
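
A quick numerical check that this expansion agrees with the matrix form above (values are random and purely illustrative):

```python
import numpy as np

n, m = 3, 2
rng = np.random.default_rng(1)
Q = rng.standard_normal((n * m, m))
R = rng.standard_normal((n, m))
x = rng.standard_normal(n * m)
u = rng.standard_normal(m)

# Matrix form: x(t+1) = diag(Qu) x + (I_m kron Ru) u
x_next = np.diag(Q @ u) @ x + np.kron(np.eye(m), (R @ u)[:, None]) @ u

# Expanded scalar form for x_1
x1 = (Q[0, 0] * u[0] * x[0] + Q[0, 1] * u[1] * x[0]
      + R[0, 0] * u[0] ** 2 + R[0, 1] * u[0] * u[1])

assert np.isclose(x_next[0], x1)
```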

Writing it in this way makes it clear that we're far from LTI land: the update has bilinear terms $u_i(t)x_1(t)$ and quadratic input terms $u_i(t)u_j(t)$, and no purely linear terms at all.

This almost looks like a bilinear control system.

Bilinear control systems

[Pedantry warning] Why do mathematicians care so much about polynomial equations among all the ways of writing down algebraic expressions? One answer: no matter how you permute a set of $*$'s and $+$'s, you can always distribute and rearrange terms until you get the canonical form $a_nx^n + \dots + a_1x + a_0$, so might as well just study that canonical form.

In the same way, when you're given difference equations consisting of additions and multiplications, you want to try to write everything out until it's in some polynomial form, as we did above. To my own brain, this helps demystify a bit what "selective" state-space models are: difference equations with bilinear terms and quadratic-in-input terms.

One reason for trying to write it down this way is that in control theory there's a whole field of study called bilinear control systems, i.e. systems of the form (resetting the notational workspace again):

\begin{align} x(t+1) &= Ax(t) + Bu(t) + \sum_{i=1}^m N_i x(t) u_i(t) \end{align}

So our Mamba SSM is almost a bilinear SSM but not quite since it includes quadratic input terms.

Of course, we can augment the input with the products $u_i(t)u_j(t)$ for all $i,j$; the system is then linear in this enlarged set of inputs and matches the bilinear control system form exactly, as the sketch below checks numerically.
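
Concretely, for the simplified SSM above we can take $N_i = \operatorname{diag}(Q_{:,i})$ and feed the products $u_iu_j$ in through an ordinary input matrix. Here's a small numpy check of this recasting (the construction of the hypothetical $\tilde{B}$ below is just one way to index the products; it's a sketch, not the paper's formulation):

```python
import numpy as np

n, m = 3, 2
rng = np.random.default_rng(2)
Q = rng.standard_normal((n * m, m))
R = rng.standard_normal((n, m))
x = rng.standard_normal(n * m)
u = rng.standard_normal(m)

# Direct simplified selective-SSM step
direct = np.diag(Q @ u) @ x + np.kron(np.eye(m), (R @ u)[:, None]) @ u

# Bilinear-control-system form: N_i = diag(Q[:, i]) acts bilinearly on (x, u_i),
# and the quadratic terms enter linearly through augmented inputs w_{ij} = u_i u_j.
N = [np.diag(Q[:, i]) for i in range(m)]
w = np.outer(u, u).reshape(-1)                    # w[i*m + j] = u_i * u_j
Btilde = np.zeros((n * m, m * m))                 # hypothetical indexing for this sketch
for j in range(m):                                # channel index
    for i in range(m):                            # which u_i multiplies u_j
        for k in range(n):                        # per-channel state index
            Btilde[j * n + k, i * m + j] = R[k, i]

bilinear = sum(N[i] @ x * u[i] for i in range(m)) + Btilde @ w
assert np.allclose(direct, bilinear)
```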

Continuous time, discrete time

We fibbed a bit. In the Mamba paper they parameterize the system in continuous time, make the step size $\Delta$ (along with $B$ and $C$) a function of the input, and then discretize, so the matrix that actually multiplies the state is a matrix exponential $\exp(\Delta A)$ rather than a raw learned $A$.

But it's worth noting that, in general, for a system $x(t+1) = Ax(t)$ it is probably better to learn $x(t+1) = \operatorname{expm}(A)x(t)$: the gradients are likely much better behaved. This shows up elsewhere in deep learning as "optimize in log space." The matrix exponential maps the roomy $S$-plane into the crowded $Z$-plane, so the $\operatorname{exp}$ is probably doing good work stabilizing learning regardless of the interpretation as a discretization step.
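
A tiny illustration of the $S$-plane to $Z$-plane point (not Mamba's actual parameterization, just the effect of $\operatorname{expm}$ on stability):

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(3)
A = rng.standard_normal((4, 4)) - 5 * np.eye(4)   # eigenvalues pushed into the left half S-plane

A_disc = expm(A)                                  # as in x(t+1) = expm(A) x(t)

print(np.linalg.eigvals(A).real.max())            # negative  -> stable in continuous time
print(np.abs(np.linalg.eigvals(A_disc)).max())    # below 1   -> stable in discrete time
```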

Takeaway

If we ignore the $\operatorname{exp}$ then we can rewrite the selective SSM as a bilinear control system consisting only of $u_i(t)x_j(t)$ terms and $u_i(t)u_j(t)$ input terms. Given the rich literature on bilinear control systems, there may be some useful connections to explore.

References

[1] Elliott, David LeRoy. Bilinear control systems: matrices in action. Vol. 169. Dordrecht: Springer, 2009.

[2] Gu, Albert, and Tri Dao. “Mamba: Linear-time sequence modeling with selective state spaces.” arXiv preprint arXiv:2312.00752 (2023).
