We find ourselves in a world with machines that fluently speak dozens of languages and can perform a wide variety of tasks, like programming, at a reasonable level - and we have no idea how they do it! This is a standard mechanistic interpretability (a.k.a. mech interp) pitch - a field that tries to express neural networks’ behaviours as human-understandable algorithms, i.e. to reverse engineer the algorithms learned by a neural network (or model, for short). The main motivation is that even though we know the exact form of the computation the model performs to transform the input (e.g. a text prompt) into the output (e.g. a text answer), we don’t know why this computation does what it does, and this is a major concern from the standpoint of AI Safety. A model may perform the computation because it was genuinely trained to do the task well, or because it learned that doing the task well correlates with its other learned goals, like gaining more power and resources. Without understanding the computation, we have no direct way of distinguishing between the two.
The solution proposed by mechanistic interpretability is closely analogous to reverse engineering ordinary computer programs from their compiled binaries. In both cases, we start from an intrinsically non-interpretable model of computation: a sequence of binary instructions performed on a string of 0s and 1s in one case, and the (mathematical) function of the neural network’s architecture applied with its learned parameters (weights) in the other.
But what makes us think that the same is possible for neural networks, especially ones as large as the current Large Language Models (LLMs)? In particular, why should we even expect neural networks to solve tasks similarly to humans, and thus adopt the same “variable-centered” model of computation? While a proof of existence for the first question appeared relatively early (see the Circuits thread by Chris Olah et al.
A feature, in general, is a fuzzy term, and you can find some good attempts to define it here
Why are sparse autoencoders called sparse? It’s actually deeply linked to the idea from the previous paragraph: if you want to use many features in a limited activation space (limited by the number of neurons), you have to exploit the fact that for any given input, most of the features will not be present. So given that modern language models are trained to predict the next token across a huge variety of possible inputs, we should expect any feature learned by the model to be sparse, i.e. used by the model only on a small fraction of all possible inputs.
But wait, how is it even possible for a model to learn input-specific features if it has a low-dimensional activation space (where the dimension equals the number of neurons) but a very high-dimensional input space? The answer is superposition - the idea of exploiting feature sparsity to store more features than there are dimensions in the activation space. It has a rich mathematical background, and we invite an unfamiliar reader to learn more about it in the “Toy Models of Superposition” paper by Elhage et al.
Coming back to SAEs, they were introduced with all of these ideas in mind to solve superposition, i.e. to recover more than n features in an n-dimensional activation space of a model. How are they supposed to do it? The answer is once again in the name - autoencoders, which means that SAEs are neural networks with the “autoencoder” architecture, which is illustrated in a diagram below (borrowed from the great Adam Karvonen’s post
So the model activations are “encoded” into a high-dimensional vector of feature activations (top right; note that it always has many more elements than the input activation vector), and this high-dimensional vector (a.k.a. the “code”) is “decoded” back to reconstruct the input, hence the name “autoencoder”. We advise the reader to take a quick look at the “Towards monosemanticity” appendix
where \(f_i(\mathbf{x}) = \text{ReLU}\left( \mathbf{W}_{enc} \mathbf{x} + \mathbf{b}_{enc} \right)_i\) are the feature activations that are computed in the left (“encoder”) part of the diagram, and \(\mathbf{d}_i\) are the rows of the decoder matrix (or columns, if you take the transpose and multiply from the other side). Note that the diagram omits bias vectors \(\mathbf{b}\) for simplicity, but conceptually they don’t change much: instead of decomposing the activation space, we’re decomposing a translation of that space by a fixed vector (because this is just easier for an SAE to learn).
If you think about it, this is exactly what we hoped to do in the analogy with decomposing program memory into variable names! The variables are now features - vectors (directions) in the activation space. And if the autoencoder does a good job at reconstructing the input, we can expect this decomposition (and hence the features) to make sense!
The last part is tricky though. Unlike variables, which are deliberately used by humans to write sensible algorithms, there is no reason to expect that the features we recover with an SAE will be interpretable, in the sense that a human can understand which inputs they activate on and can predict their “roles” based on that (e.g. which tokens they help to predict). But this is where the sparsity condition comes in: we don’t only want an SAE to reconstruct the input from a high-dimensional feature-activation representation; we also want this representation to be sparse, i.e. to have only a handful of non-zero feature activations at a time. We already touched on the reason for this - the hope is that we’ll be able to recover the “true” features used by the model in this way
The traditional view in mech interp has been that one cannot interpret a model’s weights if one cannot interpret the neurons that the weights connect. But due to the neurons’ polysemanticity
So what does this mean for AI Safety? We’ll cite the Anthropic team’s view on this topic (laid out in their “Interpretability Dreams”
We’d like a way to have confidence that models will never do certain behaviors such as “deliberately deceive” or “manipulate.” Today, it’s unclear how one might show this, but we believe a promising tool would be the ability to identify and enumerate over all features.
Ultimately we want to say that a model doesn’t implement some class of behaviors. Enumerating over all features makes it easy to say a feature doesn’t exist (e.g. “there is no ‘deceptive behavior’ feature”) but that isn’t quite what we want. We expect models that need to represent the world to represent unsavory behaviors. But it may be possible to build more subtle claims such as “all ‘deceptive behavior’ features do not participate in circuits X, Y and Z.”
Summarizing, the hope is to be able to prove statements of the following form:
After outlining the procedure behind SAE-based interpretability, we can answer a more general question: why is it relevant to translate the matrix language of neural networks (no more understandable to us than binary code) into a human-readable algorithmic language? There are several reasons, but, among others, once we are able to do so, we can understand which features of an input a model identifies before predicting an answer. This allows us to identify when a model is learning to deploy features spuriously correlated with the actual labels (an intuitive example here
Nevertheless, reality is often rougher than abstraction, and mechanistic interpretability suffers from one big problem: once we crack the interpretation of a model, we are only able to decode what is going on inside that singular, particular model, and not all models with the same architecture but different weights. Luckily, to have a model that shows emergent abilities, we need a lot of compute
Is the interpretability of a model as weak as alignment to finetuning?
In this post, we try to answer these questions and extend the positive results derived from a similar study by Kissane et al.
Lastly, we want to remark how this kind of study derives its importance from the weakness of outer alignment forced by some ad-hoc finetuning. Indeed, if interpretability is more resistant to being broken than alignment, the path towards AI safety could be reached via digital neuroscience
In choosing finetuned models to work with, we tried to strike a balance between the potential relevance of these models (how many people will actually use similar models), and the availability of pre-trained SAEs from the SAELens
We then loaded the following SAEs for these models from SAELens (SAE layer numbering starts from 0):
| Model | SAE Release | SAE Layer | N Features |
|---|---|---|---|
| Gemma-2b (v1) | gemma-2b-res-jb by Joseph Bloom | Residual layer #6 | 16384 |
| Mistral-7B (v0.1) | mistral-7b-res-wg by Josh Engels | Residual layer #8 | 65536 |
Two important things to note:
Similarly to what Kissane et al.
Prior to evaluating the SAEs’ performance, we computed different similarity metrics for the residual stream activations at the specific layer our SAEs are used for. The goal was to obtain some sort of prior probability that our SAEs will transfer to the finetuned model: the more similar the activations, the higher the (expected) probability that our SAEs will transfer. On the one hand, this analysis can be used as a first step to select a fine-tuned model from the thousands available on Hugging Face. On the other hand, further studies can try to analyze whether the phenomenon of SAE transferability actually correlates with the difference between activations of the base and fine-tuned models (which we treat here only as an unproven heuristic).
Designing rigorous approaches to evaluate SAE performance is an open problem in mechanistic interpretability. The main complicating factor is that we’re interested not so much in the SAE’s reconstructed output as in the SAE feature activations and feature vectors. However, measuring whether the SAE features are interpretable, or whether the features “are truly used by the model”, is not straightforward. For our work, we’ll just start by computing standard evaluation metrics proposed either in the original “Towards monosemanticity” paper or used in later work, e.g. this one by Joseph Bloom
We’ll compute these metrics first for the base model and its SAE to get a baseline, then for the finetuned model with the same SAE, and compare the resulting metrics against the baseline
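As an illustration, here is a rough sketch (on random toy data, with helper names of our own) of how two of these standard metrics, the L0 loss and the variance explained (\(R^2\)), can be computed:

```python
import numpy as np

def l0_loss(f):
    """Average number of features active per token; f has shape [n_tokens, n_features]."""
    return float((f > 0).sum(axis=1).mean())

def variance_explained(x, x_hat):
    """R^2 score of the reconstruction x_hat against the original activations x."""
    resid = float(((x - x_hat) ** 2).sum())
    total = float(((x - x.mean(axis=0)) ** 2).sum())
    return 1.0 - resid / total

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 8))                          # toy "residual stream" activations
f = np.maximum(0.0, rng.normal(size=(100, 32)) - 1.5)  # sparse toy feature activations
print(l0_loss(f), variance_explained(x, x))            # a perfect reconstruction gives R^2 = 1.0
```

The cross-entropy loss metrics (clean vs. reconstruction CE loss and their delta) additionally require running the model with the SAE output spliced into the residual stream, which we omit here.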
Based on the feature density histograms, we additionally zoomed in on individual features to see how well they transfer using feature activation similarity and logit weight similarity
Before analyzing the SAE metrics on the finetuned models, we will visualize some simpler computations on the residual activations (at the residual stream of the layer where we apply the corresponding SAE) to get a sense of the SAE transferability. Specifically, we are interested in the similarities between the base and finetuned model activations. We consider two metrics, the Cosine Similarity and the Euclidean Distance, for the models and datasets specified above, i.e. the Gemma-2b Python-codes and Mistral-7b MetaMath
Computing the Cosine Similarities and Euclidean Distances of the activations yields a tensor of shape [N_BATCH, N_CONTEXT] (each token position is determined by its batch number and position in the context). A simple metric to start with is the global mean of the Cosine Similarities across both the batch and context dimensions, giving a single scalar that represents the overall similarity. This can be seen in the following table:
| Model/Finetune | Global Mean (Cosine) Similarity |
|---|---|
| Gemma-2b/Gemma-2b-Python-codes | 0.6691 |
| Mistral-7b/Mistral-7b-MetaMath | 0.9648 |
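The per-token computation behind this table can be sketched as follows (function name and toy shapes are our own, not from the actual codebase):

```python
import numpy as np

def token_similarities(acts_base, acts_ft):
    """Per-token cosine similarity and Euclidean distance between two activation
    tensors of shape [N_BATCH, N_CONTEXT, D_MODEL]."""
    cos = (acts_base * acts_ft).sum(-1) / (
        np.linalg.norm(acts_base, axis=-1) * np.linalg.norm(acts_ft, axis=-1))
    dist = np.linalg.norm(acts_base - acts_ft, axis=-1)
    return cos, dist                          # both of shape [N_BATCH, N_CONTEXT]

rng = np.random.default_rng(0)
base = rng.normal(size=(4, 16, 8))            # toy stand-in for base model activations
cos, dist = token_similarities(base, base + 0.01 * rng.normal(size=base.shape))
global_mean = cos.mean()                      # the single scalar reported in the table
```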
This already suggests much better transferability of the Mistral-7b SAE for its MetaMath finetune. For a more fine-grained comparison, we flatten the similarities into a N_BATCH * N_CONTEXT vector and plot the histogram across all tokens:
Gemma-2b - Cosine Similarity Histogram
Mistral-7b - Cosine Similarity Histogram
Gemma-2b - Euclidean Distance Histogram
Mistral-7b - Euclidean Distance Histogram
We can see that the Cosine Similarities for Mistral-7b are concentrated around a value close to 1, whereas the Gemma-2b similarities are more spread around their mean of 0.67 (higher variance). The Euclidean Distance histograms show a similar distinction, with the Gemma-2b distances spread around a mean of roughly 120, while the bulk of the Mistral-7b distances stays at low values.
We also visualize the per-context mean of Cosine Similarities and Euclidean Distances. We compute the mean across batches but preserve the context dimension, giving a tensor of shape [N_CONTEXT], which reflects how the similarity changes over the context length.
Gemma-2b - Cosine Similarity Context Line
Mistral-7b - Cosine Similarity Context Line
Gemma-2b - Euclidean Distance Context Line
Mistral-7b - Euclidean Distance Context Line
In the plots above, we can see that both models start with a similarity close to 1 for the first token, and that the similarities and distances stabilise quickly after a few tokens of context, albeit around different values.
These results already anticipate a considerable difference in the transferability of the SAEs for the two models, which will be explored more in-depth in the following section.
In this section, we’ll compute a set of standard SAE metrics for base and finetuned models, using the same base SAE in both scenarios (i.e., the SAE that was trained on the base model activations):
Before delving deeper into the results, we want to point out three technical details:
Similarly to Kissane et al.
SAEs fail to reconstruct activations from the opposite model that have outlier norms (e.g. BOS tokens). These account for less than 1% of the total activations, but cause cascading errors, so we need to filter these out in much of our analysis.
To solve this problem we used a similar outlier filtering technique, where an outlier is defined as an activation vector whose (L2) norm exceeds a given threshold. We tried several ways to find a “good” threshold and arrived at values similar to those used by Kissane et al:
Using these threshold values, we found that only 0.24% of activations are classified as outliers in the Gemma-2b model, and 0.7% in Mistral-7B, agreeing with the Kissane et al. result that these outliers account for less than 1% of activations. It should be noted, however, that we only used this outlier filtering technique for our reconstruction loss & variance explained experiments, to avoid numerical errors. In practice, this means that for these experiments the true sample size was a little smaller than for the other experiments, equal to \(\left( 1 - \text{outlier\_fraction} \right) \times 256{,}000\) with the \(\text{outlier\_fraction}\) defined above.
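A minimal sketch of this kind of norm-based outlier filtering (the threshold and data here are toy values, not the ones we actually used):

```python
import numpy as np

def filter_outliers(acts, threshold):
    """Drop activation vectors whose L2 norm exceeds `threshold`.
    acts: [n_tokens, d_model], flattened over batch and context."""
    mask = np.linalg.norm(acts, axis=-1) <= threshold
    return acts[mask], 1.0 - mask.mean()      # kept activations, outlier fraction

rng = np.random.default_rng(0)
acts = rng.normal(size=(1000, 8))
acts[:5] *= 100.0                             # inject a few outlier-norm tokens (e.g. BOS positions)
kept, outlier_fraction = filter_outliers(acts, threshold=50.0)
print(kept.shape[0], round(outlier_fraction, 3))   # 995 0.005
```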
In the following table, we report the results for the first experiment with the Mistral model pair:
| Model\\Metric | L0 Loss | Clean CE Loss | Reconstruction CE Loss | Loss Delta | $$R^2$$ Score (Variance Explained) | Dead Features (%) |
|---|---|---|---|---|---|---|
| Mistral-7B | 83.37 | 1.78 | 1.93 | 0.15 | 0.68 | 0.76% |
| Mistral-7B MetaMath | 90.22 | 1.94 | 2.1 | 0.16 | 0.58 | 0.64% |
As you can see, the L0 loss increases a bit and the variance explained drops slightly, but the reconstruction loss delta is almost the same! This suggests that our Mistral SAE may still transfer after finetuning, although with slightly worse reconstruction quality. Let’s compare these results with Gemma-2b and its Python finetune:
| Model\\Metric | L0 Loss | Clean CE Loss | Reconstruction CE Loss | Loss Delta | $$R^2$$ Score (Variance Explained) | Dead Features (%) |
|---|---|---|---|---|---|---|
| Gemma-2b Base | 53.59 | 2.65 | 3.16 | 0.51 | 0.97 | 48.1% |
| Gemma-2b Python-codes | 84.74 | 3.29 | 7.5 | 4.21 | -10.27 | 0.1% |
Now, this is what bad SAE transferability looks like! But actually this should come as no surprise after the Kissane et al.
Here we show that the weights for Gemma v1 2B base vs chat models are unusually different, explaining this phenomenon (credit to Tom Lieberum for finding and sharing this result):
But what effect does this have on the SAE features? Well, we could expect that if an SAE is no longer able to reconstruct the input activations, it will always “hallucinate” - any features it “detects” will not make any sense. Let’s see if this expectation holds in practice for the Gemma-2b model.
We’ll start with the feature activations histogram plot. In general, this kind of histogram gives little insight on its own, since you will always have a large mode at 0 due to feature sparsity, and some kind of log-normal distribution over the non-zero activations. Indeed, this is what happens in the base Gemma-2b model when we plot its log10 feature activations histogram:
Two things to note:
With this in mind, let’s compare it with the same kind of histogram for our Gemma-2b finetune (where the features are given by the same SAE):
If that’s not a characterization for “cursed”, we don’t know what is! Instead of a nice bell curve, we now have some sort of 3-mode monster in the non-zero activations section. To be clear - nothing like that was present when we repeated this experiment for Mistral-7B: we obtained the expected bell curves with similar mean and standard deviation for both the base and finetuned models. We don’t have a good explanation for this Gemma-2b anomaly, but we’ll try to give some deeper insight into what happens with the SAE features in the next section.
Let’s move on to the feature densities plot, which was produced as described in the Evaluating SAEs Performance section. Starting from Gemma-2b:
As expected from the above results, the two plots have little in common. We see that most of our dead features (in the base model) turn alive in the finetuned one! To see where exactly these dead features land in the finetuned model (i.e. what their new densities are), we also made a parallel coordinates plot (below we show two versions of the same plot, with different density ranges highlighted):
So it looks like the dead features spread out quite widely in the finetuned model, contributing to more probability mass before the -3 log-density. As for the dense features (-4 to -1 log density) in the base model, their density interval gets squeezed to (-3, -1) in the finetuned model, causing a sharp mode near the -2.5 log-density value.
We’ll continue the Gemma-2b investigation in the next chapter, and conclude this section with the Mistral-7B feature density histograms:
We can see that for Mistral the feature density distribution almost doesn’t change after finetuning! The only slight difference is in the number of dead features: the finetuned Mistral has around 80 fewer dead features than the base one. To zoom in closer, we also show the parallel coordinates plot:
So yes, a small number of features do turn alive, but some features (an even smaller number) also turn dead in the finetuned model! Overall though, the feature densities look very similar, with the Pearson correlation of their log10 densities equal to 0.94 (versus 0.47 for the Gemma-2b case).
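For reference, the feature density computation behind these plots and the Pearson correlation of the log10 densities can be sketched as follows (the random arrays stand in for real SAE feature activations, so the correlation value itself is meaningless here):

```python
import numpy as np

EPS = 1e-10   # keeps the log finite for dead features (density 0)

def log10_density(f):
    """log10 of each feature's firing density (fraction of tokens it activates on).
    f: [n_tokens, n_features] feature activations."""
    return np.log10((f > 0).mean(axis=0) + EPS)

rng = np.random.default_rng(0)
f_base = np.maximum(0.0, rng.normal(size=(5000, 64)) - 1.0)   # toy stand-ins for
f_ft = np.maximum(0.0, rng.normal(size=(5000, 64)) - 1.0)     # real SAE activations
r = np.corrcoef(log10_density(f_base), log10_density(f_ft))[0, 1]
```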
We want to motivate this section with a more careful consideration of the question of what the best way to operationalize SAE transferability is. In the previous section, we simply checked the standard SAE evaluation metrics to see how well the SAEs reconstruct the activations. But this doesn’t necessarily reflect the main goal of using SAEs - interpreting the model.
As noted in the SAE features for AI Safety section of our post, the end goal of using SAEs for interpretability is to be able to use features as the basis for circuit analysis. And if we assume that some kind of circuit analysis has been done for the base model to prove that it doesn’t implement certain undesirable behaviors, the most ambitious operationalization of SAE transferability (for AI Safety) would be the ability to apply the same kind of circuit analysis with the same SAE (or the finetuned one) to prove or disprove that the finetuned model is safe.
In our case of studying transferability “by default”, a better way to demonstrate it is to show that our SAE features “stay relevant” in the finetuned model, so that we can expect them to still potentially serve as a basis for circuit analysis. Showing this rigorously would be a really difficult task (partly because there’s no standard way to do circuit analysis in the SAE basis yet), and it’s out of scope for this blog post. What we did instead is apply an approach from Towards Monosemanticity
This can be seen as a (very) rough proxy for “the feature is doing the same job in the finetuned model”, and we call it the “feature transferability test”.
Conceptually, dead features are completely different from the ordinary features: as explained by Joseph Bloom
By dead features, we mean features that are exclusively dead (never activating across our entire 256K sample of tokens), i.e. dead only in one of the models:
We observe that only a handful of features are dead in both models, so we think our definitions give more information about what we’re analysing.
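The masks behind these definitions can be sketched as follows (toy data; the helper name is ours):

```python
import numpy as np

def dead_feature_masks(f_base, f_ft):
    """Boolean masks over features from activation samples of shape [n_tokens, n_features]:
    exclusively dead in the base model, exclusively dead in the finetune, dead in both."""
    dead_base = ~(f_base > 0).any(axis=0)
    dead_ft = ~(f_ft > 0).any(axis=0)
    return dead_base & ~dead_ft, dead_ft & ~dead_base, dead_base & dead_ft

rng = np.random.default_rng(0)
f_base = rng.uniform(0.1, 1.0, size=(100, 16))   # toy activations, all features alive
f_ft = rng.uniform(0.1, 1.0, size=(100, 16))
f_base[:, 1] = 0.0                               # feature 1 never fires in the base model
f_ft[:, 0] = 0.0                                 # feature 0 never fires in the finetune
only_base, only_ft, both = dead_feature_masks(f_base, f_ft)
```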
Then, our approach for the rest of this section looks as follows:
1e-10 as the epsilon value to avoid a log of zero), take the Mean Absolute Error of the resulting vectors, and call it the activation error.

One could say that in the Gemma-2b case, it’s obvious from the previous results that our SAE doesn’t transfer. But we could imagine a case where some (perhaps a tiny fraction) of our SAE features from the regular density interval do still transfer, so we decided to conduct this experiment anyway.
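The activation error computation described above can be sketched like this (toy data; the 1e-10 epsilon is the one mentioned in the text, and the example mimics a feature that is dead in the base model but fires very weakly after finetuning):

```python
import numpy as np

EPS = 1e-10   # the epsilon that avoids a log of zero

def activation_error(acts_base, acts_ft):
    """Mean Absolute Error between the log10 activations of one feature
    across a shared sample of tokens in the base and finetuned models."""
    return float(np.abs(np.log10(acts_base + EPS) - np.log10(acts_ft + EPS)).mean())

base = np.zeros(100)                 # a feature that is dead in the base model...
ft = np.zeros(100)
ft[:3] = 1e-9                        # ...and fires extremely weakly after finetuning
print(round(activation_error(base, ft), 3))   # → 0.031
```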
Starting with the features that are exclusively dead in the base model, their mean activation error for Gemma-2b and Gemma-2b python-codes finetune is 0.025. A histogram of these 100 activation errors is given below:
This made us think that the “dead features turning alive” anomaly is not much of an anomaly after all, because the dead features activate only (very) slightly in the finetuned model. The max activation value across all 100 dead features in the finetuned model was 1.1, indicating that our “dead feature directions” are only slightly off in the finetuned model and could easily be adjusted by SAE finetuning.
As for the features that are exclusively dead in the finetuned model, Gemma-2b had only two of them in our sample, with activation errors equal to 0.34 and 3.19, which is considerably higher than in the previous case.
Moving on to the regular features, we expected to see a much more drastic dissimilarity in their activations. Indeed, the mean activation similarity for our sample of Gemma-2b regular features is 0.39. Let’s check the histogram of these similarity scores:
Interestingly, we see that a small fraction of features (~10%) have an activation similarity above 0.8! This implies that if these features were interpretable in the base model, they will most likely stay interpretable in the finetuned model
As for the logit similarity of these regular features, it turns out it’s much higher than our activation similarity, with a mean value of 0.952. Looking at the logit similarity scores histogram, it’s also much more concentrated towards the end of the interval:
However, it’s easy to be misled by the mean logit similarity score. What it’s really saying is that our unembedding matrix (which is multiplied by the feature direction to get the logit weights) hasn’t changed that much after finetuning (with a Frobenius norm ratio equal to 1.117, as we checked for our Gemma finetune). So if a feature still has the same direction, we can indeed say that its “direct feature effect” hasn’t changed in the finetuned model - but we never checked this premise! All we know is that there exist ~10% of features which have reasonably high activation similarity scores with the features from the base model. The key point is that the latter is a statement about the feature’s encoder direction (the direction the activations are projected onto to compute the feature’s activation, explained by Neel Nanda here
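For clarity, here is how a logit weight similarity of this kind can be computed for a single decoder direction (shapes and names are hypothetical toy values):

```python
import numpy as np

def logit_similarity(W_U_base, W_U_ft, d):
    """Cosine similarity between the logit weights a decoder direction d produces
    under the base and finetuned unembeddings W_U ([d_model, vocab])."""
    lb, lf = W_U_base.T @ d, W_U_ft.T @ d
    return float(lb @ lf / (np.linalg.norm(lb) * np.linalg.norm(lf)))

rng = np.random.default_rng(0)
W_U_base = rng.normal(size=(8, 50))
W_U_ft = W_U_base + 0.01 * rng.normal(size=(8, 50))   # a mildly changed unembedding
d = rng.normal(size=8)                                # a decoder feature direction
sim = logit_similarity(W_U_base, W_U_ft, d)
```

Note that this only measures the direct effect through the unembedding; it says nothing about whether the encoder direction still picks up the same inputs.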
Here we repeat all the same experiments for Mistral-7B and its MetaMath finetune, and compare the result with the Gemma-2b case.
Let’s start with the features that are exclusively dead in the Mistral base model. Their mean activation error is 0.0003, which is almost two orders of magnitude lower than in the Gemma-2b case. The corresponding histogram looks like this:
Once again, the results suggest that even though the dead features in the base model are no longer dead in the finetuned one, they activate really weakly on average, so it should be easy to adjust them with a cheap SAE finetuning.
The activation error for the features exclusively dead in the finetuned model tells a similar story:
Here the error is even smaller, implying that even though some features stopped activating after finetuning, their corresponding activation values in the base model were really low anyway. And since features are often uninterpretable in the lowest activation intervals, this should have a minor overall effect on SAE transferability.
Let’s conclude this section with an analysis of our regular features. As expected from the results of the last section, the activation similarity of these features is quite high, with a mean value of 0.958. As for the activation scores histogram:
As we can see, the distribution of the scores is strongly concentrated in the 0.9-1.0 correlation interval, so we can conclude that SAE feature transferability is very high in this case. This is also backed up by the mean logit similarity of 0.9996 and a rather unsurprising logit similarity histogram:
Going back to our original question of “Do SAEs trained on a base model transfer to the finetuned one?”, the most obvious answer that comes to mind now is - it depends! We got drastically different results for our Gemma-2b-python-codes and Mistral-7B-MetaMath finetunes. However, it seems possible that one could estimate the “degree of transferability” in advance. One method is to compute various weight deviation metrics, such as the one used by Kissane et al
Another takeaway we’ve had after finishing this post is that “SAE transferability” can mean different things. One can use the standard SAE evaluation metrics to get a high-level assessment of the SAE’s quality on the finetuned model, but they don’t always give deeper insight into what happens with the SAE features once we zoom in (which may be more interesting for real SAE applications in mech interp). Our Gemma-2b results suggest that some SAE features may still be interpretable even when finetuning has completely rendered the SAE incapable of reconstructing the input. And although the significance of this result can rightly be questioned, we still think it is worth investigating further.
The main limitations we see in our work are the following:
All code is available on GitHub