Overview

The moxin-vlm-mlx crate provides Vision-Language Model inference for Moxin-7B on Apple Silicon using MLX. It implements a dual-backbone vision encoder architecture with DINOv2 and SigLIP, coupled with a Mistral-7B decoder for multimodal understanding.

Architecture

Image (224x224)
  |-> DINOv2 ViT-L/14 -> [B, 256, 1024]
  |-> SigLIP ViT-SO400M -> [B, 256, 1152]
             | concat
       [B, 256, 2176]
             | FusedMLPProjector
       [B, 256, 4096]  (256 visual tokens)
             |
  BOS + [visual tokens] + text tokens
             | Mistral-7B decoder (36 layers)
       logits -> autoregressive generation
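The token counts and dimensions in the diagram follow from simple arithmetic; this small sketch restates them as checks (the constants are just the numbers from the diagram above):

```rust
// Sanity-check the token counts and dimensions in the architecture diagram.
fn num_visual_tokens(image_size: usize, patch_size: usize) -> usize {
    let per_side = image_size / patch_size;
    per_side * per_side
}

fn main() {
    // 224 / 14 = 16 patches per side -> 16 * 16 = 256 visual tokens.
    assert_eq!(num_visual_tokens(224, 14), 256);
    // Channel-wise concat of the two backbones: 1024 (DINOv2) + 1152 (SigLIP) = 2176.
    assert_eq!(1024 + 1152, 2176);
    println!("256 visual tokens, fused dim 2176 -> projected to 4096");
}
```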

Model loading

load_model

Loads the Moxin-7B VLM from a directory containing safetensors and config.json.
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<MoxinVLM>
  • model_dir (impl AsRef<Path>, required): Path to the directory containing model files (safetensors and config.json)
  • Returns Result<MoxinVLM>: the loaded MoxinVLM model, or an error if loading fails
The function automatically detects:
  • Sharded vs single safetensors files
  • Weight key prefix variants for vision and language models
  • Model configuration from config.json or uses defaults

MoxinVLM

The main VLM model struct combining vision encoders and language decoder.

Fields

  • dino (ViTEncoder): DINOv2 ViT-Large/14 vision encoder (24 layers, 1024-dim)
  • siglip (ViTEncoder): SigLIP ViT-SO400M/14 vision encoder (27 layers, 1152-dim)
  • projector (FusedMLPProjector): 3-layer MLP projecting fused vision features (2176-dim) to LLM space (4096-dim)
  • embed_tokens (nn::Embedding): Token embedding layer for the language model
  • layers (Vec<LLMBlock>): Mistral-7B decoder transformer blocks (default 36 layers)
  • norm (nn::RmsNorm): Final RMS normalization layer
  • lm_head (MaybeQuantized<nn::Linear>): Language model head for logit prediction
  • config (MistralConfig): Model configuration

forward

Full VLM forward pass: encode image + text to logits.
pub fn forward<C: KeyValueCache + Default>(
    &mut self,
    dino_image: &Array,
    siglip_image: &Array,
    input_ids: &Array,
    cache: &mut Vec<C>,
) -> Result<Array>
  • dino_image (&Array, required): Image array [B, 224, 224, 3] in NHWC format, ImageNet-normalized (use normalize_dino)
  • siglip_image (&Array, required): Image array [B, 224, 224, 3] in NHWC format, unit-normalized (use normalize_siglip)
  • input_ids (&Array, required): Token IDs [B, seq_len], including the BOS token
  • cache (&mut Vec<C>, required): Key-value cache for the attention layers
  • Returns Result<Array>: logits for next-token prediction

decode_token

Text-only decode for a single token with KV caching.
pub fn decode_token<C: KeyValueCache + Default>(
    &mut self,
    token: &Array,
    cache: &mut Vec<C>,
) -> Result<Array>
  • token (&Array, required): Token array [B] or [B, 1] to decode
  • cache (&mut Vec<C>, required): Key-value cache for the attention layers
  • Returns Result<Array>: logits for next-token prediction

quantize

Quantizes the LLM decoder linear layers while keeping vision encoders and projector in BF16.
pub fn quantize(self, group_size: i32, bits: i32) -> Result<Self>
  • group_size (i32, required): Group size for quantization (e.g., 64 or 128)
  • bits (i32, required): Bit width for quantization (e.g., 4 or 8)
  • Returns Result<Self>: the quantized model instance
Only the Mistral-7B decoder is quantized since it dominates memory and compute. Vision encoders have dimensions that aren’t cleanly divisible by common group sizes, and they only run once during prefill.
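A back-of-envelope byte count shows why quantizing the decoder pays off. The sketch below assumes 4-bit packed weights plus one f16 scale and one f16 bias per group; that per-group layout is an illustrative assumption, and the exact overhead depends on the quantization format MLX uses:

```rust
/// Approximate storage for a quantized weight matrix, in bytes.
/// Assumes packed weights plus one f16 scale and one f16 bias per group
/// (an illustrative layout; real formats may differ).
fn quantized_bytes(params: usize, bits: usize, group_size: usize) -> usize {
    let weight_bytes = params * bits / 8;
    let groups = params / group_size;
    let scale_bias_bytes = groups * 2 * 2; // f16 scale + f16 bias per group
    weight_bytes + scale_bias_bytes
}

fn main() {
    let params = 4096 * 4096; // one decoder projection matrix
    let bf16 = params * 2;    // 2 bytes per BF16 weight
    let q4 = quantized_bytes(params, 4, 64);
    println!(
        "bf16: {} bytes, 4-bit/g64: {} bytes (~{:.1}x smaller)",
        bf16,
        q4,
        bf16 as f64 / q4 as f64
    );
}
```

Under these assumptions a single 4096x4096 layer shrinks roughly 3.5x, and that saving repeats across every linear layer in the 36-block decoder, while the one-shot vision encoders stay in BF16.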

Image preprocessing

normalize_dino

Normalizes an image array for DINOv2 using ImageNet statistics.
pub fn normalize_dino(img: &Array) -> Result<Array>
  • img (&Array, required): Image array [B, H, W, 3] with float values in [0, 1]
  • Returns Result<Array>: image normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]

normalize_siglip

Normalizes an image array for SigLIP using unit normalization.
pub fn normalize_siglip(img: &Array) -> Result<Array>
  • img (&Array, required): Image array [B, H, W, 3] with float values in [0, 1]
  • Returns Result<Array>: image normalized with mean=[0.5, 0.5, 0.5] and std=[0.5, 0.5, 0.5]
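The two functions apply the same per-channel arithmetic with different statistics. A minimal plain-Rust sketch of that arithmetic over a single pixel (not the mlx_rs array implementation the crate uses):

```rust
/// Normalize one RGB pixel (values in [0, 1]) with per-channel mean and std.
fn normalize(px: [f32; 3], mean: [f32; 3], stddev: [f32; 3]) -> [f32; 3] {
    [0usize, 1, 2].map(|c| (px[c] - mean[c]) / stddev[c])
}

fn main() {
    let px = [0.5, 0.5, 0.5];
    // DINOv2: ImageNet statistics.
    let dino = normalize(px, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
    // SigLIP: unit normalization, which maps [0, 1] onto [-1, 1].
    let siglip = normalize(px, [0.5; 3], [0.5; 3]);
    println!("dino: {dino:?}, siglip: {siglip:?}");
}
```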

Generation

Generate

Iterator-based token generator for VLM inference.
pub struct Generate<'a, C> {
    vlm: &'a mut MoxinVLM,
    cache: &'a mut Vec<C>,
    temp: f32,
    state: GenerateState,
}

new

Creates a new generator for VLM inference.
pub fn new(
    vlm: &'a mut MoxinVLM,
    cache: &'a mut Vec<C>,
    temp: f32,
    dino_image: Array,
    siglip_image: Array,
    input_ids: Array,
) -> Self
  • vlm (&'a mut MoxinVLM, required): Mutable reference to the VLM model
  • cache (&'a mut Vec<C>, required): Mutable reference to the KV cache
  • temp (f32, required): Sampling temperature (0.0 for greedy; higher values increase randomness)
  • dino_image (Array, required): DINOv2-normalized image array
  • siglip_image (Array, required): SigLIP-normalized image array
  • input_ids (Array, required): Input token IDs
The generator implements Iterator<Item = Result<Array>> and yields tokens one at a time. The first call performs prefill (a full forward pass over the image and prompt); each subsequent call performs a single-token cached decode.
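The prefill-then-decode split can be pictured as a two-state machine inside the iterator. This toy (names and token values are invented for illustration; it is not the crate's real internal state) shows the pattern:

```rust
// Toy sketch of the prefill-then-decode iterator pattern.
#[derive(Clone, Copy)]
enum GenerateState {
    Prefill,
    Decode { last_token: u32 },
}

struct ToyGenerate {
    state: GenerateState,
    remaining: usize,
}

impl Iterator for ToyGenerate {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        let tok = match self.state {
            // First call: full forward pass over image + prompt (prefill).
            GenerateState::Prefill => 100,
            // Later calls: single-token decode against the KV cache.
            GenerateState::Decode { last_token } => last_token + 1,
        };
        self.state = GenerateState::Decode { last_token: tok };
        Some(tok)
    }
}

fn main() {
    let g = ToyGenerate { state: GenerateState::Prefill, remaining: 4 };
    let tokens: Vec<u32> = g.collect();
    println!("{tokens:?}"); // [100, 101, 102, 103]
}
```

In the real Generate, the state transition is where the expensive work happens: prefill runs the vision encoders, projector, and full decoder pass, while decode reuses the KV cache and only processes the newest token.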

sample

Samples the next token from logits using temperature-based sampling.
pub fn sample(logits: &Array, temp: f32) -> std::result::Result<Array, mlx_rs::error::Exception>
  • logits (&Array, required): Logits array from the model output
  • temp (f32, required): Sampling temperature (0.0 for argmax/greedy sampling)
  • Returns std::result::Result<Array, mlx_rs::error::Exception>: the sampled token ID
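The temperature convention can be sketched in plain Rust. This is a simplified, deterministic CPU version (the random draw is passed in, and it is not the mlx_rs array code the crate actually runs):

```rust
/// Greedy argmax when temp == 0.0; otherwise sample from softmax(logits / temp).
/// `uniform` is a random draw in [0, 1), passed in to keep the sketch deterministic.
fn sample_token(logits: &[f32], temp: f32, uniform: f32) -> usize {
    if temp == 0.0 {
        // Greedy: pick the index of the highest logit.
        return logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap();
    }
    // Temperature sampling: softmax over scaled logits, then invert the CDF.
    let scaled: Vec<f32> = logits.iter().map(|l| l / temp).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|l| (l - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut cum = 0.0;
    for (i, e) in exps.iter().enumerate() {
        cum += e / total;
        if uniform < cum {
            return i;
        }
    }
    logits.len() - 1
}

fn main() {
    let logits = [1.0, 3.0, 2.0];
    println!("greedy token: {}", sample_token(&logits, 0.0, 0.0)); // index 1
}
```

Lower temperatures sharpen the distribution toward the argmax; higher temperatures flatten it toward uniform.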

Vision encoder

ViTEncoder

Vision Transformer encoder for processing images into patch features.

forward

Processes an image through the vision transformer.
pub fn forward(&mut self, x: &Array) -> Result<Array, Exception>
  • x (&Array, required): Image array [B, H, W, 3] in NHWC format
  • Returns Result<Array, Exception>: patch features [B, num_patches, embed_dim] from the second-to-last transformer block
Features are taken from the second-to-last transformer block; the CLS and register tokens are then stripped so that only patch features are returned.
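The token bookkeeping behind the stripping is simple: per the configs below, DINOv2 prepends 1 CLS token and 4 register tokens to its 256 patches, while SigLIP has neither. A plain-Rust sketch of the count (the helper name is invented for illustration):

```rust
/// Number of prefix tokens to strip before the patch features begin.
fn prefix_tokens(has_cls_token: bool, num_registers: usize) -> usize {
    (has_cls_token as usize) + num_registers
}

fn main() {
    // DINOv2 sequence: 1 CLS + 4 registers + 256 patches = 261 tokens.
    let dino_seq = 261;
    let dino_patches = dino_seq - prefix_tokens(true, 4);
    assert_eq!(dino_patches, 256);
    // SigLIP has no CLS or register tokens, so all 256 tokens are patches.
    assert_eq!(256 - prefix_tokens(false, 0), 256);
    println!("patch tokens per encoder: {dino_patches}");
}
```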

quantize

Quantizes all linear layers in the vision encoder.
pub fn quantize(self, group_size: i32, bits: i32) -> std::result::Result<Self, Exception>
  • group_size (i32, required): Group size for quantization
  • bits (i32, required): Bit width for quantization
  • Returns std::result::Result<Self, Exception>: the quantized encoder instance

ViTConfig

Configuration for Vision Transformer models.

dinov2_large

Returns configuration for DINOv2 ViT-Large/14 with 4 register tokens.
pub fn dinov2_large() -> Self
  • Returns Self: ViTConfig with embed_dim=1024, depth=24, num_heads=16, has_cls_token=true, num_registers=4, has_layer_scale=true

siglip_so400m

Returns configuration for SigLIP ViT-SO400M/14.
pub fn siglip_so400m() -> Self
  • Returns Self: ViTConfig with embed_dim=1152, depth=27, num_heads=16, has_cls_token=false, num_registers=0, has_layer_scale=false

load_vit_encoder

Loads a ViT encoder from a weight map.
pub fn load_vit_encoder(
    weights: &HashMap<String, Array>,
    prefix: &str,
    config: ViTConfig,
) -> Result<ViTEncoder, Error>
  • weights (&HashMap<String, Array>, required): Weight dictionary containing model parameters
  • prefix (&str, required): Key prefix for this encoder’s weights
  • config (ViTConfig, required): Encoder configuration
  • Returns Result<ViTEncoder, Error>: the loaded ViT encoder instance

Projector

FusedMLPProjector

3-layer MLP with GELU activations for projecting fused vision features to LLM embedding space.
pub struct FusedMLPProjector {
    pub fc1: MaybeQuantized<nn::Linear>,
    pub fc2: MaybeQuantized<nn::Linear>,
    pub fc3: MaybeQuantized<nn::Linear>,
}
Maps fused DINOv2+SigLIP features (2176-dim) to LLM embedding space (4096-dim).

quantize

Quantizes all linear layers in the projector.
pub fn quantize(self, group_size: i32, bits: i32) -> Result<Self, Exception>
  • group_size (i32, required): Group size for quantization
  • bits (i32, required): Bit width for quantization
  • Returns Result<Self, Exception>: the quantized projector instance

load_projector

Loads projector weights from a weight map.
pub fn load_projector(
    weights: &HashMap<String, Array>,
    prefix: &str,
) -> Result<FusedMLPProjector, Error>
  • weights (&HashMap<String, Array>, required): Weight dictionary containing model parameters
  • prefix (&str, required): Key prefix for projector weights
  • Returns Result<FusedMLPProjector, Error>: the loaded projector instance

Configuration

VLMConfig

Vision-language model configuration.
  • text_config (Option<MistralConfig>): Language model configuration
  • image_sizes (Vec<i32>): Image dimensions (default: [224, 224])
  • image_token_index (i32): Special token index for image placeholders (default: 32000)

MistralConfig

Mistral-7B decoder configuration.
  • hidden_size (i32): Hidden dimension size (default: 4096)
  • num_hidden_layers (i32): Number of transformer layers (default: 36)
  • num_attention_heads (i32): Number of attention heads (default: 32)
  • num_key_value_heads (i32): Number of key-value heads for GQA (default: 8)
  • intermediate_size (i32): FFN intermediate dimension (default: 14336)
  • vocab_size (i32): Vocabulary size (default: 32064)
  • rms_norm_eps (f32): RMS normalization epsilon (default: 1e-5)
  • rope_theta (f32): RoPE base frequency (default: 10000.0)
  • tie_word_embeddings (bool): Whether to tie input and output embeddings (default: false)
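These defaults imply a head dimension of 4096 / 32 = 128, and with GQA the KV cache stores only the 8 key-value heads per layer rather than all 32. A back-of-envelope cache cost per generated token, assuming a BF16 cache at 2 bytes per element (an assumption for illustration; the helper name is invented):

```rust
/// KV cache bytes per token across all layers:
/// 2 tensors (K and V) * kv_heads * head_dim * layers * bytes per element.
fn kv_bytes_per_token(layers: usize, kv_heads: usize, head_dim: usize, elem_bytes: usize) -> usize {
    2 * kv_heads * head_dim * layers * elem_bytes
}

fn main() {
    let head_dim = 4096 / 32; // hidden_size / num_attention_heads = 128
    let per_token = kv_bytes_per_token(36, 8, head_dim, 2);
    println!("KV cache: {} KiB per token", per_token / 1024); // 144 KiB
}
```

With all 32 heads cached the figure would be 4x larger, which is the memory saving GQA buys during generation.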

Error handling

Error

Error types for VLM operations.
  • Mlx (mlx_rs::error::Exception): MLX computation errors
  • MlxIo (mlx_rs::error::IoError): MLX I/O errors
  • Io (std::io::Error): Standard I/O errors
  • Json (serde_json::Error): JSON parsing errors
  • Tokenizer (String): Tokenizer-related errors
  • WeightNotFound (String): Missing weight in the model file
  • InvalidConfig (String): Invalid configuration

Re-exports

The crate re-exports useful types from mlx_rs_core:
pub use mlx_rs_core::{
    cache::{ConcatKeyValueCache, KVCache, KeyValueCache},
    load_tokenizer,
};
  • ConcatKeyValueCache (mlx_rs_core::cache::ConcatKeyValueCache): Concatenation-based KV cache implementation
  • KVCache (mlx_rs_core::cache::KVCache): Basic KV cache implementation
  • KeyValueCache (mlx_rs_core::cache::KeyValueCache): KV cache trait
  • load_tokenizer (mlx_rs_core::load_tokenizer): Tokenizer loading function
