Bases: Module, PyTorchModelHubMixin
Diffusion Transformer (DiT) backbone model.
A Transformer architecture optimized for diffusion, supporting both
causal (GPT-style) and bidirectional (BERT-style) attention, with
adaptive layer normalization for time conditioning.
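Because the class also inherits `PyTorchModelHubMixin` (see Bases above), the usual Hugging Face Hub serialization helpers come with it. A minimal sketch, assuming an already constructed `model` (see `__init__` below) and that the constructor arguments serialize into the mixin's `config.json`; the directory and repo id are placeholders, and the import path is inferred from the source location shown below:

```python
from discrete_diffusion.models.dit import DIT  # import path inferred from the source location

# Save an already constructed DIT instance to a local directory.
model.save_pretrained("checkpoints/dit-backbone")

# Rebuild it later from that directory, or from a Hub repo id (placeholder shown).
restored = DIT.from_pretrained("checkpoints/dit-backbone")
# restored = DIT.from_pretrained("your-org/dit-backbone")
```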
Source code in src/discrete_diffusion/models/dit.py
```python
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """Diffusion Transformer (DiT) backbone model.

  A Transformer architecture optimized for diffusion, supporting both
  causal (GPT-style) and bidirectional (BERT-style) attention, with
  adaptive layer normalization for time conditioning.
  """

  def __init__(self, config, vocab_size: int):
    """Initialize the DiT model.

    Args:
      config: Hydra configuration object containing model hyperparameters.
      vocab_size: Size of the vocabulary.
    """
    super().__init__()
    if type(config) == dict:
      config = omegaconf.OmegaConf.create(config)
    self.causal = config.algo.causal_attention
    self.adaLN = not self.causal
    self.config = config
    self.vocab_size = vocab_size
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    if not self.causal:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_emb = Rotary(dim // config.model.n_heads)
    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout)
      else:
        block = DDiTBlock(
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)
    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN)
    self.scale_by_sigma = config.model.scale_by_sigma
    # Tie output projection to input embeddings if requested
    if getattr(config.model, 'tie_word_embeddings', False):
      self.output_layer.linear.weight = self.vocab_embed.embedding

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def forward(self, x, sigma):
    """Forward pass of the DiT.

    Args:
      x: Input token indices [batch, seq_len].
      sigma: Noise level/time embedding [batch] or [batch, seq_len].

    Returns:
      Tensor: Logits [batch, seq_len, vocab_size].
    """
    x = self.vocab_embed(x)
    if self.causal:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))
    rotary_cos_sin = self.rotary_emb(x)
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for i in range(len(self.blocks)):
        x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
      x = self.output_layer(x, c=t_cond)
    return x
```
__init__(config, vocab_size)
Initialize the DiT model.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` |  | Hydra configuration object containing model hyperparameters. | required |
| `vocab_size` | `int` | Size of the vocabulary. | required |
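The constructor accepts either an OmegaConf object or a plain `dict`, which it wraps via `omegaconf.OmegaConf.create`. A minimal sketch of the configuration fields `__init__` reads; the values are illustrative placeholders rather than project defaults, and the import path is inferred from the source location below:

```python
from discrete_diffusion.models.dit import DIT

# Field names mirror what __init__ accesses; values are placeholders.
config = {
    'algo': {'causal_attention': False},  # False -> bidirectional attention with adaLN conditioning
    'model': {
        'hidden_size': 768,
        'cond_dim': 128,
        'n_heads': 12,
        'n_blocks': 12,
        'dropout': 0.1,
        'scale_by_sigma': False,
        'tie_word_embeddings': False,
    },
}

model = DIT(config, vocab_size=50257)  # the plain dict is converted to OmegaConf internally
```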
Source code in src/discrete_diffusion/models/dit.py
```python
def __init__(self, config, vocab_size: int):
  """Initialize the DiT model.

  Args:
    config: Hydra configuration object containing model hyperparameters.
    vocab_size: Size of the vocabulary.
  """
  super().__init__()
  if type(config) == dict:
    config = omegaconf.OmegaConf.create(config)
  self.causal = config.algo.causal_attention
  self.adaLN = not self.causal
  self.config = config
  self.vocab_size = vocab_size
  dim = config.model.hidden_size
  cond_dim = config.model.cond_dim
  self.vocab_embed = EmbeddingLayer(dim, vocab_size)
  if not self.causal:
    self.sigma_map = TimestepEmbedder(cond_dim)
  self.rotary_emb = Rotary(dim // config.model.n_heads)
  blocks = []
  for _ in range(config.model.n_blocks):
    if self.causal:
      block = DDiTBlockCausal(
        dim=dim,
        n_heads=config.model.n_heads,
        dropout=config.model.dropout)
    else:
      block = DDiTBlock(
        dim=dim,
        n_heads=config.model.n_heads,
        cond_dim=cond_dim,
        adaLN=self.adaLN,
        dropout=config.model.dropout)
    blocks.append(block)
  self.blocks = nn.ModuleList(blocks)
  self.output_layer = DDiTFinalLayer(
    hidden_size=dim,
    out_channels=vocab_size,
    cond_dim=cond_dim,
    adaLN=self.adaLN)
  self.scale_by_sigma = config.model.scale_by_sigma
  # Tie output projection to input embeddings if requested
  if getattr(config.model, 'tie_word_embeddings', False):
    self.output_layer.linear.weight = self.vocab_embed.embedding
```
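One detail from the constructor worth noting: when `config.model.tie_word_embeddings` is truthy, the output projection reuses the input embedding table as its weight. A quick hedged check, reusing the placeholder `config` dict from the sketch above (attribute names taken from the source):

```python
# Build a variant with weight tying enabled; the output projection and the
# input embedding should then be the very same Parameter object.
tied_config = {**config, 'model': {**config['model'], 'tie_word_embeddings': True}}
tied = DIT(tied_config, vocab_size=50257)
assert tied.output_layer.linear.weight is tied.vocab_embed.embedding
```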
forward(x, sigma)
Forward pass of the DiT.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` |  | Input token indices [batch, seq_len]. | required |
| `sigma` |  | Noise level/time embedding [batch] or [batch, seq_len]. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Tensor` |  | Logits [batch, seq_len, vocab_size]. |
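A minimal call sketch, reusing `model` from the `__init__` example above. A CUDA device is assumed because the forward pass opens a `torch.amp.autocast('cuda', dtype=torch.bfloat16)` context; shapes follow the docstring.

```python
import torch

model = model.to('cuda').eval()

batch, seq_len = 4, 128
x = torch.randint(0, model.vocab_size, (batch, seq_len), device='cuda')  # token indices
sigma = torch.rand(batch, device='cuda')                                 # one noise level per sample

with torch.no_grad():
    logits = model(x, sigma)

print(logits.shape)  # expected: torch.Size([4, 128, 50257])
```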
Source code in src/discrete_diffusion/models/dit.py
```python
def forward(self, x, sigma):
  """Forward pass of the DiT.

  Args:
    x: Input token indices [batch, seq_len].
    sigma: Noise level/time embedding [batch] or [batch, seq_len].

  Returns:
    Tensor: Logits [batch, seq_len, vocab_size].
  """
  x = self.vocab_embed(x)
  if self.causal:
    t_cond = None
  else:
    t_cond = F.silu(self.sigma_map(sigma))
  rotary_cos_sin = self.rotary_emb(x)
  with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    for i in range(len(self.blocks)):
      x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
    x = self.output_layer(x, c=t_cond)
  return x
```