Diffusion Transformer (DiT)

discrete_diffusion.models.dit

DIT

Bases: Module, PyTorchModelHubMixin

Diffusion Transformer (DiT) backbone model.

A Transformer architecture optimized for diffusion, supporting both causal (GPT-style) and bidirectional (BERT-style) attention, with adaptive layer normalization for time conditioning.
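A minimal construction sketch is shown below. The config keys mirror those read by __init__ in the source that follows (config.algo.causal_attention and the config.model.* fields); the hyperparameter values are illustrative placeholders, not repository defaults.

import omegaconf
from discrete_diffusion.models.dit import DIT

model_cfg = {'hidden_size': 512, 'cond_dim': 128, 'n_heads': 8,
             'n_blocks': 12, 'dropout': 0.1, 'scale_by_sigma': False}

# Bidirectional (BERT-style) variant: uses adaLN time conditioning.
bidirectional = DIT(
    omegaconf.OmegaConf.create(
        {'algo': {'causal_attention': False}, 'model': model_cfg}),
    vocab_size=50257)

# Causal (GPT-style) variant: no timestep embedder; sigma is unused in forward().
causal = DIT(
    omegaconf.OmegaConf.create(
        {'algo': {'causal_attention': True}, 'model': model_cfg}),
    vocab_size=50257)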

Source code in src/discrete_diffusion/models/dit.py
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """Diffusion Transformer (DiT) backbone model.

  A Transformer architecture optimized for diffusion, supporting both
  causal (GPT-style) and bidirectional (BERT-style) attention, with
  adaptive layer normalization for time conditioning.
  """
  def __init__(self, config, vocab_size: int):
    """Initialize the DiT model.

    Args:
        config: Hydra configuration object containing model hyperparameters.
        vocab_size: Size of the vocabulary.
    """
    super().__init__()
    if type(config) == dict:
      config = omegaconf.OmegaConf.create(config)
    self.causal = config.algo.causal_attention
    self.adaLN = not self.causal
    self.config = config
    self.vocab_size = vocab_size
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    if not self.causal:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_emb = Rotary(dim // config.model.n_heads)

    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout)
      else:
        block = DDiTBlock(
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)

    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN)
    self.scale_by_sigma = config.model.scale_by_sigma
    # Tie output projection to input embeddings if requested
    if getattr(config.model, 'tie_word_embeddings', False):
      self.output_layer.linear.weight = self.vocab_embed.embedding

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def forward(self, x, sigma):
    """Forward pass of the DiT.

    Args:
        x: Input token indices [batch, seq_len].
        sigma: Noise level/time embedding [batch] or [batch, seq_len].

    Returns:
        Tensor: Logits [batch, seq_len, vocab_size].
    """
    x = self.vocab_embed(x)
    if self.causal:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))

    rotary_cos_sin = self.rotary_emb(x)

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for i in range(len(self.blocks)):
        x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
      x = self.output_layer(x, c=t_cond)

    return x

__init__(config, vocab_size)

Initialize the DiT model.

Parameters:

Name        Type  Description                                                    Default
config            Hydra configuration object containing model hyperparameters.   required
vocab_size  int   Size of the vocabulary.                                        required
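Because __init__ converts a plain dict into an OmegaConf object, config may also be passed as a nested dictionary. The sketch below (illustrative values, not repository defaults) also exercises the optional model.tie_word_embeddings flag, which shares the output projection weight with the input embedding.

from discrete_diffusion.models.dit import DIT

config_dict = {
    'algo': {'causal_attention': True},
    'model': {'hidden_size': 256, 'cond_dim': 64, 'n_heads': 4,
              'n_blocks': 6, 'dropout': 0.0, 'scale_by_sigma': False,
              'tie_word_embeddings': True},
}

model = DIT(config_dict, vocab_size=1024)
# The constructor assigns output_layer.linear.weight = vocab_embed.embedding,
# so both names refer to the same Parameter.
assert model.output_layer.linear.weight is model.vocab_embed.embedding
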
Source code in src/discrete_diffusion/models/dit.py
def __init__(self, config, vocab_size: int):
  """Initialize the DiT model.

  Args:
      config: Hydra configuration object containing model hyperparameters.
      vocab_size: Size of the vocabulary.
  """
  super().__init__()
  if type(config) == dict:
    config = omegaconf.OmegaConf.create(config)
  self.causal = config.algo.causal_attention
  self.adaLN = not self.causal
  self.config = config
  self.vocab_size = vocab_size
  dim = config.model.hidden_size
  cond_dim = config.model.cond_dim
  self.vocab_embed = EmbeddingLayer(dim, vocab_size)
  if not self.causal:
    self.sigma_map = TimestepEmbedder(cond_dim)
  self.rotary_emb = Rotary(dim // config.model.n_heads)

  blocks = []
  for _ in range(config.model.n_blocks):
    if self.causal:
      block = DDiTBlockCausal(
        dim=dim,
        n_heads=config.model.n_heads,
        dropout=config.model.dropout)
    else:
      block = DDiTBlock(
        dim=dim,
        n_heads=config.model.n_heads,
        cond_dim=cond_dim,
        adaLN=self.adaLN,
        dropout=config.model.dropout)
    blocks.append(block)
  self.blocks = nn.ModuleList(blocks)

  self.output_layer = DDiTFinalLayer(
    hidden_size=dim,
    out_channels=vocab_size,
    cond_dim=cond_dim,
    adaLN=self.adaLN)
  self.scale_by_sigma = config.model.scale_by_sigma
  # Tie output projection to input embeddings if requested
  if getattr(config.model, 'tie_word_embeddings', False):
    self.output_layer.linear.weight = self.vocab_embed.embedding

forward(x, sigma)

Forward pass of the DiT.

Parameters:

Name   Type  Description                                                Default
x            Input token indices [batch, seq_len].                      required
sigma        Noise level/time embedding [batch] or [batch, seq_len].    required

Returns:

Type    Description
Tensor  Logits [batch, seq_len, vocab_size].
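
A shape-level sketch of a forward call (sizes are illustrative; a CUDA device is assumed because the blocks run under torch.amp.autocast('cuda', dtype=torch.bfloat16)). In causal mode sigma is still accepted but ignored internally, since t_cond is set to None.

import torch
import omegaconf
from discrete_diffusion.models.dit import DIT

cfg = omegaconf.OmegaConf.create({
    'algo': {'causal_attention': False},
    'model': {'hidden_size': 128, 'cond_dim': 64, 'n_heads': 4,
              'n_blocks': 2, 'dropout': 0.0, 'scale_by_sigma': False},
})
model = DIT(cfg, vocab_size=256).cuda()

x = torch.randint(0, 256, (2, 64), device='cuda')  # [batch, seq_len] token indices
sigma = torch.rand(2, device='cuda')               # [batch] noise levels in [0, 1)
logits = model(x, sigma)
assert logits.shape == (2, 64, 256)                # [batch, seq_len, vocab_size]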

Source code in src/discrete_diffusion/models/dit.py
def forward(self, x, sigma):
  """Forward pass of the DiT.

  Args:
      x: Input token indices [batch, seq_len].
      sigma: Noise level/time embedding [batch] or [batch, seq_len].

  Returns:
      Tensor: Logits [batch, seq_len, vocab_size].
  """
  x = self.vocab_embed(x)
  if self.causal:
    t_cond = None
  else:
    t_cond = F.silu(self.sigma_map(sigma))

  rotary_cos_sin = self.rotary_emb(x)

  with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    for i in range(len(self.blocks)):
      x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
    x = self.output_layer(x, c=t_cond)

  return x