
Common Model Layers

discrete_diffusion.models.common

Shared model primitives for UNI-D² backbones.

This module centralizes lightweight, backend-agnostic helpers used across multiple backbones (DiT, BlockDiT, Encoder-Decoder). Attention backend selection (flash-attn vs SDPA vs flex) remains in each backbone.

DDiTBlockCausal

Bases: Module

Causal DDiT transformer block: pre-norm self-attention with rotary position embeddings and a causal mask, followed by a tanh-GELU MLP, with residual connections applied through the fused bias/dropout/scale helpers.

Source code in src/discrete_diffusion/models/common.py
class DDiTBlockCausal(nn.Module):
  def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1, attn_backend='auto'):
    super().__init__()
    self.n_heads = n_heads
    self.attn_backend = attn_backend

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def _apply_causal_attention(self, qkv, rotary_cos_sin):
    """Apply causal attention with fallback logic."""
    cos, sin = rotary_cos_sin
    cos = cos.to(qkv.dtype)
    sin = sin.to(qkv.dtype)

    # Try flash-attn first
    if self.attn_backend == 'flash_attn' or (self.attn_backend == 'auto' and supports_flash_attention()):
      with torch.amp.autocast('cuda', enabled=False):
        qkv_rotary = apply_rotary_pos_emb(qkv, cos, sin)
      return flash_varlen_attention_qkvpacked(qkv_rotary, causal=True)
    else:
      # Fallback to SDPA
      qkv_rotary = apply_rotary_pos_emb_torchscript(qkv, cos, sin)
      q, k, v = qkv_rotary.unbind(dim=2)  # drop the packed-qkv axis so each is (B, S, H, D)
      return sdpa_attention(q, k, v, causal=True, dropout_p=0.0)

  def forward(self, x, rotary_cos_sin, **kwargs):
    del kwargs
    bias_dropout_scale_fn = self._get_bias_dropout_scale()
    x_skip = x
    x = self.norm1(x)

    qkv = self.attn_qkv(x)
    qkv = rearrange(
      qkv,
      'b s (three h d) -> b s three h d',
      three=3,
      h=self.n_heads)
    x = self._apply_causal_attention(qkv, rotary_cos_sin)

    scale = torch.ones(1, device=x.device, dtype=x.dtype)
    x = bias_dropout_scale_fn(
      self.attn_out(x), None, scale, x_skip, self.dropout)

    x = bias_dropout_scale_fn(
      self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x
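A minimal usage sketch of the SDPA fallback path. Sizes are illustrative; the (1, S, 1, 1, head_dim) rotary-cache layout is inferred from the slicing in the rotary helpers below, and the module-level helpers referenced by the listing (LayerNorm, rotate_half, the fused bias-dropout functions) are assumed to be importable alongside the class.

import torch

dim, n_heads, batch, seq_len = 256, 8, 2, 128
head_dim = dim // n_heads
block = DDiTBlockCausal(dim, n_heads, attn_backend='sdpa').eval()  # any non-flash setting falls back to SDPA

# Build a rotary cache shaped (1, S, 1, 1, head_dim) with duplicated half-angles (assumed layout).
pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.cat([torch.outer(pos, inv_freq)] * 2, dim=-1)       # (S, head_dim)
cos = angles.cos()[None, :, None, None, :]
sin = angles.sin()[None, :, None, None, :]

x = torch.randn(batch, seq_len, dim)
y = block(x, (cos, sin))                                           # (B, S, dim)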

LabelEmbedder

Bases: Module

Embeds class labels into vector representations.

Source code in src/discrete_diffusion/models/common.py
class LabelEmbedder(nn.Module):
  """Embeds class labels into vector representations."""

  def __init__(self, num_classes, cond_size):
    super().__init__()
    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
    self.num_classes = num_classes

  def forward(self, labels):
    return self.embedding_table(labels)
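A short usage sketch. The embedding table has num_classes + 1 rows; treating the extra index as a null/unconditional label (e.g. for classifier-free guidance) is an assumption here.

import torch

emb = LabelEmbedder(num_classes=10, cond_size=128)
labels = torch.tensor([3, 7, 10])  # index 10 hits the extra row (assumed null/unconditional label)
vecs = emb(labels)                 # (3, 128)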

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

Source code in src/discrete_diffusion/models/common.py
class TimestepEmbedder(nn.Module):
  """Embeds scalar timesteps into vector representations."""

  def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
    super().__init__()
    self.mlp = nn.Sequential(
      nn.Linear(frequency_embedding_size, hidden_size, bias=True),
      nn.SiLU(),
      nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000):
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
      - math.log(max_period)
      * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
      / half)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
      embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

  def forward(self, t: torch.Tensor) -> torch.Tensor:
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb
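A brief usage sketch with illustrative sizes:

import torch

embedder = TimestepEmbedder(hidden_size=512)
t = torch.rand(8)                                         # e.g. diffusion times in [0, 1)
freqs = TimestepEmbedder.timestep_embedding(t, dim=256)   # (8, 256) sinusoidal features
cond = embedder(t)                                        # (8, 512) after the MLP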

apply_rotary_pos_emb(qkv, cos, sin)

In-place rotary application for qkv-packed tensors using flash-attn helper.

Source code in src/discrete_diffusion/models/common.py
def apply_rotary_pos_emb(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
  """In-place rotary application for qkv-packed tensors using flash-attn helper."""
  if flash_attn is None:
    raise RuntimeError("flash_attn is required for rotary qkv application")
  cos = cos[0, :, 0, 0, :cos.shape[-1] // 2]
  sin = sin[0, :, 0, 0, :sin.shape[-1] // 2]
  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)

apply_rotary_pos_emb_single(vec, cos, sin)

Apply rotary to a single tensor (q or k) with shape (B, S, H, D).

Source code in src/discrete_diffusion/models/common.py
def apply_rotary_pos_emb_single(vec: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
  """Apply rotary to a single tensor (q or k) with shape (B, S, H, D)."""
  if flash_attn is None:
    raise RuntimeError("flash_attn is required for rotary single-vector application")
  with torch.amp.autocast('cuda', enabled=False):
    cos = cos.to(vec.dtype)
    sin = sin.to(vec.dtype)
    if vec.shape[1] < cos.shape[1]:
      cos = cos[:, :vec.shape[1]]
      sin = sin[:, :vec.shape[1]]
    if cos.shape[0] == 1:
      cos_in = cos[0, :, 0, 0, :cos.shape[-1] // 2]
      sin_in = sin[0, :, 0, 0, :sin.shape[-1] // 2]
    else:
      cos_in = cos[:, :, 0, 0, :cos.shape[-1] // 2]
      sin_in = sin[:, :, 0, 0, :sin.shape[-1] // 2]
    vec = flash_attn.layers.rotary.apply_rotary_emb_torch(vec, cos_in, sin_in)
  return vec

apply_rotary_pos_emb_torchscript(qkv, cos, sin)

TorchScript-friendly rotary application for SDPA/backends without flash-attn.

Expects qkv shaped (B, S, 3, H, D). Returns transformed qkv.

Source code in src/discrete_diffusion/models/common.py
def apply_rotary_pos_emb_torchscript(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
  """TorchScript-friendly rotary application for SDPA/backends without flash-attn.

  Expects qkv shaped (B, S, 3, H, D). Returns transformed qkv.
  """
  return (qkv * cos) + (rotate_half(qkv) * sin)
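rotate_half is referenced here but not shown in this listing; a common GPT-NeoX-style definition, consistent with the half-duplicated cos/sin caches used by the other rotary helpers (an assumption about this codebase), is:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
  # Pairwise 90-degree rotation of the last dimension: (x1, x2) -> (-x2, x1).
  x1, x2 = x.chunk(2, dim=-1)
  return torch.cat((-x2, x1), dim=-1)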

flash_varlen_attention_qkvpacked(qkv, causal=False, dropout_p=0.0)

FlashAttention varlen over packed qkv.

Parameters:

  qkv (Tensor): Tensor shaped (B, S, 3, H, D). Required.
  causal (bool): Whether to apply causal masking. Default: False.
  dropout_p (float): Dropout probability (kept for parity; 0.0 for inference/eval). Default: 0.0.

Returns:

  Tensor: Tensor shaped (B, S, H*D).

Source code in src/discrete_diffusion/models/common.py
def flash_varlen_attention_qkvpacked(
    qkv: torch.Tensor,
    causal: bool = False,
    dropout_p: float = 0.0,
) -> torch.Tensor:
  """FlashAttention varlen over packed qkv.

  Args:
    qkv: Tensor shaped (B, S, 3, H, D)
    causal: Whether to apply causal masking.
    dropout_p: Dropout probability (kept for parity; 0.0 for inference/eval).

  Returns:
    Tensor shaped (B, S, H*D)
  """
  if flash_attn is None:
    raise RuntimeError("flash_attn is not available for flash_varlen_attention_qkvpacked")
  bsz, seqlen = qkv.shape[0], qkv.shape[1]
  qkv_flat = rearrange(qkv, 'b s ... -> (b s) ...')
  cu_seqlens = torch.arange(
    0, (bsz + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device)
  x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
    qkv_flat, cu_seqlens, seqlen, dropout_p, causal=causal)
  x = rearrange(x, '(b s) h d -> b s (h d)', b=bsz)
  return x
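For fixed-length batches, cu_seqlens simply lists cumulative sequence boundaries in the flattened (b*s) layout; a minimal illustration of the construction used above:

import torch

bsz, seqlen = 2, 4
cu_seqlens = torch.arange(0, (bsz + 1) * seqlen, step=seqlen, dtype=torch.int32)
print(cu_seqlens)  # tensor([0, 4, 8], dtype=torch.int32)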

residual_linear(x, W, x_skip, residual_scale)

Compute x_skip + residual_scale * (W @ x) efficiently via addmm.

Shapes:
  • x: (..., dim_in)
  • W: (dim_out, dim_in)
  • returns: (..., dim_out)

Source code in src/discrete_diffusion/models/common.py
def residual_linear(x: torch.Tensor, W: torch.Tensor, x_skip: torch.Tensor, residual_scale: float) -> torch.Tensor:
  """Compute x_skip + residual_scale * (W @ x) efficiently via addmm.

  Shapes:
    - x: (..., dim_in)
    - W: (dim_out, dim_in)
    - returns: (..., dim_out)
  """
  dim_out, dim_in = W.shape[0], W.shape[1]
  return torch.addmm(
    x_skip.view(-1, dim_out),
    x.view(-1, dim_in),
    W.T,
    alpha=residual_scale).view(*x.shape[:-1], dim_out)
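A quick sanity check of the intended semantics (a sketch, not part of the module):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 32)        # (..., dim_in)
W = torch.randn(64, 32)           # (dim_out, dim_in)
x_skip = torch.randn(4, 16, 64)   # (..., dim_out)

out = residual_linear(x, W, x_skip, residual_scale=0.5)
ref = x_skip + 0.5 * F.linear(x, W)   # straightforward formulation of the same op
assert torch.allclose(out, ref, atol=1e-5)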

sdpa_attention(q, k, v, attn_mask=None, causal=False, dropout_p=0.0, scale=None)

Scaled dot-product attention over packed heads using torch SDPA.

Parameters:

  q, k, v (Tensor): Tensors shaped (B, S, H, D). Required.
  attn_mask (Optional[Tensor]): Optional mask shaped (B, S, S) or broadcastable. Default: None.
  causal (bool): Whether to use causal masking. Default: False.
  dropout_p (float): Dropout probability (training only; kept for API parity). Default: 0.0.
  scale (Optional[float]): Optional scale override (1/sqrt(D)). Default: None.

Returns:

  Tensor: Tensor shaped (B, S, H*D) (flattened heads for the downstream linear).

Source code in src/discrete_diffusion/models/common.py
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: typing.Optional[torch.Tensor] = None,
    causal: bool = False,
    dropout_p: float = 0.0,
    scale: typing.Optional[float] = None,
) -> torch.Tensor:
  """Scaled dot-product attention over packed heads using torch SDPA.

  Args:
    q, k, v: Tensors shaped (B, S, H, D)
    attn_mask: Optional mask shaped (B, S, S) or broadcastable.
    causal: Whether to use causal masking.
    dropout_p: Dropout probability (training only; kept for API parity).
    scale: Optional scale override (1/sqrt(D)).

  Returns:
    Tensor shaped (B, S, H*D) (flattened heads for downstream linear).
  """
  # torch SDPA expects (B, H, S, D)
  q = q.transpose(1, 2)
  k = k.transpose(1, 2)
  v = v.transpose(1, 2)
  x = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask[:, None] if attn_mask is not None else None,
    dropout_p=dropout_p,
    is_causal=causal,
    scale=scale)
  x = x.transpose(1, 2)  # (B, S, H, D)
  return rearrange(x, 'b s h d -> b s (h d)')
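A minimal usage sketch with illustrative shapes:

import torch

B, S, H, D = 2, 16, 4, 32
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

out = sdpa_attention(q, k, v, causal=True)            # (B, S, H*D)
assert out.shape == (B, S, H * D)

# Boolean mask broadcastable to (B, S, S); True marks positions that may be attended.
mask = torch.ones(B, S, S, dtype=torch.bool)
out_masked = sdpa_attention(q, k, v, attn_mask=mask)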

sdpa_attention_masked(q, k, v, attn_mask, causal=False)

Convenience wrapper for masked SDPA returning (B,S,H*D).

Source code in src/discrete_diffusion/models/common.py
def sdpa_attention_masked(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor,
    causal: bool = False) -> torch.Tensor:
  """Convenience wrapper for masked SDPA returning (B,S,H*D)."""
  return sdpa_attention(q, k, v, attn_mask=attn_mask, causal=causal, dropout_p=0.0)

sdpa_attention_unmasked(q, k, v)

Convenience wrapper for unmasked SDPA returning (B,S,H*D).

Source code in src/discrete_diffusion/models/common.py
def sdpa_attention_unmasked(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
  """Convenience wrapper for unmasked SDPA returning (B,S,H*D)."""
  return sdpa_attention(q, k, v, attn_mask=None, causal=False, dropout_p=0.0)

split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)

Apply rotary to q,k slices of packed qkv.

Expects qkv shaped (B, S, 3, H, D). Returns (q, k, v) with shapes (B, S, H, D), (B, S, H, D), (B, S, H, D).

Source code in src/discrete_diffusion/models/common.py
def split_and_apply_rotary_pos_emb(qkv: torch.Tensor, rotary_cos_sin: typing.Tuple[torch.Tensor, torch.Tensor]):
  """Apply rotary to q,k slices of packed qkv.

  Expects qkv shaped (B, S, 3, H, D). Returns (q, k, v) with shapes
  (B, S, H, D), (B, S, H, D), (B, S, H, D).
  """
  if flash_attn is None:
    raise RuntimeError("flash_attn is required for rotary split helpers")
  with torch.amp.autocast('cuda', enabled=False):
    cos, sin = rotary_cos_sin
    cos = cos.to(qkv.dtype)
    sin = sin.to(qkv.dtype)
    # Align cached length/batch with qkv if needed
    if qkv.shape[1] < cos.shape[1]:
      cos = cos[:, :qkv.shape[1]]
      sin = sin[:, :qkv.shape[1]]
    if cos.shape[0] == 1:
      cos_in = cos[0, :, 0, 0, :cos.shape[-1] // 2]
      sin_in = sin[0, :, 0, 0, :sin.shape[-1] // 2]
    else:
      cos_in = cos[:, :, 0, 0, :cos.shape[-1] // 2]
      sin_in = sin[:, :, 0, 0, :sin.shape[-1] // 2]
    q, k, v = qkv.chunk(3, dim=2)
    q = flash_attn.layers.rotary.apply_rotary_emb_torch(q.squeeze(dim=2), cos_in, sin_in)
    k = flash_attn.layers.rotary.apply_rotary_emb_torch(k.squeeze(dim=2), cos_in, sin_in)
    v = v.squeeze(dim=2)
  return q, k, v

supports_flash_attention()

Check if flash-attn is available and functional.

Source code in src/discrete_diffusion/models/common.py
def supports_flash_attention() -> bool:
  """Check if flash-attn is available and functional."""
  return FLASH_ATTN_AVAILABLE

supports_flex_attention()

Check if torch flex attention is available (PyTorch 2.4+).

Source code in src/discrete_diffusion/models/common.py
def supports_flex_attention() -> bool:
  """Check if torch flex attention is available (PyTorch 2.4+)."""
  return hasattr(torch.nn.functional, 'flex_attention')
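As noted at the top of this page, backend selection lives in each backbone; a hypothetical dispatch helper (illustrative only, not the backbones' actual logic) could combine these checks like so:

def pick_attention_backend(preferred: str = 'auto') -> str:
  # Hypothetical helper: resolve 'auto' to the best available backend.
  if preferred != 'auto':
    return preferred
  if supports_flash_attention():
    return 'flash_attn'
  if supports_flex_attention():
    return 'flex'
  return 'sdpa'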