## Overview

The `minicpm-sala-mlx` crate provides inference for MiniCPM-SALA, a compact language model with a hybrid attention scheme combining sparse (Lightning Attention) and dense layers. SALA includes built-in thinking capabilities and speculative decoding for faster inference.
## Key features

- Hybrid attention - Alternating Lightning (sparse) and dense attention layers
- Built-in thinking - `<think>...</think>` blocks for reasoning
- Speculative decoding - Draft-model acceleration
- Custom Metal kernels - Optimized Lightning Attention implementation
- Compact size - High performance at small scale (2B-4B parameters)
## Installation

Add to your `Cargo.toml`:

```toml
[dependencies]
minicpm-sala-mlx = "0.1"
```
## Core functions

### load_model

Loads a MiniCPM-SALA model from a directory.

```rust
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<Model, Error>
```

Takes the path to a model directory containing:

- `config.json` - Model configuration
- `model.safetensors.index.json` - Weight file index (or a single `model.safetensors`)
- Model weight files

Returns a loaded `Model` ready for inference.
### load_tokenizer

Loads the tokenizer from the model directory.

```rust
pub fn load_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error>
```

Takes the path to a model directory containing `tokenizer.json`.

Returns a HuggingFace `Tokenizer` instance.
### get_model_args

Parses the model configuration from `config.json`.

```rust
pub fn get_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error>
```

Takes the path to a directory containing `config.json`.

Returns parsed `ModelArgs` with the model hyperparameters.
## Utility functions

### format_chat_prompt

Formats a single-turn chat prompt in ChatML format.

```rust
pub fn format_chat_prompt(system: &str, user: &str) -> String
```

Takes a system message defining assistant behavior and the user's message.

Returns a formatted ChatML prompt ready for tokenization.
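The ChatML layout this produces presumably looks like the following (a hypothetical reimplementation for illustration; the crate's actual template may differ in whitespace or special tokens):

```rust
// Hypothetical ChatML template; verify against the crate's real output.
pub fn format_chat_prompt(system: &str, user: &str) -> String {
    format!(
        "<|im_start|>system\n{system}<|im_end|>\n\
         <|im_start|>user\n{user}<|im_end|>\n\
         <|im_start|>assistant\n"
    )
}
```

The trailing `<|im_start|>assistant\n` leaves the prompt open so the model completes the assistant turn.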
### format_chat_prompt_multi

Formats a multi-turn chat prompt in ChatML format.

```rust
pub fn format_chat_prompt_multi(system: &str, turns: &[(&str, &str)]) -> String
```

Takes the system message and a list of (role, content) pairs where role is "user" or "assistant".

Returns a formatted multi-turn ChatML prompt.
### strip_thinking

Removes the `<think>...</think>` block from generated text.

```rust
pub fn strip_thinking(text: &str) -> &str
```

Takes generated text that may contain a thinking block.

Returns the text after the `</think>` tag, or the original text if there is no thinking block.
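The documented behavior can be sketched as follows (a hypothetical implementation, not the crate's source):

```rust
// Sketch of strip_thinking: return everything after the first
// "</think>", or the input unchanged when no block is present.
pub fn strip_thinking(text: &str) -> &str {
    match text.find("</think>") {
        Some(pos) => &text[pos + "</think>".len()..],
        None => text,
    }
}
```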
### is_stop_token

Checks whether a token is a stop token (EOS or `<|im_end|>`).

```rust
pub fn is_stop_token(token_id: u32) -> bool
```

Returns true if the token is EOS (2) or `<|im_end|>` (73440).
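The check reduces to a comparison against the two stop IDs listed under Constants (a sketch, not the crate's source):

```rust
// Documented stop-token IDs for MiniCPM-SALA.
pub const EOS_TOKEN_ID: u32 = 2;
pub const IM_END_TOKEN_ID: u32 = 73440;

// Sketch of the stop-token check.
pub fn is_stop_token(token_id: u32) -> bool {
    token_id == EOS_TOKEN_ID || token_id == IM_END_TOKEN_ID
}
```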
## Types

### Model

The main model struct for MiniCPM-SALA inference.

```rust
pub struct Model {
    pub args: ModelArgs,
    pub model: MiniCPMSALAModel,
    pub lm_head: nn::Linear,
}
```
### ModelArgs

Model configuration parsed from `config.json`.

```rust
pub struct ModelArgs {
    pub model_type: String,
    pub vocab_size: i32,
    pub hidden_size: i32,
    pub intermediate_size: i32,
    pub num_hidden_layers: i32,
    pub num_attention_heads: i32,
    pub num_key_value_heads: i32,
    pub max_position_embeddings: i32,
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub scale_emb: f32,
    pub scale_depth: f32,
    pub dim_model_base: i32,
    pub sparse_layer_interval: i32,
    pub quantization: Option<QuantizationConfig>,
}
```
Notable fields:

- `scale_emb` - Embedding scaling factor (muP scaling)
- `scale_depth` - Depth scaling for residual connections
- `sparse_layer_interval` - Interval between Lightning (sparse) attention layers
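In MiniCPM's muP-style scheme these factors are typically applied as sketched below; the placement, and the example values of `hidden_size = 2560` and `dim_model_base = 256`, are illustrative assumptions rather than this crate's exact code:

```rust
// Illustration of MiniCPM-style muP scaling (assumed placement; check
// the crate's forward pass for the authoritative version).

// Token embeddings are multiplied by scale_emb.
fn scale_embedding(e: f32, scale_emb: f32) -> f32 {
    e * scale_emb
}

// Each residual-branch output is scaled by scale_depth / sqrt(num_layers).
fn scale_residual(branch: f32, scale_depth: f32, num_layers: u32) -> f32 {
    branch * scale_depth / (num_layers as f32).sqrt()
}

// Hidden states are divided by hidden_size / dim_model_base before the
// LM head computes logits.
fn scale_hidden_for_logits(h: f32, hidden_size: u32, dim_model_base: u32) -> f32 {
    h / (hidden_size as f32 / dim_model_base as f32)
}
```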
### HybridAttention

Enum distinguishing Lightning (sparse) and dense attention layers.

```rust
pub enum HybridAttention {
    Lightning(LightningAttention), // Lightning Attention (sparse)
    Dense(DenseAttention),         // Standard dense attention
}
```
### ThinkFilter

Incremental filter for streaming output with think-block suppression.

```rust
pub struct ThinkFilter {
    think_done: bool,
    prev_text_len: usize,
}
```
#### Constructor

```rust
pub fn new(no_think: bool) -> Self
```

If `no_think` is true, `<think>...</think>` content is suppressed in the output.

#### next

```rust
pub fn next<'a>(&mut self, full_text: &'a str) -> &'a str
```

Returns the new text to emit (an empty string while still inside a think block).
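The incremental logic can be sketched roughly like this, matching the documented two fields by treating "filtering off" as "think phase already done" (a hypothetical implementation, not the crate's source):

```rust
// Hypothetical ThinkFilter sketch: suppress everything up to and
// including "</think>", then emit only text not yet emitted. A real
// implementation would also handle output with no think block at all.
pub struct ThinkFilter {
    think_done: bool,
    prev_text_len: usize,
}

impl ThinkFilter {
    pub fn new(no_think: bool) -> Self {
        // With no_think = false, nothing is suppressed.
        Self { think_done: !no_think, prev_text_len: 0 }
    }

    pub fn next<'a>(&mut self, full_text: &'a str) -> &'a str {
        if !self.think_done {
            // Still inside <think>...</think>: emit nothing until the
            // closing tag appears in the accumulated text.
            match full_text.find("</think>") {
                Some(pos) => {
                    self.think_done = true;
                    self.prev_text_len = pos + "</think>".len();
                }
                None => return "",
            }
        }
        let new_text = &full_text[self.prev_text_len..];
        self.prev_text_len = full_text.len();
        new_text
    }
}
```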
## Example usage

### Basic generation

```rust
use minicpm_sala_mlx::{
    load_model, load_tokenizer, format_chat_prompt, is_stop_token,
};
use mlx_rs::ops::indexing::{IndexOp, NewAxis};

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Format chat prompt
let prompt_text = format_chat_prompt(
    "You are a helpful assistant.",
    "What is the Fibonacci sequence?",
);
let encoding = tokenizer.encode(&prompt_text, true)?;
let mut input = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

// MiniCPM-SALA uses LayerCache instead of KVCache
let mut caches = create_layer_caches(model.args.num_hidden_layers as usize);

// Generate (simplified - actual generation uses the model-specific API;
// `sample` and `create_layer_caches` stand in for crate helpers)
for _ in 0..200 {
    let logits = model.forward(&input, None, &mut caches)?;
    let token = sample(&logits, 0.7)?;
    let id = token.item::<u32>();
    if is_stop_token(id) {
        break;
    }
    print!("{}", tokenizer.decode(&[id], true)?);
    // With the caches populated, feed only the newly sampled token
    input = mlx_rs::Array::from(&[id][..]).index(NewAxis);
}
```
### With think filtering

```rust
use minicpm_sala_mlx::{
    load_model, load_tokenizer, format_chat_prompt,
    ThinkFilter, is_stop_token,
};
use mlx_rs::ops::indexing::{IndexOp, NewAxis};

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

let prompt_text = format_chat_prompt(
    "You are a helpful assistant.",
    "Solve this math problem: What is 15% of 240?",
);
let encoding = tokenizer.encode(&prompt_text, true)?;
let mut input = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

// Track only generated tokens so the filter never sees the prompt text
let mut generated_ids: Vec<u32> = Vec::new();

// Create think filter (suppresses <think>...</think>)
let mut filter = ThinkFilter::new(true);
let mut caches = create_layer_caches(model.args.num_hidden_layers as usize);

for _ in 0..300 {
    let logits = model.forward(&input, None, &mut caches)?;
    let token = sample(&logits, 0.6)?;
    let id = token.item::<u32>();
    if is_stop_token(id) {
        break;
    }
    generated_ids.push(id);
    // Decode the generated text so far and filter thinking
    let full_text = tokenizer.decode(&generated_ids, true)?;
    let new_text = filter.next(&full_text);
    if !new_text.is_empty() {
        print!("{}", new_text);
    }
    // With the caches populated, feed only the newly sampled token
    input = mlx_rs::Array::from(&[id][..]).index(NewAxis);
}
```
### Multi-turn conversation

```rust
use minicpm_sala_mlx::{load_model, load_tokenizer, format_chat_prompt_multi};

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

let system = "You are a helpful coding assistant.";
let turns = vec![
    ("user", "Write a Python function to reverse a string."),
    ("assistant", "Here's a function:\n\n```python\ndef reverse_string(s):\n    return s[::-1]\n```"),
    ("user", "Can you add error handling?"),
];

let prompt_text = format_chat_prompt_multi(system, &turns);
let encoding = tokenizer.encode(&prompt_text, true)?;
// Continue generation...
```
## Architecture details

### Hybrid attention layers

MiniCPM-SALA alternates between Lightning (sparse) and dense attention:

```rust
fn is_sparse_layer(layer_idx: usize, interval: i32) -> bool {
    layer_idx % (interval as usize) == 0
}

// Layers 0, 4, 8, ... use Lightning Attention (sparse)
// Layers 1, 2, 3, 5, 6, 7, ... use dense attention
```
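As a sanity check on the interval logic, a small helper (illustrative, not part of the crate) can enumerate the sparse layers for a hypothetical 16-layer model with `sparse_layer_interval = 4`:

```rust
// Same predicate as in the documentation above.
fn is_sparse_layer(layer_idx: usize, interval: i32) -> bool {
    layer_idx % (interval as usize) == 0
}

// List the Lightning (sparse) layer indices for a given depth/interval.
fn sparse_layers(num_layers: usize, interval: i32) -> Vec<usize> {
    (0..num_layers)
        .filter(|&i| is_sparse_layer(i, interval))
        .collect()
}
```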
### Thinking blocks

MiniCPM-SALA can generate intermediate reasoning:

```text
<think>
The user is asking about 15% of 240.
15% = 0.15
0.15 × 240 = 36
</think>
15% of 240 is 36.
```

Use `ThinkFilter` to hide thinking in streaming output, or `strip_thinking` for post-processing.
### Speculative decoding

Use `SpeculativeDecoder` for faster inference:

```rust
use minicpm_sala_mlx::{load_model, SpeculativeDecoder};

let draft_model = load_model("models/MiniCPM3-1B")?;
let target_model = load_model("models/MiniCPM3-4B")?;

let mut decoder = SpeculativeDecoder::new(
    draft_model,
    target_model,
    4, // Generate 4 draft tokens at a time
);
// Use decoder for ~2x speedup
```
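Under the hood, speculative decoding drafts k tokens with the small model and verifies them with the large one, keeping the longest agreeing prefix. The greedy sketch below uses stand-in closures; `speculative_step`, `draft`, and `target` are illustrative only, and a real decoder verifies all k positions in one batched target forward pass rather than one call per token:

```rust
// Simplified greedy speculative decoding step. Returns how many draft
// tokens were accepted; `context` always gains at least one target token.
fn speculative_step(
    draft: impl Fn(&[u32]) -> u32,  // stand-in: draft model's next token
    target: impl Fn(&[u32]) -> u32, // stand-in: target model's next token
    context: &mut Vec<u32>,
    k: usize,
) -> usize {
    // 1. Draft k tokens autoregressively with the cheap model.
    let mut proposed = Vec::with_capacity(k);
    let mut ctx = context.clone();
    for _ in 0..k {
        let t = draft(ctx.as_slice());
        ctx.push(t);
        proposed.push(t);
    }
    // 2. Verify: accept draft tokens while the target agrees.
    let mut accepted = 0;
    for &t in &proposed {
        let want = target(context.as_slice());
        context.push(want);
        if want != t {
            break; // target disagreed: keep its token and stop
        }
        accepted += 1;
    }
    accepted
}
```

Because accepted tokens cost roughly one target pass for up to k tokens, high draft/target agreement is what produces the advertised speedup.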
## Performance

- Lightning Attention reduces memory use and computation in sparse layers
- The hybrid architecture balances quality and efficiency
- Speculative decoding can provide a 1.5-2.5x speedup
- Compact models (2B-4B parameters) run efficiently on consumer hardware
## Constants

```rust
pub const EOS_TOKEN_ID: u32 = 2;
pub const IM_END_TOKEN_ID: u32 = 73440;
```
## See also