Autoregressive Sampling

discrete_diffusion.sampling.ar

Autoregressive sampler for AR language models.

ARSampler

Bases: Sampler

Sampler for autoregressive language models.

Source code in src/discrete_diffusion/sampling/ar.py
class ARSampler(Sampler):
  """Sampler for autoregressive language models."""

  def __init__(self, config, forward_process=None):
    # forward_process is unused for AR; the argument is accepted for API
    # compatibility with diffusion samplers.
    self.config = config

  @torch.no_grad()
  def generate(self, model, *, num_samples, num_steps, eps, inject_bos):
    """Generate samples autoregressively from left to right.

    Args:
      model: The AR model instance.
      num_samples: Number of samples to generate.
      num_steps: Unused for AR (kept for API compatibility).
      eps: Unused for AR (kept for API compatibility).
      inject_bos: Whether to inject a BOS token at position 0.

    Returns:
      Generated token sequences of shape [num_samples, num_tokens].
    """
    del num_steps, eps  # Unused for AR

    # Preallocate the token buffer; position 0 holds the prefix token.
    num_pred_tokens = model.num_tokens - 1
    x = torch.zeros(
      (num_samples, num_pred_tokens + 1),
      dtype=torch.long,
      device=model.device)
    if inject_bos:
      x[:, 0] = model.tokenizer.bos_token_id

    # Precompute Gumbel(0, 1) noise for the Gumbel-max trick:
    # argmax(log-probs + Gumbel noise) is an exact categorical sample,
    # so all randomness can be drawn up front, outside the loop.
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((num_samples, num_pred_tokens, model.vocab_size))
             .to(model.device))
    if self.config.sampling.use_float64:
      noise = noise.to(torch.float64)

    # Generate tokens left to right, conditioning on the prefix so far.
    for i in range(num_pred_tokens):
      output = model.backbone(x[:, :i + 1], None)
      # Forbid sampling the mask token.
      output[:, :, model.mask_id] = model.neg_infinity
      output = output.log_softmax(-1)
      y = (output[:, -1, :] + noise[:, i, :]).argmax(-1)
      x[:, i + 1] = y

    return x
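
A minimal usage sketch, assuming a trained AR model object that exposes the attributes the sampler reads (backbone, tokenizer, device, num_tokens, vocab_size, mask_id, neg_infinity) and a config carrying a sampling.use_float64 flag; load_model is a hypothetical stand-in for your own checkpoint loading, not part of this module.

from types import SimpleNamespace

from discrete_diffusion.sampling.ar import ARSampler

model = load_model("checkpoints/ar.ckpt")  # hypothetical loader
config = SimpleNamespace(sampling=SimpleNamespace(use_float64=False))

sampler = ARSampler(config)
samples = sampler.generate(
  model,
  num_samples=8,
  num_steps=0,    # ignored by the AR sampler
  eps=0.0,        # ignored by the AR sampler
  inject_bos=True)
print(samples.shape)  # torch.Size([8, model.num_tokens])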

generate(model, *, num_samples, num_steps, eps, inject_bos)

Generate samples autoregressively from left to right.

Parameters:

  Name          Description                                    Default
  model         The AR model instance.                         required
  num_samples   Number of samples to generate.                 required
  num_steps     Unused for AR (kept for API compatibility).    required
  eps           Unused for AR (kept for API compatibility).    required
  inject_bos    Whether to inject a BOS token at position 0.   required

Returns:

  Generated token sequences of shape [num_samples, num_tokens].
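
The generation loop relies on the Gumbel-max trick: adding i.i.d. Gumbel(0, 1) noise to log-probabilities and taking the argmax yields an exact sample from the corresponding categorical distribution, which is why all noise can be drawn before the loop. A self-contained sketch that checks this empirically (the probabilities are arbitrary illustration values):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])  # arbitrary example distribution
log_probs = probs.log()

# Draw many Gumbel-max samples: argmax(log p + g), g ~ Gumbel(0, 1).
gumbel = torch.distributions.Gumbel(0, 1).sample((100_000, 3))
samples = (log_probs + gumbel).argmax(-1)

# Empirical frequencies should match probs up to Monte Carlo error.
print(torch.bincount(samples, minlength=3) / samples.numel())
# ~ tensor([0.1, 0.2, 0.7])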
