Overview
The moxin-vlm-mlx crate provides Vision-Language Model (VLM) inference for Moxin-7B on Apple Silicon using MLX. It implements a dual-backbone vision encoder architecture (DINOv2 and SigLIP) coupled with a Mistral-7B decoder for multimodal understanding.
Architecture
Model loading
load_model
Loads the Moxin-7B VLM from a directory containing safetensors weights and config.json.

Parameters:
- Path to the directory containing model files (safetensors and config.json)

Returns the loaded MoxinVLM model, or an error if loading fails.

Handles:
- Sharded vs. single safetensors files
- Weight key prefix variants for vision and language models
- Model configuration read from config.json, falling back to defaults
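The sharded-vs-single discovery step can be sketched as follows; `collect_safetensors` and its exact behavior are illustrative, not the crate's actual implementation:

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Collect all safetensors files in a model directory, sorted by name.
/// A single-file model yields one path; a sharded model (e.g.
/// model-00001-of-00002.safetensors) yields one path per shard.
fn collect_safetensors(dir: &Path) -> io::Result<Vec<PathBuf>> {
    let mut shards: Vec<PathBuf> = fs::read_dir(dir)?
        .filter_map(|entry| entry.ok().map(|e| e.path()))
        .filter(|p| p.extension().map_or(false, |ext| ext == "safetensors"))
        .collect();
    // Sort so shards load in index order regardless of directory order.
    shards.sort();
    if shards.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            "no safetensors files in model directory",
        ));
    }
    Ok(shards)
}
```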
MoxinVLM
The main VLM model struct combining the vision encoders and the language decoder.

Fields:
- DINOv2 ViT-Large/14 vision encoder (24 layers, 1024-dim)
- SigLIP ViT-SO400M/14 vision encoder (27 layers, 1152-dim)
- 3-layer MLP projecting fused vision features (2176-dim) to LLM space (4096-dim)
- Token embedding layer for the language model
- Mistral-7B decoder transformer blocks (default 36 layers)
- Final RMS normalization layer
- Language model head for logit prediction
- Model configuration
forward
Full VLM forward pass: encodes image + text into logits.

Parameters:
- Image array [B, 224, 224, 3] in NHWC format, ImageNet-normalized (use normalize_dino)
- Image array [B, 224, 224, 3] in NHWC format, unit-normalized (use normalize_siglip)
- Token IDs [B, seq_len] including the BOS token
- Key-value cache for the attention layers

Returns the logits array for next-token prediction.
decode_token
Text-only decode step for a single token with KV caching.

Parameters:
- Token array [B] or [B, 1] to decode
- Key-value cache for the attention layers

Returns the logits array for next-token prediction.
quantize
Quantizes the LLM decoder's linear layers while keeping the vision encoders and projector in BF16.

Parameters:
- Group size for quantization (e.g., 64, 128)
- Bit width for quantization (e.g., 4, 8)

Returns the quantized model instance.
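Group-wise quantization of the kind the group size and bit width control can be sketched as follows. This is a generic affine min/max scheme for illustration; the crate's actual MLX quantization kernel may differ:

```rust
/// Affine quantization of one weight row in groups of `group_size`,
/// using `bits` bits per value. Returns (codes, scales, biases) per group.
fn quantize_row(w: &[f32], group_size: usize, bits: u32) -> (Vec<u32>, Vec<f32>, Vec<f32>) {
    let levels = ((1u32 << bits) - 1) as f32;
    let mut codes = Vec::with_capacity(w.len());
    let mut scales = Vec::new();
    let mut biases = Vec::new();
    for group in w.chunks(group_size) {
        let lo = group.iter().cloned().fold(f32::INFINITY, f32::min);
        let hi = group.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // Each group stores its own scale and bias (zero point).
        let scale = if hi > lo { (hi - lo) / levels } else { 1.0 };
        scales.push(scale);
        biases.push(lo);
        for &x in group {
            codes.push((((x - lo) / scale).round() as u32).min(levels as u32));
        }
    }
    (codes, scales, biases)
}

/// Dequantize back: w ≈ code * scale + bias, per group.
fn dequantize_row(codes: &[u32], scales: &[f32], biases: &[f32], group_size: usize) -> Vec<f32> {
    codes
        .iter()
        .enumerate()
        .map(|(i, &c)| {
            let g = i / group_size;
            c as f32 * scales[g] + biases[g]
        })
        .collect()
}
```

A smaller group size or larger bit width lowers reconstruction error at the cost of more metadata per weight.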
Image preprocessing
normalize_dino
Normalizes an image array for DINOv2 using ImageNet statistics.

Parameters:
- Image array [B, H, W, 3] with float values in the range [0, 1]

Returns the normalized image array, using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
normalize_siglip
Normalizes an image array for SigLIP using unit normalization.

Parameters:
- Image array [B, H, W, 3] with float values in the range [0, 1]

Returns the normalized image array, using mean=[0.5, 0.5, 0.5] and std=[0.5, 0.5, 0.5].
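Both functions apply the same per-channel affine transform and differ only in their statistics. A minimal sketch over a flat NHWC buffer (the flat-slice layout is an illustrative simplification of the crate's array types):

```rust
/// Per-channel normalization: (x - mean[c]) / std[c], for an NHWC buffer
/// whose innermost dimension is 3 channels.
fn normalize(pixels: &[f32], mean: [f32; 3], std: [f32; 3]) -> Vec<f32> {
    pixels
        .iter()
        .enumerate()
        .map(|(i, &x)| (x - mean[i % 3]) / std[i % 3])
        .collect()
}

/// ImageNet statistics, as used for DINOv2.
const DINO_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const DINO_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Unit normalization, as used for SigLIP: maps [0, 1] to [-1, 1].
const SIGLIP_MEAN: [f32; 3] = [0.5, 0.5, 0.5];
const SIGLIP_STD: [f32; 3] = [0.5, 0.5, 0.5];
```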
Generation
Generate
Iterator-based token generator for VLM inference.

new
Creates a new generator for VLM inference.

Parameters:
- Mutable reference to the VLM model
- Mutable reference to the KV cache
- Sampling temperature (0.0 for greedy, higher for more randomness)
- DINOv2-normalized image array
- SigLIP-normalized image array
- Input token IDs

The generator implements Iterator<Item = Result<Array>> and yields tokens one at a time. The first call performs prefill (a full forward pass); subsequent calls perform cached decode.
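The prefill-then-cached-decode control flow can be illustrated with a toy iterator. `ToyModel`, its methods, and the struct fields below are stand-ins for exposition, not the crate's API:

```rust
/// Minimal shape of an iterator-based generator: the first `next()` runs a
/// full prefill over the prompt; later calls decode one cached token each.
struct ToyGenerate<'a> {
    model: &'a mut ToyModel,
    prompt: Vec<u32>,
    last: Option<u32>,
    remaining: usize, // generation budget
}

struct ToyModel {
    cache_len: usize, // stands in for the KV cache length
}

impl ToyModel {
    // Dummy "forward" passes: the next token is just the cache length.
    fn prefill(&mut self, prompt: &[u32]) -> u32 {
        self.cache_len = prompt.len();
        (self.cache_len % 100) as u32
    }
    fn decode(&mut self, _tok: u32) -> u32 {
        self.cache_len += 1;
        (self.cache_len % 100) as u32
    }
}

impl<'a> Iterator for ToyGenerate<'a> {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        let tok = match self.last {
            None => self.model.prefill(&self.prompt), // first call: prefill
            Some(t) => self.model.decode(t),          // later calls: cached decode
        };
        self.last = Some(tok);
        Some(tok)
    }
}
```

The first iteration is the expensive one; each later iteration only attends over the cached keys and values plus one new token.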
sample
Samples the next token from logits using temperature-based sampling.

Parameters:
- Logits array from the model output
- Sampling temperature (0.0 for argmax/greedy sampling)

Returns the sampled token ID.
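A minimal sketch of temperature sampling with the greedy fallback at temperature 0.0. The caller-supplied uniform draw is an illustrative simplification; the crate presumably draws its own randomness:

```rust
/// Temperature sampling over raw logits: temperature 0.0 falls back to
/// greedy argmax; otherwise sample from softmax(logits / temperature).
fn sample(logits: &[f32], temperature: f32, uniform01: f32) -> usize {
    if temperature == 0.0 {
        // Greedy: index of the maximum logit.
        return logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap();
    }
    // Softmax with temperature, computed stably by subtracting the max.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max) / temperature).exp())
        .collect();
    let total: f32 = exps.iter().sum();
    // Invert the CDF at the uniform draw.
    let mut acc = 0.0;
    for (i, e) in exps.iter().enumerate() {
        acc += e / total;
        if uniform01 < acc {
            return i;
        }
    }
    logits.len() - 1
}
```

Higher temperatures flatten the distribution and increase randomness; temperatures below 1.0 sharpen it toward the argmax.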
Vision encoder
ViTEncoder
Vision Transformer encoder for processing images into patch features.

forward
Processes an image through the vision transformer.

Parameters:
- Image array [B, H, W, 3] in NHWC format

Returns patch features [B, num_patches, embed_dim] taken from the second-to-last transformer block.
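With the /14 patch size both encoders use, a 224×224 input yields 16×16 = 256 patches; DINOv2 additionally carries a CLS token and 4 registers internally (the returned patch features presumably exclude those special tokens). The internal sequence length can be computed as:

```rust
/// Token sequence length of a ViT for an H×W image with `patch` patch size:
/// one token per patch, plus optional CLS and register tokens.
fn vit_seq_len(h: usize, w: usize, patch: usize, has_cls: bool, registers: usize) -> usize {
    let num_patches = (h / patch) * (w / patch);
    num_patches + usize::from(has_cls) + registers
}
```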
quantize
Quantizes all linear layers in the vision encoder.

Parameters:
- Group size for quantization
- Bit width for quantization

Returns the quantized encoder instance.
ViTConfig
Configuration for Vision Transformer models.

dinov2_large
Returns the configuration for DINOv2 ViT-Large/14 with 4 register tokens: ViTConfig with embed_dim=1024, depth=24, num_heads=16, has_cls_token=true, num_registers=4, has_layer_scale=true.

siglip_so400m
Returns the configuration for SigLIP ViT-SO400M/14: ViTConfig with embed_dim=1152, depth=27, num_heads=16, has_cls_token=false, num_registers=0, has_layer_scale=false.
load_vit_encoder
Loads a ViT encoder from a weight map.

Parameters:
- Weight dictionary containing model parameters
- Key prefix for this encoder's weights
- Encoder configuration

Returns the loaded ViT encoder instance.
Projector
FusedMLPProjector
3-layer MLP with GELU activations for projecting fused vision features into the LLM embedding space.

quantize
Quantizes all linear layers in the projector.

Parameters:
- Group size for quantization
- Bit width for quantization

Returns the quantized projector instance.
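The projector's linear → GELU → linear → GELU → linear shape can be sketched with tiny dimensions. In the real model the input is the 2176-dim fused feature and the output is 4096-dim; the tanh GELU approximation here is an assumption, as is the naive dense layout:

```rust
/// tanh-approximation GELU, commonly used in ViT and projector MLPs.
fn gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x * x * x)).tanh())
}

/// One dense layer: y = W x + b (rows of `w` are output neurons).
fn linear(w: &[Vec<f32>], b: &[f32], x: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
        .collect()
}

/// 3-layer projector shape: linear -> GELU -> linear -> GELU -> linear.
fn project(layers: &[(Vec<Vec<f32>>, Vec<f32>); 3], fused: &[f32]) -> Vec<f32> {
    let h1: Vec<f32> = linear(&layers[0].0, &layers[0].1, fused)
        .into_iter()
        .map(gelu)
        .collect();
    let h2: Vec<f32> = linear(&layers[1].0, &layers[1].1, &h1)
        .into_iter()
        .map(gelu)
        .collect();
    linear(&layers[2].0, &layers[2].1, &h2) // no activation on the output
}
```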
load_projector
Loads projector weights from a weight map.

Parameters:
- Weight dictionary containing model parameters
- Key prefix for the projector weights

Returns the loaded projector instance.
Configuration
VLMConfig
Vision-language model configuration.

Fields:
- Language model configuration
- Image dimensions (default: [224, 224])
- Special token index for image placeholders (default: 32000)
MistralConfig
Mistral-7B decoder configuration.

Fields:
- Hidden dimension size (default: 4096)
- Number of transformer layers (default: 36)
- Number of attention heads (default: 32)
- Number of key-value heads for GQA (default: 8)
- FFN intermediate dimension (default: 14336)
- Vocabulary size (default: 32064)
- RMS normalization epsilon (default: 1e-5)
- RoPE base frequency (default: 10000.0)
- Whether to tie input and output embeddings (default: false)
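The defaults above map naturally onto a Default implementation. The field names below are a plausible sketch, not necessarily the crate's:

```rust
/// Sketch of the decoder configuration with the documented defaults.
struct MistralConfig {
    hidden_size: usize,
    num_layers: usize,
    num_heads: usize,
    num_kv_heads: usize, // grouped-query attention: 32 query heads share 8 KV heads
    intermediate_size: usize,
    vocab_size: usize,
    rms_norm_eps: f32,
    rope_theta: f32,
    tie_embeddings: bool,
}

impl Default for MistralConfig {
    fn default() -> Self {
        Self {
            hidden_size: 4096,
            num_layers: 36,
            num_heads: 32,
            num_kv_heads: 8,
            intermediate_size: 14336,
            vocab_size: 32064,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            tie_embeddings: false,
        }
    }
}
```

Note the implied derived quantities: a per-head dimension of 4096 / 32 = 128, and GQA groups of 32 / 8 = 4 query heads per KV head.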
Error handling
Error
Error types for VLM operations:
- MLX computation errors
- MLX I/O errors
- Standard I/O errors
- JSON parsing errors
- Tokenizer-related errors
- Missing weight in model file
- Invalid configuration
Re-exports
The crate re-exports useful types from mlx_rs_core:
- Concatenation-based KV cache implementation
- Basic KV cache implementation
- KV cache trait
- Tokenizer loading function