
Overview

FLUX.2-klein is a 4B parameter image generation model optimized for Apple Silicon. This crate provides the Rust implementation using MLX bindings.

Key features

  • FLUX.2-klein transformer: 4B parameter model with 5 double-stream and 20 single-stream blocks
  • Qwen3-4B text encoder: Shared with Z-Image-Turbo, produces 7680-dim embeddings
  • VAE decoder: AutoencoderKL for latent-to-image decoding
  • 4-bit quantization: Memory-efficient inference (~3GB vs ~8GB)
  • Rectified flow sampling: Fast 4-step generation

Core types

FluxKlein

The main transformer model for FLUX.2-klein.
pub struct FluxKlein {
    pub params: FluxKleinParams,
    // ... internal layers
}
  • params (FluxKleinParams): Model configuration parameters

Methods

new
fn(params: FluxKleinParams) -> Result<Self, Exception>
Create a new FLUX.2-klein transformer with the given parameters.
  • params (FluxKleinParams): Model configuration. Use FluxKleinParams::default() for the standard 4B model.
forward
fn(&mut self, img: &Array, txt: &Array, timesteps: &Array, img_ids: &Array, txt_ids: &Array) -> Result<Array, Exception>
Run forward pass through the transformer.
  • img (&Array): Image latents [batch, seq, in_channels] where in_channels = 128 (after patchify)
  • txt (&Array): Text embeddings [batch, seq, 7680] from Qwen3TextEncoder
  • timesteps (&Array): Denoising timesteps [batch] in the range 0.0 to 1.0
  • img_ids (&Array): Image position IDs [batch, img_seq, 3] for 3-axis RoPE
  • txt_ids (&Array): Text position IDs [batch, txt_seq, 3] for 3-axis RoPE
  • returns (Array): Predicted velocity [batch, img_seq, in_channels]
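The 3-axis position IDs are plain integer grids. As a rough sketch (not the crate's actual implementation, and the (0, row, col) axis layout is an assumption about how FLUX-style 3-axis RoPE indexes image tokens), image IDs for an h×w patch grid could be built like this:

```rust
/// Sketch: build [h * w, 3] position IDs for a patch grid. Axis 0 is left
/// at 0, axes 1-2 hold the row/column index; this (0, row, col) layout is
/// an assumption, not taken from the crate.
fn make_img_ids(h: usize, w: usize) -> Vec<[i32; 3]> {
    let mut ids = Vec::with_capacity(h * w);
    for row in 0..h {
        for col in 0..w {
            ids.push([0, row as i32, col as i32]);
        }
    }
    ids
}
```

For a 1024-token image sequence (a 32×32 patch grid), this yields 1024 ID triples, matching the [batch, img_seq, 3] shape expected by forward.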
forward_with_rope
fn(&mut self, img: &Array, txt: &Array, timesteps: &Array, rope_cos: &Array, rope_sin: &Array) -> Result<Array, Exception>
Forward pass with pre-computed RoPE frequencies; faster inside a denoising loop. The img, txt, and timesteps parameters match forward.
  • rope_cos (&Array): Pre-computed cosine frequencies from compute_rope()
  • rope_sin (&Array): Pre-computed sine frequencies from compute_rope()
compute_rope
fn(txt_ids: &Array, img_ids: &Array) -> Result<(Array, Array), Exception>
Pre-compute RoPE frequencies for caching. Call once before denoising loop.
  • returns ((Array, Array)): Tuple of (cos, sin) frequencies for efficient reuse
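The point of caching is that the cos/sin tables depend only on the position IDs, not on the timestep, so they are identical across all denoising steps. A minimal single-axis sketch (the base of 10000.0 is the conventional RoPE default, assumed here rather than taken from the crate):

```rust
/// Sketch: per-axis RoPE frequency tables. For position p and frequency
/// index i, angle = p * base^(-2i/dim); computing cos/sin once up front
/// avoids redoing this work on every denoising step.
fn rope_tables(positions: &[i32], dim: usize, base: f32) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let half = dim / 2; // each frequency rotates a pair of channels
    let mut cos = Vec::with_capacity(positions.len());
    let mut sin = Vec::with_capacity(positions.len());
    for &p in positions {
        let mut c_row = Vec::with_capacity(half);
        let mut s_row = Vec::with_capacity(half);
        for i in 0..half {
            let angle = p as f32 * base.powf(-2.0 * i as f32 / dim as f32);
            c_row.push(angle.cos());
            s_row.push(angle.sin());
        }
        cos.push(c_row);
        sin.push(s_row);
    }
    (cos, sin)
}
```

Position 0 always yields cos = 1 and sin = 0, i.e. no rotation, which is a quick sanity check on any RoPE implementation.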

FluxKleinParams

Configuration parameters for FLUX.2-klein.
pub struct FluxKleinParams {
    pub in_channels: i32,      // 128 (after patchify)
    pub hidden_size: i32,      // 3072
    pub txt_embed_dim: i32,    // 7680 (from Qwen3)
    pub num_heads: i32,        // 24
    pub mlp_ratio: f32,        // 3.0
    pub depth: i32,            // 5 double stream blocks
    pub depth_single: i32,     // 20 single stream blocks
    pub head_dim: i32,         // 128
    pub mlp_hidden: i32,       // 9216 (3072 * 3)
}
  • in_channels (i32, default 128): Input channels after patchify (32 VAE channels × 2×2 patch)
  • hidden_size (i32, default 3072): Model dimension
  • txt_embed_dim (i32, default 7680): Text embedding dimension from Qwen3-4B
  • depth (i32, default 5): Number of double-stream transformer blocks
  • depth_single (i32, default 20): Number of single-stream transformer blocks
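The derived fields follow from a few base quantities already listed in the struct comments. A small sketch of that arithmetic (using only numbers stated in this document):

```rust
/// Sketch: derived sizes in FluxKleinParams. 32 VAE latent channels packed
/// into 2x2 patches give 128 input channels; mlp_hidden is
/// hidden_size * mlp_ratio; head_dim is hidden_size / num_heads.
fn derived_sizes(
    vae_channels: i32,
    patch: i32,
    hidden_size: i32,
    mlp_ratio: f32,
    num_heads: i32,
) -> (i32, i32, i32) {
    let in_channels = vae_channels * patch * patch;           // 32 * 2 * 2 = 128
    let mlp_hidden = (hidden_size as f32 * mlp_ratio) as i32; // 3072 * 3.0 = 9216
    let head_dim = hidden_size / num_heads;                   // 3072 / 24 = 128
    (in_channels, mlp_hidden, head_dim)
}
```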

Qwen3TextEncoder

Text encoder using Qwen3-4B model.
pub struct Qwen3TextEncoder {
    pub config: Qwen3Config,
    // ... internal layers
}
new
fn(config: Qwen3Config) -> Result<Self, Exception>
Create a new Qwen3 text encoder.
  • config (Qwen3Config): Model configuration. Use Qwen3Config::default() for Qwen3-4B.
forward
fn(&mut self, input_ids: &Array, attention_mask: Option<&Array>) -> Result<Array, Exception>
Encode text tokens to embeddings.
  • input_ids (&Array): Token IDs [batch, seq_len]
  • attention_mask (Option<&Array>): Optional attention mask [batch, seq_len] where 1 = real token, 0 = padding
  • returns (Array): Text embeddings [batch, seq_len, 7680] (concatenated hidden states from layers 8, 17, and 26)
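The 7680-dim output is consistent with concatenating three intermediate hidden states, assuming Qwen3-4B's hidden size of 2560 (3 × 2560 = 7680; the hidden size is an assumption, not stated in this document). A per-token sketch on plain slices:

```rust
/// Sketch: concatenate three layer hidden states for one token. With an
/// assumed hidden size of 2560 per layer, the result is 7680-dim, matching
/// the documented embedding width.
fn concat_layers(l8: &[f32], l17: &[f32], l26: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(l8.len() + l17.len() + l26.len());
    out.extend_from_slice(l8);
    out.extend_from_slice(l17);
    out.extend_from_slice(l26);
    out
}
```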

Decoder (VAE)

VAE decoder for converting latents to images.
pub struct Decoder {
    pub config: AutoEncoderConfig,
    // ... internal layers
}
new
fn(config: AutoEncoderConfig) -> Result<Self, Exception>
Create a new VAE decoder.
  • config (AutoEncoderConfig): VAE configuration. Use AutoEncoderConfig::flux2() for FLUX.2.
forward
fn(&mut self, z: &Array) -> Result<Array, Exception>
Decode latents to images.
  • z (&Array): Latent array [batch, height, width, channels] in NHWC format
  • returns (Array): Decoded images [batch, height*8, width*8, 3] (RGB, NHWC format)

Sampling

FluxSampler

Rectified flow sampler for denoising.
pub struct FluxSampler {
    pub config: FluxSamplerConfig,
}
new
fn(config: FluxSamplerConfig) -> Result<Self, Exception>
Create a new sampler.
  • config (FluxSamplerConfig): Sampler configuration
schnell
fn() -> Self
Create sampler for fast 4-step generation (FLUX.2-klein default).
timesteps
fn(&self, num_steps: Option<i32>) -> Result<Vec<f32>, Exception>
Generate timestep schedule from 1.0 to 0.0.
  • num_steps (Option<i32>): Number of steps (defaults to the config value)
  • returns (Vec<f32>): Timestep values from 1.0 (pure noise) to 0.0 (clean)
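Note that a schedule covering num_steps denoising steps needs num_steps + 1 boundary values, so the loop can walk adjacent pairs. A plain-Rust sketch of a linear schedule from 1.0 down to 0.0 (a simple linspace, assumed here as the schnell behavior rather than taken from the crate's code):

```rust
/// Sketch: linear timestep schedule from 1.0 (pure noise) to 0.0 (clean).
/// num_steps steps produce num_steps + 1 boundary values.
fn linear_timesteps(num_steps: usize) -> Vec<f32> {
    (0..=num_steps)
        .map(|i| 1.0 - i as f32 / num_steps as f32)
        .collect()
}
```

For the default 4-step configuration this yields [1.0, 0.75, 0.5, 0.25, 0.0], i.e. four (t, t_prev) pairs for the denoising loop.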
step
fn(&self, x_t: &Array, v_pred: &Array, t: f32, t_prev: f32) -> Result<Array, Exception>
Single denoising step using rectified flow.
  • x_t (&Array): Current noisy sample
  • v_pred (&Array): Model velocity prediction
  • t (f32): Current timestep
  • t_prev (f32): Target timestep
  • returns (Array): Updated sample x_{t_prev} = x_t + (t_prev - t) * v_pred
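The update formula above is a single Euler step along the predicted velocity, applied elementwise. A sketch on plain Vec<f32> (standing in for Array):

```rust
/// Sketch: one rectified-flow Euler step, x_{t_prev} = x_t + (t_prev - t) * v,
/// applied elementwise. dt is negative when denoising from t down to t_prev.
fn rf_step(x_t: &[f32], v_pred: &[f32], t: f32, t_prev: f32) -> Vec<f32> {
    let dt = t_prev - t;
    x_t.iter().zip(v_pred).map(|(x, v)| x + dt * v).collect()
}
```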
sample_prior
fn(&self, shape: &[i32]) -> Result<Array, Exception>
Sample initial noise from standard Gaussian.
  • shape (&[i32]): Latent tensor shape [batch, seq, channels]
  • returns (Array): Random Gaussian noise

FluxSamplerConfig

Sampler configuration.
pub struct FluxSamplerConfig {
    pub num_steps: i32,
    pub is_schnell: bool,
    pub shift: f32,
}
  • num_steps (i32, default 4): Number of inference steps
  • is_schnell (bool, default true): Whether to use the fast linear schedule (true for FLUX.2-klein)
  • shift (f32, default 1.0): Time shift parameter for non-schnell models
klein
fn() -> Self
Create config for FLUX.2-klein (4 steps, linear schedule).
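For non-schnell models, FLUX-style samplers commonly warp the linear schedule with t' = shift·t / (1 + (shift − 1)·t); this exact formula is an assumption here, not confirmed by the crate. A sketch:

```rust
/// Sketch: the time-shift warp commonly used in FLUX-style rectified-flow
/// samplers (an assumption, not the crate's confirmed formula). shift = 1.0
/// is the identity, matching the schnell default; the endpoints 0.0 and 1.0
/// are always preserved.
fn shift_time(shift: f32, t: f32) -> f32 {
    shift * t / (1.0 + (shift - 1.0) * t)
}
```

With shift > 1 the warp spends more of the schedule near t = 1 (high noise), which is typically where larger models benefit from extra steps.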

Quantization

QuantizedFluxKlein

4-bit quantized FLUX.2-klein model for reduced memory usage.
load_quantized_flux_klein
fn(weights: HashMap<String, Array>, params: FluxKleinParams) -> Result<QuantizedFluxKlein, Exception>
Load 4-bit quantized model from pre-quantized weights.
  • weights (HashMap<String, Array>): Pre-quantized model weights
  • params (FluxKleinParams): Model configuration
  • returns (QuantizedFluxKlein): 4-bit quantized model (~3GB memory)
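The memory figures are consistent with back-of-the-envelope arithmetic: 4B weights at 16 bits is roughly 7.5 GiB, while 4 bits plus group-quantization scale/bias overhead lands in the 2-3 GiB range (the overhead fraction below is an assumption, not a measured figure from the crate; the rest of the quoted ~3GB budget presumably goes to activations and any layers kept at higher precision):

```rust
/// Sketch: rough weight-memory estimate in GiB for a model with the given
/// parameter count and bits per weight. overhead_frac approximates
/// quantization scale/bias storage (an assumed figure).
fn approx_gib(params: f64, bits_per_weight: f64, overhead_frac: f64) -> f64 {
    params * bits_per_weight / 8.0 * (1.0 + overhead_frac) / (1024.0 * 1024.0 * 1024.0)
}
```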

Weight utilities

load_safetensors
fn(path: &Path) -> Result<HashMap<String, Array>, Exception>
Load weights from safetensors file.
  • path (&Path): Path to the .safetensors file
sanitize_flux2_klein_weights
fn(weights: HashMap<String, Array>) -> HashMap<String, Array>
Sanitize FLUX.2-klein weights from HuggingFace format to Rust format.
  • weights (HashMap<String, Array>): Raw weights in HuggingFace format
  • returns (HashMap<String, Array>): Sanitized weights ready to load
sanitize_qwen3_weights
fn(weights: HashMap<String, Array>) -> HashMap<String, Array>
Sanitize Qwen3 text encoder weights.
sanitize_vae_weights
fn(weights: HashMap<String, Array>) -> HashMap<String, Array>
Sanitize VAE decoder weights from HuggingFace format.

Example usage

use flux_klein_mlx::{FluxKlein, Qwen3TextEncoder, Decoder};
use flux_klein_mlx::{FluxSampler, FluxKleinParams, Qwen3Config, AutoEncoderConfig};

// Load models (mutable because forward takes &mut self; weights are
// loaded separately via load_safetensors + the sanitize_* helpers)
let mut text_encoder = Qwen3TextEncoder::new(Qwen3Config::default())?;
let mut transformer = FluxKlein::new(FluxKleinParams::default())?;
let mut vae = Decoder::new(AutoEncoderConfig::flux2())?;

// Create sampler
let sampler = FluxSampler::schnell(); // 4-step generation

// Encode text (input_ids come from the tokenizer; not shown here)
let text_embed = text_encoder.forward(&input_ids, None)?;

// Generate image (txt_ids and img_ids are the 3-axis RoPE position IDs;
// their construction is not shown here)
let mut latents = sampler.sample_prior(&[1, 1024, 128])?;
let (rope_cos, rope_sin) = FluxKlein::compute_rope(&txt_ids, &img_ids)?;

for window in sampler.timesteps(None)?.windows(2) {
    let (t_curr, t_next) = (window[0], window[1]);
    let t = Array::from_slice(&[t_curr], &[1]);
    let v_pred = transformer.forward_with_rope(
        &latents, &text_embed, &t, &rope_cos, &rope_sin,
    )?;
    latents = sampler.step(&latents, &v_pred, t_curr, t_next)?;
}

// Decode to image (the [batch, seq, 128] latents must first be
// un-patchified back to NHWC [batch, h, w, 32] for the decoder)
let image = vae.forward(&latents)?;
