From my Notion page on Adaptive entropy sparsity
(This is a very crude first draft. There are no exact details or code to go along with this since I wanted some understanding of the abstraction first.)
I will avoid timestamping this for obvious reasons. But here is the vague idea for now, just to have something to send to Wowo. You start by taking the entropy $S$ and defining something like
\[ \zeta _\alpha (S)=\begin{cases} \zeta _{\text{min}}\equiv 0 & S\geq \mathcal{S}_{\text{max}} \quad \text{(high entropy = softmax)} \\ \mathcal{D}(S) & \mathcal{S}_{0}\leq S \leq \mathcal{S}_{\text{max}} \\ \zeta _{\text{max}} & S\leq \mathcal{S}_{0} \quad \text{(low entropy = sparsemax)} \end{cases}\;, \]
where $\mathcal{D}(S)$ is a function of entropy that ranges from $\zeta _{\text{min}}=0$ to $\zeta _{\text{max}}$. This just bounds the parameter between the usual $\mathrm{softmax}$ and $\mathrm{sparsemax}$ functions, so basically we are saying that if we have high entropy, we want lower sparsity, and if we have lower entropy, we want higher sparsity. This essentially employs a version of magnitude-clipping or pruning, but the stronger claim I want to make is that doing this entropy-dynamic sparsity increases the $\text{effective correlation}$ of the model by reducing the correlations between irrelevant and relevant inputs. In most base models this is not exactly an issue, but arguably, in the efficiency vs. performance trade-off, maximizing effective correlation is possible only if we don't restrict ourselves to dense attention weights, meaning that $\mathrm{softmax}$ is not very efficient in this sense.
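To make the abstraction slightly more concrete, here is a minimal sketch of the mapping I have in mind, assuming $\mathcal{D}(S)$ is a plain linear ramp between the two bounds; the thresholds `s_low`, `s_high` and the $\alpha =1+\zeta$ reading at the end are placeholder assumptions on my part, not settled parts of the idea:

```python
import torch

def adaptive_zeta(scores: torch.Tensor,
                  s_low: float = 0.5,
                  s_high: float = 3.0,
                  zeta_max: float = 1.0) -> torch.Tensor:
    """Map per-row attention entropy S to a sparsity parameter zeta in [0, zeta_max].

    High entropy (S >= s_high)  -> zeta = 0         (softmax regime, dense)
    Low entropy  (S <= s_low)   -> zeta = zeta_max  (sparsemax regime, sparse)
    In between                  -> D(S), here a plain linear ramp (placeholder).
    """
    probs = torch.softmax(scores, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)

    ramp = (s_high - entropy) / (s_high - s_low)   # 1 at s_low, 0 at s_high
    return zeta_max * ramp.clamp(0.0, 1.0)

# alpha-entmax reading: alpha = 1 + zeta, so zeta = 0 -> softmax and zeta = 1 -> sparsemax.
scores = torch.randn(4, 16)           # hypothetical (queries, keys) attention logits
alpha = 1.0 + adaptive_zeta(scores)   # one alpha per query row
```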
While I have not come up with a very formal definition of $\mathcal{D}(S)$ or of how exactly this model gets implemented, here is a reason why this whole entropy-based sparsity does not have to be that adaptive, and here is where $\mathrm{varentropy}$ comes in: it decides whether $\zeta (S)$ has to be a continuously adaptive parameter. In calculating the attention metrics, we compute the $\mathrm{entropy}$ as well as the $\mathrm{varentropy}$. The varentropy decides whether the sparsity factor is dynamic or fixed; at low $\mathrm{varentropy}$, the bounds are exactly $\mathrm{softmax}$ and $\mathrm{sparsemax}$. Higher varentropy implies that the bounds on $\mathcal{D}(S)$ change, although they remain bounded by the $\mathrm{softmax}$ and $\mathrm{sparsemax}$ functions. I am still working out how exactly to see this relation and give a mathematical formulation for it.
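For reference, both metrics come from the same attention distribution. A sketch of how they would be computed per query row, with a hypothetical varentropy threshold deciding fixed vs. dynamic $\zeta$ (the threshold value is purely illustrative):

```python
import torch

def entropy_and_varentropy(scores: torch.Tensor):
    """Per-row entropy S = E[-log p] and varentropy Var[-log p] of the attention weights."""
    probs = torch.softmax(scores, dim=-1)
    surprisal = -torch.log(probs + 1e-12)
    entropy = (probs * surprisal).sum(dim=-1)
    varentropy = (probs * (surprisal - entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
    return entropy, varentropy

scores = torch.randn(4, 16)
S, V = entropy_and_varentropy(scores)
dynamic_zeta = V > 0.25   # hypothetical rule: high varentropy -> let zeta follow D(S)
```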
The next thing I will mention here is the order of complexity for the Transformer model with adaptive entropy sparsity (AES). The usual vanilla Transformer has a complexity that is quadratic in the context length, $O(n^{2})$, whereas for a sparsity model we can obtain a much better complexity of $O(n\sqrt{n})$. I am not exactly sure whether the fixed-ness of the model makes this complexity better or worse in terms of efficiency, but all in all, the complexity (at least theoretically) is not fixed with respect to sparsity and therefore depends on $\zeta (S)$:
$$ O(\mathrm{complexity})\propto \frac{1}{\zeta (S)}\;. $$
When sparsity is higher, the complexity typically decreases since the model has to attend to fewer tokens in the token pool $\mathcal{T}_{i}$. This depends on the overall varentropy metric calculations, and so we arrive at the following theoretical complexity for AES:
$$ O(\mathrm{complexity})\propto \frac{1}{\mathrm{varentropy}}\;. $$
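Just to put numbers on the quadratic vs. sub-quadratic claim, here is the raw count of score evaluations per attention layer at a few context lengths (asymptotics only, not a benchmark):

```python
# Attention score evaluations per layer: dense O(n^2) vs. the hoped-for AES regime O(n * sqrt(n)).
for n in (1_024, 4_096, 16_384):
    dense = n * n
    sparse = int(n * n ** 0.5)
    print(f"n={n:>6}: dense={dense:>12,}  sparse={sparse:>12,}  savings={dense / sparse:.0f}x")
```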
It would be nice to actually run a model with AES and see how it changes the complexity and the validation metrics. For now, we are still left with the question of what $\mathcal{D}(S)$ exactly is. Remember, though, that it is a continuous function taking values between $\mathrm{softmax}$ at $\zeta (S)=0$ and $\mathrm{sparsemax}$ at $\zeta (S)=\zeta _{\mathrm{max}}$. This definition bounds the entropy sparsity parameter between either extremely dense or extremely sparse attention weights. I used this definition to agree with the $\alpha$-entmax function, where $\alpha =1$ is $\mathrm{softmax}$ and $\alpha =2$ is the $\mathrm{sparsemax}$ function. This bounding may be changed if appropriate, since the $\alpha$ in our context is a similar learnable parameter of the model that is optimized with stochastic gradient descent. Note that this is purely theoretical; in a real-time scenario it is very possible that this sparsity does not translate into any efficiency-curve benefits at all.
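For completeness, here is a toy, forward-only $\alpha$-entmax via bisection (my own throwaway implementation, not the reference one from the entmax papers), just to see the two limits: $\alpha =2$ produces exact zeros, while $\alpha$ slightly above $1$ is essentially $\mathrm{softmax}$:

```python
import torch

def entmax_bisect(z: torch.Tensor, alpha: float, n_iter: int = 50) -> torch.Tensor:
    """Toy alpha-entmax (alpha > 1), forward pass only.

    Finds tau such that p_i = [(alpha - 1) * z_i - tau]_+ ** (1 / (alpha - 1))
    sums to 1 along the last dimension. alpha = 2 is sparsemax; alpha -> 1+ approaches softmax.
    """
    x = (alpha - 1.0) * z
    tau_lo = x.max(dim=-1, keepdim=True).values - 1.0   # here the total mass is >= 1
    tau_hi = x.max(dim=-1, keepdim=True).values         # here the total mass is 0
    for _ in range(n_iter):
        tau = 0.5 * (tau_lo + tau_hi)
        p = torch.clamp(x - tau, min=0.0) ** (1.0 / (alpha - 1.0))
        mass = p.sum(dim=-1, keepdim=True)
        tau_lo = torch.where(mass > 1.0, tau, tau_lo)    # mass too big -> raise tau
        tau_hi = torch.where(mass > 1.0, tau_hi, tau)    # mass too small -> lower tau
    return p / mass                                      # renormalise away residual error

z = torch.randn(8)
print(entmax_bisect(z, alpha=2.0))    # sparsemax: typically several exact zeros
print(entmax_bisect(z, alpha=1.01))   # nearly indistinguishable from softmax(z)
print(torch.softmax(z, dim=-1))
```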
One final thing I will mention is that there is the possibility of a $\text{sampling redundancy}$ in using entropy for the sparsity model. This redundancy arises when we comb through $\mathcal{T}_{i}$ enough to make the token pool reduce TO BE COMPLETED
And now, to the main objective: a mathematical description for this effective correlation. My first thought was to write down the partition function
$$ Z=\sum _{i} \exp \left( \mathbf{x}_{i}^{\textsf{T}} \right)\;, $$
where $\mathbf{x}_{i}^{\textsf{T}}$ denotes $x_{i}/T$, with $T$ the temperature. This is just an observation about the $\mathrm{softmax}$ function and isn't really anything special, until you identify the $\text{free energy}$ as the negative of $\log Z$:
$$ F=-\log Z\;. $$
Then the $n$-point correlators, or moments, or whatever (sorry, I'm too QFT-brained rn), are the expansions like
$$ \langle x_{1}\dots x_{n}\rangle =\frac{1}{Z}\frac{\partial^{n} Z(\xi )}{\partial \xi _{1}\dots \partial \xi _{n}}\;. $$
The $1$-point correlator $\langle x_{1}\rangle$ is obtained from the source-dependent free energy, which looks like
\[ F(\xi )=-\log \underbrace{\sum _{a} \exp \left( q k^{\textsf{T}}_{a}+\xi \, x^{\textsf{T}}_{a} \right)}_{\text{partition function } Z(\xi )}\;. \]
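As a quick sanity check (my own working, not quoted from either paper), differentiating once with respect to the source and setting $\xi =0$ gives back the usual attention output, so the familiar $\mathrm{softmax}$-weighted value average is literally the first moment of this partition function:
$$ \langle x_{1}\rangle =\frac{1}{Z}\frac{\partial Z(\xi )}{\partial \xi }\bigg|_{\xi =0}=\frac{\sum _{a} x^{\textsf{T}}_{a}\exp \left( q k^{\textsf{T}}_{a} \right)}{\sum _{a}\exp \left( q k^{\textsf{T}}_{a} \right)}=\sum _{a}\mathrm{softmax}\left( q k^{\textsf{T}} \right)_{a}\, x^{\textsf{T}}_{a}\;. $$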
This is nothing new and comes straight from Vasu's paper Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU Clusters and the Yann LeCun et al. paper on energy-based models. I think there are some interesting connections between these two approaches w.r.t. entropy sampling, and possibly an extension to something like the Helmholtz free energy. When computing these terms in LLMs, it might be helpful to work with these quantities in different settings. As an example, the $\log Z$ term might be useful in making the sampling uncertainty a more precise thing; see Yann LeCun et al.'s paper for what the possible relation is. I will update this with a few more thoughts in a while.
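In code, none of this needs anything exotic: the $\log Z$ of a logit vector is just a logsumexp, and the entropy (and hence a candidate sampling-uncertainty signal) can be read off from the same distribution. A minimal sketch, with the uncertainty interpretation being my own assumption:

```python
import torch

def logZ_and_entropy(logits: torch.Tensor, temperature: float = 1.0):
    """log Z = logsumexp(logits / T) and entropy of p = softmax(logits / T)."""
    x = logits / temperature
    log_z = torch.logsumexp(x, dim=-1)    # log partition function
    probs = torch.softmax(x, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)
    return log_z, entropy

logits = torch.randn(32_000)          # e.g. next-token logits over a vocabulary
log_z, S = logZ_and_entropy(logits)   # candidate per-step uncertainty signals
```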
Note of caution
I have to say, though, that most of this is purely theoretical at the moment and has not been personally tested anywhere. Not to mention that all of this is VERY crude and a lot of refinement will be necessary to obtain a base working model for the entropy sparsity model I mentioned. Or maybe not, or maybe there is something key that I am missing altogether. There may, and most likely will, be mistakes; do not assume otherwise. But I will try to find a proper formalism for the partition function thing in the next two hours, hopefully. Or maybe I will update the GH doc draft and make it public.
1018 hrs local time
Update 1: The free energy is not free
So writing $Z$ a little more formally (taking $\beta =1/T$):
$$ Z_{\beta }=\sum _{a} \exp \left( -\beta E_{a} \right)\;, $$
we could write the free energy in terms of $Z_{\beta }$ and $S$, i.e. $F=-\tfrac{1}{\beta }\log Z_{\beta }=\langle E\rangle -\tfrac{1}{\beta }S$, and then rewrite our entire model in terms of $F$. But the point is that the entropy-based sampling in itself suffices for the most part; there is no need to go through all this pain of computing $Z_{\beta }$ or to do anything even remotely related to EBMs. The worth of doing this free energy stuff is that we can adopt an EBM approach to sampling and obtain better optimization routines in terms of $\nabla F$, and the uncertainties in the pool can be made dynamic from this. But I have no stringent way to sell this idea yet. Emphasis on the yet.
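For what it's worth, that identity is a two-line numerical check with a logsumexp. A minimal sketch (the energies here are random placeholders, my assumption purely for illustration):

```python
import torch

beta = 2.0                                   # inverse temperature, beta = 1/T
energies = torch.randn(16)                   # E_a; random placeholders for illustration

log_Z = torch.logsumexp(-beta * energies, dim=-1)
F = -log_Z / beta                            # free energy, F = -(1/beta) log Z_beta

probs = torch.softmax(-beta * energies, dim=-1)
avg_E = (probs * energies).sum()             # <E>
S = -(probs * torch.log(probs)).sum()        # entropy of the Boltzmann distribution

print(F, avg_E - S / beta)                   # the two agree: F = <E> - S / beta
```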