tuned_lens.plotting.prediction_trajectory
Plot a lens table for some given text and model.
Classes
- class tuned_lens.plotting.prediction_trajectory.PredictionTrajectory(log_probs, input_ids, targets=None, anti_targets=None, tokenizer=None)
Contains the trajectory predictions for a sequence of tokens.
A prediction trajectory is the set of next-token predictions produced by the conjunction of a lens and a model when evaluated on a specific sequence of tokens. This class includes multiple methods for visualizing different aspects of the trajectory.
- anti_targets: Optional[ndarray[Any, dtype[int64]]] = None
(…, seq_len)
- property batch_axes: Sequence[int]
Returns the batch axes for the trajectory.
- property batch_shape: Sequence[int]
Returns the batch shape of the trajectory.
- cross_entropy(**kwargs)
The cross entropy of the predictions to the targets.
- Parameters:
**kwargs – are passed to largest_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the cross entropy of the predictions to the targets.
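As a rough illustration of the quantity reported here, the cross entropy at a single position is the negative log probability assigned to the target token. The sketch below uses a plain Python list in place of the class's arrays. `token_cross_entropy` is a hypothetical helper, not part of tuned_lens, and the library's handling of layers and batch axes is not modeled.

```python
import math

def token_cross_entropy(log_probs, target):
    """Cross entropy (in nats) of one predicted distribution to one target id.

    `log_probs` is a list of log probabilities over the vocabulary.
    Hypothetical helper for illustration; not a tuned_lens function.
    """
    return -log_probs[target]

# A toy 3-token vocabulary with probabilities 0.5, 0.25, 0.25.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.25)]
ce = token_cross_entropy(log_probs, target=0)  # equals log(2)
```

A lower value means the prediction at that layer and position puts more mass on the target.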
- entropy(**kwargs)
The entropy of the predictions.
- Parameters:
**kwargs – are passed to largest_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the entropy of the predictions.
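For intuition, the entropy of one predicted distribution can be computed directly from its log probabilities. This is a minimal sketch using natural logarithms (nats); the units the library reports are not stated on this page, and `entropy` below is an illustrative stand-in rather than the method itself.

```python
import math

def entropy(log_probs):
    """Shannon entropy in nats: -sum_i p_i * log(p_i)."""
    return -sum(math.exp(lp) * lp for lp in log_probs)

# A uniform prediction over 4 tokens versus a confident one.
uniform = [math.log(0.25)] * 4
peaked = [math.log(0.97), math.log(0.01), math.log(0.01), math.log(0.01)]

h_uniform = entropy(uniform)  # equals log(4), the maximum for 4 tokens
h_peaked = entropy(peaked)    # much smaller: the prediction is confident
```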
- forward_kl(**kwargs)
KL divergence of the lens predictions to the model predictions.
- Parameters:
**kwargs – are passed to largest_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the KL divergence of the lens predictions to the final output of the model.
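KL divergence is asymmetric, so the direction matters when comparing a lens prediction with the model's final output. The sketch below computes both directions for a toy distribution. Treating "forward" as KL(model ‖ lens) is an assumption made for illustration, not something this page confirms, and `kl_divergence` here is a hypothetical helper.

```python
import math

def kl_divergence(p_log_probs, q_log_probs):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i), in nats."""
    return sum(
        math.exp(lp) * (lp - lq) for lp, lq in zip(p_log_probs, q_log_probs)
    )

model = [math.log(0.7), math.log(0.2), math.log(0.1)]  # final model output
lens = [math.log(0.5), math.log(0.3), math.log(0.2)]   # lens prediction at one layer

kl_model_lens = kl_divergence(model, lens)  # assumed "forward" direction
kl_lens_model = kl_divergence(lens, model)  # reverse direction, for comparison
```

Both values are positive and they differ, which is why the direction should be checked against the library source before interpreting a plot.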
- classmethod from_lens_and_cache(lens, input_ids, cache, model_logits, targets=None, anti_targets=None, residual_component='resid_pre', mask_input=False)
Construct a prediction trajectory from a set of residual stream vectors.
- Parameters:
lens – A lens to use to produce the predictions.
cache – The activation cache produced by running the model.
input_ids – (…, seq_len) Ids that were input into the model.
model_logits – (…, seq_len, d_vocab) The model's final output logits.
targets – (…, seq_len) The targets the model should predict. Used for cross_entropy() and log_prob_diff() visualization.
anti_targets – (…, seq_len) The incorrect labels the model should not predict. Used for log_prob_diff() visualization.
residual_component – Name of the stream vector being visualized.
mask_input – Whether to mask the input ids when computing the log probs.
- Return type:
PredictionTrajectory
- Returns:
PredictionTrajectory constructed from the residual stream vectors.
- classmethod from_lens_and_model(lens, model, input_ids, tokenizer=None, targets=None, anti_targets=None, mask_input=False)
Construct a prediction trajectory from a set of residual stream vectors.
- Parameters:
lens – A lens to use to produce the predictions. Note this should be compatible with the model.
model – A Hugging Face causal language model to use to produce the predictions.
tokenizer – The tokenizer to use for decoding the input ids.
input_ids – (seq_len) Ids that were input into the model.
targets – (seq_len) The targets the model should predict. Used for cross_entropy() and log_prob_diff() visualization.
anti_targets – (seq_len) The incorrect labels the model should not predict. Used for log_prob_diff() visualization.
residual_component – Name of the stream vector being visualized.
mask_input – Whether to mask the input ids when computing the log probs.
- Return type:
PredictionTrajectory
- Returns:
PredictionTrajectory constructed from the residual stream vectors.
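Since targets has the same (seq_len) shape as input_ids, one plausible way to build it for next-token prediction is to shift the ids left by one and pad the last position. The sketch below assumes that convention; `next_token_targets` is a hypothetical helper and the ids are made up.

```python
def next_token_targets(input_ids, eos_token_id):
    """Shift ids left by one so position i holds the token at i + 1.

    The final position has no successor, so it is padded with
    `eos_token_id` (an assumption made for this sketch).
    """
    return list(input_ids[1:]) + [eos_token_id]

input_ids = [101, 7, 42, 9]  # made-up token ids
targets = next_token_targets(input_ids, eos_token_id=0)
```

The resulting list could then be passed as the targets argument, keeping it aligned with input_ids position by position.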
- input_ids: ndarray[Any, dtype[int64]]
(…, seq_len)
- js_divergence(other, **kwargs)
Compute the JS divergence between self and other prediction trajectory.
- Parameters:
other – The other prediction trajectory to compare to.
**kwargs – are passed to largest_delta_in_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the JS divergence between self and other.
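The Jensen-Shannon divergence is a symmetric, bounded alternative to KL: each distribution is compared with their average. A minimal sketch on plain probability lists follows; the helper names are illustrative and the library's own implementation may differ in details such as units.

```python
import math

def kl(p, q):
    """KL(p || q) for probability lists; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    """JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.7, 0.2, 0.1]
q = [0.1, 0.2, 0.7]
js_pq = js_divergence(p, q)
js_qp = js_divergence(q, p)  # same value: JS is symmetric
```

With natural logarithms the result always lies between 0 and log(2).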
- kl_divergence(other, **kwargs)
Compute the KL divergence between self and other prediction trajectory.
- Parameters:
other – The other prediction trajectory to compare to.
**kwargs – are passed to largest_delta_in_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the KL divergence between self and other.
- log_prob_diff(delta=False)
The difference in logits between two tokens.
- Return type:
- Returns:
The difference between the log probabilities of the two tokens.
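The summary mentions logits while the return description mentions log probabilities; for a single distribution the two differences coincide, because log-softmax subtracts the same constant from every logit. The sketch below checks this on toy numbers. It assumes the two tokens are the target and the anti-target, and it does not model the `delta` parameter, which this page leaves undocumented.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

logits = [2.0, 0.5, -1.0]        # made-up logits for a 3-token vocabulary
log_probs = log_softmax(logits)

target, anti_target = 0, 1       # assumed roles of the two tokens
diff = log_probs[target] - log_probs[anti_target]
logit_diff = logits[target] - logits[anti_target]  # same value, 1.5
```

A positive difference means the prediction favors the target over the anti-target.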
- log_probs: ndarray[Any, dtype[float32]]
(…, n_layers, seq_len, vocab_size) The log probabilities of the predictions for each hidden layer plus the model's logits.
- max_probability(**kwargs)
Max probability among the predictions.
- Parameters:
**kwargs – are passed to largest_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the max probability among the predictions.
- property model_log_probs: ndarray[Any, dtype[float32]]
Returns the log probs of the model (…, seq_len, vocab_size).
- property n_batch_axis: int
Returns the number of batch dimensions.
- property num_layers: int
Returns the number of layers in the stream not including the model output.
- property num_tokens: int
Returns the number of tokens in this slice of the sequence.
- property probs: ndarray[Any, dtype[float32]]
Returns the probabilities of the predictions.
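To make the shape conventions concrete, the sketch below mimics an unbatched log_probs array with nested Python lists and derives the quantities the properties describe. It assumes the model's own output is stored as the last entry along the layer axis, which the descriptions suggest but do not state outright. The variables only mirror the property names; they are not the library's implementation.

```python
import math

# Toy trajectory: 2 hidden layers + the model output, 2 tokens, 3-word vocab.
# Indexing is log_probs[layer][position][token].
row = [math.log(0.5), math.log(0.3), math.log(0.2)]
log_probs = [[list(row) for _ in range(2)] for _ in range(3)]

num_layers = len(log_probs) - 1   # excludes the model output entry
num_tokens = len(log_probs[0])    # sequence length
vocab_size = len(log_probs[0][0])
model_log_probs = log_probs[-1]   # assumed: last layer entry, (seq_len, vocab_size)

# Probabilities are the elementwise exponential of the log probabilities.
probs = [[[math.exp(lp) for lp in pos] for pos in layer] for layer in log_probs]
```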
- rank(show_ranks=False, **kwargs)
The rank of the targets among the predictions.
That is, if the target is the most likely prediction, its rank is 1; the second most likely has rank 2, etc.
- Parameters:
show_ranks – Whether to show the rank of the target or the top token.
**kwargs – are passed to largest_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the rank of the targets among the predictions.
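The 1-based ranking described above can be sketched for a single position as one plus the number of tokens scored strictly higher than the target. How the library breaks ties is not stated here, so the tie handling below is an assumption, and `target_rank` is a hypothetical helper.

```python
def target_rank(log_probs, target):
    """1-based rank of `target` among the predictions at one position.

    Rank 1 means no other token has a strictly higher log probability.
    Tied tokens share a rank (an assumption made for this sketch).
    """
    return 1 + sum(lp > log_probs[target] for lp in log_probs)

log_probs = [-2.3, -0.4, -1.6, -3.0]  # made-up scores for a 4-token vocabulary

rank_of_best = target_rank(log_probs, target=1)   # most likely token
rank_of_worst = target_rank(log_probs, target=3)  # least likely token
```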
- slice_sequence(slice)
Create a slice of the prediction trajectory along the sequence dimension.
- Return type:
PredictionTrajectory
- targets: Optional[ndarray[Any, dtype[int64]]] = None
(…, seq_len)
- total_variation(other, **kwargs)
Total variation distance between self and other prediction trajectory.
- Parameters:
other – The other prediction trajectory to compare to.
**kwargs – are passed to largest_delta_in_prob_labels.
- Return type:
TrajectoryStatistic
- Returns:
A TrajectoryStatistic with the total variation distance between self and other.
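For reference, the total variation distance between two discrete distributions is half the sum of the absolute differences of their probabilities. The sketch below applies that definition to two toy distributions; `total_variation` here is a standalone illustration, not the method itself.

```python
def total_variation(p, q):
    """TV(p, q) = 0.5 * sum_i |p_i - q_i| for two probability lists."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

p = [0.7, 0.2, 0.1]
q = [0.1, 0.2, 0.7]
tv = total_variation(p, q)  # 0.5 * (0.6 + 0.0 + 0.6)
```

The distance ranges from 0 for identical distributions to 1 for distributions with disjoint support.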
- property vocab_size: int
Returns the size of the vocabulary.