Overview
RAG strategies control how REMem indexes documents and performs retrieval for question answering. Each extraction method can have a corresponding strategy that defines its indexing and QA logic.

Strategy Factory Pattern
REMem uses a factory pattern (rag_strategies/factory.py:9-59) to manage strategies:
from typing import Dict, Type
from .base_strategy import RAGStrategy
class RAGStrategyFactory:
    """Factory class for creating RAG strategies based on extraction method.

    The registry is a class-level dict, so strategies added through
    ``register_strategy`` are visible to every later ``create_strategy``
    call in the same process.
    """

    # Maps the ``extract_method`` config value to a strategy class.
    # PassageTripleStrategy, EpisodicGistStrategy and TemporalStrategy are
    # presumably imported in the lines of factory.py not shown here.
    # NOTE(review): the "Built-in Strategies" section describes
    # DefaultRAGStrategy as the OpenIE strategy — confirm which class
    # "openie" maps to.
    _strategies: Dict[str, Type[RAGStrategy]] = {
        "openie": PassageTripleStrategy,
        "episodic_gist": EpisodicGistStrategy,
        "temporal": TemporalStrategy,
    }

    @classmethod
    def create_strategy(cls, extract_method: str, remem_instance) -> RAGStrategy:
        """Create a RAG strategy based on the extraction method.

        Args:
            extract_method: A registered extraction-method name.
            remem_instance: The ReMem object handed to the strategy constructor.

        Returns:
            A new instance of the registered strategy class.

        Raises:
            ValueError: If no strategy is registered for ``extract_method``.
        """
        if extract_method not in cls._strategies:
            raise ValueError(
                f"Unsupported extraction method: {extract_method}. "
                f"Supported methods: {list(cls._strategies.keys())}"
            )
        strategy_class = cls._strategies[extract_method]
        return strategy_class(remem_instance)

    @classmethod
    def register_strategy(cls, extract_method: str, strategy_class: Type[RAGStrategy]):
        """Register (or replace) the strategy for an extraction method.

        Args:
            extract_method: Name to use as ``extract_method`` in the config.
            strategy_class: A subclass of ``RAGStrategy`` (the class, not an
                instance).

        Raises:
            TypeError: If ``strategy_class`` is not a ``RAGStrategy`` subclass.
        """
        # Validate eagerly: a bad registration would otherwise surface only
        # later, inside create_strategy(), far from its cause.
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, RAGStrategy)):
            raise TypeError(
                f"strategy_class must be a subclass of RAGStrategy, got {strategy_class!r}"
            )
        cls._strategies[extract_method] = strategy_class

    @classmethod
    def get_supported_methods(cls) -> list[str]:
        """Get list of supported extraction methods."""
        return list(cls._strategies.keys())
Base Strategy Interface
All strategies inherit from `RAGStrategy` (`rag_strategies/base_strategy.py:7-146`):
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from remem.utils.misc_utils import QuerySolution
class RAGStrategy(ABC):
    """Base class that every RAG strategy derives from.

    Subclasses must implement ``index`` and ``rag_for_qa``.
    ``retrieve_each_query`` and ``get_graph_info`` have default versions.
    The ReMem object passed to the constructor is kept as ``self.remem``.
    """

    def __init__(self, remem_instance):
        self.remem = remem_instance

    @abstractmethod
    def index(self, docs: List[str]) -> None:
        """Index ``docs`` according to this strategy."""

    @abstractmethod
    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Answer ``queries`` with retrieval-augmented generation.

        Returns a 5-tuple:
        (query solutions, response messages, per-query metadata,
        retrieval metrics, QA metrics).
        """

    def retrieve_each_query(self, query: str, return_chunk: Optional[str] = None):
        """Retrieve for a single query; override to customize.

        By default this defers to the ReMem instance.
        """
        remem = self.remem
        return remem.retrieve_each_query(query, return_chunk)

    def get_graph_info(self) -> Dict:
        """Return statistics about the graph structure.

        The body is elided in this excerpt. The full implementation is shown
        under "Helper: Get Graph Info".
        """
Built-in Strategies
DefaultRAGStrategy (OpenIE)
From `rag_strategies/default_strategy.py:8-60`:
class DefaultRAGStrategy(RAGStrategy):
    """Default RAG strategy for standard OpenIE-based extraction.

    Indexing delegates to ``remem.index_original``. QA runs the following
    steps through the ReMem instance:

    1. retrieval (when needed)
    2. retrieval evaluation
    3. answer generation
    4. QA evaluation
    5. result saving (optional)
    """

    def index(self, docs: List[str]) -> None:
        """Index documents using the standard OpenIE approach."""
        self.remem.index_original(docs)

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 5,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        to_save: bool = True,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform QA using the standard RAG approach.

        Args:
            queries: Query strings, or QuerySolution objects that already carry
                retrieval results. In the latter case retrieval is skipped.
                The type of the first element decides which case applies.
            num_to_retrieve: Accepted for interface compatibility. It is not
                forwarded to ``remem.retrieve`` here.
            gold_answers: Reference answers for QA evaluation.
            gold_docs: Reference documents for retrieval evaluation.
            metrics: Metric names passed to ``remem.get_evaluators``.
            question_metadata: Optional list with one entry per query. Each
                entry is attached as ``question_metadata`` on its solution.
            to_save: Persist results via ``remem.save_rag_results`` when True.

        Returns:
            ``(query_solutions, response_messages, metadata,
            retrieval_metrics, qa_metrics)``. These are empty lists/dicts when
            ``queries`` is empty.

        Raises:
            ValueError: If ``question_metadata`` does not have exactly one
                entry per query.
        """
        # Nothing to answer. This also avoids an IndexError on queries[0].
        if not queries:
            return [], [], [], {}, {}

        # Retrieve documents unless the caller already supplied them.
        if isinstance(queries[0], QuerySolution):
            query_solutions = queries
        else:
            query_solutions = self.remem.retrieve(queries=queries)

        # Attach metadata. Require exactly one entry per query: a short list
        # would raise an opaque IndexError, and extras would be silently
        # ignored.
        if question_metadata is not None:
            if len(question_metadata) != len(query_solutions):
                raise ValueError(
                    f"question_metadata has {len(question_metadata)} entries "
                    f"but there are {len(query_solutions)} queries"
                )
            for solution, meta in zip(query_solutions, question_metadata):
                solution.question_metadata = meta

        # Evaluate retrieval
        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        overall_retrieval_metrics = self.remem.evaluate_retrieval(
            gold_docs, query_solutions, retrieval_evaluators
        )

        # Perform QA
        query_solutions, all_response_message, all_metadata = self.remem.qa(query_solutions)

        # Evaluate QA
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, question_metadata
        )

        # Save results
        if to_save:
            self.remem.save_rag_results(
                gold_answers, gold_docs, query_solutions,
                overall_qa_metrics, overall_retrieval_metrics,
            )

        return (
            query_solutions,
            all_response_message,
            all_metadata,
            overall_retrieval_metrics,
            overall_qa_metrics,
        )
EpisodicGistStrategy
An advanced strategy with gist-based retrieval (`rag_strategies/episodic_gist_strategy.py:21-1013`; abridged below, with elided parts marked `...`):
class EpisodicGistStrategy(RAGStrategy):
    """Strategy for episodic gist-based extraction and retrieval.

    Abridged excerpt: elided sections are marked with ``...``.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Both flags are read from the ReMem instance's global config.
        self.concatenate_gists_per_chunk = remem_instance.global_config.concatenate_gists_per_chunk
        self.split_verbatim_per_chunk = remem_instance.global_config.split_verbatim_per_chunk

    def index(self, docs: List) -> None:
        """Index ``docs`` with episodic gist extraction, then build and save the graph."""
        # Add chunk embeddings
        self.remem.add_chunk_and_embeddings(docs)
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row

        # Load cached extraction results. Only the chunks reported in
        # chunk_keys_to_process are sent through extraction.
        all_openie_info, chunk_keys_to_process = self.remem.load_existing_openie(
            chunk_dict.keys()
        )
        new_openie_rows = {key: chunk_dict[key] for key in chunk_keys_to_process}
        if len(chunk_keys_to_process) > 0:
            ie_results = self.remem.openie.batch_openie(new_openie_rows)
            self.merge_gist_extraction_results(all_openie_info, chunk_keys_to_process, ie_results)

        # Build episodic embedding stores
        element_to_encode = defaultdict(list)
        # NOTE: construction of episode_results_dict is elided in this excerpt.
        for chunk in episode_results_dict.values():
            # Process verbatim, gists, facts, entities
            ...

        # Construct graph with gist->fact and verbatim->gist edges
        self._augment_episodic_graph()
        self.remem.save_igraph()

    def rag_for_qa(self, queries, **kwargs):
        """Perform QA with parallel processing and per-sample evaluation.

        Supports the following (implementation elided in this excerpt):

        - parallel query processing
        - per-sample saving/loading
        - gist-based retrieval
        - agent-based reasoning
        """
        ...
Creating a Custom Strategy
1. Define Your Strategy Class
# src/remem/rag_strategies/my_custom_strategy.py
from typing import Dict, List, Optional, Tuple, Union
from .base_strategy import RAGStrategy
from remem.utils.misc_utils import QuerySolution
class MyCustomStrategy(RAGStrategy):
    """Custom RAG strategy with specialized retrieval logic.

    Template: ``_custom_rerank``, ``_build_custom_index`` and
    ``_augment_custom_graph`` are placeholders for your own logic.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Strategy-specific parameters are read from the global config,
        # so the config object must define ``custom_param``.
        self.custom_param = remem_instance.global_config.custom_param

    def index(self, docs: List[str]) -> None:
        """Embed ``docs``, run extraction, then build the custom index and graph."""
        # 1. Add chunks and embeddings
        self.remem.add_chunk_and_embeddings(docs)
        # 2. Run extraction over every stored chunk
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row
        ie_results = self.remem.openie.batch_openie(chunk_dict)
        # 3. Build custom data structures
        self._build_custom_index(ie_results)
        # 4. Construct the graph and persist it
        self._augment_custom_graph()
        self.remem.save_igraph()

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1"),
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Run custom retrieval and QA, then evaluate.

        Returns:
            ``(query_solutions, responses, metadata, retrieval_metrics,
            qa_metrics)``. ``retrieval_metrics`` is ``{}`` when no gold
            documents or retrieval evaluators are available.
        """
        query_solutions = []
        for query in queries:
            # Accept both raw strings and QuerySolution inputs, as the type
            # hint allows. Otherwise a QuerySolution would be wrapped as the
            # question text.
            question = query.question if isinstance(query, QuerySolution) else query
            # 1. Custom retrieval
            docs, scores = self._custom_retrieve(question, num_to_retrieve)
            # 2. Create QuerySolution
            query_solutions.append(
                QuerySolution(question=question, docs=docs, doc_scores=scores)
            )

        # 3. Generate answers
        query_solutions, responses, metadata = self.remem.qa(query_solutions)

        # 4. Evaluate (retrieval metrics need gold documents)
        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        if gold_docs is not None and retrieval_evaluators:
            overall_retrieval_metrics = self.remem.evaluate_retrieval(
                gold_docs, query_solutions, retrieval_evaluators
            )
        else:
            overall_retrieval_metrics = {}
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, None
        )
        return query_solutions, responses, metadata, overall_retrieval_metrics, overall_qa_metrics

    def _custom_retrieve(self, query: str, k: int) -> Tuple[List[str], List[float]]:
        """Return ``(docs, scores)`` for the top-``k`` matches of ``query``.

        Example: over-fetch with semantic search, then rerank.
        """
        # 1. Over-fetch candidates so the reranker has room to reorder
        candidates = self.remem.chunk_embedding_store.search(query, k=k * 2)
        # 2. Apply custom reranking and keep the top-k (doc, score) pairs
        reranked = self._custom_rerank(query, candidates)[:k]
        # 3. Split into the parallel lists that QuerySolution expects
        docs = [doc for doc, _ in reranked]
        scores = [score for _, score in reranked]
        return docs, scores

    def _custom_rerank(self, query: str, candidates) -> List[Tuple[str, float]]:
        """Rerank candidates and return ``(doc, score)`` pairs, best first.

        Placeholder: replace with real scoring. It assumes ``candidates`` is
        already an iterable of ``(doc, score)`` pairs; adapt it to whatever
        your embedding store's ``search`` actually returns.
        """
        return sorted(candidates, key=lambda pair: pair[1], reverse=True)

    def _build_custom_index(self, ie_results):
        """Build strategy-specific indices."""
        # Example: Build entity co-occurrence matrix
        pass

    def _augment_custom_graph(self):
        """Construct custom graph structure."""
        # Example: Add weighted edges based on custom similarity
        pass
2. Register Your Strategy
# Register the custom strategy (in your application code or in
# remem/__init__.py) so RAGStrategyFactory can resolve it by name.
from remem.rag_strategies.factory import RAGStrategyFactory
from remem.rag_strategies.my_custom_strategy import MyCustomStrategy

# The key is the value to pass as ``extract_method`` in the config.
RAGStrategyFactory.register_strategy(
    extract_method="my_custom",
    strategy_class=MyCustomStrategy,
)
3. Use Your Strategy
from remem.remem import ReMem
from remem.utils.config_utils import BaseConfig

# extract_method must match the name the strategy was registered under.
# custom_param is the strategy-specific value MyCustomStrategy reads from
# global_config.
# NOTE(review): if BaseConfig does not declare custom_param, declare it on a
# config subclass (see "Strategy-Specific Configuration").
config = BaseConfig(
    dataset="test",
    extract_method="my_custom",
    llm_name="gpt-4o-mini",
    custom_param="value",
)
rag = ReMem(global_config=config)

# Index a couple of documents, then run retrieval + QA scored against gold answers.
rag.index(["Document 1", "Document 2"])
solutions, responses, meta, ret_metrics, qa_metrics = rag.rag_for_qa(
    queries=["What is in the documents?"],
    gold_answers=[["Answer"]],
)
Advanced: Multi-Step Retrieval
Example from `EpisodicGistStrategy`:
def _rag_each_query(self, remem, query, return_chunk="gists", num_gists=20, **kwargs):
    """Multi-step retrieval with gist-based exploration.

    Args:
        remem: ReMem instance providing ``episodic_embedding_stores`` and ``graph``.
        query: Query text.
        return_chunk: ``"verbatim"`` maps the retrieved gists back to verbatim
            chunks via ``self._map_to_verbatim``. Any other value returns the
            gist results directly.
        num_gists: Number of gists requested from the gist store (default 20).

    Returns:
        Verbatim IDs or gist results, depending on ``return_chunk``.
    """
    # Step 1: Initial gist retrieval
    gist_results = remem.episodic_embedding_stores["gists"].search(query, k=num_gists)

    # Step 2: Expand via graph.
    # remem.graph appears to be a python-igraph Graph (cf. save_igraph).
    # igraph's neighbors() returns vertex indices, not names, so each index
    # is resolved to its "name" before checking the node-type prefix.
    graph = remem.graph
    expanded_facts = []
    for gist_id in gist_results:
        for vertex_idx in graph.neighbors(gist_id):
            name = graph.vs[vertex_idx]["name"]
            if name.startswith("facts-"):
                expanded_facts.append(name)
    # A fact reachable from several gists would otherwise be listed repeatedly.
    expanded_facts = list(dict.fromkeys(expanded_facts))
    # NOTE: reranking with expanded_facts is elided in this example.

    # Step 3: Return the requested granularity
    if return_chunk == "verbatim":
        return self._map_to_verbatim(gist_results)
    return gist_results
Strategy-Specific Configuration
Add custom config fields:
from dataclasses import dataclass
from remem.utils.config_utils import BaseConfig
@dataclass
class MyCustomConfig(BaseConfig):
    """BaseConfig extended with strategy-specific fields (defaults shown)."""

    # Strategy-specific fields
    custom_rerank_weight: float = 0.5  # MyCustomStrategy reads this into self.rerank_weight
    custom_expansion_hops: int = 2  # MyCustomStrategy reads this into self.expansion_hops
    custom_threshold: float = 0.7
class MyCustomStrategy(RAGStrategy):
    """Strategy that reads its tuning knobs from the MyCustomConfig fields."""

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        cfg = remem_instance.global_config
        # Copy every strategy-specific config field onto the strategy.
        # custom_threshold was previously declared on the config but never read.
        self.rerank_weight = cfg.custom_rerank_weight
        self.expansion_hops = cfg.custom_expansion_hops
        self.threshold = cfg.custom_threshold
Parallel Processing
From `EpisodicGistStrategy.rag_for_qa()` (simplified):
def rag_for_qa(self, queries, parallel=True, max_workers=8, **kwargs):
    """Process queries in parallel (or sequentially) with ``_process_single_query``.

    Args:
        queries: Queries to process.
        parallel: Use a thread pool when True; otherwise process one by one.
        max_workers: Thread-pool size when ``parallel`` is True.

    Returns:
        Per-query results, in the same order as ``queries``.
    """
    # Both branches pass the same per-query arguments. Here they are
    # simplified to the raw query.
    query_args = list(queries)
    # Pre-sized so each result lands at its query's index, even though
    # futures complete out of order.
    query_solutions = [None] * len(query_args)

    if parallel:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self._process_single_query, args): idx
                for idx, args in enumerate(query_args)
            }
            for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx)):
                idx = future_to_idx[future]
                query_solutions[idx] = future.result()
    else:
        # Sequential processing
        for idx, args in enumerate(tqdm(query_args)):
            query_solutions[idx] = self._process_single_query(args)

    return query_solutions
Helper: Get Graph Info
The base class provides `get_graph_info()` (`base_strategy.py:68-146`):
def get_graph_info(self) -> Dict:
    """Summarise the size of the graph.

    Returns:
        Dict with the following keys:

        - ``num_phrase_nodes``: distinct IDs in the phrase embedding store.
        - ``num_passage_nodes``: distinct IDs in the chunk embedding store.
        - ``num_extracted_edges``: number of IDs in the triple embedding store.
        - ``num_total_edges``: number of entries in ``node_to_node_count``.
    """
    remem = self.remem
    unique_phrases = set(remem.phrase_embedding_store.get_all_ids())
    unique_passages = set(remem.chunk_embedding_store.get_all_ids())
    triple_ids = remem.triple_embedding_store.get_all_ids()
    return {
        "num_phrase_nodes": len(unique_phrases),
        "num_passage_nodes": len(unique_passages),
        "num_extracted_edges": len(triple_ids),
        "num_total_edges": len(remem.node_to_node_count),
    }
Next Steps
- Custom Extraction - Define what gets extracted
- Custom Prompts - Control LLM behavior in QA
- Custom Metrics - Evaluate strategy performance
- Architecture - Understand the full pipeline