Common Model Layers
discrete_diffusion.models.common
Shared model primitives for UNI-D² backbones.
This module centralizes lightweight, backend-agnostic helpers used across multiple backbones (DiT, BlockDiT, Encoder-Decoder). Attention backend selection (flash-attn vs. SDPA vs. flex) remains the responsibility of each backbone.
DDiTBlockCausal
Bases: Module
Source code in src/discrete_diffusion/models/common.py
LabelEmbedder
Bases: Module
Embeds class labels into vector representations.
Source code in src/discrete_diffusion/models/common.py
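The class internals are not reproduced here. As a rough illustration, a label embedder of this kind is typically a learned embedding table indexed by class id; the sketch below follows that pattern, and its constructor arguments (num_classes, hidden_size) are assumptions rather than the module's documented signature.

```python
# Minimal sketch of a class-label embedder consistent with the description above.
# Constructor arguments are assumed, not taken from the real module.
import torch
import torch.nn as nn

class LabelEmbedderSketch(nn.Module):
    def __init__(self, num_classes: int, hidden_size: int):
        super().__init__()
        # One learned vector per class label.
        self.embedding_table = nn.Embedding(num_classes, hidden_size)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        # labels: (B,) integer class ids -> (B, hidden_size) embeddings.
        return self.embedding_table(labels)
```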
TimestepEmbedder
Bases: Module
Embeds scalar timesteps into vector representations.
Source code in src/discrete_diffusion/models/common.py
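A common realization of this pattern (as popularized by DiT) maps each scalar timestep to sinusoidal features and passes them through a small MLP. The sketch below follows that pattern; the frequency count and MLP widths are assumptions, not the module's actual hyperparameters.

```python
# Sketch of a sinusoidal timestep embedder; sizes are illustrative assumptions.
import math
import torch
import torch.nn as nn

class TimestepEmbedderSketch(nn.Module):
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (B,) scalar timesteps -> (B, hidden_size)
        half = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None, :]          # (B, half)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(emb)
```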
apply_rotary_pos_emb(qkv, cos, sin)
In-place rotary application for qkv-packed tensors using flash-attn helper.
Source code in src/discrete_diffusion/models/common.py
apply_rotary_pos_emb_single(vec, cos, sin)
Apply rotary to a single tensor (q or k) with shape (B, S, H, D).
Source code in src/discrete_diffusion/models/common.py
apply_rotary_pos_emb_torchscript(qkv, cos, sin)
TorchScript-friendly rotary application for SDPA/backends without flash-attn.
Expects qkv shaped (B, S, 3, H, D). Returns transformed qkv.
Source code in src/discrete_diffusion/models/common.py
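A minimal sketch of this kind of rotary application, using only plain tensor ops so it can be scripted. The cos/sin layout is assumed to broadcast against (B, S, H, D), which may differ from the helper's actual convention.

```python
# Sketch only: cos/sin broadcasting layout is an assumption.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in two halves and swap them with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_torchscript_sketch(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # qkv: (B, S, 3, H, D); rotate q and k, leave v untouched.
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return torch.stack((q, k, v), dim=2)
```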
flash_varlen_attention_qkvpacked(qkv, causal=False, dropout_p=0.0)
FlashAttention varlen over packed qkv.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| qkv | Tensor | Tensor shaped (B, S, 3, H, D). | required |
| causal | bool | Whether to apply causal masking. | False |
| dropout_p | float | Dropout probability (kept for parity; 0.0 for inference/eval). | 0.0 |
Returns:

| Type | Description |
|---|---|
| Tensor | Tensor shaped (B, S, H*D). |
Source code in src/discrete_diffusion/models/common.py
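A usage sketch based only on the documented signature and shapes; it assumes flash-attn is installed and runs on a CUDA device with a half-precision dtype, which the flash-attn kernels require.

```python
# Usage sketch for the documented signature; requires flash-attn and CUDA.
import torch
from discrete_diffusion.models.common import flash_varlen_attention_qkvpacked

B, S, H, D = 2, 128, 8, 64
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.bfloat16)

out = flash_varlen_attention_qkvpacked(qkv, causal=False, dropout_p=0.0)
assert out.shape == (B, S, H * D)  # heads flattened for the downstream linear
```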
residual_linear(x, W, x_skip, residual_scale)
Compute x_skip + residual_scale * (W @ x) efficiently via addmm.
Shapes:
- x: (..., dim_in)
- W: (dim_out, dim_in)
- returns: (..., dim_out)
Source code in src/discrete_diffusion/models/common.py
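The formula above maps directly onto torch.addmm, which fuses the matmul with the scaled addition. A minimal sketch under the documented shapes (the real helper may handle contiguity and dtypes differently):

```python
# Sketch of the addmm formulation: x_skip + residual_scale * (x @ W.T).
import torch

def residual_linear_sketch(x, W, x_skip, residual_scale):
    # x: (..., dim_in), W: (dim_out, dim_in), x_skip: (..., dim_out)
    dim_out, dim_in = W.shape
    out = torch.addmm(
        x_skip.reshape(-1, dim_out),  # the skip term (scaled by beta)
        x.reshape(-1, dim_in),        # input flattened to 2-D
        W.t(),                        # (dim_in, dim_out)
        beta=1.0,                     # keep the skip term unscaled
        alpha=residual_scale,         # scale only the matmul term
    )
    return out.view(*x.shape[:-1], dim_out)
```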
sdpa_attention(q, k, v, attn_mask=None, causal=False, dropout_p=0.0, scale=None)
Scaled dot-product attention over packed heads using torch SDPA.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q, k, v | Tensor | Tensors shaped (B, S, H, D). | required |
| attn_mask | Optional[Tensor] | Optional mask shaped (B, S, S) or broadcastable. | None |
| causal | bool | Whether to use causal masking. | False |
| dropout_p | float | Dropout probability (training only; kept for API parity). | 0.0 |
| scale | Optional[float] | Optional scale override (1/sqrt(D)). | None |
Returns:

| Type | Description |
|---|---|
| Tensor | Tensor shaped (B, S, H*D) (flattened heads for downstream linear). |
Source code in src/discrete_diffusion/models/common.py
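A sketch of the SDPA path under the documented shapes, built on torch.nn.functional.scaled_dot_product_attention (the scale keyword needs PyTorch 2.1+); the real helper may convert masks and dtypes differently.

```python
# Sketch only: mask/dtype handling and layout conversions are assumptions.
import torch
import torch.nn.functional as F

def sdpa_attention_sketch(q, k, v, attn_mask=None, causal=False, dropout_p=0.0, scale=None):
    # (B, S, H, D) -> (B, H, S, D), the layout torch SDPA expects.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    if attn_mask is not None and attn_mask.dim() == 3:
        attn_mask = attn_mask.unsqueeze(1)  # broadcast the (B, S, S) mask over heads
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=causal,
        scale=scale,
    )
    # Back to (B, S, H, D), then flatten heads for the downstream linear.
    return out.transpose(1, 2).flatten(-2)
```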
sdpa_attention_masked(q, k, v, attn_mask, causal=False)
Convenience wrapper for masked SDPA returning (B, S, H*D).
Source code in src/discrete_diffusion/models/common.py
sdpa_attention_unmasked(q, k, v)
Convenience wrapper for unmasked SDPA returning (B, S, H*D).
Source code in src/discrete_diffusion/models/common.py
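A usage sketch for the two wrappers, using only the documented shapes; the all-True boolean mask here is illustrative, and the exact mask dtype handling depends on the helper.

```python
# Usage sketch for the convenience wrappers; mask dtype handling is assumed.
import torch
from discrete_diffusion.models.common import (
    sdpa_attention_masked,
    sdpa_attention_unmasked,
)

B, S, H, D = 2, 64, 8, 32
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

out = sdpa_attention_unmasked(q, k, v)             # (B, S, H*D)

mask = torch.ones(B, S, S, dtype=torch.bool)       # True = attend
out_masked = sdpa_attention_masked(q, k, v, mask)  # (B, S, H*D)
```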
split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
Apply rotary to q,k slices of packed qkv.
Expects qkv shaped (B, S, 3, H, D). Returns (q, k, v), each shaped (B, S, H, D).
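A sketch of the split-then-rotate pattern described above, assuming rotary_cos_sin is a (cos, sin) pair broadcastable against (B, S, H, D); the real helper may unpack or reshape its inputs differently.

```python
# Sketch only: the (cos, sin) layout of rotary_cos_sin is an assumption.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def split_and_apply_rotary_sketch(qkv: torch.Tensor, rotary_cos_sin):
    cos, sin = rotary_cos_sin
    q, k, v = qkv.unbind(dim=2)         # each (B, S, H, D)
    q = q * cos + rotate_half(q) * sin  # rotate queries
    k = k * cos + rotate_half(k) * sin  # rotate keys; values are left untouched
    return q, k, v
```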