tuned_lens.plotting.trajectory_plotting
Contains utility classes for creating heatmap visualizations.
Functions
- tuned_lens.plotting.trajectory_plotting.trunc_string_left(string, new_len)
Truncate a string to the left.
- Return type:
str
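The page gives only a one-line description of `trunc_string_left`. Below is a minimal sketch of one plausible reading. It assumes that "truncate to the left" means dropping characters from the start of the string and keeping the last `new_len`. The helper `trunc_left_sketch` is hypothetical, and the library function may behave differently (for instance, by padding short strings).

```python
def trunc_left_sketch(string: str, new_len: int) -> str:
    """Hypothetical sketch: keep only the rightmost ``new_len`` characters."""
    if new_len <= 0:
        # Guard needed because string[-0:] would return the whole string.
        return ""
    # Dropping characters from the left keeps the end of the string.
    return string[-new_len:]


print(trunc_left_sketch("tuned_lens", 4))  # → lens
```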
Classes
- class tuned_lens.plotting.trajectory_plotting.TrajectoryLabels(label_strings, hover_over_entries=None)
Contains sets of labels for each layer and position in the residual stream.
- hover_over_entries: Optional[ndarray[Any, dtype[str_]]] = None
(n_layers x sequence_length x rows x cols) table of strings to display when hovering over a cell. For example, the top-k predictions from the lens.
- label_strings: ndarray[Any, dtype[str_]]
(n_layers x sequence_length) label for each layer and position in the stream.
- stride(stride)
Return a new TrajectoryLabels with the given stride.
- Parameters:
stride – The number of layers between consecutive kept layers (e.g. 2 keeps every other layer).
- Return type:
TrajectoryLabels
- Returns:
A new TrajectoryLabels with the given stride.
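Striding operates on the layer axis, which is the first axis of both `label_strings` and `hover_over_entries`. The sketch below uses a hypothetical helper, `stride_layers`. It assumes the final layer should always survive striding so that the output layer stays visible; this page does not say whether `TrajectoryLabels.stride` does this.

```python
import numpy as np


def stride_layers(x: np.ndarray, stride: int) -> np.ndarray:
    """Hypothetical helper: keep every ``stride``-th layer along axis 0.

    Assumption: the last layer is always kept, even when it does not
    fall on the stride.
    """
    if stride < 1:
        raise ValueError("stride must be >= 1")
    kept = x[::stride]
    # If the slice did not end on the final layer, append it explicitly.
    if (len(x) - 1) % stride != 0:
        kept = np.concatenate([kept, x[-1:]], axis=0)
    return kept


# (n_layers=6 x sequence_length=3) toy labels.
labels = np.array([[f"L{l}P{p}" for p in range(3)] for l in range(6)])
strided = stride_layers(labels, 2)  # layers 0, 2, 4 and the final layer 5
```

The same helper would apply unchanged to a `hover_over_entries` array, since only axis 0 is sliced.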
- template_and_customdata(col_width_limit=10)
Construct a template for use with Plotly's hovertemplate.
- Return type:
Tuple[str, ndarray[Any, dtype[str_]]]
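In Plotly, a `hovertemplate` can reference per-cell values with `%{customdata[k]}`. For this to work, `customdata` carries one extra trailing axis beyond the heatmap grid. Below is a minimal sketch of that mechanism under assumed conventions: `<br>` separates rows, ` | ` separates columns, and entries are cut to `col_width_limit` characters. `hover_template_sketch` is hypothetical and does not reproduce the library's actual layout.

```python
import numpy as np


def hover_template_sketch(entries: np.ndarray, col_width_limit: int = 10):
    """Hypothetical sketch: build a Plotly hovertemplate and its customdata.

    ``entries`` has shape (n_layers, seq_len, rows, cols). Each cell's
    rows x cols table is flattened so ``%{customdata[k]}`` can index it.
    """
    n_layers, seq_len, rows, cols = entries.shape
    # Truncate every entry to at most ``col_width_limit`` characters.
    truncated = [str(s)[:col_width_limit] for s in entries.ravel()]
    # C-order reshape: table entry (r, c) lands at index r * cols + c.
    customdata = np.array(truncated, dtype=str).reshape(
        n_layers, seq_len, rows * cols
    )
    # One template line per table row, columns separated by " | ".
    lines = [
        " | ".join(f"%{{customdata[{r * cols + c}]}}" for c in range(cols))
        for r in range(rows)
    ]
    return "<br>".join(lines), customdata
```

The returned pair would be passed to a heatmap trace as `hovertemplate=template, customdata=customdata`.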
- class tuned_lens.plotting.trajectory_plotting.TrajectoryStatistic(name, stats, sequence_labels=None, trajectory_labels=None, units=None, max=None, min=None, includes_output=True, _layer_labels=None)
This class represents a trajectory statistic that can be visualized.
For example, the entropy of the lens predictions at each layer.
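To illustrate the `stats` array, the entropy example can be computed with NumPy. This sketch assumes you already have lens logits as an array of shape (n_layers, sequence_length, vocab_size). How those logits are produced is outside this page, and `entropy_trajectory` is a hypothetical helper.

```python
import numpy as np


def entropy_trajectory(logits: np.ndarray) -> np.ndarray:
    """Entropy (in nats) of softmax(logits) at each layer and position.

    ``logits`` has shape (n_layers, seq_len, vocab_size). The result has
    shape (n_layers, seq_len) and dtype float32.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    # H = -sum_v p(v) log p(v)
    return -(probs * log_probs).sum(axis=-1).astype(np.float32)
```

The result matches the (n_layers x sequence_length) float32 shape documented for `stats`. It could therefore be passed as, for example, `TrajectoryStatistic("entropy", stats, units="nats")`.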
- clip(min, max)
Return a new TrajectoryStatistic with the given min and max.
- Parameters:
min – The minimum value to clip to.
max – The maximum value to clip to.
- Return type:
TrajectoryStatistic
- Returns:
A new TrajectoryStatistic with the given min and max.
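The docstring suggests that `clip` returns a copy with new bounds rather than mutating the object. Below is a minimal sketch of that pattern using `dataclasses.replace` and a hypothetical `StatSketch` stand-in. It assumes the bounds are stored as `min`/`max`, for example for use as the color range. This page does not state whether the library also clips the `stats` array itself. The last line shows the data-altering alternative, `np.clip`.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class StatSketch:
    """Hypothetical stand-in for a few TrajectoryStatistic fields."""

    name: str
    stats: np.ndarray
    min: Optional[float] = None
    max: Optional[float] = None

    def clip(self, min: float, max: float) -> "StatSketch":
        # Build a copy with new bounds; the original object is left unchanged.
        return replace(self, min=min, max=max)


stat = StatSketch("entropy", np.array([[0.5, 3.0]]))
clipped = stat.clip(0.0, 2.0)
# Data-altering alternative: saturate the values themselves at the bounds.
display_values = np.clip(clipped.stats, clipped.min, clipped.max)
```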
- figure(title='', colorscale='rdbu_r', token_width=80)
Produce a heatmap plot of the statistic.
- Parameters:
title – The title of the plot.
colorscale – The colorscale to use for the heatmap.
token_width – The width of each token in the plot.
- Return type:
Figure
- Returns:
The Plotly heatmap figure.
- heatmap(colorscale='rdbu_r', log_scale=False, **kwargs)
Return a Plotly Heatmap object for this statistic.
- Parameters:
colorscale – The colorscale to use for the heatmap.
log_scale – Whether to use a log scale for the colorbar.
**kwargs – Additional keyword arguments to pass to the Heatmap constructor.
- Return type:
Heatmap
- Returns:
A Plotly Heatmap where the x-axis is the sequence dimension, the y-axis is the layer dimension, and the color of each cell is the value of the statistic.
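The sketch below shows the kind of arguments such a heatmap needs. It builds them as a plain dict so that it runs without Plotly installed. It assumes `log_scale` works by plotting `log10` of the values and relabeling the colorbar ticks with the original magnitudes. That is one common approach and not necessarily the library's. For simplicity, the axis values are integer indices, and `heatmap_kwargs_sketch` is hypothetical.

```python
import math

import numpy as np


def heatmap_kwargs_sketch(stats, colorscale="rdbu_r", log_scale=False):
    """Hypothetical sketch of keyword arguments for a Plotly heatmap trace.

    ``stats`` has shape (n_layers, sequence_length): x indexes positions,
    y indexes layers, and z holds the statistic.
    """
    z = np.asarray(stats, dtype=np.float64)
    n_layers, seq_len = z.shape
    kwargs = {
        "x": list(range(seq_len)),
        "y": list(range(n_layers)),
        "colorscale": colorscale,
    }
    if log_scale:
        if np.any(z <= 0):
            raise ValueError("log_scale requires strictly positive values")
        # Color by log10 of the values, but label ticks in the original units.
        z = np.log10(z)
        exponents = list(range(math.floor(z.min()), math.ceil(z.max()) + 1))
        kwargs["colorbar"] = {
            "tickvals": exponents,
            "ticktext": [f"1e{k}" for k in exponents],
        }
    kwargs["z"] = z
    return kwargs
```

With Plotly installed, `plotly.graph_objects.Heatmap(**kwargs)` accepts these keys (`x`, `y`, `z`, `colorscale`, `colorbar`).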
- includes_output: bool = True
Whether the statistic includes the final output layer.
- max: Optional[float] = None
The maximum value of the statistic.
- min: Optional[float] = None
The minimum value of the statistic.
- name: str
The name of the statistic. For example, "entropy".
- sequence_labels: Optional[ndarray[Any, dtype[str_]]] = None
(sequence_length) labels for the sequence dimension, e.g. input tokens.
- stats: ndarray[Any, dtype[float32]]
(n_layers x sequence_length) value of the statistic across layer and position.
- stride(stride)
Return a new TrajectoryStatistic with the given stride.
- Parameters:
stride – The number of layers between consecutive kept layers (e.g. 2 keeps every other layer).
- Return type:
TrajectoryStatistic
- Returns:
A new TrajectoryStatistic with the given stride.
- trajectory_labels: Optional[TrajectoryLabels] = None
Labels for each layer and position in the stream. For example, the top-1 prediction from the lens at each layer.
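The top-1 example can be sketched with NumPy. This assumes lens logits of shape (n_layers, sequence_length, vocab_size) and a plain list `vocab` that maps token ids to strings. In practice, a tokenizer's decoding step would play the role of `vocab`. `top1_labels` is a hypothetical helper.

```python
import numpy as np


def top1_labels(logits: np.ndarray, vocab: list) -> np.ndarray:
    """Hypothetical sketch: top-1 token string at each layer and position.

    ``logits`` has shape (n_layers, seq_len, vocab_size). The result is an
    (n_layers, seq_len) array of strings.
    """
    top_ids = logits.argmax(axis=-1)
    # Fancy indexing maps each predicted id to its string.
    return np.array(vocab, dtype=str)[top_ids]
```

The resulting array has the (n_layers x sequence_length) shape documented for `label_strings`. It could therefore be wrapped as `TrajectoryLabels(labels)` and passed as `trajectory_labels`.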
- units: Optional[str] = None
The units of the statistic.