
Base Schedule

discrete_diffusion.noise_schedules.base

Base interfaces for noise schedules used in discrete diffusion.

Defines a stable NoiseSchedule interface with continuous-time semantics, plus an adapter, ScheduleAdapter, that preserves the legacy call signature schedule(t) -> (alpha_prime_t, alpha_t) used throughout the existing trainers.
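
The adapter idea can be illustrated with a minimal sketch. It is not the library's ScheduleAdapter: `ExampleLinearSchedule` and `LegacyCallAdapter` are hypothetical names, and the linear form alpha(t) = 1 - (1 - eps) * t is assumed purely for illustration. To stay self-contained, the schedule subclasses torch.nn.Module directly instead of importing NoiseSchedule. The wrapper forwards a call `schedule(t)` to the two methods and returns their results in the legacy order `(alpha_prime_t, alpha_t)`.

```python
import torch


class ExampleLinearSchedule(torch.nn.Module):
  """Hypothetical schedule: alpha(t) = 1 - (1 - eps) * t, so alpha stays in [eps, 1]."""

  def __init__(self, eps: float = 1e-3) -> None:
    super().__init__()
    self.eps = eps

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return 1.0 - (1.0 - self.eps) * t

  def alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor:
    # The slope is constant; expand it to the shape of t.
    return torch.full_like(t, -(1.0 - self.eps))


class LegacyCallAdapter(torch.nn.Module):
  """Hypothetical wrapper exposing the legacy `schedule(t) -> (alpha_prime_t, alpha_t)` call."""

  def __init__(self, schedule: torch.nn.Module) -> None:
    super().__init__()
    self.schedule = schedule  # registered as a submodule, so .to(device) propagates

  def forward(self, t: torch.Tensor):
    # Legacy order: derivative first, then the attenuation itself.
    return self.schedule.alpha_prime_t(t), self.schedule.alpha_t(t)


legacy = LegacyCallAdapter(ExampleLinearSchedule(eps=1e-3))
t = torch.tensor([0.0, 0.5, 1.0])
alpha_prime, alpha = legacy(t)
```

In this sketch the adapter is a thin nn.Module, so code written against the tuple-returning call keeps working while newer code can call alpha_t and alpha_prime_t directly.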

NoiseSchedule

Bases: Module

Abstract base class for continuous-time noise schedules.

Implementations should return attenuation factors alpha(t) in (0, 1] and their derivative alpha'(t) with respect to t. Some schedules also provide a cumulative ("total") noise measure, which certain methods, such as SEDD, require.

Source code in src/discrete_diffusion/noise_schedules/base.py
```python
import torch


class NoiseSchedule(torch.nn.Module):
  """Abstract base class for continuous-time noise schedules.

  Implementations should return attenuation factors `alpha(t)` in (0, 1] and
  their derivative `alpha'(t)` with respect to `t`. Some schedules may also
  provide a cumulative/"total" noise measure (e.g., required by SEDD).
  """

  def __init__(self) -> None:
    super().__init__()

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    """Return attenuation `alpha(t)` for timesteps `t` in [0, 1].

    Args:
      t: Tensor of shape `(B,)` or broadcastable to `(B, 1)` with dtype float.

    Returns:
      Tensor matching the shape of `t` (broadcastable) with values in (0, 1].
    """
    raise NotImplementedError

  def alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor:
    """Return derivative `d/dt alpha(t)` for timesteps `t`.

    Args:
      t: Tensor of shape `(B,)` or broadcastable to `(B, 1)` with dtype float.

    Returns:
      Tensor broadcastable to the shape of `t`.
    """
    raise NotImplementedError

  def total_noise(self, t: torch.Tensor) -> torch.Tensor:
    """Optional cumulative noise measure for schedules that define it.

    Implementations that do not support a total noise measure may raise
    `NotImplementedError`. This is used by SEDD-style forward processes.
    """
    raise NotImplementedError
```
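
In this sketch, a concrete schedule overrides only the methods it supports. It assumes a hypothetical cosine-shaped schedule, alpha(t) = eps + (1 - eps) * cos(pi * t / 2), chosen so that alpha stays in [eps, 1] and never reaches 0. So that it runs on its own, it declares a minimal stand-in for NoiseSchedule; inside the package you would subclass the real class from discrete_diffusion.noise_schedules.base instead. `CosineSchedule` is an illustrative name, and total_noise is deliberately left unimplemented, so calling it raises NotImplementedError as in the base class.

```python
import math

import torch


class NoiseSchedule(torch.nn.Module):
  """Minimal stand-in mirroring the base class above, so this sketch runs on its own."""

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError

  def alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError

  def total_noise(self, t: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError


class CosineSchedule(NoiseSchedule):
  """Hypothetical schedule: alpha(t) = eps + (1 - eps) * cos(pi * t / 2), in [eps, 1]."""

  def __init__(self, eps: float = 1e-3) -> None:
    super().__init__()
    self.eps = eps

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return self.eps + (1.0 - self.eps) * torch.cos(0.5 * math.pi * t)

  def alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor:
    # d/dt cos(pi t / 2) = -(pi / 2) * sin(pi t / 2)
    return -(1.0 - self.eps) * 0.5 * math.pi * torch.sin(0.5 * math.pi * t)

  # total_noise is not overridden, so it still raises NotImplementedError.


schedule = CosineSchedule()
t = torch.linspace(0.0, 1.0, steps=11)
alpha = schedule.alpha_t(t)
# Contract checks taken from the docstrings: shape follows t, values lie in (0, 1].
assert alpha.shape == t.shape
assert bool((alpha > 0).all()) and bool((alpha <= 1).all())
```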

alpha_prime_t(t)

Return the derivative d/dt alpha(t) at timesteps t.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| t | Tensor | Tensor of shape (B,) or broadcastable to (B, 1), with float dtype. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Tensor broadcastable to the shape of t. |
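
One way to validate an alpha_prime_t implementation is to compare it with an automatic derivative of alpha_t. The sketch below assumes a hypothetical exponential schedule, alpha(t) = exp(-beta * t), which lies in (0, 1] for t in [0, 1]; `ExponentialSchedule` and the helper `autograd_alpha_prime` are illustrative names, not part of the package. Differentiating the sum gives the elementwise derivative only because each alpha(t_i) depends on its own t_i alone, which holds for per-sample schedules like this one.

```python
import torch


class ExponentialSchedule(torch.nn.Module):
  """Hypothetical schedule: alpha(t) = exp(-beta * t)."""

  def __init__(self, beta: float = 5.0) -> None:
    super().__init__()
    self.beta = beta

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return torch.exp(-self.beta * t)

  def alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor:
    # Analytic derivative: d/dt exp(-beta t) = -beta * exp(-beta t).
    return -self.beta * torch.exp(-self.beta * t)


def autograd_alpha_prime(schedule: torch.nn.Module, t: torch.Tensor) -> torch.Tensor:
  """Differentiate alpha_t with autograd, treating each batch entry independently."""
  t = t.detach().clone().requires_grad_(True)
  alpha = schedule.alpha_t(t)
  # alpha[i] depends only on t[i], so grad of the sum equals the elementwise derivative.
  (grad,) = torch.autograd.grad(alpha.sum(), t)
  return grad


schedule = ExponentialSchedule(beta=5.0)
t = torch.rand(8, dtype=torch.float64)
analytic = schedule.alpha_prime_t(t)
numeric = autograd_alpha_prime(schedule, t)
print(torch.allclose(analytic, numeric))  # True
```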

alpha_t(t)

Return the attenuation alpha(t) for timesteps t in [0, 1].

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| t | Tensor | Tensor of shape (B,) or broadcastable to (B, 1), with float dtype. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Tensor broadcastable to the shape of t, with values in (0, 1]. |
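
The docstring allows t of shape (B,) or anything broadcastable to (B, 1). When per-token tensors have shape (B, L), reshaping t to (B, 1) lets alpha(t) broadcast across the sequence. The sketch below assumes a hypothetical linear `alpha_t` function and an absorbing ("masked") forward process in which each token is kept with probability alpha(t); neither is stated on this page, and `mask_id = 0` is an arbitrary placeholder index.

```python
import torch


def alpha_t(t: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
  """Hypothetical linear attenuation alpha(t) = 1 - (1 - eps) * t, applied elementwise."""
  return 1.0 - (1.0 - eps) * t


batch_size, seq_len, mask_id = 4, 16, 0  # mask_id is an assumed absorbing-token index
tokens = torch.randint(1, 100, (batch_size, seq_len))  # ids 1..99, never equal to mask_id
t = torch.rand(batch_size)  # shape (B,): one timestep per sequence

alpha = alpha_t(t[:, None])  # shape (B, 1): broadcasts across the L tokens
# Assumption: each token survives with probability alpha(t), otherwise it becomes mask_id.
keep = torch.rand(batch_size, seq_len) < alpha  # (B, L) boolean mask
noisy = torch.where(keep, tokens, torch.full_like(tokens, mask_id))
```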

total_noise(t)

Return an optional cumulative noise measure for schedules that define one.

Implementations that do not support a total noise measure may raise NotImplementedError (the base class always does). SEDD-style forward processes use this quantity.
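
A consumer that needs a total noise value, but may receive a schedule without one, can catch the NotImplementedError. The sketch below assumes the convention alpha(t) = exp(-total_noise(t)), which some masked/absorbing diffusion formulations use but which this page does not state. `LinearTotalNoiseSchedule`, `AlphaOnlySchedule`, and the helper `total_noise_or_neg_log_alpha` are hypothetical names.

```python
import torch


class LinearTotalNoiseSchedule(torch.nn.Module):
  """Hypothetical schedule with total_noise(t) = sigma_max * t and alpha(t) = exp(-total_noise(t))."""

  def __init__(self, sigma_max: float = 5.0) -> None:
    super().__init__()
    self.sigma_max = sigma_max

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return torch.exp(-self.total_noise(t))

  def total_noise(self, t: torch.Tensor) -> torch.Tensor:
    # Cumulative noise grows linearly in t for this illustrative schedule.
    return self.sigma_max * t


class AlphaOnlySchedule(torch.nn.Module):
  """Hypothetical schedule that does not define a total noise measure."""

  def alpha_t(self, t: torch.Tensor) -> torch.Tensor:
    return 1.0 - 0.999 * t

  def total_noise(self, t: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError


def total_noise_or_neg_log_alpha(schedule: torch.nn.Module, t: torch.Tensor) -> torch.Tensor:
  """Use total_noise when available; otherwise fall back to -log alpha(t) (assumed convention)."""
  try:
    return schedule.total_noise(t)
  except NotImplementedError:
    return -torch.log(schedule.alpha_t(t))


t = torch.tensor([0.0, 0.5, 1.0])
sigma_native = total_noise_or_neg_log_alpha(LinearTotalNoiseSchedule(), t)
sigma_fallback = total_noise_or_neg_log_alpha(AlphaOnlySchedule(), t)
```

The fallback is only meaningful if the schedule really follows the assumed exponential relation; a schedule with a different definition of total noise should implement total_noise explicitly.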
