Streaming from Flash Attention

💡

Hire: Complete Parts A, B, and the first part of C. You should be able to apply the insights from Flash Attention to the streaming mean or, with guidance, come to the same insight through discussion.

Strong Hire: Complete Parts A, B, and C. You should be able to recognize the core insight from either the paper or streaming mean, and be able to extend that idea to all parts of C.

Before attempting this question, you should study the Flash Attention paper. Prepare for it as though your interviewer had told you in advance that this was going to be the subject of your interview.

This question won't ask you about details in the paper. Instead, it tests your understanding of the paper's core insights in a hands-on way.

A. Flash Attention Paper

We can start by asking questions about what the core insights of the paper were.

What were the core insights of the Flash Attention paper?

Most people should understand the selling point from its abstract: fewer memory reads and writes. However, a more subtle contribution that makes this work is a streamed (online) version of softmax. If the candidate recognizes both of these insights, this is a big green flag. However, I don't expect most people to catch both, so the rest of this question dives into the second one. For more details, see How Flash Attention works.
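The streamed softmax can be sketched on its own. The softmax denominator is a sum of exponentials, and for numerical stability each exponent is offset by the maximum logit, which is unknown until the stream ends. The minimal sketch below (the name streaming_softmax_stats is hypothetical, not from the paper) keeps a running max and rescales the running sum by exp(old_max - new_max) whenever the max grows. That rescaling is a multiplicative correction factor of the same kind that shows up in part C.

```python
import math

def streaming_softmax_stats(logits):
    """Yield (running_max, running_denominator) after each logit.

    running_denominator is sum(exp(x_i - running_max)) over the logits seen so far.
    """
    running_max = -math.inf
    denom = 0.0
    for x in logits:
        new_max = max(running_max, x)
        # multiplicative correction: re-express the old sum relative to the new max
        # (math.exp(-math.inf) == 0.0, so the first step starts from an empty sum)
        denom = denom * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
        yield running_max, denom

# Given the final statistics, the softmax probability of any logit x is
#   math.exp(x - running_max) / denom
```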

How does Flash Attention reduce memory reads and writes?

In short, Flash Attention's main contribution is to avoid writing the full set of attention scores $QK^T$ out to GPU main memory. That tensor has shape n_heads x seq_len x seq_len per batch element, which is humongous, especially for long contexts. To accomplish this, Flash Attention fuses two back-to-back matrix multiplies (with the softmax in between). Candidates should be able to give a summary along these lines. Bonus points if they understand tiling, and how Flash Attention fuses tiling for two consecutive matrix multiplies. For more details, see When to fuse multiple matrix multiplies.
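To make "fusing two matrix multiplies" concrete, here is a toy, single-query, pure-Python sketch (function names are hypothetical, and it assumes standard scaled dot-product attention with a 1/sqrt(d) scale). It ignores tiling, batching, heads, and the GPU memory hierarchy, which is where the real savings come from. It only shows that each score can be computed, folded into the output, and discarded, as long as the running output is rescaled whenever the running max changes. A naive version that materializes the whole row of scores is included for comparison.

```python
import math

def naive_attention(query, keys, values):
    """Reference: materializes the full row of scores before the softmax."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / denom
            for j in range(len(values[0]))]

def streaming_attention(query, kv_pairs):
    """Same output, visiting one (key, value) pair at a time; assumes at least one pair.

    Keeps only a running max, a running softmax denominator, and a running
    weighted sum of values: no score is stored after it has been used.
    """
    scale = 1.0 / math.sqrt(len(query))
    running_max = -math.inf
    denom = 0.0
    acc = None
    for key, value in kv_pairs:
        score = scale * sum(q * k for q, k in zip(query, key))  # one entry of q K^T
        new_max = max(running_max, score)
        correction = math.exp(running_max - new_max)  # rescales everything summed so far
        weight = math.exp(score - new_max)
        denom = denom * correction + weight
        if acc is None:
            acc = [0.0] * len(value)
        # fuse the second matmul: fold weight * value in immediately
        acc = [a * correction + weight * v for a, v in zip(acc, value)]
        running_max = new_max
    return [a / denom for a in acc]

q = [1.0, 0.5]
keys = [[0.2, 1.0], [1.5, -0.3], [0.0, 0.7]]
values = [[1.0, 2.0, 3.0], [0.5, -1.0, 0.0], [2.0, 0.0, 1.0]]
out = streaming_attention(q, zip(keys, values))
```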

💡

These questions are just for checking off a box. I would likely not correct the candidate very much if they made a mistake, because the bulk of the question is in the coming parts: if they can complete the remaining parts, they will have derived one of the paper's core insights on their own anyway.

B. Streaming Dot Product

Say we have two streams of numbers, and we'd like to compute their dot product in a streaming fashion.

How long can these streams be?

Let's consider them to be infinitely long, and let's also trust that Python can handle arbitrarily large integer values, so a running count never overflows.

Will the streams ever contain anything other than valid numbers?

Let's assume that all values will be valid floats. No NaN, inf, or imaginary numbers.

Do we have a bound on the value of the final dot product? What if it overflows the datatype it's stored in?

For simplicity, let's assume Python can hold arbitrarily large numbers. We don't have a bound in advance.

Pick a representation for the stream of numbers. Hint: You cannot use a data structure to hold all values explicitly.

You can use an iterator or a generator. In my own solutions, I'll opt for a generator, but either works.

import random

def make_stream():
    while True:
        # unknown min and max; 1.0 - random.random() is in (0.0, 1.0], so we never divide by zero
        yield random.random() / (1.0 - random.random())

a = make_stream()
b = make_stream()

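For comparison, here is a sketch of the iterator alternative: a class implementing Python's iterator protocol (__iter__ and __next__). The class name RandomStream is illustrative only. Both forms work with zip and itertools.islice, and islice is handy for pulling a finite prefix out of an infinite stream without trying to exhaust it.

```python
import itertools
import random

class RandomStream:
    """Iterator-based equivalent of make_stream(): an endless stream of non-negative floats."""

    def __iter__(self):
        return self

    def __next__(self):
        # 1.0 - random.random() lies in (0.0, 1.0], so the division is always defined
        return random.random() / (1.0 - random.random())

# Peek at a few values without trying to exhaust the infinite stream.
first_five = list(itertools.islice(RandomStream(), 5))
```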
Implement a streaming dot product, and design an interface for passing the stream in.

Note that you should also be streaming the output back out. Make sure your own function also returns an iterator or generator.

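def streaming_dot_product(a_stream, b_stream):
    total = 0
    # zip pulls one value from each stream at a time and stops when either stream ends
    for a, b in zip(a_stream, b_stream):
        total += a * b
        yield total  # emit the running dot product after every pair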
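A quick way to exercise the function is to feed it short, finite streams, since any iterator works, and to use itertools.islice when the input is infinite so that the caller only pulls as many outputs as it needs. A small usage sketch, repeating the function so that the snippet runs on its own:

```python
import itertools

def streaming_dot_product(a_stream, b_stream):
    total = 0
    for a, b in zip(a_stream, b_stream):
        total += a * b
        yield total

# Finite streams make the running totals easy to check by hand:
# 1*4 = 4, then 4 + 2*5 = 14, then 14 + 3*6 = 32.
running = list(streaming_dot_product(iter([1, 2, 3]), iter([4, 5, 6])))

# For infinite inputs, take a finite prefix of the output instead of calling list() on it.
first_three = list(itertools.islice(
    streaming_dot_product(itertools.repeat(1.0), itertools.repeat(1.0)), 3))
```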

C. Streaming Standardized Sum

Now, we'd like to standardize and sum values from a stream.

How would you compute the streaming mean?

This is a common interview question, so you may already have memorized the answer. If you haven't, the interviewer can work with you to come up with the answer.

Algebraically, we can express the new mean $\mu_{k+1}$ as a function of the old mean $\mu_k$. Let's try to do that now. Here's our expression for the mean.

$$\mu_{k+1} = \frac{1}{k+1}\sum_{i=1}^{k+1}x_i$$

First, let's handle the number of terms. We have $k+1$ terms but want $k$ terms. Isolate the $k+1$th term so our summation runs from $i=1$ to $k$.

$$\frac{1}{k+1}\sum_{i=1}^k x_i + \frac{1}{k+1}x_{k+1}$$

Now, let's handle the denominator. Multiply by $1 = \frac{k}{k}$.

$$\frac{k}{k+1}\underbrace{\left(\frac{1}{k}\sum_{i=1}^k x_i\right)}_{\mu_k} + \frac{1}{k+1}x_{k+1}$$

Finally, we can recognize and substitute $\mu_k$ in.

$$\mu_{k+1} = \frac{k}{k+1}\mu_k + \frac{1}{k+1}x_{k+1}$$
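As a quick worked check of this recurrence, take a hypothetical stream $x = 2, 4, 6$. Starting from $\mu_1 = x_1 = 2$:

$$\mu_2 = \frac{1}{2}\mu_1 + \frac{1}{2}x_2 = 1 + 2 = 3, \qquad \mu_3 = \frac{2}{3}\mu_2 + \frac{1}{3}x_3 = 2 + 2 = 4$$

This matches the direct mean $\frac{2 + 4 + 6}{3} = 4$.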

In the final form above, we have two main pieces:

  1. In the first term, we have a multiplicative correction factor. Specifically, we "replaced" the denominator in the old mean by multiplying the old mean by $\frac{k}{k+1}$.
  2. In the second term, we have the newest term of the new mean's summation, the one for $x_{k+1}$.

To accomplish this in practice, always keep a copy of the count so far, $k$. That allows us to compute the correction factor and properly divide the new sample $x_{k+1}$.

def streaming_mean(x_stream):
    count = 0
    mean = 0
    for x in x_stream:
        new_count = count + 1
        mean = mean * (count / new_count) + x / new_count
        count = new_count
        yield mean
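Expanding $\frac{k}{k+1}\mu_k = \mu_k - \frac{1}{k+1}\mu_k$ gives an algebraically equivalent update, $\mu_{k+1} = \mu_k + \frac{x_{k+1} - \mu_k}{k+1}$, which is how this update is often written in practice. Below is a sketch of that form (streaming_mean_incremental is a hypothetical name), checked against Python's statistics.fmean on every prefix of a small list:

```python
import statistics

def streaming_mean_incremental(x_stream):
    """Same recurrence as streaming_mean, rearranged as mu_{k+1} = mu_k + (x_{k+1} - mu_k) / (k+1)."""
    count = 0
    mean = 0.0
    for x in x_stream:
        count += 1
        # move the old mean toward the new sample by 1/count of the gap
        mean += (x - mean) / count
        yield mean

data = [2.0, 4.0, 6.0, 1.5, -3.0]
running = list(streaming_mean_incremental(iter(data)))
```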
How would you compute the streaming variance?

This question can be broken down into two parts.

  1. Handle the streaming count. Luckily, just like in the streaming mean, we can apply a multiplicative "correction factor" by multiplying $\frac{k}{k+1}$.
  2. To handle the streaming mean, we can apply a similar idea. In the previous bullet point, we "replaced" the denominator using a multiplicative correction factor. Now, we "replace" the mean that was subtracted using an additive correction factor.

This additive correction factor takes some explaining though. First, start with the formulation for variance. Just like before, our ultimate goal is to represent $\sigma_{k+1}^2$ as a function of $\sigma_k^2$.

$$\sigma^2_{k+1} = \frac{1}{k+1}\sum_{i=1}^{k+1} (x_i - \mu_{k+1})^2$$

Again, just like before, we'll follow the same pattern of isolating terms and applying a correction. Start by isolating the $k+1$th term.

$$\frac{1}{k+1}\sum_{i=1}^{k} (x_i - \mu_{k+1})^2 + \frac{1}{k+1}(x_{k+1} - \mu_{k+1})^2$$

For readability, let's define that $k+1$th term to be $c_{k+1}$.

$$\frac{1}{k+1}\sum_{i=1}^{k} (x_i - \mu_{k+1})^2 + c_{k+1}$$

Like before, let's also apply a multiplicative correction factor by multiplying by $1 = \frac{k}{k}$.

$$\frac{k}{k+1} \frac{1}{k}\sum_{i=1}^{k} (x_i - \mu_{k+1})^2 + c_{k+1}$$

We have a big problem: the summand uses $\mu_{k+1}$, but to recover the old variance we need $\mu_k$. So, let's add $0 = \mu_k - \mu_k$ inside the square, which gives us the following expression.

$$\frac{k}{k+1} \frac{1}{k}\sum_{i=1}^{k} [\underbrace{(x_i - \mu_k)}_{\delta_{ik}} + \underbrace{(\mu_k - \mu_{k+1})}_{\delta_k}]^2 + c_{k+1}$$

Expand that massive quadratic, substituting $\delta_{ik} = x_i - \mu_k$ and $\delta_k = \mu_k - \mu_{k+1}$.

$$\frac{k}{k+1} \left(\frac{1}{k}\sum_{i=1}^{k} \delta_{ik}^2 + 2\frac{1}{k}\sum_{i=1}^{k}\delta_{ik}\delta_k + \frac{1}{k}\sum_{i=1}^{k}\delta_k^2 \right)+ c_{k+1}$$

There are a few simplifications we can make:

  • The first term is just the old variance: $\frac{1}{k}\sum_{i=1}^{k} \delta_{ik}^2 = \frac{1}{k}\sum_{i=1}^k (x_i - \mu_k)^2 = \sigma_k^2$.
  • The third term is independent of $i$, so it reduces to $\frac{1}{k}\sum_{i=1}^k \delta_k^2 = \delta_k^2$.
  • The second term reduces to zero, because of the following.

$$2\frac{1}{k}\sum_{i=1}^{k}\delta_{ik}\delta_k = 2\frac{1}{k}\sum_{i=1}^k (x_i - \mu_k )\delta_k = 2\delta_k \frac{1}{k}\sum_{i=1}^k (x_i - \mu_k) = 2\delta_k (\mu_k - \mu_k) = 0$$

Knowing the above, we can now simplify the original expression to become the following.

$$\sigma_{k+1}^2 = \frac{k}{k+1}(\sigma_k^2 + \delta_k^2) + c_{k+1}$$

Let's plug in $\delta_k = \mu_k - \mu_{k+1}$ and $c_{k+1}$ again.

$$\sigma_{k+1}^2 = \frac{k}{k+1}[\sigma_k^2 + (\mu_{k+1} - \mu_k)^2] + \frac{1}{k+1}(x_{k+1} - \mu_{k+1})^2$$
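As a worked check, take the hypothetical stream $x = 2, 4, 6$, for which $\mu_2 = 3$, $\sigma_2^2 = \frac{(2-3)^2 + (4-3)^2}{2} = 1$, and $\mu_3 = 4$. The update gives

$$\sigma_3^2 = \frac{2}{3}[1 + (4 - 3)^2] + \frac{1}{3}(6 - 4)^2 = \frac{4}{3} + \frac{4}{3} = \frac{8}{3}$$

This matches the direct computation $\frac{(2-4)^2 + (4-4)^2 + (6-4)^2}{3} = \frac{8}{3}$.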

Notice that the structure of our final formulation is very similar to the mean update from before. Notably, we have familiar elements:

  1. In the first term, we have a multiplicative correction factor $\frac{k}{k+1}$, just like in the mean.
  2. In the first term, we also have a new additive correction factor $(\mu_{k+1} - \mu_k)^2$.
  3. In the second term, we have the latest member of the new variance's summation, for $x_{k+1}$, just like in the mean.

So, after all this, crazily enough, the only significant difference between the streaming mean and streaming variance is just the addition of this new, additive correction factor! Let's code this up.

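import itertools

def streaming_variance(x_stream):
    # tee the input: zipping a generator with streaming_mean(x_stream) would make both
    # loops pull from the same generator, so each would only see every other value
    x_stream, x_copy = itertools.tee(x_stream)
    mean = 0
    count = 0
    variance = 0
    for x, new_mean in zip(x_stream, streaming_mean(x_copy)):
        new_count = count + 1

        # additive correction: replace the old mean with the new one
        variance = variance + (new_mean - mean) ** 2

        # multiplicative correction for the count in the denominator, plus the newest term
        variance = variance * (count / new_count) + (x - new_mean) ** 2 / new_count

        count = new_count
        mean = new_mean
        yield variance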
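Composing generators is elegant, but it is easy to accidentally consume the same input twice. As an alternative (streaming_mean_and_variance is a hypothetical name, not part of the solution above), the same two corrections can be applied in a single loop that carries the count, mean, and variance together. The sketch checks each prefix against Python's statistics.pvariance, which computes the population variance (dividing by $n$, matching the derivation above):

```python
import statistics

def streaming_mean_and_variance(x_stream):
    """Yield (mean, population variance) after each value, in a single pass."""
    count = 0
    mean = 0.0
    variance = 0.0
    for x in x_stream:
        new_count = count + 1
        new_mean = mean * (count / new_count) + x / new_count
        # additive correction (swap old mean for new), then multiplicative count correction
        variance = (
            (variance + (new_mean - mean) ** 2) * (count / new_count)
            + (x - new_mean) ** 2 / new_count
        )
        count, mean = new_count, new_mean
        yield mean, variance

data = [2.0, 4.0, 6.0, 1.5, -3.0, 10.0]
running_stats = list(streaming_mean_and_variance(iter(data)))
```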

💡

Most interviews will probably end around here. Even if the streaming mean was a short discussion, the streaming variance takes time to derive and explain, even for candidates who have seen it before and already know the answer.

How would you now construct a function that standardizes a stream?
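import itertools
import math

def streaming_standardize(x_stream):
    x_stream, x_mean, x_var = itertools.tee(x_stream, 3)  # one generator can't feed three loops
    mean_stream = streaming_mean(x_mean)
    var_stream = streaming_variance(x_var)
    for x, mean, variance in zip(x_stream, mean_stream, var_stream):
        std = math.sqrt(variance)  # standardize by the standard deviation, not the variance
        yield (x - mean) / std if std > 0 else 0.0  # std is 0 for the first value; define 0 then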
How would you now compute the streaming standardized sum?
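import itertools
import math

def streaming_standardized_sum(x_stream):
    x_stream, x_mean, x_var = itertools.tee(x_stream, 3)  # one generator can't feed three loops
    count, total, old_mean, old_std = 0, 0.0, 0.0, 0.0
    for x, mean, var in zip(x_stream, streaming_mean(x_mean), streaming_variance(x_var)):
        std = math.sqrt(var)
        # undo the old scaling (total * old_std = sum of x_i - old_mean), shift each of the
        # `count` old terms to the new mean (additive correction), then add the newest term
        unscaled = total * old_std + count * (old_mean - mean) + (x - mean)
        # rescale by the new std; while std is 0, every deviation is 0, so the sum is 0
        total = unscaled / std if std > 0 else 0.0
        count += 1
        old_mean, old_std = mean, std
        yield total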
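To check the correction logic, it helps to compare against a brute-force reference that stores every value and re-standardizes all of them at each step. Both functions below are hypothetical sketches that assume population statistics and define the sum as 0 while the standard deviation is 0. The streaming one repeats the updates in a single loop so that the snippet runs on its own. Deviations from the current mean always sum to zero, so mathematically this particular total is 0 at every step, and both versions should agree up to floating-point error. That makes it a sharp sanity check: for example, dropping the factor of count on the mean shift produces a nonzero total as soon as the third value arrives.

```python
import math
import statistics

def standardized_sums_brute_force(xs):
    """Reference: store every value and re-standardize all of them at each step (O(n) per step)."""
    seen = []
    for x in xs:
        seen.append(x)
        mean = statistics.fmean(seen)
        std = statistics.pstdev(seen)  # population std, matching the derivation's 1/k
        yield sum((v - mean) / std for v in seen) if std > 0 else 0.0

def standardized_sums_single_loop(xs):
    """Streaming version with O(1) state: count, mean, variance, std, and the running total."""
    count, mean, variance, std, total = 0, 0.0, 0.0, 0.0, 0.0
    for x in xs:
        new_count = count + 1
        new_mean = mean * (count / new_count) + x / new_count
        variance = (
            (variance + (new_mean - mean) ** 2) * (count / new_count)
            + (x - new_mean) ** 2 / new_count
        )
        new_std = math.sqrt(variance)
        # undo the old scaling, shift the `count` old terms to the new mean, add the newest term
        unscaled = total * std + count * (mean - new_mean) + (x - new_mean)
        total = unscaled / new_std if new_std > 0 else 0.0
        count, mean, std = new_count, new_mean, new_std
        yield total

data = [2.0, 4.0, 6.0, 1.5, -3.0, 10.0]
fast = list(standardized_sums_single_loop(iter(data)))
slow = list(standardized_sums_brute_force(iter(data)))
```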