Comparing prediction trajectories#

A prediction trajectory is the set of latent predictions produced by running a lens against each layer of a model. This process creates a sequence of distributions over the next token that, in general, become more accurate the higher in the model they are sourced from. You can think of these distributions as the best guesses that can be made about the final token distribution by the lens's affine translator for that layer.

Since we generally care about more than just one token, the sequence of predictions is represented as a 3-dimensional tensor we call the prediction trajectory. This tensor has the shape (num_layers x sequence_length x vocab_size). These distributions are typically stored in log space for numerical precision.
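
To make that layout concrete, here is a minimal, self-contained sketch (illustrative only, not part of the library's API) of a prediction trajectory stored as log probabilities:

import torch

# Hypothetical dimensions, for illustration only.
num_layers, sequence_length, vocab_size = 12, 28, 50304

# One set of lens logits per layer and sequence position, normalized in log space.
logits = torch.randn(num_layers, sequence_length, vocab_size)
log_probs = torch.log_softmax(logits, dim=-1)

print(log_probs.shape)               # (num_layers, sequence_length, vocab_size)
print(log_probs[0, 0].exp().sum())   # ~1.0: each slice is a proper distribution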

In order to start visualizing and playing with prediction trajectories, we will first need to load our model and lens.

[1]:
import torch
from tuned_lens.nn.lenses import TunedLens
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cpu')
# To try a different model / lens, check if the lens is available and then modify this code
model = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-160m-deduped')
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-160m-deduped')
tuned_lens = TunedLens.from_model_and_pretrained(model, map_location=device)
tuned_lens = tuned_lens.to(device)

Now let's prepare some interesting text to examine. Here we will use a quote from Tolkien that has some nice repetition. It's also common enough that it was probably in the training data, so modifying it will hopefully let us see some conflicts between the model's parametric knowledge and its in-context learning.

[2]:

input_ids_ring = tokenizer.encode(
    "One Ring to rule them all,\n"
    "One Ring to find them,\n"
    "One Ring to bring them all\n"
    "and in the darkness bind them"
)
input_ids_model = tokenizer.encode(
    "One Model to rule them all,\n"
    "One Model to find them,\n"
    "One Model to bring them all\n"
    "and in the darkness bind them"
)
targets_ring = input_ids_ring[1:] + [tokenizer.eos_token_id]
targets_model = input_ids_model[1:] + [tokenizer.eos_token_id]

Let's validate that the tokenizations line up and that this is indeed going to be a one-token substitution.

[3]:
print(tokenizer.convert_ids_to_tokens(input_ids_ring))
print(tokenizer.convert_ids_to_tokens(input_ids_model))
['One', 'ĠRing', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ', 'One', 'ĠRing', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ', 'One', 'ĠRing', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ', 'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
['One', 'ĠModel', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ', 'One', 'ĠModel', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ', 'One', 'ĠModel', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ', 'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
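
As a quick programmatic sanity check (just a sketch using the variables defined above), we can also confirm that the two sequences have the same length and differ only at the positions where ĠRing was swapped for ĠModel:

assert len(input_ids_ring) == len(input_ids_model)

# Positions where the two token sequences disagree.
diff_positions = [
    i for i, (a, b) in enumerate(zip(input_ids_ring, input_ids_model)) if a != b
]
print(diff_positions)
print([tokenizer.convert_ids_to_tokens(input_ids_ring[i]) for i in diff_positions])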

Now let's generate a prediction trajectory to examine the third line of Tolkien's epigram; that line consists of tokens [14, 21).

[4]:
from tuned_lens.plotting import PredictionTrajectory

third_line = slice(14, 21)

prediction_traj_ring = PredictionTrajectory.from_lens_and_model(
    tuned_lens,
    model,
    tokenizer=tokenizer,
    input_ids=input_ids_ring,
    targets=targets_ring,
).slice_sequence(third_line)

Now let's visualize the prediction trajectory for this slice of the transformer's activations. Note that the entire sequence is still being fed to the model; we are just visualizing a prediction trajectory for this particular slice ([14:21]) of the activations.
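
As a rough shape check (this assumes the PredictionTrajectory object exposes its stacked log probabilities as a log_probs array, which is an assumption about the library's internals), the sliced trajectory should keep every layer but cover only the seven positions of the third line:

# Assumed attribute: `log_probs`, shaped roughly (layers, positions, vocab_size).
print(prediction_traj_ring.log_probs.shape)
# Expect the middle dimension to be 7 after slice_sequence(third_line),
# while the layer and vocabulary dimensions are unchanged.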

[5]:
import plotly.io as pio
pio.renderers.default = "sphinx_gallery" # Remove this if you are not seeing the plots
[6]:
from plotly.subplots import make_subplots

fig = make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.03,
    subplot_titles=("Entropy", "Forward KL", "Cross Entropy", "Max Probability"),
)

fig.add_trace(
    prediction_traj_ring.entropy().heatmap(
        colorbar_y=0.89, colorbar_len=0.25, textfont={'size':10}
    ),
    row=1, col=1
)

fig.add_trace(
    prediction_traj_ring.forward_kl().heatmap(
        colorbar_y=0.63, colorbar_len=0.25, textfont={'size':10}
    ),
    row=2, col=1
)

fig.add_trace(
    prediction_traj_ring.cross_entropy().heatmap(
        colorbar_y=0.37, colorbar_len=0.25, textfont={'size':10}
    ),
    row=3, col=1
)

fig.add_trace(
    prediction_traj_ring.max_probability().heatmap(
        colorbar_y=0.11, colorbar_len=0.25, textfont={'size':10}
    ),
    row=4, col=1
)

fig.update_layout(height=800, width=500, title_text="Tolkien's Tokens Visualized with the Tuned Lens")
fig.show()

Now let’s look at the prediction trajectory for our modified sequence.

[7]:
prediction_traj_model = PredictionTrajectory.from_lens_and_model(
    tuned_lens,
    model,
    tokenizer=tokenizer,
    input_ids=input_ids_model,
    targets=targets_model,
).slice_sequence(third_line)
[8]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.03,
    subplot_titles=("Entropy", "Forward KL", "Cross Entropy", "Max Probability"),
)

fig.add_trace(
    prediction_traj_model.entropy().heatmap(
        colorbar_y=0.89, colorbar_len=0.25, textfont={'size':10}
    ),
    row=1, col=1
)

fig.add_trace(
    prediction_traj_model.forward_kl().heatmap(
        colorbar_y=0.63, colorbar_len=0.25, textfont={'size':10}
    ),
    row=2, col=1
)

fig.add_trace(
    prediction_traj_model.cross_entropy().heatmap(
        colorbar_y=0.37, colorbar_len=0.25, textfont={'size':10}
    ),
    row=3, col=1
)

fig.add_trace(
    prediction_traj_model.max_probability().heatmap(
        colorbar_y=0.11, colorbar_len=0.25, textfont={'size':10}
    ),
    row=4, col=1
)

fig.update_layout(height=800, width=500, title_text="Tolkien's Modified Tokens Visualized with the Tuned Lens")
fig.show()

Now for the fun part: let's use the tools provided by the tuned lens to observe the changes between the original trajectory with _Ring and our modified trajectory where we introduced the token _Model.

[9]:
fig = prediction_traj_ring.total_variation(
    prediction_traj_model, min_prob_delta=0.05
).figure("Total Variation")
fig.show()

This seems to strongly suggest that there is an induction head at layer 9 within this model.
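
As a rough numeric follow-up (again assuming the trajectories expose their stacked log probabilities via a log_probs array), we can check which layer shows the largest average disagreement between the two trajectories:

import numpy as np

# Total variation distance between the two trajectories at each layer,
# averaged over the sequence positions of the third line.
p = np.exp(prediction_traj_ring.log_probs)
q = np.exp(prediction_traj_model.log_probs)
tv_per_layer = 0.5 * np.abs(p - q).sum(axis=-1).mean(axis=-1)

for layer, tv in enumerate(tv_per_layer):
    print(f"layer {layer:2d}: mean total variation {tv:.3f}")

A pronounced change in this statistic around layer 9 would support the reading above.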