FlexMDM Transformer
discrete_diffusion.models.flexmdm_transformer
FlexMDM Transformer Model for Any-Order Mask Insertion Flow.
This module implements the transformer architecture for FlexMDM, including adaptive layer normalization, rotary embeddings, and dual prediction heads for both token logits and expected gap lengths.
AnyOrderMaskInsertionFlow
Bases: Module
FlexMDM Any-Order Mask Insertion Flow model.
This model predicts both token logits and expected gap lengths for the joint insertion-masking process.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
forward(indices, t)
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `indices` | `Tensor` | Token indices `[B, L]` | *required* |
| `t` | `Tensor` | Timestep `[B]` | *required* |
Returns:
| Type | Description |
|---|---|
| `ModelPrediction` | `ModelPrediction` with `token_logits` and `expected_gaps` or `length_posterior` |
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
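A minimal usage sketch; the constructor arguments and output shapes below are assumptions for illustration, not the documented API:

```python
import torch

# Hypothetical instantiation: these constructor arguments are assumptions,
# not the documented signature.
model = AnyOrderMaskInsertionFlow(vocab_size=32000, hidden_size=768, n_heads=12)

indices = torch.randint(0, 32000, (4, 128))  # [B, L] token indices
t = torch.rand(4)                            # [B] timesteps

pred = model(indices, t)
# pred.token_logits has shape [B, L, vocab_size]; pred also carries
# expected_gaps or length_posterior, depending on configuration.
```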
DDiTBlock
Bases: Module
Diffusion Transformer block with adaptive layer norm.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
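A minimal sketch of the adaptive-layer-norm pattern such a block typically uses (DiT-style adaLN; the module names and exact modulation layout are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLNBlockSketch(nn.Module):
    """Illustrative DiT-style block: the conditioning vector c (e.g. a
    timestep embedding) produces shift/scale/gate parameters per sublayer."""

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim)
        )
        # Six modulation vectors: shift/scale/gate for attention and MLP.
        self.adaLN = nn.Linear(dim, 6 * dim)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        s1, sc1, g1, s2, sc2, g2 = self.adaLN(F.silu(c)).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + sc1[:, None]) + s1[:, None]
        x = x + g1[:, None] * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + sc2[:, None]) + s2[:, None]
        return x + g2[:, None] * self.mlp(h)
```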
DDitFinalLayer
Bases: Module
Final output layer with adaptive layer norm.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
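A sketch of the usual final-layer pattern (adaLN modulation followed by a linear projection; the exact layout here is an assumption):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FinalLayerSketch(nn.Module):
    """Illustrative final layer: condition-dependent shift/scale on a
    parameter-free layer norm, then a projection to the output width."""

    def __init__(self, dim: int, out_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.linear = nn.Linear(dim, out_dim)
        self.adaLN = nn.Linear(dim, 2 * dim)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN(F.silu(c)).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return self.linear(h)
```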
EmbeddingLayer
Bases: Module
Token embedding layer.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
LayerNorm
Bases: Module
Layer normalization with learnable scale.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
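One plausible reading, sketched below: normalize without affine parameters, then multiply by a learnable per-channel scale with no bias. The exact variant used here is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleOnlyLayerNorm(nn.Module):
    """Layer norm with a learnable scale and no bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last dimension, then apply the learned scale.
        x = F.layer_norm(x, (x.shape[-1],), eps=self.eps)
        return x * self.weight
```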
Rotary
Bases: Module
Rotary positional embeddings.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
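A sketch of a standard rotary-embedding cache (base frequency and pair layout follow the usual RoPE conventions, assumed here):

```python
import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    """Precomputes cos/sin tables for rotary position embeddings."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # One frequency per pair of channels in the per-head dimension.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len: int):
        t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)    # [L, dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [L, dim]
        return emb.cos(), emb.sin()
```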
ScalarLengthHead
Bases: Module
Predicts expected gap lengths as scalars.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
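A sketch of what such a head could look like; the softplus nonlinearity (to keep expected gap lengths nonnegative) is an assumption, not confirmed by the source:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScalarLengthHeadSketch(nn.Module):
    """Projects each hidden state to one nonnegative scalar, read as the
    expected number of tokens to insert at that position."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [B, L, D] -> expected gaps: [B, L]
        return F.softplus(self.proj(h)).squeeze(-1)
```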
TimestepEmbedder
Bases: Module
Embeds scalar timesteps into vector representations.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
timestep_embedding(t, dim, max_period=10000)
staticmethod
Create sinusoidal timestep embeddings.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
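The sinusoidal recipe is standard (DDPM/DiT style); a reference implementation matching the signature above might look like this:

```python
import math
import torch

def timestep_embedding_sketch(t: torch.Tensor, dim: int, max_period: int = 10000):
    """Sinusoidal embeddings: cos/sin at geometrically spaced frequencies."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(t.device)
    args = t[:, None].float() * freqs[None]                      # [B, half]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, 2*half]
    if dim % 2 == 1:
        # Zero-pad so the output width matches an odd dim.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
```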
apply_rotary_pos_emb(qkv, cos, sin)
Apply rotary positional embeddings (uses flash_attn if available).
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
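A pure-PyTorch fallback for the same operation (the packed `[B, L, 3, H, D]` qkv layout and full-head rotation are assumptions; flash_attn's kernel may differ in convention):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_sketch(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate the query and key slices of a packed qkv tensor [B, L, 3, H, D]."""
    q, k, v = qkv.unbind(dim=2)
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return torch.stack((q, k, v), dim=2)
```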
bias_dropout_add_scale_fused_inference(x, bias, scale, residual, prob)
Fused bias-dropout-add-scale for inference.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
bias_dropout_add_scale_fused_train(x, bias, scale, residual, prob)
Fused bias-dropout-add-scale for training.
Source code in src/discrete_diffusion/models/flexmdm_transformer.py
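An unfused reference covering both variants (training applies dropout; inference skips it):

```python
import torch
import torch.nn.functional as F

def bias_dropout_add_scale_sketch(x, bias, scale, residual, prob, training):
    """scale * dropout(x + bias), added onto the residual stream."""
    out = x if bias is None else x + bias
    out = F.dropout(out, p=prob, training=training)
    return residual + scale * out
```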
get_mask_mod(seq_len)
Create a mask function for variable-length sequences.
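Assuming this follows PyTorch flex_attention's `mask_mod` convention and that `seq_len` holds per-sample valid lengths (both assumptions), a sketch:

```python
import torch

def get_mask_mod_sketch(seq_len: torch.Tensor):
    """Returns a mask_mod(b, h, q_idx, kv_idx) that permits attention only
    between positions inside each sample's valid (non-padding) range."""
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx < seq_len[b]) & (kv_idx < seq_len[b])
    return mask_mod
```

Such a function would typically be consumed by `torch.nn.attention.flex_attention.create_block_mask`; whether this module routes it through flex_attention is likewise an assumption.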