
Forward Process Utils

discrete_diffusion.forward_process.utils

Utilities for forward-process implementations.

Includes tokenizer helpers and a numerically stable categorical sampler.

sample_categorical(categorical_probs)

Sample category indices using a Gumbel-max formulation for numerical stability.

Expects categorical_probs to be non-negative and to sum to one along the last dimension. Returns the sampled indices (dtype torch.long) with that last dimension removed. This implementation mirrors the stable sampler used in the existing absorbing-process helpers, for consistency.
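The source code divides each probability by E_i = -log U_i, where U_i is uniform noise from torch.rand_like, and takes the argmax. Setting aside the 1e-10 guards, the derivation below shows that this is the Gumbel-max trick and that index k is drawn with probability p_k. The last step assumes the probabilities sum to one, as the docstring requires.

```latex
\begin{aligned}
E_i &= -\log U_i \sim \mathrm{Exp}(1), \qquad G_i = -\log E_i \sim \mathrm{Gumbel}(0, 1), \\
\arg\max_i \frac{p_i}{E_i} &= \arg\max_i \left(\log p_i - \log E_i\right) = \arg\max_i \left(\log p_i + G_i\right), \\
\Pr\!\left[\arg\max_i \frac{p_i}{E_i} = k\right] &= \Pr\!\left[\frac{E_k}{p_k} < \min_{j \ne k} \frac{E_j}{p_j}\right] = \frac{p_k}{\sum_j p_j} = p_k .
\end{aligned}
```

The last line uses the fact that E_j / p_j is exponential with rate p_j. The minimum of independent exponentials is attained at index k with probability proportional to its rate. A category with p_i = 0 gives p_i / E_i = 0, so it is never selected while some other category has positive probability.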

Source code in src/discrete_diffusion/forward_process/utils.py
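import torch


def sample_categorical(categorical_probs: torch.Tensor) -> torch.Tensor:
  """Sample categories via a Gumbel-max formulation for stability.

  Expects `categorical_probs` to be non-negative and to sum to one along the
  last dimension. This implementation mirrors the stable sampler used in the
  existing absorbing helpers for consistency.
  """
  # -log(U) with U ~ Uniform(0, 1) is an Exponential(1) sample; the 1e-10
  # terms guard against log(0) and a zero denominator.
  gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
  # argmax(p / E) == argmax(log p - log E), and -log E is Gumbel(0, 1) noise,
  # so this is the Gumbel-max trick without computing log(p).
  return (categorical_probs / gumbel_norm).argmax(dim=-1)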
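The sketch below checks numerically that dividing by exponential noise behaves like a Gumbel-max sampler. It assumes only PyTorch. It re-implements the same formula with the uniform noise passed in explicitly, then compares the result with a textbook Gumbel-max sampler that receives the same noise. It also checks the empirical category frequencies and the output shape for batched input. The helpers exp_race_sample and gumbel_max_sample are hypothetical names used for illustration. They are not part of discrete_diffusion.

```python
import torch


def exp_race_sample(probs: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
  # Hypothetical helper: same formula as sample_categorical, but the uniform
  # noise `u` is an argument so it can be shared with the reference sampler.
  return (probs / (1e-10 - (u + 1e-10).log())).argmax(dim=-1)


def gumbel_max_sample(probs: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
  # Hypothetical helper: textbook Gumbel-max, argmax(log p + G) with
  # G = -log(-log U).
  gumbel = -torch.log(-torch.log(u))
  return (torch.log(probs) + gumbel).argmax(dim=-1)


torch.manual_seed(0)
n = 100_000
# One zero-probability category, which should never be sampled.
probs = torch.tensor([0.0, 0.3, 0.7]).expand(n, 3)
u = torch.rand(n, 3)

a = exp_race_sample(probs, u)
b = gumbel_max_sample(probs, u)
# Fraction of rows where both samplers pick the same index from the same noise.
agreement = (a == b).float().mean().item()
# Empirical frequency of each category; should be close to [0.0, 0.3, 0.7].
freqs = torch.bincount(a, minlength=3).float() / a.numel()
print(agreement, freqs)

# Batched usage: the last dimension is the category axis and is reduced away.
batch_probs = torch.randn(2, 5, 10).softmax(dim=-1)
tokens = exp_race_sample(batch_probs, torch.rand_like(batch_probs))
print(tokens.shape)  # torch.Size([2, 5])
```

Agreement can fall slightly below 1.0 because of the 1e-10 guards and floating-point rounding near ties. Dividing by E, rather than adding Gumbel noise to log p, avoids computing log(0) for zero-probability categories.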