tuned_lens.nn.unembed

Provides a class for mapping transformer hidden states to logits (and vice versa).

Classes

class tuned_lens.nn.unembed.InversionOutput(preimage, grad_norm, kl, loss, nfev)

Output of Unembed.invert.

class tuned_lens.nn.unembed.Unembed(model)

Module that maps transformer hidden states to logits (and vice versa).

forward(h)

Convert hidden states into logits.

Return type:

Tensor

invert(logits, *, h0=None, max_iter=1000, optimizer='lbfgs', prior_weight=0.0, prior=None, step_size=1.0, tol=0.001, weight=None)

Project logits onto the image of the unembed operation.

When the hidden state dimension is smaller than the vocabulary size, the unembed operation cannot represent arbitrary logits exactly, because its image is restricted to a subspace; this phenomenon is known as the softmax bottleneck (cf. https://arxiv.org/abs/1711.03953). The inverse is therefore only approximate in general. Here, we use gradient-based optimization to find a hidden state h* that minimizes the KL divergence from the target distribution p to the distribution q(h) given by the unembedded logits: h* = argmin_h KL(p || q(h)).
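The objective above can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: it uses plain gradient descent rather than L-BFGS, treats the unembedding as a bare matrix product with no normalization layer, and starts from a zero hidden state. The names `unembed_sketch` and `invert_sketch` and the toy matrix `U` are hypothetical and exist only for this illustration.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def unembed_sketch(h, U):
    """Toy unembedding: logits[j] = sum_i h[i] * U[i][j] (no normalization layer)."""
    return [sum(h[i] * U[i][j] for i in range(len(U))) for j in range(len(U[0]))]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def invert_sketch(target_logits, U, steps=2000, step_size=0.5):
    """Hypothetical helper: minimize KL(p || softmax(hU)) by gradient descent."""
    p = softmax(target_logits)
    h = [0.0] * len(U)  # zero start for simplicity; not the library's default h0
    for _ in range(steps):
        q = softmax(unembed_sketch(h, U))
        # d KL / d h[i] = sum_j U[i][j] * (q[j] - p[j])
        grad = [sum(U[i][j] * (q[j] - p[j]) for j in range(len(p)))
                for i in range(len(U))]
        h = [hi - step_size * gi for hi, gi in zip(h, grad)]
    return h, kl_divergence(p, softmax(unembed_sketch(h, U)))


# Hidden size 2, vocabulary size 4: the image of U is a 2-D subspace of R^4.
U = [[1.0, 0.0, -1.0, 0.0],
     [0.0, 1.0, 0.0, -1.0]]

# Logits produced by an actual hidden state can be recovered with KL close to 0.
h_exact, kl_exact = invert_sketch(unembed_sketch([0.5, -0.3], U), U)

# Arbitrary logits generally cannot be matched: the residual KL stays positive.
h_approx, kl_approx = invert_sketch([2.0, 0.0, 0.0, 0.0], U)
```

In this toy setting the first call recovers the original hidden state, while the second converges to the closest reachable distribution and leaves a nonzero KL divergence, which is the softmax bottleneck in miniature.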

Parameters:
  • logits – Tensor of shape […, vocab_size] containing logits to invert.

  • h0 – Initial guess for the hidden state. If None, the least-squares solution of the linear equation xU = logits is used, where U is the unembedding matrix.

  • max_iter – Maximum number of iterations for the optimizer to take.

  • optimizer – Optimization algorithm to use. Currently, only “lbfgs” and “sgd” are supported.

  • prior_weight – The weight given to the prior distribution in the loss.

  • prior – Prior distribution over hidden states used to regularize the inversion.

  • step_size – The step size for the optimizer.

  • tol – Tolerance for the inversion objective.

  • weight – Optional tensor of shape […, vocab_size] containing weights for each vocabulary item. If None, all classes are weighted equally.

Return type:

InversionOutput
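The default initial guess described for h0, the least-squares solution of xU = logits, can be sketched as follows. This is an illustrative, dependency-free version that solves the normal equations (U Uᵀ) x = U logitsᵀ by Gaussian elimination; the function name `least_squares_init` is hypothetical, and in practice a library routine such as `torch.linalg.lstsq` would normally replace the hand-written solver. The sketch assumes U has full row rank.

```python
def least_squares_init(logits, U):
    """Hypothetical helper: argmin_x ||xU - logits||^2 via the normal equations."""
    d, v = len(U), len(logits)
    # Gram matrix A = U U^T (d x d) and right-hand side b = U logits^T (length d).
    A = [[sum(U[i][k] * U[j][k] for k in range(v)) for j in range(d)]
         for i in range(d)]
    b = [sum(U[i][k] * logits[k] for k in range(v)) for i in range(d)]

    # Forward elimination with partial pivoting.
    for col in range(d):
        pivot = max(range(col, d), key=lambda r: abs(A[r][col]))
        A[col], A[pivot] = A[pivot], A[col]
        b[col], b[pivot] = b[pivot], b[col]
        for r in range(col + 1, d):
            factor = A[r][col] / A[col][col]
            for c in range(col, d):
                A[r][c] -= factor * A[col][c]
            b[r] -= factor * b[col]

    # Back substitution.
    x = [0.0] * d
    for i in reversed(range(d)):
        tail = sum(A[i][j] * x[j] for j in range(i + 1, d))
        x[i] = (b[i] - tail) / A[i][i]
    return x


# Toy unembedding matrix with hidden size 2 and vocabulary size 3.
U_toy = [[1.0, 2.0, 0.0],
         [0.0, 1.0, 1.0]]

# These logits equal [2, -1] @ U_toy, so the least-squares fit is exact here.
h0 = least_squares_init([2.0, 3.0, -1.0], U_toy)
```

A least-squares fit in logit space is not the KL-optimal hidden state in general, which is why it serves only as a starting point for the optimizer.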

unembedding_hash()

Hash the unembedding matrix to identify the model.

Return type:

str
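The idea of fingerprinting a model by hashing its unembedding matrix can be sketched as below. The digest algorithm and byte layout the library actually uses are not stated here, so SHA-256 over little-endian float64 values, and the name `matrix_hash`, are assumptions made purely for illustration.

```python
import hashlib
import struct


def matrix_hash(matrix):
    """Hypothetical helper: hex digest of a matrix given as a list of rows."""
    flat = [value for row in matrix for value in row]
    # Include the shape so that reshaped matrices with equal data hash differently.
    header = struct.pack("<II", len(matrix), len(matrix[0]))
    payload = struct.pack(f"<{len(flat)}d", *flat)
    return hashlib.sha256(header + payload).hexdigest()


weights_a = [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]
weights_b = [[1.0, 0.0, -1.0], [0.0, 1.0, -1.5]]

digest_a = matrix_hash(weights_a)
digest_b = matrix_hash(weights_b)
```

Identical weights always yield the same string, and any change to a single entry yields a different one, so the digest can be compared cheaply to check that two objects were built from the same unembedding.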