Overview

Z-Image-Turbo is a 6B-parameter Single-Stream DiT (S3-DiT) optimized for fast image generation with distilled 9-step turbo inference.

Key features

  • Z-Image-Turbo transformer: 6B-parameter model with Noise Refiner, Context Refiner, and 30 Joint blocks
  • Qwen3-4B text encoder: Extracts hidden states from layer 34 (2560-dim embeddings)
  • 3-axis RoPE: Position encoding with dimensions [32, 48, 48] and theta=256
  • 9-step turbo inference: Distilled for fast generation
  • 4-bit quantization: Memory-efficient inference (~3GB vs ~12GB)

Architecture differences from FLUX

  • Uses Noise Refiner + Context Refiner + Joint blocks (vs FLUX's double-stream + single-stream blocks)
  • 3-axis RoPE [32, 48, 48] with theta=256 (vs FLUX's 4-axis RoPE)
  • Per-block AdaLN with tanh gates
  • Qwen3-4B layer 34 extraction (vs concatenating layers 8, 17, 26)

Core types

ZImageTransformer

The main transformer model for Z-Image-Turbo.
pub struct ZImageTransformer {
    pub config: ZImageConfig,
    // ... internal layers
}
  • config (ZImageConfig): Model configuration parameters

Methods

new
fn(config: ZImageConfig) -> Result<Self, Exception>
Create a new Z-Image transformer.
  • config (ZImageConfig): Model configuration. Use ZImageConfig::default() for the standard 6B model.
forward
fn(&mut self, x: &Array, t: &Array, cap_feats: &Array, x_pos: &Array, cap_pos: &Array, x_mask: Option<&Array>, cap_mask: Option<&Array>) -> Result<Array, Exception>
Run forward pass through the transformer.
  • x (&Array): Image latents [batch, img_seq, in_channels * patch^2], where in_channels=16
  • t (&Array): Timesteps [batch], scaled by t_scale=1000.0
  • cap_feats (&Array): Caption features from Qwen3 layer 34 [batch, cap_seq, 2560]
  • x_pos (&Array): Image position coordinates [batch, img_seq, 3] for (h, w, t)
  • cap_pos (&Array): Caption position coordinates [batch, cap_seq, 3]
  • x_mask (Option<&Array>): Optional image mask for padding
  • cap_mask (Option<&Array>): Optional caption mask for padding
  • Returns (Array): Predicted velocity [batch, img_seq, in_channels * patch^2]
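
For concreteness, a single un-cached forward call might look like the following sketch. The shapes in the comments assume a 32x32 patch grid (img_seq = 32*32 = 1024, in_channels * patch^2 = 16 * 4 = 64); transformer, latents, timestep, and cap_feats are assumed to already be in scope.

// One denoising step without RoPE caching.
// latents:   [1, 1024, 64]        (32*32 patches, 16 channels * 2*2 patch)
// timestep:  [1]
// cap_feats: [1, cap_seq, 2560]
let v_pred = transformer.forward(
    &latents, &timestep, &cap_feats,
    &x_pos, &cap_pos,
    None, // x_mask: no image padding
    None, // cap_mask: no caption padding
)?;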
forward_with_rope
fn(&mut self, x: &Array, t: &Array, cap_feats: &Array, x_pos: &Array, cap_pos: &Array, cos: &Array, sin: &Array, x_mask: Option<&Array>, cap_mask: Option<&Array>) -> Result<Array, Exception>
Forward pass with pre-computed RoPE frequencies (faster for denoising loops).
  • cos (&Array): Pre-computed cosine frequencies from compute_rope()
  • sin (&Array): Pre-computed sine frequencies from compute_rope()
The remaining parameters match forward().
compute_rope
fn(&self, x_pos: &Array, cap_pos: &Array) -> Result<(Array, Array), Exception>
Pre-compute 3-axis RoPE frequencies for caching.
  • x_pos (&Array): Image position coordinates [batch, img_seq, 3]
  • cap_pos (&Array): Caption position coordinates [batch, cap_seq, 3]
  • Returns ((Array, Array)): Tuple of (cos, sin) frequencies for efficient reuse across denoising steps

ZImageConfig

Configuration for Z-Image-Turbo.
pub struct ZImageConfig {
    pub dim: i32,              // 3840
    pub n_heads: i32,          // 30
    pub n_kv_heads: i32,       // 30
    pub n_layers: i32,         // 30 joint blocks
    pub n_refiner_layers: i32, // 2 refiner blocks each
    pub in_channels: i32,      // 16
    pub cap_feat_dim: i32,     // 2560 (Qwen3 layer 34)
    pub axes_dims: [i32; 3],   // [32, 48, 48]
    pub rope_theta: f32,       // 256.0
    pub t_scale: f32,          // 1000.0
    pub norm_eps: f32,         // 1e-5
    pub patch_size: i32,       // 2
}
  • dim (i32, default 3840): Model dimension (hidden size)
  • n_heads (i32, default 30): Number of attention heads
  • n_layers (i32, default 30): Number of joint transformer blocks
  • n_refiner_layers (i32, default 2): Number of noise refiner and context refiner blocks (2 each)
  • cap_feat_dim (i32, default 2560): Caption feature dimension from Qwen3 layer 34
  • axes_dims ([i32; 3], default [32, 48, 48]): 3-axis RoPE dimensions for (h, w, t) axes
  • rope_theta (f32, default 256.0): RoPE base frequency (different from FLUX's 2000.0)
  • t_scale (f32, default 1000.0): Timestep scaling factor
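
All fields are public, so variants can be built with struct update syntax. A minimal sketch, assuming ZImageConfig implements Default (as the ZImageConfig::default() calls above suggest); the overridden value is purely illustrative:

use zimage_mlx::ZImageConfig;

// Standard 6B Z-Image-Turbo configuration.
let config = ZImageConfig::default();
assert_eq!(config.dim, 3840);

// Hypothetical variant: override individual fields, keep the rest at defaults.
let custom = ZImageConfig {
    rope_theta: 512.0, // illustrative value, not an official checkpoint setting
    ..ZImageConfig::default()
};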

ZImageTransformerBlock

Single transformer block with optional AdaLN modulation.
pub struct ZImageTransformerBlock {
    pub dim: i32,
    pub has_modulation: bool,
    // ... internal layers
}
  • has_modulation (bool): Whether the block uses AdaLN modulation (true for noise refiner and joint blocks, false for the context refiner)
new
fn(config: &ZImageConfig, has_modulation: bool) -> Result<Self, Exception>
Create a new transformer block.
  • config (&ZImageConfig): Model configuration
  • has_modulation (bool): Enable AdaLN modulation with tanh gates
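
For illustration, the two block flavors can be constructed directly; a minimal sketch, assuming a config is already in scope:

// Joint blocks and noise refiner blocks modulate via AdaLN with tanh gates.
let joint_block = ZImageTransformerBlock::new(&config, true)?;

// Context refiner blocks process caption features without modulation.
let context_block = ZImageTransformerBlock::new(&config, false)?;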

3-axis RoPE utilities

create_coordinate_grid
fn(size: (i32, i32, i32), start: (i32, i32, i32)) -> Result<Array, Exception>
Create 3D coordinate grid for position encoding.
  • size ((i32, i32, i32)): Grid dimensions (d0, d1, d2)
  • start ((i32, i32, i32)): Starting coordinates (s0, s1, s2)
  • Returns (Array): Coordinate array [d0*d1*d2, 3]
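
To make the output layout concrete, here is a plain-Rust reference sketch of what such a grid contains, assuming row-major enumeration over (d0, d1, d2); the actual implementation may order rows differently:

// Reference sketch: enumerate every (i, j, k) offset from the start coordinates.
fn coordinate_grid_reference(size: (i32, i32, i32), start: (i32, i32, i32)) -> Vec<[i32; 3]> {
    let (d0, d1, d2) = size;
    let (s0, s1, s2) = start;
    let mut rows = Vec::with_capacity((d0 * d1 * d2) as usize);
    for i in 0..d0 {
        for j in 0..d1 {
            for k in 0..d2 {
                rows.push([s0 + i, s1 + j, s2 + k]); // one [d0*d1*d2, 3] row per cell
            }
        }
    }
    rows
}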
compute_rope_3axis
fn(positions: &Array, axes_dims: &[i32; 3], theta: f32) -> Result<(Array, Array), Exception>
Compute 3-axis RoPE frequencies.
  • positions (&Array): Position coordinates [batch, seq, 3] for (h, w, t)
  • axes_dims (&[i32; 3]): Dimensions per axis, e.g., [32, 48, 48]
  • theta (f32): Base frequency (256.0 for Z-Image)
  • Returns ((Array, Array)): Tuple of (cos, sin), each of shape [batch, seq, 1, half_total_dim]
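
Note that the axis dimensions sum to the per-head dimension (32 + 48 + 48 = 128 = 3840/30), so half_total_dim = 64. Assuming the standard RoPE frequency recipe is applied to each axis independently, the angle computation for a single position looks like this plain-Rust sketch:

// Sketch: angles for one position (h, w, t) with axes_dims = [32, 48, 48].
fn rope_angles(pos: [f32; 3], axes_dims: [usize; 3], theta: f32) -> Vec<f32> {
    let mut angles = Vec::new();
    for (axis, &dim) in axes_dims.iter().enumerate() {
        for i in 0..dim / 2 {
            // Standard RoPE frequency theta^(-2i/dim), scaled by the coordinate.
            let freq = theta.powf(-2.0 * i as f32 / dim as f32);
            angles.push(pos[axis] * freq);
        }
    }
    angles // length (32 + 48 + 48) / 2 = 64; cos/sin of these fill the cached tables
}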
apply_rope_3axis
fn(x: &Array, cos: &Array, sin: &Array) -> Result<Array, Exception>
Apply rotary embedding using even/odd split.
  • x (&Array): Input tensor [batch, seq, heads, head_dim]
  • cos (&Array): Cosine frequencies [batch, seq, 1, head_dim/2]
  • sin (&Array): Sine frequencies [batch, seq, 1, head_dim/2]
  • Returns (Array): Rotated tensor with the same shape as the input
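
The even/odd split treats consecutive element pairs (x[2i], x[2i+1]) as 2D vectors and rotates each by its cached angle. A plain-Rust sketch of that rotation for a single head vector, under the standard interleaved-RoPE assumption:

// Rotate interleaved pairs of a single head vector in place.
fn apply_rope_even_odd(x: &mut [f32], cos: &[f32], sin: &[f32]) {
    debug_assert_eq!(x.len(), 2 * cos.len());
    for i in 0..cos.len() {
        let (even, odd) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = even * cos[i] - odd * sin[i];     // real part of the 2D rotation
        x[2 * i + 1] = even * sin[i] + odd * cos[i]; // imaginary part
    }
}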

Quantization

ZImageTransformerQuantized

4-bit quantized Z-Image transformer for reduced memory.
load_quantized_zimage_transformer
fn(weights: HashMap<String, Array>, config: ZImageConfig) -> Result<ZImageTransformerQuantized, Exception>
Load 4-bit quantized model from pre-quantized weights.
  • weights (HashMap<String, Array>): Pre-quantized model weights
  • config (ZImageConfig): Model configuration
  • Returns (ZImageTransformerQuantized): 4-bit quantized model (~3GB memory vs ~12GB full precision)
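
Putting loading together, a minimal sketch; the checkpoint path is hypothetical, and it assumes the re-exported load_safetensors returns the HashMap<String, Array> of pre-quantized weights that this loader expects:

use zimage_mlx::{load_quantized_zimage_transformer, ZImageConfig};
use flux_klein_mlx::load_safetensors;

// Hypothetical path to a pre-quantized 4-bit checkpoint.
let weights = load_safetensors("zimage-turbo-q4.safetensors")?;
let transformer = load_quantized_zimage_transformer(weights, ZImageConfig::default())?;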

QuantizedQwen3TextEncoder

4-bit quantized Qwen3 text encoder.
load_quantized_qwen3_encoder
fn(weights: HashMap<String, Array>, config: Qwen3Config) -> Result<QuantizedQwen3TextEncoder, Exception>
Load 4-bit quantized Qwen3 encoder.
  • weights (HashMap<String, Array>): Pre-quantized weights
  • config (Qwen3Config): Qwen3 configuration

Weight utilities

sanitize_mlx_weights
fn(weights: HashMap<String, Array>) -> HashMap<String, Array>
Sanitize Z-Image weights from MLX format to Rust format.
  • weights (HashMap<String, Array>): Raw weights in MLX format
  • Returns (HashMap<String, Array>): Sanitized weights ready to load
sanitize_zimage_weights
fn(weights: HashMap<String, Array>) -> HashMap<String, Array>
Sanitize Z-Image weights from PyTorch diffusers format.
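
Which sanitizer to call depends on where the checkpoint came from. A minimal sketch, with the file names purely illustrative and load_safetensors assumed to return a raw HashMap<String, Array>:

use flux_klein_mlx::load_safetensors;
use zimage_mlx::{sanitize_mlx_weights, sanitize_zimage_weights};

// Checkpoints exported from MLX use one key layout...
let raw = load_safetensors("zimage-mlx.safetensors")?; // illustrative path
let weights = sanitize_mlx_weights(raw);

// ...while PyTorch diffusers checkpoints use another.
let raw = load_safetensors("zimage-diffusers.safetensors")?; // illustrative path
let weights = sanitize_zimage_weights(raw);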

Re-exported from flux-klein-mlx

Z-Image shares several components with FLUX.2-klein:
  • Qwen3Config, Qwen3TextEncoder, sanitize_qwen3_weights
  • Decoder, AutoEncoderConfig
  • load_safetensors, sanitize_vae_weights
  • FluxSampler, FluxSamplerConfig
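
Since these are re-exports, they should be importable from the zimage_mlx crate root as well as from flux_klein_mlx directly; a sketch assuming crate-root re-exports:

use zimage_mlx::{
    Qwen3Config, Qwen3TextEncoder, sanitize_qwen3_weights,
    Decoder, AutoEncoderConfig,
    load_safetensors, sanitize_vae_weights,
    FluxSampler, FluxSamplerConfig,
};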

Performance comparison

Mode                 Memory   Speed
Dequantized (f32)    ~12GB    ~1.87s/step
Quantized (4-bit)    ~3GB     ~2.08s/step

Example usage

use zimage_mlx::{
    ZImageTransformer, ZImageConfig, create_coordinate_grid,
    load_quantized_zimage_transformer,
};
use flux_klein_mlx::{
    Qwen3TextEncoder, Qwen3Config, Decoder, AutoEncoderConfig,
    FluxSampler, FluxSamplerConfig,
};
use mlx_rs::Array;

// Load models
let text_encoder = Qwen3TextEncoder::new(Qwen3Config::default())?;
let mut transformer = ZImageTransformer::new(ZImageConfig::default())?;
let vae = Decoder::new(AutoEncoderConfig::default())?;

// Or use the quantized transformer for lower memory
// (weights: pre-quantized HashMap<String, Array>):
// let transformer = load_quantized_zimage_transformer(weights, ZImageConfig::default())?;

// Create position grids (32x32 latent patch grid, 77 caption tokens)
let x_pos = create_coordinate_grid((32, 32, 1), (0, 0, 0))?;
let cap_pos = create_coordinate_grid((77, 1, 1), (0, 0, 0))?;

// Pre-compute RoPE once and reuse it across all denoising steps
let (rope_cos, rope_sin) = transformer.compute_rope(&x_pos, &cap_pos)?;

// Generate image. `sampler` (a FluxSampler), `latents` (initial noise), and
// `cap_feats` (text-encoder output) are assumed to be prepared above.
// Iterating adjacent timestep pairs gives each step its (t, t_next),
// assuming timesteps() yields an ordered Vec of sigmas.
let timesteps = sampler.timesteps(None)?;
for pair in timesteps.windows(2) {
    let (t, t_next) = (pair[0], pair[1]);
    let t_arr = Array::from_slice(&[t], &[1]);
    let v_pred = transformer.forward_with_rope(
        &latents, &t_arr, &cap_feats,
        &x_pos, &cap_pos, &rope_cos, &rope_sin,
        None, None,
    )?;
    latents = sampler.step(&latents, &v_pred, t, t_next)?;
}

// Decode the final latents to pixels with the VAE decoder.
