
Overview

RAG strategies control how REMem indexes documents and retrieves evidence for question answering. Each extraction method (the `extract_method` config value) maps to a strategy class that defines its indexing and QA logic.

Strategy Factory Pattern

REMem uses a factory (`rag_strategies/factory.py`, lines 9–59) to map each extraction method name to its strategy class:
from typing import Dict, Type
from .base_strategy import RAGStrategy

class RAGStrategyFactory:
    """Factory class for creating RAG strategies based on extraction method.

    Keeps a class-level registry mapping an extraction-method name (the
    ``extract_method`` config value) to a ``RAGStrategy`` subclass. Every call
    to :meth:`create_strategy` builds a fresh strategy instance.
    """

    # extract_method name -> strategy class. Shared by all callers, so
    # registrations made through register_strategy are globally visible.
    _strategies: Dict[str, Type[RAGStrategy]] = {
        "openie": PassageTripleStrategy,
        "episodic_gist": EpisodicGistStrategy,
        "temporal": TemporalStrategy,
    }

    @classmethod
    def create_strategy(cls, extract_method: str, remem_instance) -> RAGStrategy:
        """Create a RAG strategy based on the extraction method.

        Args:
            extract_method: Registered extraction-method name.
            remem_instance: ReMem instance passed to the strategy constructor.

        Returns:
            A new instance of the strategy class registered for ``extract_method``.

        Raises:
            ValueError: If no strategy is registered for ``extract_method``.
        """
        if extract_method not in cls._strategies:
            raise ValueError(
                f"Unsupported extraction method: {extract_method}. "
                f"Supported methods: {list(cls._strategies.keys())}"
            )

        strategy_class = cls._strategies[extract_method]
        return strategy_class(remem_instance)

    @classmethod
    def register_strategy(cls, extract_method: str, strategy_class: Type[RAGStrategy]):
        """Register a new strategy for an extraction method.

        Overwrites any strategy already registered under ``extract_method``.

        Args:
            extract_method: Name that configs will use as ``extract_method``.
            strategy_class: A subclass of ``RAGStrategy`` (the class itself,
                not an instance).

        Raises:
            TypeError: If ``strategy_class`` is not a ``RAGStrategy`` subclass.
                Checking here reports a bad registration at the call site,
                not later inside ``create_strategy``.
        """
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, RAGStrategy)):
            raise TypeError(
                f"strategy_class must be a subclass of RAGStrategy, got {strategy_class!r}"
            )
        cls._strategies[extract_method] = strategy_class

    @classmethod
    def get_supported_methods(cls) -> list[str]:
        """Get list of supported extraction methods (registered names)."""
        return list(cls._strategies.keys())

Base Strategy Interface

All strategies inherit from `RAGStrategy` (`rag_strategies/base_strategy.py`, lines 7–146) and must implement the abstract `index()` and `rag_for_qa()` methods:
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from remem.utils.misc_utils import QuerySolution

class RAGStrategy(ABC):
    """Abstract base class for RAG strategies.

    A strategy wraps a ReMem instance (``self.remem``) and defines how
    documents are indexed and how retrieval-augmented QA is run over them.
    Subclasses must implement :meth:`index` and :meth:`rag_for_qa`.
    """

    def __init__(self, remem_instance):
        # The ReMem instance whose stores, graph and QA helpers the strategy drives.
        self.remem = remem_instance

    @abstractmethod
    def index(self, docs: List[str]) -> None:
        """Index documents using the specific strategy."""
        pass

    @abstractmethod
    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform RAG-based QA using the specific strategy.

        Returns:
            A 5-tuple of (query solutions, response messages, per-query
            metadata, overall retrieval metrics, overall QA metrics).
        """
        pass

    def retrieve_each_query(self, query: str, return_chunk: Optional[str] = None):
        """Retrieve documents for a single query. Can be overridden.

        The default delegates to ``self.remem.retrieve_each_query``.
        """
        return self.remem.retrieve_each_query(query, return_chunk)

    def get_graph_info(self) -> Dict:
        """Get statistics about the graph structure.

        Returns:
            Dict with ``num_phrase_nodes`` and ``num_passage_nodes`` (counts
            of distinct ids in the phrase and chunk stores),
            ``num_extracted_edges`` (number of triple ids) and
            ``num_total_edges`` (size of ``remem.node_to_node_count``).
        """
        graph_info = {}

        # Phrase and passage ids are de-duplicated before counting.
        phrase_nodes = self.remem.phrase_embedding_store.get_all_ids()
        graph_info["num_phrase_nodes"] = len(set(phrase_nodes))

        passage_nodes = self.remem.chunk_embedding_store.get_all_ids()
        graph_info["num_passage_nodes"] = len(set(passage_nodes))

        graph_info["num_extracted_edges"] = len(
            self.remem.triple_embedding_store.get_all_ids()
        )
        graph_info["num_total_edges"] = len(self.remem.node_to_node_count)

        return graph_info

Built-in Strategies

DefaultRAGStrategy (OpenIE)

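From `rag_strategies/default_strategy.py` (lines 8–60). Note that the factory excerpt above registers `PassageTripleStrategy`, not `DefaultRAGStrategy`, under `"openie"`. Check `factory.py` for the mapping in your version: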
class DefaultRAGStrategy(RAGStrategy):
    """Default RAG strategy for standard OpenIE-based extraction.

    Indexing is delegated to ``remem.index_original``. QA runs the pipeline
    retrieve -> evaluate retrieval -> answer -> evaluate QA, then optionally
    saves the results.
    """

    def index(self, docs: List[str]) -> None:
        """Index documents using standard OpenIE approach."""
        self.remem.index_original(docs)

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 5,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1", "retrieval_recall"),
        question_metadata: Optional[List[Dict]] = None,
        to_save: bool = True,
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Perform QA using standard RAG approach.

        Args:
            queries: Raw query strings, or ``QuerySolution`` objects that
                already carry retrieved documents (retrieval is then skipped).
            num_to_retrieve: Accepted for interface compatibility; see the
                note in the body.
            gold_answers: Reference answers used for QA evaluation.
            gold_docs: Reference documents used for retrieval evaluation.
            metrics: Metric names passed to ``remem.get_evaluators``.
            question_metadata: Optional per-query metadata, one entry per
                query. Each entry is attached to the matching solution as
                ``question_metadata``.
            to_save: Whether to persist results via ``remem.save_rag_results``.

        Returns:
            (query_solutions, response messages, per-query metadata,
            overall retrieval metrics, overall QA metrics).

        Raises:
            ValueError: If ``queries`` is empty, or if ``question_metadata``
                does not have exactly one entry per query.
        """
        if not queries:
            raise ValueError("queries must be a non-empty list")

        # Only the first element is inspected; the list is expected to be
        # homogeneous (all strings or all QuerySolution objects).
        if isinstance(queries[0], QuerySolution):
            query_solutions = queries
        else:
            # NOTE(review): num_to_retrieve is not forwarded here, so the
            # retrieval depth is whatever remem.retrieve uses by default.
            # Confirm whether that is intended.
            query_solutions = self.remem.retrieve(queries=queries)

        if question_metadata is not None:
            if len(question_metadata) != len(query_solutions):
                raise ValueError(
                    f"question_metadata has {len(question_metadata)} entries "
                    f"but there are {len(query_solutions)} queries"
                )
            for query_solution, metadata in zip(query_solutions, question_metadata):
                query_solution.question_metadata = metadata

        # Evaluate retrieval before QA, on the retrieved documents.
        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )
        overall_retrieval_metrics = self.remem.evaluate_retrieval(
            gold_docs, query_solutions, retrieval_evaluators
        )

        query_solutions, all_response_message, all_metadata = self.remem.qa(query_solutions)

        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, question_metadata
        )

        if to_save:
            self.remem.save_rag_results(
                gold_answers, gold_docs, query_solutions,
                overall_qa_metrics, overall_retrieval_metrics
            )

        return (
            query_solutions,
            all_response_message,
            all_metadata,
            overall_retrieval_metrics,
            overall_qa_metrics,
        )

EpisodicGistStrategy

An advanced strategy with gist-based retrieval. The excerpt below is abridged from `rag_strategies/episodic_gist_strategy.py` (lines 21–1013), and elided sections are marked with `...`:
class EpisodicGistStrategy(RAGStrategy):
    """Strategy for episodic gist-based extraction and retrieval.

    Abridged excerpt: sections elided from the real implementation are
    marked with ``...``.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Gist/verbatim chunking options are read from the global config.
        self.concatenate_gists_per_chunk = remem_instance.global_config.concatenate_gists_per_chunk
        self.split_verbatim_per_chunk = remem_instance.global_config.split_verbatim_per_chunk

    def index(self, docs: List) -> None:
        """Index with episodic gist extraction.

        Embeds the chunks, runs extraction on chunks that do not yet have
        results, builds the episodic embedding stores, then augments and
        saves the graph.
        """
        from collections import defaultdict

        # Add chunk embeddings
        self.remem.add_chunk_and_embeddings(docs)
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row

        # Load existing extraction results; only the chunks listed in
        # chunk_keys_to_process are extracted below.
        all_openie_info, chunk_keys_to_process = self.remem.load_existing_openie(
            chunk_dict.keys()
        )

        if len(chunk_keys_to_process) > 0:
            # Send only the chunks that still need extraction.
            new_openie_rows = {key: chunk_dict[key] for key in chunk_keys_to_process}
            ie_results = self.remem.openie.batch_openie(new_openie_rows)
            self.merge_gist_extraction_results(all_openie_info, chunk_keys_to_process, ie_results)

        # Build episodic embedding stores
        element_to_encode = defaultdict(list)
        # NOTE(review): episode_results_dict is produced by code elided from
        # this excerpt; it is not defined in the lines shown.
        for chunk in episode_results_dict.values():
            # Process verbatim, gists, facts, entities
            ...

        # Construct graph with gist->fact and verbatim->gist edges
        self._augment_episodic_graph()
        self.remem.save_igraph()

    def rag_for_qa(self, queries, **kwargs):
        """Perform QA with parallel processing and per-sample evaluation."""
        # Support for:
        # - Parallel query processing
        # - Per-sample saving/loading
        # - Gist-based retrieval
        # - Agent-based reasoning
        ...

Creating a Custom Strategy

1. Define Your Strategy Class

# src/remem/rag_strategies/my_custom_strategy.py
from typing import Dict, List, Optional, Tuple, Union
from .base_strategy import RAGStrategy
from remem.utils.misc_utils import QuerySolution

class MyCustomStrategy(RAGStrategy):
    """Custom RAG strategy with specialized retrieval logic.

    Template showing the two hooks every strategy implements:
    ``index`` (build data structures and the graph) and ``rag_for_qa``
    (retrieve, answer, evaluate). The private ``_custom_*`` methods are the
    extension points to fill in.
    """

    def __init__(self, remem_instance):
        super().__init__(remem_instance)
        # Strategy-specific parameters come from the global config; the config
        # class must declare a ``custom_param`` field.
        self.custom_param = remem_instance.global_config.custom_param

    def index(self, docs: List[str]) -> None:
        """Custom indexing logic."""
        # 1. Add chunks and embeddings
        self.remem.add_chunk_and_embeddings(docs)

        # 2. Run extraction over every stored chunk
        chunk_dict = self.remem.chunk_embedding_store.hash_id_to_row
        ie_results = self.remem.openie.batch_openie(chunk_dict)

        # 3. Build custom data structures
        self._build_custom_index(ie_results)

        # 4. Construct and persist the graph
        self._augment_custom_graph()
        self.remem.save_igraph()

    def rag_for_qa(
        self,
        queries: Union[List[str], List[QuerySolution]],
        num_to_retrieve: int = 10,
        gold_answers: Optional[List[List[str]]] = None,
        gold_docs: Optional[List[List[str]]] = None,
        metrics: Tuple[str, ...] = ("qa_em", "qa_f1"),
        **kwargs,
    ) -> Tuple[List[QuerySolution], List[str], List[Dict], Dict, Dict]:
        """Custom QA logic.

        Returns:
            (query_solutions, responses, metadata, retrieval metrics, QA
            metrics). Retrieval metrics are ``{}`` when ``gold_docs`` is None.
        """
        query_solutions = []
        for query in queries:
            # Queries that already carry retrieved docs skip retrieval.
            if isinstance(query, QuerySolution):
                query_solutions.append(query)
                continue

            # 1. Custom retrieval
            docs, scores = self._custom_retrieve(query, num_to_retrieve)

            # 2. Wrap the result for the shared QA/evaluation helpers
            query_solutions.append(
                QuerySolution(question=query, docs=docs, doc_scores=scores)
            )

        qa_evaluators, retrieval_evaluators = self.remem.get_evaluators(
            gold_answers, gold_docs, metrics
        )

        # 3. Evaluate retrieval (only possible with gold documents)
        overall_retrieval_metrics = {}
        if gold_docs is not None:
            overall_retrieval_metrics = self.remem.evaluate_retrieval(
                gold_docs, query_solutions, retrieval_evaluators
            )

        # 4. Generate answers
        query_solutions, responses, metadata = self.remem.qa(query_solutions)

        # 5. Evaluate QA
        overall_qa_metrics = self.remem.evaluate_qa(
            gold_answers, qa_evaluators, query_solutions, None
        )

        return query_solutions, responses, metadata, overall_retrieval_metrics, overall_qa_metrics

    def _custom_retrieve(self, query: str, k: int) -> Tuple[List, List]:
        """Return the top-``k`` documents for ``query`` and their scores.

        Returns:
            A ``(docs, scores)`` pair of parallel lists, which is the shape
            ``rag_for_qa`` unpacks.
        """
        # Over-fetch so the reranker has extra candidates to choose from.
        candidates = self.remem.chunk_embedding_store.search(query, k=k * 2)

        reranked = self._custom_rerank(query, candidates)

        top = reranked[:k]
        docs = [doc for doc, _ in top]
        scores = [score for _, score in top]
        return docs, scores

    def _custom_rerank(self, query: str, candidates) -> List:
        """Reorder retrieval candidates; override with your ranking logic.

        Assumes ``candidates`` is a sequence of ``(doc, score)`` pairs. Adapt
        this to whatever ``chunk_embedding_store.search`` returns.
        The default keeps the original order.
        """
        return list(candidates)

    def _build_custom_index(self, ie_results):
        """Build strategy-specific indices."""
        # Example: Build entity co-occurrence matrix
        pass

    def _augment_custom_graph(self):
        """Construct custom graph structure."""
        # Example: Add weighted edges based on custom similarity
        pass

2. Register Your Strategy

# In your application code or remem/__init__.py
from remem.rag_strategies.factory import RAGStrategyFactory
from remem.rag_strategies.my_custom_strategy import MyCustomStrategy

# Register the strategy under the name a config's extract_method will use.
RAGStrategyFactory.register_strategy(
    extract_method="my_custom",
    strategy_class=MyCustomStrategy,
)

3. Use Your Strategy

from dataclasses import dataclass

from remem.remem import ReMem
from remem.utils.config_utils import BaseConfig


# Strategy-specific options must be declared fields on a BaseConfig subclass;
# a dataclass config rejects undeclared keyword arguments.
@dataclass
class MyStrategyConfig(BaseConfig):
    custom_param: str = ""  # read by MyCustomStrategy.__init__


config = MyStrategyConfig(
    dataset="test",
    extract_method="my_custom",  # Must match registered name
    llm_name="gpt-4o-mini",
    custom_param="value",  # Strategy-specific params
)

rag = ReMem(global_config=config)
docs = ["Document 1", "Document 2"]
rag.index(docs)

queries = ["What is in the documents?"]
solutions, responses, meta, ret_metrics, qa_metrics = rag.rag_for_qa(
    queries=queries,
    gold_answers=[["Answer"]],
)

Advanced: Multi-Step Retrieval

Example from EpisodicGistStrategy:
def _rag_each_query(self, remem, query, return_chunk="gists", **kwargs):
    """Multi-step retrieval with gist-based exploration.

    1. Search the gist embedding store (top 20).
    2. Expand each gist hit to its neighbouring fact nodes in the graph.
    3. Return the gist hits, or the verbatim chunks they map to.

    Args:
        remem: ReMem instance providing ``episodic_embedding_stores`` and
            ``graph``.
        query: Query text.
        return_chunk: ``"verbatim"`` to map results back to verbatim chunks;
            any other value returns the gist hits.
    """
    # Step 1: Initial gist retrieval
    gist_results = remem.episodic_embedding_stores["gists"].search(
        query, k=20
    )

    # Step 2: Expand via graph
    expanded_facts = []
    for gist_id in gist_results:
        # igraph's neighbors() returns integer vertex indices, so resolve each
        # to its vertex name before filtering on the "facts-" prefix.
        neighbor_names = [
            remem.graph.vs[neighbor]["name"] for neighbor in remem.graph.neighbors(gist_id)
        ]
        expanded_facts.extend(name for name in neighbor_names if name.startswith("facts-"))
    # NOTE(review): expanded_facts is not consumed by the return below; a full
    # implementation would feed it into a reranking step.

    # Step 3: Return
    if return_chunk == "verbatim":
        # Map back to verbatim chunks
        verbatim_ids = self._map_to_verbatim(gist_results)
        return verbatim_ids
    else:
        return gist_results

Strategy-Specific Configuration

Declare strategy-specific options as fields on a `BaseConfig` subclass:
from dataclasses import dataclass
from remem.utils.config_utils import BaseConfig

@dataclass
class MyCustomConfig(BaseConfig):
    """BaseConfig extended with options for a custom strategy."""

    custom_rerank_weight: float = 0.5  # read by the strategy as rerank_weight
    custom_expansion_hops: int = 2  # read by the strategy as expansion_hops
    custom_threshold: float = 0.7  # extra knob available to custom logic
Then read those fields in your strategy's constructor:
class MyCustomStrategy(RAGStrategy):
    def __init__(self, remem_instance):
        """Copy the strategy's tuning options from the global config."""
        super().__init__(remem_instance)
        cfg = remem_instance.global_config
        self.rerank_weight = cfg.custom_rerank_weight
        self.expansion_hops = cfg.custom_expansion_hops

Parallel Processing

Abridged excerpt from `EpisodicGistStrategy.rag_for_qa()`:
def rag_for_qa(self, queries, parallel=True, max_workers=8, **kwargs):
    """Process queries, optionally in parallel, preserving input order.

    Args:
        queries: Queries to answer; each is passed to
            ``self._process_single_query``.
        parallel: If True, fan the queries out over a thread pool.
        max_workers: Thread-pool size when ``parallel`` is True.

    Returns:
        A list with one ``_process_single_query`` result per query, in the
        same order as ``queries``.
    """
    # Pre-size the output so results arriving out of order from as_completed
    # can be written back to their original positions.
    query_solutions = [None] * len(queries)

    if parallel:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self._process_single_query, query): idx
                for idx, query in enumerate(queries)
            }

            for future in tqdm(as_completed(future_to_idx), total=len(queries)):
                idx = future_to_idx[future]
                # result() re-raises any exception raised in the worker thread.
                query_solutions[idx] = future.result()
    else:
        # Sequential processing, with the same per-query call as the parallel path
        for idx, query in tqdm(enumerate(queries), total=len(queries)):
            query_solutions[idx] = self._process_single_query(query)

    return query_solutions

Helper: Get Graph Info

The base class provides `get_graph_info()` (abridged from `base_strategy.py`, lines 68–146):
def get_graph_info(self) -> Dict:
    """Summarize graph size as node and edge counts.

    Phrase and passage ids are de-duplicated before counting. Triple ids and
    entries of ``node_to_node_count`` are counted as-is.
    """
    remem = self.remem
    return {
        "num_phrase_nodes": len(set(remem.phrase_embedding_store.get_all_ids())),
        "num_passage_nodes": len(set(remem.chunk_embedding_store.get_all_ids())),
        "num_extracted_edges": len(remem.triple_embedding_store.get_all_ids()),
        "num_total_edges": len(remem.node_to_node_count),
    }

Next Steps
