Overview

GritLMEmbeddingModel provides access to GritLM (Generalized Representation Instruction Tuned Language Model), a unified model that can both generate embeddings and perform text generation tasks.

Class Definition

from remem.embedding_model.GritLM import GritLMEmbeddingModel
Defined in: src/remem/embedding_model/GritLM.py:17

Initialization

__init__

def __init__(
    self,
    global_config: Optional[BaseConfig] = None,
    embedding_model_name: Optional[str] = None
) -> None
Parameters:
global_config
BaseConfig
default:"None"
Global configuration object containing:
  • embedding_return_as_cpu: Return embeddings on CPU
  • embedding_return_as_numpy: Convert to numpy arrays
  • embedding_return_as_normalized: Normalize embeddings
  • embedding_batch_size: Batch size for encoding
embedding_model_name
str
default:"None"
Model name/path containing "GritLM" (e.g., "GritLM/GritLM-7B"). If provided, overrides the name from global_config.
Example:
from remem.utils.config_utils import BaseConfig
from remem.embedding_model.GritLM import GritLMEmbeddingModel

config = BaseConfig()
config.embedding_model_name = "GritLM/GritLM-7B"
config.embedding_batch_size = 16
config.embedding_return_as_normalized = True

model = GritLMEmbeddingModel(global_config=config)
print(f"Embedding dimension: {model.embedding_dim}")
print(f"Device: {model.device}")

Attributes

embedding_model
GritLM
The loaded GritLM model instance
embedding_dim
int
Embedding dimension (depends on model variant)
device
torch.device
Device where the model is loaded (e.g., cuda:0, cpu)
embedding_config
EmbeddingConfig
Configuration containing:
  • embedding_model_name: Model identifier
  • return_cpu: Whether to return CPU tensors
  • return_numpy: Whether to convert to numpy
  • norm: Whether to normalize embeddings
  • model_init_params: Model initialization parameters
  • encode_params: Default encoding parameters
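
For quick inspection, these attributes can be read directly off the model instance; a minimal sketch, assuming the EmbeddingConfig fields above are exposed as plain attributes:
# Inspect the loaded model (assumes `model` from the initialization example above)
print(model.embedding_dim)              # e.g., 4096 for GritLM-7B
print(model.device)                     # e.g., cuda:0

# EmbeddingConfig fields, per the list above (assumed to be plain attributes)
cfg = model.embedding_config
print(cfg.embedding_model_name)         # "GritLM/GritLM-7B"
print(cfg.norm)                         # True when embeddings are normalized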

Methods

batch_encode

def batch_encode(self, texts: List[str], **kwargs) -> np.ndarray
Encodes a batch of text strings into embeddings.
Parameters:
texts
List[str] | str
required
Text strings to encode. Can be a single string or list of strings.
instruction
str
default:"''"
Optional task instruction. Will be formatted as: "<|user|>\n{instruction}\n<|embed|>\n" if provided, or "<|embed|>\n" if empty
batch_size
int
default:"16"
Number of texts to process in each batch
Returns:
embeddings
np.ndarray
2D numpy array of shape (n_texts, embedding_dim). Normalized if embedding_return_as_normalized is True.
Example:
# Simple encoding
texts = [
    "What is machine learning?",
    "Explain neural networks"
]
embs = model.batch_encode(texts)
print(embs.shape)  # (2, embedding_dim)

# With instruction for retrieval task
query_instruction = "Given a query, retrieve relevant documents"
query_embs = model.batch_encode(
    ["machine learning applications"],
    instruction=query_instruction
)
# Formatted as: "<|user|>\nGiven a query, retrieve relevant documents\n<|embed|>\nmachine learning applications"

# Document encoding (no instruction)
doc_embs = model.batch_encode([
    "Machine learning is a field of artificial intelligence"
])
# Formatted as: "<|embed|>\nMachine learning is a field of artificial intelligence"

# Compute similarity
scores = model.get_query_doc_scores(query_embs[0], doc_embs)
print(scores)
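
Because embeddings are normalized by default (norm: True), query-document scoring reduces to dot products. The call above can be reproduced in plain numpy; a sketch, assuming get_query_doc_scores computes a normalized dot-product similarity:
import numpy as np

# Dot product of L2-normalized vectors equals cosine similarity
manual_scores = doc_embs @ query_embs[0]    # shape: (n_docs,)
print(np.allclose(scores, manual_scores))   # expected True under the assumption above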

batch_generate

def batch_generate(
    self,
    chat: List[TextChatMessage],
) -> str
This method is not yet implemented. GritLM itself supports generation, but the wrapper's implementation is pending.
Parameters:
chat
List[TextChatMessage]
required
List of chat messages for generation

GritLM Instruction Format

GritLM uses special tokens to distinguish between embedding and generation modes:

Embedding Mode

<|embed|>
{text}
For task-specific embeddings:
<|user|>
{instruction}
<|embed|>
{text}
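
The wrapping logic can be expressed as a small helper; a sketch that mirrors the format strings documented for batch_encode above, not part of the public API:
def format_gritlm_prompt(text: str, instruction: str = "") -> str:
    # Matches the template used by batch_encode:
    # "<|user|>\n{instruction}\n<|embed|>\n" when an instruction is given,
    # "<|embed|>\n" otherwise; the text to embed follows the template.
    prefix = f"<|user|>\n{instruction}\n<|embed|>\n" if instruction else "<|embed|>\n"
    return prefix + text

print(format_gritlm_prompt("climate change", "Given a query, retrieve relevant passages"))
# "<|user|>\nGiven a query, retrieve relevant passages\n<|embed|>\nclimate change"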

Example Instructions

Query Encoding:
model.batch_encode(
    ["climate change"],
    instruction="Given a query, retrieve relevant passages"
)
# Result: "<|user|>\nGiven a query, retrieve relevant passages\n<|embed|>\nclimate change"
Document Encoding:
model.batch_encode([
    "Climate change refers to long-term shifts in temperatures..."
])
# Result: "<|embed|>\nClimate change refers to long-term shifts in temperatures..."
Symmetric Search:
# Use same instruction for both queries and documents
instruction = "Represent this text for semantic search"

query_embs = model.batch_encode(
    ["python tutorial"],
    instruction=instruction
)
doc_embs = model.batch_encode(
    ["Learn Python programming basics"],
    instruction=instruction
)

Configuration Details

The model initializes with the following default configuration:
{
    "embedding_model_name": "GritLM/GritLM-7B",
    "return_cpu": True,
    "return_numpy": True,
    "norm": True,  # Normalize embeddings
    "model_init_params": {
        "model_name_or_path": "GritLM/GritLM-7B",
        "torch_dtype": "auto",
        "device_map": "auto"  # Multi-GPU support
    },
    "encode_params": {
        "batch_size": 16
    },
    "generate_params": {
        # Not yet configured
    }
}
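
These defaults can be overridden through the BaseConfig fields listed under __init__; a sketch using only the embedding_* fields documented above (the mapping to the config keys is an assumption):
from remem.utils.config_utils import BaseConfig
from remem.embedding_model.GritLM import GritLMEmbeddingModel

config = BaseConfig()
config.embedding_model_name = "GritLM/GritLM-7B"
config.embedding_batch_size = 8              # assumed to map to encode_params["batch_size"]
config.embedding_return_as_cpu = False       # keep tensors on the GPU (return_cpu)
config.embedding_return_as_numpy = False     # return torch tensors (return_numpy)

model = GritLMEmbeddingModel(global_config=config)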

Model Variants

GritLM comes in different sizes:
GritLM/GritLM-7B
7B parameters
Base 7B model, good balance of quality and speed
GritLM/GritLM-8x7B
8x7B parameters
Mixture of experts model for higher quality
Check the GritLM organization on HuggingFace for the latest models.

Unified Embedding and Generation

GritLM’s key feature is its ability to handle both tasks with a single model:
from remem.embedding_model.GritLM import GritLMEmbeddingModel

model = GritLMEmbeddingModel(
    embedding_model_name="GritLM/GritLM-7B"
)

# Task 1: Embedding for retrieval
query_emb = model.batch_encode(
    ["What is quantum computing?"],
    instruction="Represent this question for retrieval"
)

doc_embs = model.batch_encode([
    "Quantum computing uses quantum mechanics...",
    "Machine learning is a subset of AI..."
])

scores = model.get_query_doc_scores(query_emb[0], doc_embs)
print(f"Relevance scores: {scores}")

# Task 2: Generation (when implemented)
# response = model.batch_generate(chat_messages)

Performance Considerations

GritLM models are larger than specialized embedding models. Ensure adequate GPU memory.
Memory Requirements:
  • GritLM-7B: ~14GB VRAM (fp16)
  • GritLM-8x7B: ~90GB VRAM (fp16)
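These figures follow from a rough parameter-count estimate; a back-of-envelope sketch that counts weights only (activations and KV caches add overhead), and assumes ~46.7B total parameters for the Mixtral-style 8x7B variant:
def vram_estimate_gb(n_params: float, bytes_per_param: int = 2) -> float:
    # fp16/bf16 stores 2 bytes per parameter; weights only, no runtime overhead
    return n_params * bytes_per_param / 1024**3

print(f"{vram_estimate_gb(7e9):.0f} GB")     # GritLM-7B: ~13 GB of weights
print(f"{vram_estimate_gb(46.7e9):.0f} GB")  # GritLM-8x7B: ~87 GB of weights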
Optimization Tips:
  1. Device Map: Auto-placement across GPUs with device_map="auto"
  2. Batch Size: Adjust based on available memory
  3. Dtype: Use torch_dtype="auto" for automatic precision selection
  4. Instructions: Use consistent instructions for query/doc pairs

Use Cases

Retrieval-Augmented Generation (RAG):
import numpy as np

# Assumes `documents` (List[str]) and `user_query` (str) are defined elsewhere
# 1. Encode documents
doc_embs = model.batch_encode(documents)

# 2. Encode query
query_emb = model.batch_encode(
    [user_query],
    instruction="Given a query, retrieve relevant documents"
)

# 3. Retrieve top-k documents
scores = model.get_query_doc_scores(query_emb[0], doc_embs)
top_k_indices = np.argsort(scores)[-5:][::-1]

# 4. Generate response (when implemented)
# context = [documents[i] for i in top_k_indices]
# response = model.batch_generate(build_chat(user_query, context))
