UNI-D²
Discrete diffusion models operate in discrete state spaces (tokens, labels, programs, etc.), replacing continuous Gaussian noise with carefully designed transition kernels that schedule how states are corrupted and then denoised.
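For example, in the absorbing-state (masked) formulation used by MDLM-style models, the forward process interpolates between the data token and a mask token: q(x_t | x_0) = Cat(x_t; α_t x_0 + (1 − α_t) m), where x_0 is the one-hot data token, m is the one-hot mask token, and α_t decreases from 1 to 0 over the noise schedule.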
This repository centralizes tooling, datasets, experiments, and evaluation pipelines so researchers have a reliable, extensible codebase for discrete diffusion variants in text and structured domains.
Highlights
- Unified Entry Point: Hydra + Lightning workflow for experimenting with MDLM, UDLM, BD3LM, FlexMDM, GIDD, SEDD, PartitionMDLM, and CANDI.
- Comprehensive Sampling: Helpers for absorbing, autoregressive, block, and flexible sampling strategies.
- Reproducibility: Scripts to reproduce training recipes for datasets like LM1B, OpenWebText, and Text8.
Papers Implemented
- MDLM – Sahoo et al. (NeurIPS 2024)
- UDLM – Schiff et al. (arXiv 2024)
- FlexMDM – Kim et al. (arXiv 2025)
- Block Diffusion – Arriola et al. (arXiv 2025)
- GIDD – von Rütte et al. (arXiv 2025)
- SEDD – Lou et al. (arXiv 2023)
- PartitionMDLM – Deschenaux et al. (arXiv 2025)
- CANDI – Pynadath et al. (arXiv 2025)
Installation
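Install the package in editable mode from the repository root (a minimal sketch assuming a standard pip setup; check the repository for its exact dependency files):

# sketch: assumes a standard pip editable install from the repository root
pip install -e .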
On systems that support Flash Attention (CUDA 11.4+), install it after the editable install to boost throughput:
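For example, following flash-attn's own documented install path:

pip install flash-attn --no-build-isolation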
For optimized cross-entropy computation on CUDA devices:
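The exact dependency is repository-specific; one common choice in MDLM-style codebases, shown here purely as an assumption, is the fused cross-entropy kernel from the flash-attention source tree:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/csrc/xentropy
pip install .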
Quick Start
1. Training
Run the Hydra-powered CLI. Here is a minimal example for training MDLM on OpenWebText:
PYTHONPATH=src python -u -m discrete_diffusion \
data=owt \
model=small \
algo=mdlm \
loader.batch_size=32 \
trainer.devices=8 \
hydra.run.dir=./outputs/owt/mdlm
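To inspect the fully composed configuration before launching a run, Hydra's built-in --cfg flag can help (a generic Hydra feature, not specific to this repository):

PYTHONPATH=src python -m discrete_diffusion data=owt model=small algo=mdlm --cfg job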
2. Sampling
Once you have a checkpoint, use the generation script:
PYTHONPATH=src python -m discrete_diffusion.evaluations.generate_samples \
checkpoint_path=outputs/owt/mdlm/checkpoints/last.ckpt \
num_samples=16 \
num_steps=2000
Extending
Want to implement a new discrete diffusion method? Check out our Extension Guides to learn how to add custom algorithms, forward processes, noise schedules, and models.
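As a purely illustrative sketch, the skeleton below shows the pieces a custom method typically needs: a forward corruption step and a loss over corrupted positions. The class and method names are hypothetical, not the repository's actual interfaces; see the Extension Guides for the real registration hooks.

# Hypothetical sketch only -- names are illustrative, not this repo's API.
import torch
import torch.nn.functional as F

class MyAbsorbingAlgo:
    """Toy absorbing-state (masked) discrete diffusion algorithm."""

    def __init__(self, mask_token_id: int):
        self.mask_token_id = mask_token_id

    def forward_noise(self, tokens: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Corrupt: replace each token with the mask token with probability t
        # (t has shape [batch] and is broadcast across the sequence dimension).
        keep = torch.rand(tokens.shape, device=tokens.device) >= t.unsqueeze(-1)
        return torch.where(keep, tokens, torch.full_like(tokens, self.mask_token_id))

    def loss(self, logits: torch.Tensor, targets: torch.Tensor, noisy: torch.Tensor) -> torch.Tensor:
        # Denoising loss: cross-entropy only on the masked positions,
        # as in masked diffusion language models.
        masked = noisy == self.mask_token_id
        return F.cross_entropy(logits[masked], targets[masked])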