Poster:

https://docs.google.com/presentation/d/1y0-kRhTVF8nXgVid5uOujw1LFdqwdyfN/edit?usp=sharing&ouid=113677000722797735161&rtpof=true&sd=true

Introduction

Mechanistic interpretation of how neural networks perform algorithmic tasks is crucial for understanding how these models implement logic and combinatorial reasoning. Automata provide one model of computation of particular interest. Given a sequence of $T$ inputs $\sigma_t$ and an initial state $q_0$, an automaton applies a deterministic transition function $\delta(q_{t-1}, \sigma_t)$ to compute a state sequence $q_t$. We consider finite automata with discrete sets of possible inputs $\sigma$ and possible states $q$. The transition function that underlies an automaton (without mapping states to outputs) is referred to as a finite semiautomaton. Examples of semiautomata include a parity counter, 1D or 2D Gridworld, cyclic groups $C_n$, symmetric permutation groups $S_n$, and alternating permutation groups $A_n$.

Intuitively, an automaton can be simulated using a recurrent neural network (RNN). The recurrent composition of the RNN cell is a natural match for the sequential application of an automaton’s transition function. The number of computational steps required in this case is therefore $\mathcal{O}(T)$, that is, linearly proportional to the sequence length.
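The sequential view can be made concrete with a minimal sketch. The names `simulate` and `parity_delta` are illustrative choices for this example, not taken from [1]; the loop simply applies $\delta$ once per input, which is the $\mathcal{O}(T)$ computation an RNN would mirror.

```python
from typing import Callable, Hashable, List, Sequence


def simulate(delta: Callable, q0: Hashable, inputs: Sequence) -> List:
    """Apply the transition function once per input and record each state q_t."""
    states = []
    q = q0
    for sigma in inputs:  # T sequential steps, one per input token
        q = delta(q, sigma)
        states.append(q)
    return states


def parity_delta(q: int, sigma: int) -> int:
    """Parity counter: flip the state whenever the input bit is 1."""
    return q ^ sigma


print(simulate(parity_delta, 0, [1, 0, 1, 1]))  # [1, 1, 0, 1]
```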

Surprisingly, however, Liu et al. (2022) have shown that Transformers can simulate automata with depth (number of layers) sublinear in the sequence length, using parallel shortcut solutions [1]. They demonstrated that Transformers can simulate all semiautomata with $\mathcal{O}(\log T)$ depth and polynomial-width attention and MLP layers. In addition, for certain semiautomata, even shorter shortcuts are possible. In particular, Liu et al. (2022) showed that the Gridworld automaton (described below) can be simulated with a Transformer having only 2 layers. As part of their proof, they provided a specific algorithm that Transformers may use to implement this simulation, described below.
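One way to build intuition for why logarithmic depth is plausible: each input token induces a map from states to states, and composing such maps is associative, so all prefixes can be composed in roughly $\log_2 T$ parallel rounds. The sketch below illustrates this parallel-prefix idea in plain Python. It is not the Transformer construction from [1], and the names `compose` and `parallel_scan_states` are hypothetical.

```python
def compose(f, g):
    """Return the state map 'apply f first, then g' (maps are tuples indexed by state)."""
    return tuple(g[q] for q in f)


def parallel_scan_states(delta, n_states, q0, inputs):
    """Compute all states q_1..q_T with a doubling (Hillis-Steele style) prefix scan.

    Each round could run in parallel over positions; only the number of
    rounds grows with T, and it grows logarithmically.
    """
    # One state map per input token: maps[t][q] = delta(q, inputs[t]).
    maps = [tuple(delta(q, s) for q in range(n_states)) for s in inputs]
    rounds, d = 0, 1
    while d < len(maps):
        # Position i absorbs the block ending d positions earlier.
        maps = [compose(maps[i - d], maps[i]) if i >= d else maps[i]
                for i in range(len(maps))]
        d *= 2
        rounds += 1
    return [m[q0] for m in maps], rounds


# Cyclic counter C_3: delta(q, a) = (q + a) mod 3.
states, rounds = parallel_scan_states(lambda q, a: (q + a) % 3, 3, 0,
                                      [1, 2, 2, 0, 1, 1, 2])
print(states, rounds)  # [1, 0, 2, 2, 0, 1, 0] 3
```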

In this project, we investigated the mechanisms by which Transformers simulate automata, in order to determine the extent to which the algorithms implemented by these models resemble the theoretical constructions proposed by Liu et al. (2022) and mechanistically understand how neural networks implement algorithmic tasks more broadly. By examining attention patterns and applying causal scrubbing [2], we found that, at least in some cases, Transformers do appear to implement an algorithm similar to the construction proposed by Liu et al. (2022). However, the internal structure of the model varied dramatically depending on the network architecture and hyperparameters. Finally, our preliminary investigations of information flow via network subspaces and causal scrubbing of positional embeddings suggest promising directions for continuing this work.

Shortcut Solution for 1D Gridworld

The Gridworld$_n$ automaton is defined by $n=S+1$ states arranged in a line, with the output at each sequence position being the current state in $\{0, \dots, S\}$. At each time step, the automaton receives an input token indicating whether to move left or right. It decrements or increments its current state by 1, respectively, unless that would move it past a boundary, in which case it stays where it is.

Diagram of 1D Gridworld$_4$, from [1].


Formally, the transition function is defined as

$$ \begin{align*} \delta(q, L) &= \max(q-1, 0) \\ \delta(q, R) &= \min(q+1, S) \end{align*} $$
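The transition function translates directly into code. The following is a minimal sketch; the function names and the encoding of moves as the characters `"L"` and `"R"` are assumptions made for illustration.

```python
def gridworld_delta(q: int, move: str, S: int) -> int:
    """Clamped step on the line of states 0..S."""
    if move == "L":
        return max(q - 1, 0)
    if move == "R":
        return min(q + 1, S)
    raise ValueError(f"unknown move: {move!r}")


def gridworld_states(moves: str, S: int, q0: int = 0) -> list:
    """Sequentially simulate Gridworld and return the state after each move."""
    states, q = [], q0
    for move in moves:
        q = gridworld_delta(q, move, S)
        states.append(q)
    return states


# Gridworld_4 (S = 3): the first L and the fourth R both hit a boundary.
print(gridworld_states("LRRRRLL", S=3))  # [0, 1, 2, 3, 3, 2, 1]
```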

The depth-2 shortcut solution for Transformers for Gridworld$_n$ derived by Liu et al. (2022) consists of two distinct steps. The first layer computes the prefix sum, or what the state would be if there were no boundaries. This is done by placing an amount of attention on an extra start token inversely proportional to the sequence position, and no attention anywhere else. Since the model is causal and softmax normalizes the total attention pattern to 1, the effect is to uniformly sum the values of every preceding token. The second layer operates as a boundary detector which determines the location of the most recent boundary. This is done by finding the minimum or maximum position within the shortest suffix containing $S+1$ distinct prefix sums, which must be the location of the most recent boundary by the intermediate value theorem. The current state can then be determined by computing the prefix sum difference since the most recent boundary.
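The arithmetic behind these two steps can be checked with a plain Python reference. This sketch is not a Transformer and not the exact construction from [1]: it computes unbounded prefix sums, then for each position walks back to the shortest suffix containing $S+1$ distinct prefix sums and reads the state off as the offset from that suffix's minimum. To handle the start of the sequence uniformly, it prepends a virtual walk from $S$ down to $0$ and back up to $q_0$, which is a device of this sketch only.

```python
def gridworld_shortcut(moves: str, S: int, q0: int = 0) -> list:
    """Recover Gridworld states from prefix sums and the most recent boundary."""
    # Virtual prefix (assumption of this sketch): S left moves, then q0 right
    # moves, so every position has S+1 distinct prefix sums behind it.
    steps = [-1] * S + [+1] * q0 + [+1 if m == "R" else -1 for m in moves]

    # Step 1: prefix sums, i.e. the position if there were no boundaries.
    prefix = [S]
    for s in steps:
        prefix.append(prefix[-1] + s)

    first_real = S + q0 + 1  # index of the prefix sum after the first real move
    states = []
    for t in range(first_real, len(prefix)):
        # Step 2: extend the suffix backwards until it spans S+1 distinct
        # values. Steps are +/-1, so that means max - min == S.
        lo = hi = prefix[t]
        j = t
        while hi - lo < S:
            j -= 1
            lo = min(lo, prefix[j])
            hi = max(hi, prefix[j])
        # Within this window the walk touched a boundary, and the state is the
        # prefix-sum offset from the window minimum.
        states.append(prefix[t] - lo)
    return states


print(gridworld_shortcut("LRRRRLL", S=3))  # [0, 1, 2, 3, 3, 2, 1]
```

The inner loop makes this quadratic in the worst case; it is meant only to confirm that the prefix-sum and boundary-detection steps reproduce the sequential dynamics.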

Illustration of the Gridworld shortcut solution, from [1].


Computing the boundary $t_{\mathrm{final}}$ from the most recent n distinct prefix sums, from [1].


This “even shorter shortcut” result is formally stated as follows (Liu et al. (2022), Theorem 3):

For each positive integer $T$, Transformers can simulate the $(S+1)$-state gridworld semiautomaton with 2 attention layers, where the MLP has either (i) depth $\mathcal{O}(\log S)$, width $\mathcal{O}(T+S)$, or (ii) depth $\mathcal{O}(1)$, width $\mathcal{O}(T) + 2^{\mathcal{O}(S)}$. The weight norms are bounded by $\mathrm{poly}(T)$.

Training Transformers on Automata

In order to obtain a model to interpret, we first trained Transformer models to simulate 1D Gridworld. We found that, consistent with reports in previous work [1], training Transformers to simulate automata can be highly unstable. In particular, the loss curves exhibit steep jumps after which the model never recovers, even with a relatively low learning rate (see figure below). We hypothesize that this behavior may be caused by the discrete, logical nature of algorithmic tasks, such that a partially trained model that makes a single incorrect prediction at a position early in the sequence will make confident but incorrect predictions at all subsequent positions, leading to an enormous loss that distorts training.

Example accuracy curve demonstrating unstable training on Gridworld$_4$.


To stabilize training, we introduced gradient clipping to limit the damage caused by occasional extremely bad predictions. Exponential moving average smoothing of the weights further stabilized training, and a learning rate schedule that decreased the learning rate throughout training allowed the network to achieve maximum performance.
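The three stabilizers can be sketched framework-free on a toy problem. Everything specific here is an assumption for illustration rather than the setup used in this project: the quadratic objective, the linear decay schedule, the clipping threshold, the EMA decay, and the injected "bad batch" gradient.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale the gradient vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grads]


def ema_update(ema, weights, decay):
    """Exponential moving average of the weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


def linear_decay_lr(step, total_steps, base_lr):
    """Learning rate that decreases linearly to zero (one possible schedule)."""
    return base_lr * (1.0 - step / total_steps)


# Toy run: minimize 0.5 * ||w||^2, with one enormous gradient at step 50
# standing in for a batch of confident but wrong predictions.
weights = [5.0, -3.0]
ema = list(weights)
total_steps = 100
for step in range(total_steps):
    grads = list(weights)  # gradient of 0.5 * ||w||^2 is w
    if step == 50:
        grads = [1e6, -1e6]  # simulated loss spike
    grads = clip_by_global_norm(grads, max_norm=1.0)
    lr = linear_decay_lr(step, total_steps, base_lr=0.5)
    weights = [w - lr * g for w, g in zip(weights, grads)]
    ema = ema_update(ema, weights, decay=0.9)

print(weights, ema)
```

Without clipping, the simulated spike would move the weights by several orders of magnitude in a single step; with clipping, the update at that step is bounded by the learning rate times `max_norm`.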