MaxDiffusion supports multiple data input pipelines for training. This guide covers dataset formats, preprocessing scripts, and best practices.

Dataset types

MaxDiffusion supports four dataset types, controlled by the dataset_type flag:
| Pipeline | Location | Formats | Features |
| --- | --- | --- | --- |
| hf | HuggingFace Hub or Cloud Storage | parquet, arrow, json, csv, txt | Streaming, good for large datasets |
| tf | HuggingFace Hub (downloads to disk) | parquet, arrow, json, csv, txt | In-memory, works for small datasets |
| tfrecord | Local/Cloud Storage | TFRecord | Streaming, good for large datasets |
| grain | Local/Cloud Storage | ArrayRecord | Streaming, global shuffle, deterministic |

HuggingFace streaming (dataset_type=hf)

Stream data directly from HuggingFace Hub or cloud storage without downloading.

From HuggingFace Hub

dataset_type: hf
dataset_name: BleachNick/UltraEdit_500k
image_column: source_image
caption_column: source_caption
train_split: FreeForm
hf_access_token: ''  # For gated datasets

From cloud storage

dataset_type: hf
dataset_name: parquet  # builder name matching the file format (or json, arrow, etc.)
hf_train_files: gs://my-bucket/my-dataset/*-train-*.parquet

tf.data in-memory (dataset_type=tf)

Downloads the entire dataset into memory. Best for small datasets.
dataset_type: tf
dataset_name: diffusers/pokemon-gpt4-captions
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
cache_latents_text_encoder_outputs: True
When cache_latents_text_encoder_outputs=True, the VAE and text encoder run once during dataset creation, and the resulting latents and embeddings are saved to dataset_save_location so training steps can skip both encoders.
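The caching idea can be sketched as follows. This is an illustration of the concept, not MaxDiffusion's implementation; the `encode_image` and `encode_text` callables are hypothetical stand-ins for the VAE and text encoder:

```python
def cache_latents_and_text(examples, encode_image, encode_text):
    """Run both encoders once at dataset-creation time and keep only
    their outputs, so each training step can skip the encoders."""
    return [
        {"latents": encode_image(ex["image"]),
         "text_embeds": encode_text(ex["caption"])}
        for ex in examples
    ]

# Hypothetical stand-ins for the real VAE and text encoder:
examples = [{"image": [1, 2], "caption": "a pokemon"}]
cached = cache_latents_and_text(
    examples,
    encode_image=lambda img: [v * 2 for v in img],
    encode_text=lambda txt: [len(txt.split())],
)
print(cached)  # [{'latents': [2, 4], 'text_embeds': [2]}]
```

The trade-off is up-front compute and disk space in exchange for cheaper training steps.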

TFRecord format (dataset_type=tfrecord)

Use TFRecord files for efficient streaming of large datasets.
dataset_type: tfrecord
train_data_dir: gs://my-bucket/my-dataset/  # Directory containing .tfrec files

Grain format (dataset_type=grain)

Grain provides global shuffle and deterministic data iteration.
dataset_type: grain
grain_train_files: gs://my-bucket/my-dataset/*.arrayrecord

Wan dataset preprocessing

Wan models require special preprocessing to create TFRecord datasets with video latents and text embeddings.

Wan PusaV1 dataset example

This example uses the PusaV1 dataset.

Download the dataset

export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR

Create training dataset

python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/train \
  no_records_per_shard=10 \
  enable_eval_timesteps=False
Check progress:
ls -l $TFRECORDS_DATASET_DIR/train

Create evaluation dataset

python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  train_data_dir=$HF_DATASET_DIR \
  tfrecords_dir=$TFRECORDS_DATASET_DIR/eval_timesteps \
  no_records_per_shard=10 \
  enable_eval_timesteps=True
This creates an evaluation dataset of 420 samples annotated with fixed timesteps for quality evaluation, following the protocol described in Scaling Rectified Flow Transformers.
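A sketch of how fixed timestep buckets might be assigned to eval samples. The bucket values and the cycling order here are illustrative assumptions, not the script's actual logic:

```python
import itertools

def annotate_eval_samples(num_samples, timesteps_list):
    """Cycle through a fixed timestep list so every eval sample gets a
    deterministic timestep annotation and buckets come out evenly sized."""
    cycle = itertools.cycle(timesteps_list)
    return [{"sample_id": i, "timestep": next(cycle)} for i in range(num_samples)]

# Illustrative values: 420 samples over 7 hypothetical timesteps
samples = annotate_eval_samples(420, [100, 250, 400, 550, 700, 850, 1000])
```

Fixed timesteps keep evaluation loss comparable across runs, since none of the variance comes from random timestep sampling.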

Remove duplicates from training set

Delete the first 420 samples from training data (they’re in eval):
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | \
  awk -F '[-.]' '$2+0 <= 420' | \
  xargs -d '\n' rm
Verify deletion:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | \
  awk -F '[-.]' '$2+0 <= 420' | \
  xargs -d '\n' echo
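The same selection can be expressed in Python. This sketch assumes shard names follow the file_<shard>-<cumulative-count>.tfrec pattern shown above, where the number after the dash is the running record count:

```python
import re

def shards_to_delete(filenames, max_records=420):
    """Select shard files whose cumulative record count is at most
    max_records, mirroring the awk filter above."""
    selected = []
    for name in filenames:
        m = re.search(r"file_\d+-(\d+)\.tfrec$", name)
        if m and int(m.group(1)) <= max_records:
            selected.append(name)
    return selected

names = ["file_00-10.tfrec", "file_41-420.tfrec", "file_42-430.tfrec"]
print(shards_to_delete(names))  # ['file_00-10.tfrec', 'file_41-420.tfrec']
```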

Clean up empty files

The eval run can leave a trailing empty shard beyond the 420th sample; remove it if present:
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec 2>/dev/null || true

Directory structure

Your dataset should now have:
$TFRECORDS_DATASET_DIR/
├── train/
│   ├── file_00-10.tfrec
│   ├── file_01-20.tfrec
│   └── ...
└── eval_timesteps/
    ├── file_00-10.tfrec
    ├── file_01-20.tfrec
    └── ...

General text-to-video preprocessing

For other video datasets, use the general preprocessing script:
python src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  dataset_name="your-dataset/name" \
  tfrecords_dir=$TFRECORDS_DATASET_DIR \
  no_records_per_shard=10 \
  caption_column="text" \
  image_column="image" \
  height=1280 \
  width=720 \
  seed=42
This script:
  1. Loads videos from HuggingFace datasets
  2. Encodes videos using the VAE
  3. Encodes captions using the T5 text encoder
  4. Saves latents and embeddings to TFRecord format

Configuration options

| Parameter | Description |
| --- | --- |
| train_data_dir | Path to the downloaded dataset |
| tfrecords_dir | Output directory for TFRecord files |
| no_records_per_shard | Number of examples per TFRecord file |
| enable_eval_timesteps | Add timestep annotations for evaluation |
| timesteps_list | Timesteps for evaluation buckets |
| num_eval_samples | Number of evaluation samples (default: 420) |
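Judging from the directory listing above (file_00-10.tfrec, file_01-20.tfrec, ...), shard filenames appear to encode the shard index and the cumulative record count. A stdlib sketch of that naming scheme, inferred from the listing rather than taken from the script:

```python
def shard_filenames(total_records, records_per_shard):
    """Generate file_<shard>-<cumulative>.tfrec names, matching the
    pattern seen in the directory structure (file_00-10, file_01-20, ...)."""
    names = []
    cumulative = 0
    shard = 0
    while cumulative < total_records:
        cumulative = min(cumulative + records_per_shard, total_records)
        names.append(f"file_{shard:02d}-{cumulative}.tfrec")
        shard += 1
    return names

print(shard_filenames(30, 10))
# ['file_00-10.tfrec', 'file_01-20.tfrec', 'file_02-30.tfrec']
```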

Upload to cloud storage

Copy preprocessed data to GCS for distributed training:
BUCKET_NAME=my-bucket
# ${TFRECORDS_DATASET_DIR##*/} expands to the last path component (the dataset directory name)
gcloud storage cp --recursive $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}

Using preprocessed data

For training

python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  dataset_type='tfrecord' \
  train_data_dir=gs://$BUCKET_NAME/wan_tfr_dataset_pusa_v1/train/ \
  load_tfrecord_cached=True

For evaluation

python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  eval_every=100 \
  eval_data_dir=gs://$BUCKET_NAME/wan_tfr_dataset_pusa_v1/eval_timesteps/

Multihost dataloading

In multihost environments, optimal performance requires each data file to be accessed by only one host.

Best practices

  • Number of files > Number of hosts - Each host reads a subset of files
  • File assignment - Files are distributed evenly across hosts
  • Epoch handling - Hosts may finish epochs at different times
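One simple way to satisfy these constraints is a strided assignment by host index. This is a sketch of the general idea, not necessarily how MaxDiffusion assigns files:

```python
def files_for_host(files, host_index, num_hosts):
    """Assign every num_hosts-th file to this host, so each file is read
    by exactly one host and the split is roughly even.  Sorting first
    makes the assignment identical on every host."""
    return sorted(files)[host_index::num_hosts]

files = [f"shard_{i:02d}.tfrec" for i in range(8)]
print(files_for_host(files, host_index=1, num_hosts=4))
# ['shard_01.tfrec', 'shard_05.tfrec']
```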

Resharding datasets

If you have fewer files than hosts, reshard your dataset:
# Decrease no_records_per_shard so preprocessing produces more, smaller files
no_records_per_shard=5  # Instead of 10
Or split existing files:
python -c "
import tensorflow as tf
import glob

files = sorted(glob.glob('dataset/*.tfrec'))
target_shards = 64  # Number of output files

# Read all records as one stream, then write them back out as
# target_shards interleaved output files.
dataset = tf.data.TFRecordDataset(files)
for i in range(target_shards):
    shard = dataset.shard(target_shards, i)
    with tf.io.TFRecordWriter(f'dataset/resharded_{i:03d}.tfrec') as writer:
        for record in shard:
            writer.write(record.numpy())
"

Synthetic data

For testing and benchmarking without real data:
dataset_type: 'synthetic'
synthetic_num_samples: null  # Infinite samples

# Override dimensions
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
synthetic_override_max_sequence_length: 512
synthetic_override_text_embed_dim: 4096
synthetic_override_num_channels_latents: 16
synthetic_override_vae_scale_factor_spatial: 8
synthetic_override_vae_scale_factor_temporal: 4
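These overrides together imply a synthetic latent shape. Assuming the common causal-VAE convention that spatial dimensions divide by the spatial scale factor and frames compress as 1 + (num_frames - 1) / temporal factor (an assumption about the VAE, not stated in this guide):

```python
def synthetic_latent_shape(height, width, num_frames, channels,
                           spatial_factor, temporal_factor):
    """Compute the latent tensor shape implied by the synthetic override
    values, under an assumed causal-VAE compression convention."""
    latent_frames = 1 + (num_frames - 1) // temporal_factor
    return (channels, latent_frames,
            height // spatial_factor, width // spatial_factor)

# Using the override values from the config above:
print(synthetic_latent_shape(720, 1280, 85, 16, 8, 4))  # (16, 22, 90, 160)
```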

Dataset configuration reference

Common parameters

image_column: 'image'           # Column name for images/videos
caption_column: 'text'          # Column name for captions
resolution: 1024                # Image resolution (for images)
height: 1280                    # Video height (for videos)
width: 720                      # Video width (for videos)
num_frames: 81                  # Number of video frames
center_crop: False              # Center crop images
random_flip: False              # Random horizontal flip
enable_data_shuffling: True     # Shuffle data during training

Performance parameters

tokenize_captions_num_proc: 4   # Parallel workers for tokenization
transform_images_num_proc: 4    # Parallel workers for image processing
reuse_example_batch: False      # Reuse same batch (for debugging)
