Base Process

discrete_diffusion.forward_process.base

Base interface for discrete forward processes.

Forward processes encapsulate tokenizer-specific details and apply a chosen noise schedule to produce noised latent variables z_t (or x_t).

This module only defines the abstract interface; concrete implementations will be introduced separately.

ForwardProcess

Bases: Module

Abstract base class for discrete forward noising dynamics.

Implementations should use self.tokenizer and self.schedule to compute noised states for given inputs and timesteps.

Source code in src/discrete_diffusion/forward_process/base.py
class ForwardProcess(torch.nn.Module):
  """Abstract base class for discrete forward noising dynamics.

  Implementations should use `self.tokenizer` and `self.schedule` to compute
  noised states for given inputs and timesteps.
  """

  def __init__(self, tokenizer, schedule: NoiseSchedule, name=None) -> None:
    super().__init__()
    self.tokenizer = tokenizer
    self.schedule = schedule
    self.name = name

  def forward(self, input_ids: torch.Tensor, t: torch.Tensor):  # pragma: no cover - abstract method
    """Return the noised tokens at time `t`.

    Concrete classes may return additional tensors as needed (e.g.,
    per-position `t` for blockwise sampling).
    """
    raise NotImplementedError
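
As a rough illustration of how a subclass might fill in this interface, the sketch below implements an absorbing-state (masking) process. None of it comes from the library: `MaskingForwardProcess`, `ToyLinearSchedule` and its `alpha_t` method, and the `mask_token_id` attribute on the tokenizer are assumptions made for the example. The base class is restated in condensed form so the snippet runs on its own.

```python
import torch
from types import SimpleNamespace


class ForwardProcess(torch.nn.Module):
  """Condensed restatement of the base class documented above."""

  def __init__(self, tokenizer, schedule, name=None) -> None:
    super().__init__()
    self.tokenizer = tokenizer
    self.schedule = schedule
    self.name = name

  def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
    raise NotImplementedError


class ToyLinearSchedule:
  """Hypothetical schedule: a token survives with probability 1 - t."""

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return 1.0 - t


class MaskingForwardProcess(ForwardProcess):
  """Hypothetical absorbing-state process: tokens are replaced by a mask id."""

  def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
    # t has shape (batch,); broadcast the keep-probability over positions.
    keep_prob = self.schedule.alpha_t(t).unsqueeze(-1)
    # Each position is masked independently with probability 1 - keep_prob.
    masked = torch.rand(input_ids.shape, device=input_ids.device) >= keep_prob
    mask_id = self.tokenizer.mask_token_id  # assumed tokenizer attribute
    return torch.where(masked, torch.full_like(input_ids, mask_id), input_ids)


tokenizer = SimpleNamespace(mask_token_id=0)  # stand-in for a real tokenizer
process = MaskingForwardProcess(tokenizer, ToyLinearSchedule(), name="masking")
input_ids = torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]])
x_clean = process(input_ids, torch.tensor([0.0, 0.0]))  # t = 0: nothing is masked
x_noise = process(input_ids, torch.tensor([1.0, 1.0]))  # t = 1: everything is masked
```

Because the base class derives from `torch.nn.Module`, calling `process(input_ids, t)` dispatches to `forward`, so a forward process can be invoked like any other module.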

forward(input_ids, t)

Return the noised tokens at time t.

Concrete classes may return additional tensors as needed (e.g., per-position t for blockwise sampling).

Source code in src/discrete_diffusion/forward_process/base.py
def forward(self, input_ids: torch.Tensor, t: torch.Tensor):  # pragma: no cover - abstract method
  """Return the noised tokens at time `t`.

  Concrete classes may return additional tensors as needed (e.g.,
  per-position `t` for blockwise sampling).
  """
  raise NotImplementedError
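
The note about additional return values could look like the following sketch, in which each block of positions gets its own noise level and `forward` returns the per-position `t` alongside the noised tokens. The class name, the `(batch, num_blocks)` shape of `t`, the masking rule, and the `(x_t, t_pos)` return order are all assumptions for illustration. To keep the snippet self-contained it derives from `torch.nn.Module` directly; a real implementation would subclass `ForwardProcess`.

```python
import torch


class BlockwiseMasking(torch.nn.Module):
  """Hypothetical stand-in for a ForwardProcess subclass with blockwise noise."""

  def __init__(self, mask_id: int, block_size: int) -> None:
    super().__init__()
    self.mask_id = mask_id
    self.block_size = block_size

  def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
    # t has shape (batch, num_blocks); repeat each entry across its block.
    t_pos = t.repeat_interleave(self.block_size, dim=-1)
    # Assumed rule: a position is masked with probability equal to its own t.
    masked = torch.rand(input_ids.shape, device=input_ids.device) < t_pos
    x_t = input_ids.masked_fill(masked, self.mask_id)
    # Return the noised tokens together with the per-position timesteps.
    return x_t, t_pos


process = BlockwiseMasking(mask_id=0, block_size=2)
input_ids = torch.tensor([[5, 6, 7, 8]])
# First block is left clean (t = 0), second block is fully masked (t = 1).
x_t, t_pos = process(input_ids, torch.tensor([[0.0, 1.0]]))
```

Callers that work with several forward processes then need to handle both a bare tensor and a tuple as the return value.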