tuned_lens.plotting.trajectory_plotting#

Contains utility classes for creating heatmap visualizations.

Functions

# empty test needed in case the module has no example usage.
# otherwise, testsetup throws an error
pass
tuned_lens.plotting.trajectory_plotting.trunc_string_left(string, new_len)#

Truncate a string to the left.

Return type:

str

Classes

class tuned_lens.plotting.trajectory_plotting.TrajectoryLabels(label_strings, hover_over_entries=None)#

Contains sets of labels for each layer and position in the residual stream.

hover_over_entries: Optional[ndarray[Any, dtype[str_]]] = None#

(n_layers x sequence_length x rows x cols) table of strings to display when hovering over a cell. For example, the top-k predictions from the lens.

label_strings: ndarray[Any, dtype[str_]]#

(n_layers x sequence_length) label for each layer and position in the stream.

stride(stride)#

Return a new TrajectoryLabels with the given stride.

Parameters:

stride – The number of layers between each layer we keep.

Return type:

TrajectoryLabels

Returns:

A new TrajectoryLabels with the given stride.

template_and_customdata(col_width_limit=10)#

Construct a template for use with Plotly’s hovertemplate.

Return type:

Tuple[str, ndarray[Any, dtype[str_]]]

class tuned_lens.plotting.trajectory_plotting.TrajectoryStatistic(name, stats, sequence_labels=None, trajectory_labels=None, units=None, max=None, min=None, includes_output=True, _layer_labels=None)#

This class represents a trajectory statistic that can be visualized.

For example, the entropy of the lens predictions at each layer.

clip(min, max)#

Return a new TrajectoryStatistic with the given min and max.

Parameters:
  • min – The minimum value to clip to.

  • max – The maximum value to clip to.

Return type:

TrajectoryStatistic

Returns:

A new TrajectoryStatistic with the given min and max.

figure(title='', colorscale='rdbu_r', token_width=80)#

Produce a heatmap plot of the statistic.

Parameters:
  • title – The title of the plot.

  • colorscale – The colorscale to use for the heatmap.

  • token_width – The width of each token in the plot.

Return type:

Figure

Returns:

The plotly heatmap figure.

heatmap(colorscale='rdbu_r', log_scale=False, **kwargs)#

Return a Plotly Heatmap object for this statistic.

Parameters:
  • colorscale – The colorscale to use for the heatmap.

  • log_scale – Whether to use a log scale for the colorbar.

  • **kwargs – Additional keyword arguments to pass to the Heatmap constructor.

Return type:

Heatmap

Returns:

A plotly Heatmap where the x-axis is the sequence dimension, the y-axis is the layer dimension, and the color of each cell is the value of the statistic.

includes_output: bool = True#

Whether the statistic includes the final output layer.

max: Optional[float] = None#

The maximum value of the statistic.

min: Optional[float] = None#

The minimum value of the statistic.

name: str#

The name of the statistic. For example, “entropy”.

sequence_labels: Optional[ndarray[Any, dtype[str_]]] = None#

(sequence_length) labels for the sequence dimension, e.g. input tokens.

stats: ndarray[Any, dtype[float32]]#

(n_layers x sequence_length) value of the statistic across layer and position.

stride(stride)#

Return a new TrajectoryStatistic with the given stride.

Parameters:

stride – The number of layers between each layer we keep.

Return type:

TrajectoryStatistic

Returns:

A new TrajectoryStatistic with the given stride.

trajectory_labels: Optional[TrajectoryLabels] = None#

Labels for each layer and position in the stream. For example, the top 1 prediction from the lens at each layer.

units: Optional[str] = None#

The units of the statistic.