tuned_lens.plotting.trajectory_plotting
Contains utility classes for creating heatmap visualizations.
Functions
- tuned_lens.plotting.trajectory_plotting.trunc_string_left(string, new_len)
Truncate a string to the left.
- Return type:
str
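The docstring does not say which side of the string is kept. The sketch below shows one plausible reading, in which the rightmost new_len characters are kept and shorter strings are left-padded with spaces. The name trunc_left_sketch and the padding behavior are assumptions for illustration, not the library's implementation.

```python
def trunc_left_sketch(string: str, new_len: int) -> str:
    """Hypothetical stand-in: keep the last new_len characters of string.

    Assumption: shorter strings are left-padded with spaces so that every
    result has exactly new_len characters.
    """
    if new_len <= 0:
        return ""
    # Slice from the right, then pad on the left up to the target width.
    return string[-new_len:].rjust(new_len)


print(repr(trunc_left_sketch("tokenization", 5)))  # 'ation'
print(repr(trunc_left_sketch("the", 5)))           # '  the'
```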
Classes
- class tuned_lens.plotting.trajectory_plotting.TrajectoryLabels(label_strings, hover_over_entries=None)
Contains sets of labels for each layer and position in the residual stream.
- hover_over_entries: Optional[ndarray[Any, dtype[str_]]] = None
(n_layers x sequence_length x rows x cols) table of strings to display when hovering over a cell. For example, the top-k predictions from the lens.
- label_strings: ndarray[Any, dtype[str_]]
(n_layers x sequence_length) label for each layer and position in the stream.
- stride(stride)
Return a new TrajectoryLabels with the given stride.
- Parameters:
stride – The number of layers between each layer we keep.
- Return type:
TrajectoryLabels
- Returns:
A new TrajectoryLabels with the given stride.
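As a rough illustration of striding over the layer dimension, the sketch below keeps every stride-th row of an (n_layers x sequence_length) table. It uses plain lists and a hypothetical helper, stride_layers. Whether the library treats the final layer specially is not stated here, so the sketch does not model that.

```python
from typing import List


def stride_layers(rows: List[List[str]], stride: int) -> List[List[str]]:
    """Hypothetical helper: keep every stride-th layer (row), starting at 0."""
    if stride < 1:
        raise ValueError("stride must be a positive integer")
    # Slicing the outer list subsamples layers and leaves positions intact.
    return rows[::stride]


# A 6-layer, 3-position table of placeholder labels.
labels = [[f"L{layer}T{pos}" for pos in range(3)] for layer in range(6)]
kept = stride_layers(labels, 2)
print([row[0] for row in kept])  # ['L0T0', 'L2T0', 'L4T0']
```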
- template_and_customdata(col_width_limit=10)
Construct a template for use with Plotly’s hovertemplate.
- Return type:
Tuple[str,ndarray[Any,dtype[str_]]]
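To make the hovertemplate idea concrete, the sketch below builds a template for a single cell from a rows x cols table of strings. It relies on Plotly's %{customdata[i]} placeholder syntax and the <extra></extra> tag that hides the secondary hover box. The helper name build_hover_template and the reading of col_width_limit as a per-entry character cap are assumptions, not the library's behavior.

```python
from typing import List, Tuple


def build_hover_template(
    entries: List[List[str]], col_width_limit: int = 10
) -> Tuple[str, List[str]]:
    """Hypothetical helper: template plus flat customdata for one cell."""
    flat: List[str] = []
    lines: List[str] = []
    for row in entries:
        refs = []
        for text in row:
            # Each entry is referenced by its index in the flat customdata.
            refs.append("%{customdata[" + str(len(flat)) + "]}")
            # Assumption: col_width_limit caps the characters per entry.
            flat.append(text[:col_width_limit])
        lines.append(" ".join(refs))
    # <br> separates table rows; <extra></extra> hides the trace-name box.
    template = "<br>".join(lines) + "<extra></extra>"
    return template, flat


template, data = build_hover_template([["the", "a"], ["cat", "dog"]])
print(template)
print(data)
```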
- class tuned_lens.plotting.trajectory_plotting.TrajectoryStatistic(name, stats, sequence_labels=None, trajectory_labels=None, units=None, max=None, min=None, includes_output=True, _layer_labels=None)
This class represents a trajectory statistic that can be visualized.
For example, the entropy of the lens predictions at each layer.
- clip(min, max)
Return a new TrajectoryStatistic with the given min and max.
- Parameters:
min – The minimum value to clip to.
max – The maximum value to clip to.
- Return type:
TrajectoryStatistic
- Returns:
A new TrajectoryStatistic with the given min and max.
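Clipping here can be pictured as bounding every value of the (n_layers x sequence_length) table to the range [min, max] without modifying the original. The sketch below does this on nested lists with a hypothetical helper, clip_values. It is an illustration of the idea, not the method's actual implementation.

```python
from typing import List


def clip_values(stats: List[List[float]], lo: float, hi: float) -> List[List[float]]:
    """Hypothetical helper: return a copy of stats bounded to [lo, hi]."""
    if lo > hi:
        raise ValueError("lo must not exceed hi")
    # Build new rows so the input table is left untouched.
    return [[min(max(value, lo), hi) for value in row] for row in stats]


stats = [[0.2, 5.0], [-1.5, 3.0]]
clipped = clip_values(stats, 0.0, 4.0)
print(clipped)  # [[0.2, 4.0], [0.0, 3.0]]
```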
- figure(title='', colorscale='rdbu_r', token_width=80)
Produce a heatmap plot of the statistic.
- Parameters:
title – The title of the plot.
colorscale – The colorscale to use for the heatmap.
token_width – The width of each token in the plot.
- Return type:
Figure
- Returns:
The Plotly heatmap figure.
- heatmap(colorscale='rdbu_r', log_scale=False, **kwargs)
Returns a Plotly Heatmap object for this statistic.
- Parameters:
colorscale – The colorscale to use for the heatmap.
log_scale – Whether to use a log scale for the colorbar.
**kwargs – Additional keyword arguments to pass to the Heatmap constructor.
- Return type:
Heatmap
- Returns:
A Plotly Heatmap where the x-axis is the sequence dimension, the y-axis is the layer dimension, and the color of each cell is the value of the statistic.
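The axis layout described above can be sketched as a plain dict in Plotly's trace schema, where z[layer][position] holds the statistic. The helper heatmap_spec is hypothetical and ignores log_scale and labels. Passing the resulting dict to plotly.graph_objects.Figure(data=[spec]) is one way to render it, assuming Plotly is installed.

```python
from typing import Any, Dict, List


def heatmap_spec(stats: List[List[float]], colorscale: str = "rdbu_r") -> Dict[str, Any]:
    """Hypothetical helper: describe a heatmap trace as a plain dict."""
    n_layers = len(stats)
    seq_len = len(stats[0]) if stats else 0
    return {
        "type": "heatmap",
        "z": stats,                    # z[layer][position] sets the cell color
        "x": list(range(seq_len)),     # sequence dimension on the x-axis
        "y": list(range(n_layers)),    # layer dimension on the y-axis
        "colorscale": colorscale,
    }


# Two layers, three positions of made-up values.
spec = heatmap_spec([[0.1, 0.5, 0.9], [0.2, 0.4, 0.8]])
print(spec["x"], spec["y"])  # [0, 1, 2] [0, 1]
```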
- includes_output: bool = True
Whether the statistic includes the final output layer.
- max: Optional[float] = None
The maximum value of the statistic.
- min: Optional[float] = None
The minimum value of the statistic.
- name: str
The name of the statistic. For example, “entropy”.
- sequence_labels: Optional[ndarray[Any, dtype[str_]]] = None
(sequence_length) labels for the sequence dimension, e.g. input tokens.
- stats: ndarray[Any, dtype[float32]]
(n_layers x sequence_length) value of the statistic across layer and position.
- stride(stride)
Return a new TrajectoryStatistic with the given stride.
- Parameters:
stride – The number of layers between each layer we keep.
- Return type:
TrajectoryStatistic
- Returns:
A new TrajectoryStatistic with the given stride.
- trajectory_labels: Optional[TrajectoryLabels] = None
Labels for each layer and position in the stream. For example, the top-1 prediction from the lens at each layer.
- units: Optional[str] = None
The units of the statistic.