Combining the Tuned Lens and TransformerLens#

TransformerLens is another open-source package designed to provide a standard interface for investigating the internals of transformer models. Integrating TransformerLens with tuned-lens allows you to observe how model edits affect the prediction trajectories, and to make use of all of the visualizations provided by the tuned-lens package. This is primarily useful for preliminary investigations of circuits. Note that this tutorial will be very hard to follow unless you are already familiar with the TransformerLens package.

To demonstrate this, we will investigate the greater-than circuit in gpt2-small.

[1]:
import torch as th
from tuned_lens.plotting import PredictionTrajectory
from tuned_lens.nn import TunedLens, Unembed, LogitLens
import transformer_lens as tl

# Load GPT-2 small as a HookedTransformer on the CPU.
model = tl.HookedTransformer.from_pretrained(
    "gpt2",
    device="cpu",
    fold_ln=False, # The tuned lens applies the final layer norm so we should not fold
    # this into the unembed operation.
)
# The tokenizer attribute may be None; to_targets below reads
# model.tokenizer.eos_token_id, so fail early if it is missing.
assert model.tokenizer is not None


# Tuned lens for GPT-2: wraps the model's unembedding in Unembed and loads the
# pretrained lens identified by the "gpt2" resource id (presumably downloaded
# on first use — confirm in the tuned-lens docs).
tuned_lens = TunedLens.from_unembed_and_pretrained(
    unembed=Unembed(model),
    lens_resource_id="gpt2",
)

# Plain logit lens built from the same model, used as a baseline for comparison.
logit_lens = LogitLens.from_model(model)

def to_targets(input_ids: th.Tensor, fill_token_id: "int | None" = None) -> th.Tensor:
    """Build next-token targets by shifting ``input_ids`` left by one position.

    The final position has no following token, so it is filled with
    ``fill_token_id``.

    Args:
        input_ids: Token ids whose last dimension is the sequence dimension.
        fill_token_id: Id written into the final target position. Defaults to
            ``model.tokenizer.eos_token_id``.

    Returns:
        A tensor on the same device and with the same dtype as ``input_ids``;
        for non-empty sequences it also has the same shape.
    """
    if fill_token_id is None:
        fill_token_id = model.tokenizer.eos_token_id
    # Match the input's dtype and device so the concatenation works for GPU
    # tensors and does not silently promote e.g. int32 ids to int64.
    padding = th.full(
        input_ids.shape[:-1] + (1,),
        fill_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    return th.cat((input_ids[..., 1:], padding), dim=-1)
Using pad_token, but it is not set yet.
Loaded pretrained model gpt2 into HookedTransformer
[2]:
import plotly.io as pio
# Select plotly's "sphinx_gallery" renderer (presumably for the documentation
# build); in an interactive notebook the default renderer is usually what you want.
pio.renderers.default = "sphinx_gallery" # REMOVE THIS IF YOU ARE NOT SEEING PLOTS
[3]:
# Sanity check: sample two tokens to see how the model completes the end year.
# do_sample=True, so the continuation differs between runs.
model.generate(" The war lasted from 1754 to 17", max_new_tokens=2, do_sample=True)
[3]:
' The war lasted from 1754 to 1776 and'
[4]:
str_tokens = model.to_str_tokens(" The war lasted from")

# Each start year is given as its leading and trailing two digits,
# e.g. [17, 54] stands for the year 1754.
dates = [[12, 21],
         [11, 23],
         [10, 24],
         [16, 89],
         [17, 54],
         [14, 47],
         [15, 36],
         [17, 32],
         [18, 21],
         [11, 57]]


def _date_span(cc, yy_text):
    """Return the tokens " CC", yy_text, " to", " CC" that follow the prompt prefix."""
    return [" " + str(cc), yy_text, " to", " " + str(cc)]


def _ids_tensor(token_rows):
    """Map every single-token string in each row to its id and stack into a tensor."""
    return th.tensor([[model.to_single_token(tok) for tok in row] for row in token_rows])


# Clean prompts, e.g. "... from 1221 to 12".
input_ids_strs = [str_tokens + _date_span(cc, str(yy)) for cc, yy in dates]
input_ids = _ids_tensor(input_ids_strs)

# Corrupted prompts whose year digits are replaced by "01" (used for swap ablation).
scrub_ids_strs = [str_tokens + _date_span(cc, "01") for cc, _yy in dates]
scrub_ids = _ids_tensor(scrub_ids_strs)

# Next-token targets: the prompt shifted left by one, ending in a year three
# above (targets) or three below (anti-targets) the start year's last two digits.
targets_strs = [str_tokens[1:] + _date_span(cc, str(yy)) + [str(yy + 3)] for cc, yy in dates]
targets = _ids_tensor(targets_strs)
anti_targets_strs = [str_tokens[1:] + _date_span(cc, str(yy)) + [str(yy - 3)] for cc, yy in dates]
anti_targets = _ids_tensor(anti_targets_strs)
[5]:
from pprint import pprint

# Shared clip range for every log-probability plot below, so the figures are
# directly comparable.
log_prob_range = (-2, 2)

# Preview the first two examples of each token sequence.
for heading, rows, width in (
    ("Scrubbed:", scrub_ids_strs, 120),
    ("Input:", input_ids_strs, 120),
    ("Targets:", targets_strs, 80),
    ("Anti targets:", anti_targets_strs, 80),
):
    print(heading)
    pprint(rows[:2], width=width)

with th.inference_mode():
    logits, cache = model.run_with_cache(input=input_ids, return_type="logits")

    # Both lenses read the same clean cache; only the lens itself differs.
    shared = dict(
        cache=cache,
        model_logits=logits,
        input_ids=input_ids,
        targets=targets,
        anti_targets=anti_targets,
    )
    pred_traj_clean = PredictionTrajectory.from_lens_and_cache(lens=tuned_lens, **shared)
    pred_traj_clean_logit = PredictionTrajectory.from_lens_and_cache(lens=logit_lens, **shared)
Scrubbed:
[['<|endoftext|>', ' The', ' war', ' lasted', ' from', ' 12', '01', ' to', ' 12'],
 ['<|endoftext|>', ' The', ' war', ' lasted', ' from', ' 11', '01', ' to', ' 11']]
Input:
[['<|endoftext|>', ' The', ' war', ' lasted', ' from', ' 12', '21', ' to', ' 12'],
 ['<|endoftext|>', ' The', ' war', ' lasted', ' from', ' 11', '23', ' to', ' 11']]
Targets:
[[' The', ' war', ' lasted', ' from', ' 12', '21', ' to', ' 12', '24'],
 [' The', ' war', ' lasted', ' from', ' 11', '23', ' to', ' 11', '26']]
Anti targets:
[[' The', ' war', ' lasted', ' from', ' 12', '21', ' to', ' 12', '18'],
 [' The', ' war', ' lasted', ' from', ' 11', '23', ' to', ' 11', '20']]
[6]:
# Tuned lens, clean run: per-layer change (delta=True) in the target/anti-target
# log-prob difference over the last five positions, clipped to log_prob_range.
pred_traj_clean.slice_sequence(slice(-5, None)).log_prob_diff(delta=True).clip(*log_prob_range).figure(title="Effects of each layer on the target/anti-target ratio")
[7]:
# Same view as the tuned-lens plot, computed with the logit lens instead.
pred_traj_clean_logit.slice_sequence(slice(-5, None)).log_prob_diff(delta=True).clip(*log_prob_range).figure(title="Same as above but with the logit lens")

The above results mostly agree with the results in the greater-than paper (Hanna et al., 2023, "How does GPT-2 compute greater-than?"). We see that the majority of the contributions to the correct logits do indeed come from layers 8 through 11. Interestingly, layer 11, and in particular its MLP, seems to act as a regularizer, reducing the confidence in the target prediction.

Let's start removing components!#

Now we will show how to ablate the components of this circuit. I recommend opening this notebook in Google Colab and playing with it!

Here are some exercises you can try:

* What are the effects of ablating the MLP in layer 8?
* What happens when we remove MLP 11?
* How much collateral damage does this cause?
* How do the different types of ablation work?
* Why might we prefer a swap ablation to a zero ablation? (Hint: consider how far an all-zero activation is from the activations the model normally sees.)
* CHALLENGE: What edit would you make to the model that disrupts the greater-than circuit but has minimal effects on the model's behavior?

[8]:
# Cache activations on the corrupted ("01") prompts; swap_ablation_hook copies
# them into the clean run. use_attn_result is enabled so the per-head attention
# "result" activations are recorded as well. The previous setting is restored
# even if the forward pass raises, so later runs are not silently affected.
_prev_use_attn_result = model.cfg.use_attn_result
model.cfg.use_attn_result = True
try:
    # No gradients are needed; this matches every other forward pass here.
    with th.inference_mode():
        scrubed_logits, scrubed_cache = model.run_with_cache(
            input=scrub_ids, return_type="logits"
        )
finally:
    model.cfg.use_attn_result = _prev_use_attn_result
[9]:
import transformer_lens.utils as utils
from functools import partial


def zero_ablation_hook(result: th.Tensor, hook: tl.hook_points.HookPoint) -> th.Tensor:
    """Forward hook that zero-ablates an activation.

    Overwrites ``result`` in place with zeros and returns it. ``hook`` is not used.
    """
    result[:] = 0
    return result

def swap_ablation_hook(result: th.Tensor, hook: tl.hook_points.HookPoint) -> th.Tensor:
    """Forward hook that swap-ablates an activation.

    Overwrites ``result`` in place with the activation recorded at the same hook
    point (``hook.name``) in the global ``scrubed_cache``, i.e. from the run on
    the corrupted "01" prompts, and returns it. The cached activation must be
    assignable to ``result``'s shape, which in practice means inputs shaped like
    ``scrub_ids``.
    """
    result[:] = scrubed_cache[hook.name]
    return result


# Layers whose MLP output ("mlp_out") is ablated; each entry becomes a
# (hook name, hook function) pair for model.hooks.
MLPS_TO_ABLATE = [9]
mlp_hooks = [(utils.get_act_name("mlp_out", layer), zero_ablation_hook) for layer in MLPS_TO_ABLATE]

# Layers whose attention "result" activations are ablated (none by default).
# These hooks target the per-head attention results, which is why the runs
# below set model.cfg.use_attn_result = True.
ATTN_MODULES_TO_ABLATE = []
attn_hooks = [(utils.get_act_name("result", layer), zero_ablation_hook) for layer in ATTN_MODULES_TO_ABLATE]

# Re-run the clean date prompts with the selected components ablated, and build
# tuned-lens and logit-lens trajectories from the ablated cache.
with model.hooks(fwd_hooks=(mlp_hooks + attn_hooks)), th.inference_mode():
    # Attention "result" hooks need use_attn_result enabled. Restore the previous
    # value even if the forward pass raises, so later cells are not affected.
    prev_use_attn_result = model.cfg.use_attn_result
    model.cfg.use_attn_result = True
    try:
        logits, cache = model.run_with_cache(
            input=input_ids, return_type="logits"
        )
    finally:
        model.cfg.use_attn_result = prev_use_attn_result

    pred_traj_ablated = PredictionTrajectory.from_lens_and_cache(
        lens=tuned_lens,
        cache=cache,
        model_logits=logits,
        input_ids=input_ids,
        targets=targets,
        anti_targets=anti_targets,
    )

    pred_traj_ablated_logit = PredictionTrajectory.from_lens_and_cache(
        lens=logit_lens,
        cache=cache,
        model_logits=logits,
        input_ids=input_ids,
        targets=targets,
        anti_targets=anti_targets,
    )
[10]:
# Tuned lens, ablated run: same per-layer view as the clean plot, for comparison.
pred_traj_ablated.slice_sequence(slice(-5, None)).log_prob_diff(delta=True).clip(*log_prob_range).figure(title="Effects of each layer on the target/anti-target ratio after ablation")
[11]:
# Logit-lens view of the ablated run. It is restricted to the last five positions
# and clipped to log_prob_range, like the other plots, so the figures stay comparable.
pred_traj_ablated_logit.slice_sequence(slice(-5, None)).log_prob_diff(delta=True).clip(*log_prob_range).figure(title="Same as above but with the logit lens")

How much collateral damage does ablating the above components cause?#

An interesting question we can answer with the tuned lens is: what other capabilities have our edits affected?

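Note that the code below assumes you are using zero ablation. The swap-ablation hook copies activations cached from the date prompts, and their shapes will not match the control text.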

[12]:
# Generate a control continuation while the ablation hooks are active. Then
# compare clean vs. ablated trajectories on that same text to measure collateral
# damage outside the greater-than task.
with model.hooks(fwd_hooks=(mlp_hooks + attn_hooks)), th.inference_mode():
    # Attention "result" hooks need use_attn_result enabled. Restore the previous
    # value even if generation raises.
    prev_use_attn_result = model.cfg.use_attn_result
    model.cfg.use_attn_result = True
    try:
        control_text = model.generate(" Numbers I love them!", max_new_tokens=10, do_sample=False)
    finally:
        model.cfg.use_attn_result = prev_use_attn_result

input_ids_control = model.to_tokens(control_text)
targets_control = to_targets(input_ids_control)

# Clean (unhooked) reference run. Like the other forward passes in this notebook,
# it runs under inference_mode, so no autograd graph is built through the model
# or the tuned lens.
with th.inference_mode():
    logits, cache = model.run_with_cache(
        input=input_ids_control, return_type="logits"
    )

    pred_traj_control_clean = PredictionTrajectory.from_lens_and_cache(
        lens=tuned_lens,
        cache=cache,
        model_logits=logits,
        input_ids=input_ids_control,
        targets=targets_control,
    )

# Ablated run on exactly the same tokens.
with model.hooks(fwd_hooks=(mlp_hooks + attn_hooks)), th.inference_mode():
    prev_use_attn_result = model.cfg.use_attn_result
    model.cfg.use_attn_result = True
    try:
        logits, cache = model.run_with_cache(
            input=input_ids_control, return_type="logits"
        )
    finally:
        model.cfg.use_attn_result = prev_use_attn_result

    pred_traj_control_ablated = PredictionTrajectory.from_lens_and_cache(
        lens=tuned_lens,
        cache=cache,
        model_logits=logits,
        input_ids=input_ids_control,
        targets=targets_control,
    )

# KL divergence between the ablated and clean trajectories on the control text.
pred_traj_control_ablated.kl_divergence(pred_traj_control_clean).figure()