Skip to content

Flex Schedule

discrete_diffusion.noise_schedules.flex

FlexMDM schedule primitives and factory.

FlexSchedule

Bases: ABC

Minimal interface matching FlexMDM's schedule objects.

Note

Flex noise schedules increase from 0 to 1 (the amount of noise/masking grows over time), whereas standard schedules decrease from 1 to 0 (the remaining signal shrinks). The two conventions should be unified in the future.

Source code in src/discrete_diffusion/noise_schedules/flex.py
class FlexSchedule(abc.ABC):
  """Abstract base for FlexMDM-style noise schedules.

  Subclasses provide three primitives: the schedule value ``at(t)``, its
  derivative ``derivative_at(t)``, and the inverse map ``inv(alpha)``. On top
  of those, this base class derives a rate scale factor and two time
  samplers.

  Note:
      Flex noise schedules go from 0 to 1 (increasing noise/masking), while
      standard schedules go from 1 to 0 (decreasing signal). These should be
      unified in the future.
  """

  @abc.abstractmethod
  def at(self, t: Tensor) -> Tensor:
    """Return the schedule value at time ``t``."""
    raise NotImplementedError

  @abc.abstractmethod
  def derivative_at(self, t: Tensor) -> Tensor:
    """Return the derivative of the schedule value at time ``t``."""
    raise NotImplementedError

  @abc.abstractmethod
  def inv(self, alpha: Tensor) -> Tensor:
    """Map a schedule value ``alpha`` back to the time that produces it."""
    raise NotImplementedError

  def rate_scale_factor(self, t: Tensor) -> Tensor:
    """Return ``derivative_at(t) / (1 - at(t))``.

    The denominator is floored at ``1e-6`` so the ratio stays finite as
    ``at(t)`` approaches 1.
    """
    alpha = self.at(t)
    remaining = (1 - alpha).clamp_min(1e-6)
    slope = self.derivative_at(t)
    return slope / remaining

  def sample(self, shape: tuple[int, ...], device: torch.device) -> Tensor:
    """Draw times of the given ``shape`` by pushing ``torch.rand`` noise through ``inv``.

    When ``inv`` inverts ``at``, ``at`` of the returned times is uniform on
    ``[0, 1)``.
    """
    return self.inv(torch.rand(shape, device=device))

  def sample_truncated(
    self, threshold: Tensor, shape: tuple[int, ...], device: torch.device
  ) -> Tensor:
    """Like ``sample``, but restricted to schedule values ``>= at(threshold)``.

    Uniform noise is rescaled affinely from ``[0, 1)`` onto
    ``[at(threshold), 1)`` before ``inv`` is applied. Assuming ``at`` is
    increasing, the returned times are no earlier than ``threshold``.
    """
    noise = torch.rand(shape, device=device)
    floor = self.at(threshold)
    return self.inv(floor + noise * (1 - floor))

build_flex_schedule(config)

Instantiate a Flex-style schedule from a Hydra config snippet.

Source code in src/discrete_diffusion/noise_schedules/flex.py
def build_flex_schedule(config: Mapping[str, Any] | None) -> FlexSchedule:
  """Instantiate a Flex-style schedule from a Hydra config snippet.

  Args:
    config: Schedule configuration (or ``None``), normalized to a plain dict
      by ``_to_dict``. The ``"type"`` entry selects the schedule
      case-insensitively and defaults to ``"linear"`` when it is absent or
      null. ``"polynomial"`` additionally requires ``"exp"``; ``"geometric"``
      additionally requires ``"min"`` and ``"max"``. These numeric entries are
      coerced with ``float``.

  Returns:
    A new ``LinearSchedule``, ``CosineSchedule``, ``SinSchedule``,
    ``PolynomialSchedule`` or ``GeometricSchedule`` instance.

  Raises:
    ValueError: If ``"type"`` is not a string, names an unsupported schedule,
      or a key required by the selected schedule is missing.
  """
  cfg = _to_dict(config)
  raw_type = cfg.get("type")
  # An explicit `type: null` in YAML arrives as None; treat it like a missing
  # key instead of crashing with an opaque AttributeError on `.lower()`.
  if raw_type is None:
    raw_type = "linear"
  if not isinstance(raw_type, str):
    raise ValueError(
      f"Flex schedule 'type' must be a string, got {type(raw_type).__name__}."
    )
  schedule_type = raw_type.lower()

  if schedule_type == "linear":
    return LinearSchedule()
  if schedule_type == "cosine":
    return CosineSchedule()
  if schedule_type == "sin":
    return SinSchedule()
  if schedule_type == "polynomial":
    if "exp" not in cfg:
      raise ValueError("Polynomial schedule requires 'exp'.")
    return PolynomialSchedule(exp=float(cfg["exp"]))
  if schedule_type == "geometric":
    missing = [k for k in ("min", "max") if k not in cfg]
    if missing:
      raise ValueError(f"Geometric schedule missing keys: {missing}")
    return GeometricSchedule(min_val=float(cfg["min"]), max_val=float(cfg["max"]))

  raise ValueError(f"Unsupported Flex schedule type: {schedule_type}")