Overview

The minicpm-sala-mlx crate provides inference for MiniCPM-SALA, a compact language model with a hybrid attention architecture that alternates sparse (Lightning Attention) and dense layers. MiniCPM-SALA has built-in thinking capabilities, and the crate supports speculative decoding for faster inference.

Key features

  • Hybrid attention - Alternating Lightning (sparse) and dense attention layers
  • Built-in thinking - <think>...</think> blocks for reasoning
  • Speculative decoding - Draft model acceleration
  • Custom Metal kernels - Optimized Lightning Attention implementation
  • Compact size - High performance in small models (2B-4B parameters)

Installation

Add to your Cargo.toml:
[dependencies]
minicpm-sala-mlx = "0.1"

Core functions

load_model

Loads a MiniCPM-SALA model from a directory.
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<Model, Error>
  • model_dir (impl AsRef<Path>, required) - Path to the model directory containing:
      • config.json - Model configuration
      • model.safetensors.index.json - Weight file index (or single model.safetensors)
      • Model weight files

Returns Result<Model, Error> - a loaded Model ready for inference

load_tokenizer

Loads the tokenizer from the model directory.
pub fn load_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error>
  • model_dir (impl AsRef<Path>, required) - Path to the model directory containing tokenizer.json

Returns Result<Tokenizer, Error> - a HuggingFace Tokenizer instance

get_model_args

Parses model configuration from config.json.
pub fn get_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error>
  • model_dir (impl AsRef<Path>, required) - Path to the directory containing config.json

Returns Result<ModelArgs, Error> - parsed ModelArgs with model hyperparameters

Utility functions

format_chat_prompt

Formats a single-turn chat prompt in ChatML format.
pub fn format_chat_prompt(system: &str, user: &str) -> String
  • system (&str, required) - System message defining assistant behavior
  • user (&str, required) - User message/question

Returns String - formatted ChatML prompt ready for tokenization
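
Concretely, ChatML wraps each message in <|im_start|>/<|im_end|> markers. A minimal sketch of a single-turn formatter follows; the exact template string is an assumption here, so verify against format_chat_prompt's real output:

```rust
// Sketch of ChatML single-turn formatting (template string is an assumption).
fn format_chat_prompt(system: &str, user: &str) -> String {
    format!(
        "<|im_start|>system\n{system}<|im_end|>\n\
         <|im_start|>user\n{user}<|im_end|>\n\
         <|im_start|>assistant\n"
    )
}
```

The trailing open assistant turn is what prompts the model to generate its reply.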

format_chat_prompt_multi

Formats a multi-turn chat prompt in ChatML format.
pub fn format_chat_prompt_multi(system: &str, turns: &[(&str, &str)]) -> String
  • system (&str, required) - System message
  • turns (&[(&str, &str)], required) - List of (role, content) pairs where role is "user" or "assistant"

Returns String - formatted multi-turn ChatML prompt

strip_thinking

Removes <think>...</think> block from generated text.
pub fn strip_thinking(text: &str) -> &str
  • text (&str, required) - Generated text that may contain a thinking block

Returns &str - the text after the </think> tag, or the original text if there is no thinking block
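
The documented behavior can be sketched in a few lines (whitespace handling may differ from the crate's implementation):

```rust
// Sketch of strip_thinking: return everything after the closing </think>
// tag, or the input unchanged if no tag is present.
fn strip_thinking(text: &str) -> &str {
    match text.find("</think>") {
        Some(pos) => &text[pos + "</think>".len()..],
        None => text,
    }
}
```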

is_stop_token

Checks if a token is a stop token (EOS or <|im_end|>).
pub fn is_stop_token(token_id: u32) -> bool
  • token_id (u32, required) - Token ID to check

Returns bool - true if the token is EOS (2) or <|im_end|> (73440)
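
Given the token IDs documented under Constants, the check amounts to:

```rust
// Stop-token check using the crate's documented constants.
pub const EOS_TOKEN_ID: u32 = 2;
pub const IM_END_TOKEN_ID: u32 = 73440;

fn is_stop_token(token_id: u32) -> bool {
    token_id == EOS_TOKEN_ID || token_id == IM_END_TOKEN_ID
}
```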

Types

Model

The main model struct for MiniCPM-SALA inference.
pub struct Model {
    pub args: ModelArgs,
    pub model: MiniCPMSALAModel,
    pub lm_head: nn::Linear,
}

ModelArgs

Model configuration from config.json.
pub struct ModelArgs {
    pub model_type: String,
    pub vocab_size: i32,
    pub hidden_size: i32,
    pub intermediate_size: i32,
    pub num_hidden_layers: i32,
    pub num_attention_heads: i32,
    pub num_key_value_heads: i32,
    pub max_position_embeddings: i32,
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub scale_emb: f32,
    pub scale_depth: f32,
    pub dim_model_base: i32,
    pub sparse_layer_interval: i32,
    pub quantization: Option<QuantizationConfig>,
}
Notable fields:
  • scale_emb (f32) - Embedding scaling factor (muP scaling)
  • scale_depth (f32) - Depth scaling for residual connections
  • sparse_layer_interval (i32) - Interval between Lightning (sparse) attention layers

HybridAttention

Enum for sparse or dense attention layers.
pub enum HybridAttention {
    Sparse(SparseAttention),       // Lightning (sparse) attention layer
    Lightning(LightningAttention), // dense attention layer
}

ThinkFilter

Incremental filter for streaming output with think block suppression.
pub struct ThinkFilter {
    think_done: bool,
    prev_text_len: usize,
}

Constructor

pub fn new(no_think: bool) -> Self
  • no_think (bool, required) - If true, suppresses <think>...</think> content in the output

next

pub fn next<'a>(&mut self, full_text: &'a str) -> &'a str
  • full_text (&'a str, required) - Full decoded text so far

Returns &'a str - the new text to emit (an empty string while still inside the think block)
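
A minimal sketch of how the filter can implement this contract, using the fields shown above (assumed logic, not necessarily the crate's implementation):

```rust
// Sketch of an incremental think-block filter (assumed logic).
pub struct ThinkFilter {
    think_done: bool,
    prev_text_len: usize,
}

impl ThinkFilter {
    pub fn new(no_think: bool) -> Self {
        // When no_think is false, nothing is suppressed.
        Self { think_done: !no_think, prev_text_len: 0 }
    }

    pub fn next<'a>(&mut self, full_text: &'a str) -> &'a str {
        if !self.think_done {
            // Still inside the think block: emit nothing until </think> appears.
            match full_text.find("</think>") {
                Some(pos) => {
                    self.think_done = true;
                    self.prev_text_len = pos + "</think>".len();
                }
                None => return "",
            }
        }
        // Emit only the text added since the last call.
        let new_text = &full_text[self.prev_text_len..];
        self.prev_text_len = full_text.len();
        new_text
    }
}
```

Tracking prev_text_len against the full decoded text, rather than decoding tokens one at a time, keeps multi-byte characters intact across token boundaries.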

Example usage

Basic generation

use minicpm_sala_mlx::{
    load_model, load_tokenizer, format_chat_prompt, is_stop_token,
};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Format chat prompt
let prompt_text = format_chat_prompt(
    "You are a helpful assistant.",
    "What is the Fibonacci sequence?",
);

let encoding = tokenizer.encode(&prompt_text, true)?;
let mut ids = encoding.get_ids().to_vec();

// MiniCPM-SALA uses LayerCache instead of KVCache
let mut caches = create_layer_caches(model.args.num_hidden_layers as usize);

// Generate (simplified - actual generation uses model-specific API)
for _ in 0..200 {
    let input = mlx_rs::Array::from(&ids).index(NewAxis);
    let logits = model.forward(&input, None, &mut caches)?;
    let token = sample(&logits, 0.7)?;
    let id = token.item::<u32>();

    if is_stop_token(id) {
        break;
    }

    // Feed the sampled token back before the next step
    ids.push(id);
    let text = tokenizer.decode(&[id], true)?;
    print!("{}", text);
}

With think filtering

use minicpm_sala_mlx::{
    load_model, load_tokenizer, format_chat_prompt, 
    ThinkFilter, is_stop_token,
};

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

let prompt_text = format_chat_prompt(
    "You are a helpful assistant.",
    "Solve this math problem: What is 15% of 240?",
);

let encoding = tokenizer.encode(&prompt_text, true)?;
let mut generated_ids = encoding.get_ids().to_vec();

// Create think filter (suppresses <think>...</think>)
let mut filter = ThinkFilter::new(true);
let mut caches = create_layer_caches(model.args.num_hidden_layers as usize);

for _ in 0..300 {
    let prompt = mlx_rs::Array::from(&generated_ids).index(NewAxis);
    let logits = model.forward(&prompt, None, &mut caches)?;
    let token = sample(&logits, 0.6)?;
    let id = token.item::<u32>();
    
    if is_stop_token(id) {
        break;
    }
    
    generated_ids.push(id);
    
    // Decode and filter thinking
    let full_text = tokenizer.decode(&generated_ids, true)?;
    let new_text = filter.next(&full_text);
    if !new_text.is_empty() {
        print!("{}", new_text);
    }
}

Multi-turn conversation

use minicpm_sala_mlx::{load_model, load_tokenizer, format_chat_prompt_multi};

let model_dir = "models/MiniCPM3-4B";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

let system = "You are a helpful coding assistant.";
let turns = vec![
    ("user", "Write a Python function to reverse a string."),
    ("assistant", "Here's a function:\n\n```python\ndef reverse_string(s):\n    return s[::-1]\n```"),
    ("user", "Can you add error handling?"),
];

let prompt_text = format_chat_prompt_multi(system, &turns);
let encoding = tokenizer.encode(&prompt_text, true)?;

// Continue generation...

Architecture details

Hybrid attention layers

MiniCPM-SALA alternates between Lightning (sparse) and dense attention:
fn is_sparse_layer(layer_idx: usize, interval: i32) -> bool {
    layer_idx % (interval as usize) == 0
}

// With sparse_layer_interval = 4:
// Layers 0, 4, 8, ... use Lightning Attention (sparse)
// Layers 1, 2, 3, 5, 6, 7, ... use dense attention

Thinking blocks

MiniCPM-SALA can generate intermediate reasoning:
<think>
The user is asking about 15% of 240.
15% = 0.15
0.15 × 240 = 36
</think>
15% of 240 is 36.
Use ThinkFilter to hide thinking in streaming output or strip_thinking for post-processing.

Speculative decoding

Use SpeculativeDecoder for faster inference:
use minicpm_sala_mlx::{load_model, SpeculativeDecoder};

let draft_model = load_model("models/MiniCPM3-1B")?;
let target_model = load_model("models/MiniCPM3-4B")?;

let mut decoder = SpeculativeDecoder::new(
    draft_model,
    target_model,
    4, // Generate 4 draft tokens at a time
);

// Use decoder for ~2x speedup
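
The accept/reject logic behind speculative decoding can be illustrated with a toy greedy sketch. Closures stand in for real models here; a real decoder (including, presumably, SpeculativeDecoder) verifies all draft tokens with one batched target forward pass rather than token by token:

```rust
// Toy illustration of speculative decoding's accept/reject loop.
// `draft` and `target` are stand-in greedy next-token functions.
fn speculative_step(
    draft: &dyn Fn(&[u32]) -> u32,
    target: &dyn Fn(&[u32]) -> u32,
    ctx: &mut Vec<u32>,
    k: usize,
) -> usize {
    // 1. Cheaply propose k tokens with the draft model.
    let mut proposed = Vec::with_capacity(k);
    let mut tmp = ctx.clone();
    for _ in 0..k {
        let t = draft(&tmp);
        proposed.push(t);
        tmp.push(t);
    }

    // 2. Verify against the target model; keep the agreeing prefix.
    let mut accepted = 0;
    for &p in &proposed {
        let expected = target(&ctx[..]);
        ctx.push(expected);
        accepted += 1;
        if p != expected {
            break; // first mismatch: keep the target's token and stop
        }
    }
    accepted
}
```

Because every kept token is the target model's own choice, the output matches what the target alone would produce; the speedup comes from accepting several tokens per target pass when the draft model agrees.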

Performance notes

  • Lightning Attention reduces memory and computation in sparse layers
  • Hybrid architecture balances quality and efficiency
  • Speculative decoding can provide 1.5-2.5x speedup
  • Compact models (2B-4B) run efficiently on consumer hardware

Constants

pub const EOS_TOKEN_ID: u32 = 2;
pub const IM_END_TOKEN_ID: u32 = 73440;
