mlx-rs-core
Shared inference infrastructure used across all model-specific MLX Rust crates (qwen3-mlx, glm4-mlx, gpt-sovits-mlx, etc.).

Overview

mlx-rs-core provides common components for efficient inference:

- KV cache - Fast autoregressive decoding with key-value caching
- Token generation - Generic generation infrastructure with sampling
- Attention utilities - RoPE, attention masks, scaled dot-product attention
- Custom kernels - Fused Metal kernels for performance-critical operations
- Audio processing - Mel spectrogram and audio utilities
- Error handling - Common error types and conversions
Module exports
Cache
KV cache implementations for efficient autoregressive generation.

KeyValueCache trait

Trait for key-value caches used in attention mechanisms:

- Returns the current cache offset (number of tokens cached)
- Returns the maximum cache size for sliding-window attention, if any
ConcatKeyValueCache
A concatenation-based KV cache that appends new keys/values to the existing arrays. Straightforward, but can be slow for long sequences.
KVCache
Optimized step-based KV cache with pre-allocation. Pre-allocates buffers in steps of 256 tokens and uses in-place updates, avoiding expensive concatenation. Matches the Python mlx-lm implementation.

- Create a cache with the default step size of 256 tokens
- Create a cache with a custom step size
- Number of tokens to pre-allocate per growth step
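To illustrate the step-based growth strategy, here is a minimal plain-Rust sketch (no mlx-rs; `StepCache` and its fields are hypothetical names, with a flat `Vec<f32>` standing in for an MLX array). The point is the capacity arithmetic: round up to the next multiple of `step` tokens and write in place, rather than concatenating on every update.

```rust
// Illustrative sketch of step-based pre-allocation, NOT the crate's
// actual implementation. Capacity grows in multiples of `step` tokens;
// new tokens are written in place at `offset`.
struct StepCache {
    buf: Vec<f32>,    // flattened [capacity_tokens, head_dim] buffer
    head_dim: usize,
    step: usize,      // tokens to pre-allocate per growth step
    offset: usize,    // tokens currently cached
}

impl StepCache {
    fn new(head_dim: usize, step: usize) -> Self {
        Self { buf: Vec::new(), head_dim, step, offset: 0 }
    }

    /// Append `new` (n_tokens * head_dim values), growing capacity in
    /// `step`-token chunks only when needed.
    fn update(&mut self, new: &[f32]) {
        let n_tokens = new.len() / self.head_dim;
        let needed = (self.offset + n_tokens) * self.head_dim;
        if needed > self.buf.len() {
            // Round capacity up to the next multiple of `step` tokens.
            let cap_tokens = (self.offset + n_tokens).div_ceil(self.step) * self.step;
            self.buf.resize(cap_tokens * self.head_dim, 0.0);
        }
        let start = self.offset * self.head_dim;
        self.buf[start..start + new.len()].copy_from_slice(new);
        self.offset += n_tokens;
    }
}
```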
Generate
Generic token generation infrastructure with a builder pattern.

Generate struct
Iterator-based token generator with configurable sampling
Builder
Create a new generation builder.

- Set the tokenizer for decoding tokens to text
- Set the model for generation (must implement the Module trait)
- Set the prompt tokens as an Array
- Set the sampling temperature (default: 0.0 for greedy)
  - temp = 0.0 - greedy decoding (argmax)
  - temp > 0.0 - sampling with temperature
- Set the maximum number of tokens to generate (default: 256)
- Set a custom sampler implementation
- Build the generator (consumes the builder)
Response
Generated text response
Decoded text from generated tokens
Generated token IDs
Sampler
Token sampling strategies for generation.

Sampler trait
Trait for implementing custom sampling strategies
DefaultSampler
Default sampling implementation
- Temperature 0.0: Greedy decoding (argmax)
- Temperature > 0.0: Categorical sampling with temperature scaling
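The two modes above can be sketched in plain Rust (no mlx-rs; `sample` is a hypothetical free function, and the uniform random number is passed in so the sketch stays deterministic). Temperature scaling divides the logits by `temp` before a softmax, then samples by inverse CDF.

```rust
// Illustrative sketch of greedy vs. temperature sampling, NOT the
// crate's actual DefaultSampler.
fn sample(logits: &[f32], temp: f32, uniform: f32) -> usize {
    if temp == 0.0 {
        // Greedy decoding: index of the largest logit.
        return logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap();
    }
    // Temperature scaling + numerically stable softmax.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / temp).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Inverse-CDF sampling from the categorical distribution.
    let mut acc = 0.0;
    for (i, e) in exps.iter().enumerate() {
        acc += e / sum;
        if uniform < acc {
            return i;
        }
    }
    logits.len() - 1
}
```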
Utilities (utils)
Attention utilities, RoPE initialization, and helper functions.

RoPE initialization
Initialize a rotary position embedding.

Parameters:

- Dimension of each attention head
- Base frequency for rotations (typically 10000.0)
- Whether to use the traditional RoPE formulation
- Optional scaling configuration for extended context:
  - "type": "default", "linear", etc.
  - "factor": scaling factor
- Maximum sequence length

Returns the configured RoPE module.
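For intuition on the base-frequency parameter, the standard RoPE frequency schedule can be sketched in plain Rust (this shows the textbook formula, not the crate's internals; `rope_inv_freqs` is a hypothetical helper name):

```rust
// Standard RoPE frequency schedule: dimension pair i rotates at angular
// frequency base^(-2i / head_dim), so low pairs rotate quickly and high
// pairs slowly. Illustrative only.
fn rope_inv_freqs(head_dim: usize, base: f32) -> Vec<f32> {
    (0..head_dim / 2)
        .map(|i| base.powf(-(2.0 * i as f32) / head_dim as f32))
        .collect()
}
```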
Attention masks
Create a causal attention mask for autoregressive generation.

Parameters:

- Hidden states tensor (shape: [batch, seq_len, ...])
- KV cache arrays (one per layer)
- Whether to force an explicit mask array instead of the hardware causal mask

Returns:

- None for a single token (no mask needed)
- Some(AttentionMask::Causal) for a hardware-optimized causal mask
- Some(AttentionMask::Array) for an explicit mask array
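When an explicit mask array is needed, the cache offset shifts which positions are visible: position i in the current chunk may attend to all cached tokens plus positions up to i. A plain-Rust sketch of that construction (illustrative, not the crate's implementation; `causal_mask` is a hypothetical name):

```rust
// Explicit causal mask with a cache offset: allowed positions get 0.0,
// disallowed positions get -inf (added to attention scores pre-softmax).
// Illustrative only; nested Vecs stand in for an MLX array.
fn causal_mask(seq_len: usize, offset: usize) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|i| {
            (0..offset + seq_len)
                .map(|j| if j <= offset + i { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect()
}
```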
Attention mask variants
Mask type for scaled dot-product attention
Scaled dot-product attention
Compute scaled dot-product attention.

Parameters:

- Query tensor [batch, n_heads, seq_q, head_dim]
- Key tensor [batch, n_kv_heads, seq_k, head_dim]
- Value tensor [batch, n_kv_heads, seq_v, head_dim]
- Optional KV cache
- Attention scale factor (typically 1.0 / sqrt(head_dim))
- Optional attention mask

Returns the attention output [batch, n_heads, seq_q, head_dim].
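The computation, reduced to a single head and a single query vector, is `softmax(scale * q·Kᵀ) @ V`. A plain-Rust sketch of that inner step (illustrative; `attend` is a hypothetical name, and batching, heads, and masking are omitted):

```rust
// Single-query scaled dot-product attention: score each key, softmax
// the scores, then take the weighted sum of values. Illustrative only.
fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], scale: f32) -> Vec<f32> {
    // Scaled dot products q·k for each key.
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    // Numerically stable softmax.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Weighted sum of value vectors.
    let dim = values[0].len();
    let mut out = vec![0.0; dim];
    for (w, v) in exps.iter().zip(values) {
        for (o, &vi) in out.iter_mut().zip(v) {
            *o += (w / sum) * vi;
        }
    }
    out
}
```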
Metal kernels
Custom fused Metal kernels for performance-critical operations.

fused_swiglu

Fused SwiGLU activation using a custom Metal kernel.

Computes: silu(gate) * x = (gate / (1 + exp(-gate))) * x

Performance: 10-12x faster than separate silu() + multiply() calls. Critical for MoE models with many SwiGLU operations.

Parameters:

- Input tensor (any shape)
- Gate tensor (same shape as x)

Returns the SwiGLU output (same shape as the inputs).
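For reference, the unfused computation the kernel replaces can be written directly from the formula above (plain Rust over slices; `swiglu` is a hypothetical name, and the fused Metal kernel produces the same values in a single pass):

```rust
// Reference (unfused) SwiGLU matching silu(gate) * x, where
// silu(g) = g / (1 + exp(-g)). Illustrative only.
fn swiglu(x: &[f32], gate: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(gate)
        .map(|(&xi, &g)| (g / (1.0 + (-g).exp())) * xi)
        .collect()
}
```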
fused_modulate
Fused LayerNorm + modulation using a custom Metal kernel.

Computes: (1 + scale) * LayerNorm(x) + shift, where LayerNorm has no learnable parameters (elementwise_affine=False).

Performance: fuses 7+ operations into a single kernel. Critical for DiT (Diffusion Transformer) models:

- 4 modulate calls per block
- 60 blocks
- 40 forward passes per generation
- = 9,600 modulate calls per image

Parameters:

- Input tensor [batch, seq, dim] or [seq, dim]
- Shift tensor (flattened to [dim])
- Scale tensor (flattened to [dim])

Returns the modulated output (same shape as x).
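The unfused math for a single [dim] row can be sketched in plain Rust (illustrative; `modulate` is a hypothetical name): normalize without learnable affine parameters, then apply `(1 + scale) * norm + shift` elementwise.

```rust
// Reference (unfused) LayerNorm + modulation for one [dim] row,
// matching (1 + scale) * LayerNorm(x) + shift with
// elementwise_affine=False. Illustrative only.
fn modulate(x: &[f32], shift: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(shift.iter().zip(scale))
        .map(|(&v, (&sh, &sc))| (1.0 + sc) * ((v - mean) * inv_std) + sh)
        .collect()
}
```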
Model input/output traits
Traits for generic generation infrastructure.

ModelInput trait

Trait for model input types that can be constructed from the builder. Models implement this to receive tokens, cache, and state during generation.
ModelOutput trait
Trait for model output types that provide logits. Implemented by model outputs to extract next-token logits.
Tokenizer loading
Error handling
Common error type for mlx-rs-core operations
Result type alias with mlx-rs-core Error
Helper macros
Helper macro for early returns in iterator contexts
Audio processing
The audio module provides utilities for audio processing:
- Mel spectrogram computation
- Audio preprocessing for speech models
- Feature extraction utilities
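Mel filterbanks are built on a mel-scale frequency mapping; a plain-Rust sketch of one common variant, the HTK formula, is below (the crate may use a different variant, e.g. the Slaney scale; `hz_to_mel`/`mel_to_hz` are hypothetical names):

```rust
// HTK mel-scale conversion: mel = 2595 * log10(1 + hz / 700), and its
// inverse. Illustrative only; the actual mel spectrogram code may use
// a different mel-scale variant.
fn hz_to_mel(hz: f32) -> f32 {
    2595.0 * (1.0 + hz / 700.0).log10()
}

fn mel_to_hz(mel: f32) -> f32 {
    700.0 * (10f32.powf(mel / 2595.0) - 1.0)
}
```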
Speculative decoding
The speculative module provides support for speculative decoding to accelerate generation:
- Draft model integration
- Verification and acceptance logic
- Multi-token generation strategies
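The verification-and-acceptance step, in its greedy form, can be sketched in plain Rust (illustrative; `verify_greedy` is a hypothetical name, and the general form uses probability-ratio acceptance rather than exact argmax matching): accept draft tokens while they match the target model's argmax, then substitute the target's token at the first mismatch.

```rust
// Greedy draft verification: accept matching prefix, replace the first
// mismatch with the target's token; if everything matched, the target
// contributes one bonus token. Illustrative only.
fn verify_greedy(draft: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    let mut accepted = Vec::new();
    for (i, &d) in draft.iter().enumerate() {
        if d == target_argmax[i] {
            accepted.push(d);
        } else {
            // First mismatch: keep the target's token and stop.
            accepted.push(target_argmax[i]);
            return accepted;
        }
    }
    // All draft tokens accepted; the target provides one bonus token.
    if target_argmax.len() > draft.len() {
        accepted.push(target_argmax[draft.len()]);
    }
    accepted
}
```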
Convert utilities
The convert module (requires the convert feature) provides model conversion utilities.
Example usage
Feature flags
- convert - Enable model conversion utilities