
What is lazy evaluation?

Lazy evaluation is a strategy where operations are not executed immediately when called. Instead, MLX builds a computation graph and only evaluates it when the results are explicitly needed.
use mlx_rs::{array, ops};

// These operations don't execute yet - they build a graph
let a = array!([1.0, 2.0, 3.0]);
let b = array!([4.0, 5.0, 6.0]);
let c = ops::add(&a, &b)?;  // Graph: c = add(a, b)
let d = ops::mul(&c, 2.0)?;  // Graph: d = mul(add(a, b), 2.0)

// Evaluation happens here
d.eval()?;  // Now MLX executes the optimized graph
println!("{:?}", d);  // [10.0, 14.0, 18.0]
This is similar to how TensorFlow 1.x worked with graph mode, but MLX constructs graphs dynamically like PyTorch, combining the benefits of both approaches.
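The build-then-evaluate model can be sketched in plain Rust (a toy illustration with a hypothetical `Expr` type, not the actual mlx-rs internals): constructing nodes is cheap, and no arithmetic runs until `eval` walks the graph.

```rust
// Toy computation graph: building nodes records operations; eval() runs them.
enum Expr {
    Leaf(Vec<f64>),
    Add(Box<Expr>, Box<Expr>),
    MulScalar(Box<Expr>, f64),
}

impl Expr {
    fn eval(&self) -> Vec<f64> {
        match self {
            Expr::Leaf(v) => v.clone(),
            Expr::Add(a, b) => a
                .eval()
                .iter()
                .zip(b.eval())
                .map(|(x, y)| x + y)
                .collect(),
            Expr::MulScalar(a, s) => a.eval().iter().map(|x| x * s).collect(),
        }
    }
}

fn main() {
    // Mirrors the MLX example above: d = (a + b) * 2.0
    let a = Expr::Leaf(vec![1.0, 2.0, 3.0]);
    let b = Expr::Leaf(vec![4.0, 5.0, 6.0]);
    let c = Expr::Add(Box::new(a), Box::new(b));
    let d = Expr::MulScalar(Box::new(c), 2.0);

    // No arithmetic has happened yet; eval() executes the whole graph.
    println!("{:?}", d.eval()); // [10.0, 14.0, 18.0]
}
```

Because the whole expression is visible before anything runs, an engine like MLX can inspect and optimize the graph at this point, which is exactly what eager execution gives up.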

Why lazy evaluation?

1. Kernel fusion

Multiple operations can be combined into a single GPU kernel, reducing overhead.

Without fusion (eager execution):
// Each operation launches a separate kernel
let x = ops::add(&a, &b)?;       // Kernel 1: x = a + b
x.eval()?;                        // GPU wait
let y = ops::mul(&x, &c)?;       // Kernel 2: y = x * c  
y.eval()?;                        // GPU wait
let z = ops::relu(&y)?;          // Kernel 3: z = relu(y)
z.eval()?;                        // GPU wait

// Total: 3 kernel launches + 3 GPU waits
With fusion (lazy evaluation):
// Build graph without execution
let x = ops::add(&a, &b)?;
let y = ops::mul(&x, &c)?;
let z = ops::relu(&y)?;

// Single kernel: z = relu((a + b) * c)
z.eval()?;  

// Total: 1 kernel launch + 1 GPU wait
Kernel fusion reduces:
  • Kernel launch overhead: ~10-50 microseconds per launch
  • Memory bandwidth: Intermediate results stay in GPU registers/cache
  • Global memory writes: x and y never written to memory
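The savings above can be modeled in plain Rust (an illustrative sketch, not real GPU kernels): each eager op is a separate pass over the data with its own "launch" and its own intermediate buffer, while the fused version composes all three ops into a single pass.

```rust
// Toy model of kernel fusion for z = relu((a + b) * c).

// Eager path: one pass ("launch") per op, writing out x and y each time.
fn eager(a: &[f32], b: &[f32], c: &[f32], launches: &mut u32) -> Vec<f32> {
    *launches += 1; // Kernel 1: x = a + b
    let x: Vec<f32> = a.iter().zip(b).map(|(p, q)| p + q).collect();
    *launches += 1; // Kernel 2: y = x * c
    let y: Vec<f32> = x.iter().zip(c).map(|(p, q)| p * q).collect();
    *launches += 1; // Kernel 3: z = relu(y)
    y.iter().map(|p| p.max(0.0)).collect()
}

// Fused path: a single pass computes all three ops, no intermediate buffers.
fn fused(a: &[f32], b: &[f32], c: &[f32], launches: &mut u32) -> Vec<f32> {
    *launches += 1; // Single kernel: z = relu((a + b) * c)
    a.iter()
        .zip(b)
        .zip(c)
        .map(|((p, q), r)| ((p + q) * r).max(0.0))
        .collect()
}

fn main() {
    let (a, b, c) = ([1.0, -2.0, 3.0], [4.0, 5.0, 6.0], [2.0, 2.0, 2.0]);
    let (mut eager_launches, mut fused_launches) = (0, 0);
    let z1 = eager(&a, &b, &c, &mut eager_launches);
    let z2 = fused(&a, &b, &c, &mut fused_launches);
    assert_eq!(z1, z2); // Same result, 3 launches vs 1
    println!("eager: {} launches, fused: {} launch", eager_launches, fused_launches);
}
```

The `launches` counter stands in for kernel launch overhead; the absence of `x` and `y` allocations in the fused path stands in for the memory-bandwidth savings.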

2. Memory optimization

Intermediate arrays can be skipped entirely:
let (a, b) = heavy_computation()?;  // Graphs for both built, nothing evaluated

// Only use a, ignore b
let c = ops::add(&a, 1.0)?;
c.eval()?;  // b is never materialized

// Memory saved: size of b
Real-world example from transformer models:
// Attention computation
let scores = ops::matmul(&q, &k)?;  // [batch, heads, seq, seq]
let scaled = ops::div(&scores, scale)?;
let masked = ops::add(&scaled, &mask)?;
let probs = ops::softmax(&masked, -1)?;
let out = ops::matmul(&probs, &v)?;

out.eval()?;  
// Only 'out' materialized, intermediate tensors (scores, scaled, masked, probs) 
// may be fused or computed on-the-fly

3. Computation graph optimization

MLX can optimize the graph before execution.

Constant folding:
let x = array!([1.0, 2.0, 3.0]);
let scale = Array::from_f32(2.0);
let offset = Array::from_f32(3.0);

let combined = ops::add(&scale, &offset)?;  // Can be precomputed
let y = ops::mul(&x, &combined)?;           // Uses precomputed constant
Common subexpression elimination:
let a_squared = ops::mul(&a, &a)?;
let expr1 = ops::add(&a_squared, &b)?;
let expr2 = ops::sub(&a_squared, &c)?;  // Reuses a_squared computation
Dead code elimination:
fn process(x: &Array, use_branch_a: bool) -> Result<Array> {
    let branch_a = expensive_computation_a(x)?;  // Graph built, not evaluated
    let branch_b = expensive_computation_b(x)?;  // Graph built, not evaluated

    if use_branch_a {
        Ok(branch_a)  // branch_b never evaluated
    } else {
        Ok(branch_b)  // branch_a never evaluated
    }
}
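The dead-code-elimination idea can be illustrated with plain Rust thunks (a toy sketch with a hypothetical `select` helper, not mlx-rs machinery): building a branch costs nothing, and only the branch that control flow selects ever does work.

```rust
use std::cell::Cell;

// Toy dead-code elimination: each branch is a thunk (closure) that records
// when it runs; only the branch selected by control flow ever executes.
fn select<A, B>(use_a: bool, branch_a: A, branch_b: B) -> i32
where
    A: FnOnce() -> i32,
    B: FnOnce() -> i32,
{
    if use_a { branch_a() } else { branch_b() }
}

fn main() {
    let a_runs = Cell::new(0);
    let b_runs = Cell::new(0);

    // Building the thunks is free, like building an unevaluated graph.
    let result = select(
        true,
        || { a_runs.set(a_runs.get() + 1); 10 },
        || { b_runs.set(b_runs.get() + 1); 20 },
    );

    assert_eq!(result, 10);
    assert_eq!((a_runs.get(), b_runs.get()), (1, 0)); // branch_b never ran
}
```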
See mlx-rs/src/lib.rs:60 for detailed lazy evaluation documentation.

When evaluation happens

Explicit evaluation

The eval() method forces evaluation:
let a = array!([1, 2, 3]);
let b = ops::add(&a, 1)?;  // Not evaluated yet

b.eval()?;  // Explicitly evaluate

Implicit evaluation

Certain operations automatically trigger evaluation.

Accessing array data:
let x = ops::add(&a, &b)?;

// These all trigger evaluation:
let slice: &[f32] = x.as_slice();  // Access raw data
let value: f32 = x.item();          // Get scalar value  
println!("{:?}", x);                // Display array (calls eval internally)
Saving to disk:
let weights = model.parameters();

// Saves trigger evaluation of all parameters
Array::save_safetensors("model.safetensors", &weights)?;
Control flow based on array values:
let condition = ops::greater(&x, &threshold)?;

// Using as boolean triggers evaluation
if condition.item::<bool>() {  // eval() called here
    println!("Threshold exceeded");
}
Using arrays for control flow can be inefficient if done frequently. The graph is evaluated at each branch, preventing larger graph optimizations.

Batch evaluation

Evaluate multiple arrays at once:
use mlx_rs::transforms::eval;

let a = ops::add(&x, &y)?;
let b = ops::mul(&x, &z)?;
let c = ops::sub(&y, &z)?;

// Evaluate all three in a single call
eval(&[&a, &b, &c])?;

// Now all three are materialized
This is more efficient than:
a.eval()?;  // GPU launch 1
b.eval()?;  // GPU launch 2  
c.eval()?;  // GPU launch 3

Practical usage patterns

Training loops

Evaluate at the end of each iteration:
use mlx_rs::transforms::{grad, eval};

for (step, batch) in dataset.enumerate() {
    // Build computation graph (no evaluation)
    let loss = model.forward(&batch)?;
    let grads = grad(loss_fn, &[0])(&[&batch])?;
    optimizer.update(&mut model, grads)?;
    
    // Evaluate once per step
    eval(&[&loss])?;  // Ideally batch the updated parameters into this eval as well
    
    if step % 100 == 0 {
        println!("Step {}: loss = {:.4}", step, loss.item::<f32>());
    }
}
Evaluate at the natural iteration boundary (batch/epoch) rather than after every operation.

Inference with KV caching

use mlx_rs_core::{ConcatKeyValueCache, Generate};

let mut cache: Vec<Option<ConcatKeyValueCache>> = vec![None; num_layers];
let mut tokens = prompt_tokens.clone();

for step in 0..max_tokens {
    // Build graph for forward pass
    let logits = model.forward(&tokens, &mut cache)?;
    let next_token = sample(&logits, temperature)?;
    
    // Evaluate once per token generation
    next_token.eval()?;  // Also evaluates cache updates
    
    tokens = vec![next_token.item::<u32>()];
}
The KV cache grows without re-evaluating past tokens; lazy evaluation makes this efficient.

Audio processing pipeline

use mlx_rs_core::audio::{load_wav, resample};
use funasr_mlx::transcribe;

// CPU: Load audio (I/O, no graph)
let (samples, rate) = load_wav("audio.wav")?;
let samples_16k = resample(&samples, rate, 16000);

// GPU: Build encoder graph
let features = extract_features(&samples_16k)?;
let hidden = encoder.forward(&features)?;
let logits = decoder.forward(&hidden)?;

// Evaluate once at the end
logits.eval()?;
let transcript = decode_tokens(&logits)?;

Conditional computation

fn process(x: &Array, mode: ProcessingMode) -> Result<Array> {
    match mode {
        ProcessingMode::Fast => {
            // Simple graph
            let y = ops::relu(&x)?;
            y.eval()?;
            Ok(y)
        }
        ProcessingMode::Accurate => {
            // Complex graph  
            let y = expensive_operation(&x)?;
            y.eval()?;
            Ok(y)
        }
    }
}
Only the selected branch is evaluated.

When to evaluate

Good evaluation points

End of iteration: After processing a batch/sample
for batch in dataset {
    let output = model.forward(&batch)?;
    output.eval()?;  // Once per batch
}
Before inspection: When you need to see intermediate results
let hidden = encoder.forward(&x)?;
hidden.eval()?;  // Needed to inspect values
println!("Hidden stats: mean={:.3}", ops::mean(&hidden, None)?.item::<f32>());
Before control flow: When array values determine branching
let confidence = ops::max(&probs, &[-1], None)?;
confidence.eval()?;

if confidence.item::<f32>() > 0.95 {
    // High confidence path
}
Memory management: Clear memory after large operations
let result = huge_computation(&x)?;
result.eval()?;
// result now materialized, intermediate graphs freed

Bad evaluation points

After every operation: Defeats the purpose of lazy evaluation
// Bad: Too many evaluations
let a = ops::add(&x, &y)?;
a.eval()?;
let b = ops::mul(&a, &z)?;
b.eval()?;
let c = ops::relu(&b)?;
c.eval()?;

// Good: Single evaluation
let c = ops::relu(&ops::mul(&ops::add(&x, &y)?, &z)?)?;
c.eval()?;
Inside tight loops: Accumulates overhead
// Bad
for i in 0..1000 {
    let x = ops::add(&a, Array::from_int(i))?;
    x.eval()?;  // 1000 kernel launches
}

// Good: Vectorize
let indices = Array::from_slice(&(0..1000).collect::<Vec<_>>(), &[1000])?;
let results = ops::add(&a, &indices)?;
results.eval()?;  // 1 kernel launch
For intermediate debugging: Use logging instead
// Bad
let a = ops::add(&x, &y)?;
a.eval()?;  // Just to debug
println!("DEBUG: a = {:?}", a);
let b = ops::mul(&a, &z)?;

// Good: Defer debugging
let a = ops::add(&x, &y)?;
let b = ops::mul(&a, &z)?;
b.eval()?;

if log::log_enabled!(log::Level::Debug) {
    println!("DEBUG: a = {:?}", a);  // Already evaluated
}

Performance considerations

Graph size vs. evaluation frequency

There’s a trade-off between graph size and evaluation overhead.

Very frequent evaluation (many small graphs):
  • ❌ High kernel launch overhead
  • ✅ Low graph construction overhead
  • ✅ Low memory usage
Very infrequent evaluation (few large graphs):
  • ✅ Amortized kernel launch overhead
  • ❌ High graph construction overhead
  • ❌ Higher memory usage
Sweet spot: tens to thousands of operations per eval()
// Good: ~100 ops per eval
for batch in dataset {  // 1000 batches
    let out = model.forward(&batch)?;  // ~100 ops
    out.eval()?;
}

// Bad: 1 op per eval
for batch in dataset {
    for layer in model.layers() {
        let out = layer.forward(&batch)?;  // 1 op
        out.eval()?;  // Too frequent
    }
}

// Bad: 100,000 ops per eval
let mut acc = initial_value();
for i in 0..100_000 {
    acc = ops::add(&acc, &step)?;  // Graph grows huge
}
acc.eval()?;  // Graph construction overhead too high

Memory vs. recomputation

Lazy evaluation trades memory for computation:
// Approach 1: Evaluate intermediate
let intermediate = expensive_op(&x)?;
intermediate.eval()?;  // Stored in memory

let a = use_intermediate(&intermediate)?;
let b = use_intermediate(&intermediate)?;  // Reuses stored value

// Approach 2: Keep lazy
let intermediate = expensive_op(&x)?;  // Not evaluated

let a = use_intermediate(&intermediate)?;
let b = use_intermediate(&intermediate)?;

// If a and b are evaluated separately, intermediate computed twice
a.eval()?;  // Computes expensive_op
b.eval()?;  // Computes expensive_op again

// Better: Batch evaluation
eval(&[&a, &b])?;  // Computes expensive_op once
Rule: If a value is used multiple times, evaluate it:
if reused_multiple_times {
    value.eval()?;  // Store in memory
} else {
    // Keep lazy, will be fused into consumers
}
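The rule above can be sketched with a toy compute-once cache in plain Rust (the `LazyValue` type is hypothetical, not an mlx-rs API): the first consumer forces the expensive computation, and later consumers reuse the stored result, just like evaluating a shared intermediate before its uses.

```rust
use std::cell::{Cell, OnceCell};

// Toy "evaluate once, reuse" cache: forcing the value the first time runs
// the computation; every later force reads the stored result.
struct LazyValue<'a> {
    slot: OnceCell<f64>,
    compute: &'a dyn Fn() -> f64,
}

impl<'a> LazyValue<'a> {
    fn new(compute: &'a dyn Fn() -> f64) -> Self {
        LazyValue { slot: OnceCell::new(), compute }
    }
    fn force(&self) -> f64 {
        *self.slot.get_or_init(|| (self.compute)())
    }
}

fn main() {
    let count = Cell::new(0);
    let expensive = || { count.set(count.get() + 1); 21.0 };
    let value = LazyValue::new(&expensive);

    // Two consumers, one computation.
    let a = value.force() + 1.0;
    let b = value.force() * 2.0;

    assert_eq!((a, b), (22.0, 42.0));
    assert_eq!(count.get(), 1); // expensive ran exactly once
}
```

Skipping the cache and calling `expensive()` at each use is the analogue of evaluating `a` and `b` separately in the mlx-rs example: the shared work runs twice.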

Parallel streams

Lazy evaluation enables parallelism across streams:
use mlx_rs::{StreamOrDevice, transforms::eval};

// Build graphs on different streams
let cpu_stream = StreamOrDevice::cpu();
let gpu_stream = StreamOrDevice::gpu();

let cpu_result = cpu_intensive_op(&x, cpu_stream)?;
let gpu_result = gpu_intensive_op(&y, gpu_stream)?;

// Evaluate in parallel
eval(&[&cpu_result, &gpu_result])?;
// CPU and GPU execute concurrently

Common pitfalls

Pitfall 1: Assuming immediate execution

let x = ops::div(&a, &b)?;
println!("Result: {}", x);  // This evaluates x!

let y = ops::add(&x, 1)?;  // x already evaluated
y.eval()?;  // Only evaluates y, can't fuse with x
Fix: Avoid printing/accessing until ready:
let x = ops::div(&a, &b)?;
let y = ops::add(&x, 1)?;
y.eval()?;  // Fuses div and add

println!("Result: {}", y);  // Access after eval

Pitfall 2: Multiple unnecessary evals

let params = model.parameters();
for (name, param) in params {
    param.eval()?;  // N evaluations
}
Fix: Batch evaluate:
let params: Vec<&Array> = model.parameters().map(|(_, p)| p).collect();
eval(&params)?;  // Single evaluation

Pitfall 3: Control flow evaluation inside hot loops

// Bad: Evaluates every iteration
for token in generated_tokens {
    let logits = model.forward(&token)?;
    let top_logit = ops::max(&logits, &[-1], None)?;
    
    if top_logit.item::<f32>() > threshold {  // Eval!
        break;
    }
}
Fix: Evaluate explicitly once per iteration so the whole forward pass runs as one fused graph before the scalar check:
for token in generated_tokens {
    let logits = model.forward(&token)?;
    logits.eval()?;  // Explicit eval
    
    let top_logit = ops::max(&logits, &[-1], None)?.item::<f32>();
    if top_logit > threshold {
        break;
    }
}

Comparison with other frameworks

| Framework | Execution Model | Graph Construction |
| --- | --- | --- |
| MLX | Lazy (default) | Dynamic |
| PyTorch | Eager (default) | Dynamic (compile with torch.compile) |
| TensorFlow 1.x | Lazy | Static (must define graph first) |
| TensorFlow 2.x | Eager (default) | Dynamic (trace with @tf.function) |
| JAX | Lazy (via jit) | Dynamic (trace with @jax.jit) |
MLX’s approach:
  • Always builds graphs (like TensorFlow 1.x)
  • Dynamic shapes and control flow (like PyTorch)
  • Automatic kernel fusion (like JAX)
  • No separate “compile” step needed

Additional resources

  • MLX framework: How lazy evaluation fits into MLX
  • Unified memory: Memory optimization benefits
  • Performance guide: Tips for optimizing lazy evaluation
  • Transforms: Function transformations (grad, compile)
