
Overview

Middleware provides extension points throughout the agent execution lifecycle, allowing you to intercept, modify, and control behavior without changing core agent logic. The middleware system supports both class-based and decorator-based patterns.

Middleware Architecture

Middleware operates at multiple points in the agent loop:
before_agent → before_model → [wrap_model_call] → after_model → [wrap_tool_call] → after_agent
                   ↓              Model Call           ↓            Tool Calls          ↓
                Messages          Response          Messages         Results       Final State

When the model requests tool calls, the loop runs [wrap_tool_call] for each tool and returns to before_model; otherwise it proceeds to after_agent.

Quick Start with Decorators

The simplest way to add middleware is using decorators:
from langchain.agents import create_agent
from langchain.agents.middleware import before_model, after_model
from langchain.agents.middleware.types import AgentState
from langgraph.runtime import Runtime

@before_model
def log_request(state: AgentState, runtime: Runtime) -> None:
    """Log each model call."""
    print(f"Calling model with {len(state['messages'])} messages")

@after_model  
def track_usage(state: AgentState, runtime: Runtime) -> None:
    """Track token usage after each model response."""
    last_message = state['messages'][-1]
    if hasattr(last_message, 'usage_metadata'):
        print(f"Tokens used: {last_message.usage_metadata}")

agent = create_agent(
    model="openai:gpt-4",
    tools=[search_tool],
    middleware=[log_request, track_usage],
)

Middleware Hooks

before_agent

Executes once before the agent loop starts:
from langchain.agents.middleware import before_agent, AgentState
from langgraph.runtime import Runtime

@before_agent
async def initialize_session(state: AgentState, runtime: Runtime) -> dict:
    """Set up session context before agent starts."""
    user_id = runtime.config.get("configurable", {}).get("user_id")
    
    # Initialize session in database
    session_id = await create_session(user_id)
    
    # Add to state
    return {"session_id": session_id}
state (AgentState, required)
    Current agent state, including messages and custom fields.

runtime (Runtime, required)
    Runtime context with configuration, stream writer, and metadata.

Returns (dict[str, Any] | None)
    State updates to merge into agent state. Return None for no changes.

before_model

Executes before each model invocation:
from langchain.agents.middleware import before_model
from langchain_core.messages import SystemMessage

@before_model
def add_system_time(state: AgentState, runtime: Runtime) -> dict:
    """Inject current time into context."""
    from datetime import datetime
    
    current_time = datetime.now().isoformat()
    
    # Add system message with timestamp
    return {
        "messages": [SystemMessage(content=f"Current time: {current_time}")]
    }

after_model

Executes after each model response:
from langchain.agents.middleware import after_model

@after_model(can_jump_to=["end"])
def check_completion(state: AgentState, runtime: Runtime) -> dict | None:
    """End agent if response indicates completion."""
    last_message = state["messages"][-1]
    
    if "task complete" in last_message.content.lower():
        runtime.stream_writer({"type": "status", "message": "Task completed"})
        return {"jump_to": "end"}
    
    return None
can_jump_to (list[str])
    Allowed jump destinations: ["end", "model", "tools"]. Enables conditional control flow.

after_agent

Executes once after the agent loop completes:
from langchain.agents.middleware import after_agent

@after_agent
async def cleanup_session(state: AgentState, runtime: Runtime) -> None:
    """Clean up resources after agent finishes."""
    session_id = state.get("session_id")
    if session_id:
        await close_session(session_id)

Advanced: wrap_model_call

Intercept model execution directly, with full control over retries, fallbacks, and response modification:
import time

from langchain.agents.middleware import wrap_model_call
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage

@wrap_model_call
def retry_on_error(request: ModelRequest, handler) -> ModelResponse:
    """Retry model calls up to 3 times on errors."""
    for attempt in range(3):
        try:
            return handler(request)
        except Exception as e:
            if attempt == 2:
                # Final attempt failed, return error message instead of raising
                return ModelResponse(
                    result=[AIMessage(content=f"Error: {e}")]
                )
            time.sleep(2 ** attempt)  # Exponential backoff: 1s, then 2s

ModelRequest Object

model (BaseChatModel)
    The chat model instance to use for this request.

messages (list[AnyMessage])
    Messages to send to the model (excluding the system message).

system_message (SystemMessage | None)
    Optional system message to prepend to the conversation.

tools (list[BaseTool])
    Tools available to the model for this request.

tool_choice (Any | None)
    Tool selection configuration ("auto", "required", or a specific tool).

response_format (ResponseFormat | None)
    Structured output schema if using structured responses.

state (AgentState)
    Current agent state with messages and custom fields.

runtime (Runtime)
    Runtime context with configuration and stream writer.

Modifying Requests

Use request.override() to create modified requests:
from langchain_core.messages import SystemMessage

@wrap_model_call
def inject_instructions(request: ModelRequest, handler) -> ModelResponse:
    """Add custom instructions to every model call."""
    custom_system = SystemMessage(
        content="You are a helpful assistant. Be concise and accurate."
    )
    
    # Create new request with modified system message
    modified_request = request.override(system_message=custom_system)
    
    return handler(modified_request)

Response Transformation

import time

from langchain_core.messages import AIMessage

@wrap_model_call
def add_metadata(request: ModelRequest, handler) -> ModelResponse:
    """Add metadata to model responses."""
    response = handler(request)
    
    # Modify response message
    ai_message = response.result[0]
    modified_message = AIMessage(
        content=ai_message.content,
        additional_kwargs={
            **ai_message.additional_kwargs,
            "processing_time": time.time(),
        },
    )
    
    return ModelResponse(
        result=[modified_message],
        structured_response=response.structured_response,
    )

Advanced: wrap_tool_call

Intercept individual tool executions:
import time

from langchain.agents.middleware import wrap_tool_call
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage

@wrap_tool_call
def log_tool_calls(request: ToolCallRequest, handler) -> ToolMessage:
    """Log tool execution time and results."""
    tool_name = request.tool.name
    start_time = time.time()
    
    try:
        result = handler(request)
        duration = time.time() - start_time
        
        print(f"Tool {tool_name} completed in {duration:.2f}s")
        return result
    except Exception as e:
        duration = time.time() - start_time
        print(f"Tool {tool_name} failed after {duration:.2f}s: {e}")
        raise

ToolCallRequest Object

tool_call (dict)
    Tool call dictionary with id, name, and args fields.

tool (BaseTool | None)
    The BaseTool instance if available, None otherwise.

state (AgentState)
    Current agent state.

runtime (Runtime)
    Runtime context.
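The tool_call dict is a natural place to validate arguments before execution. A minimal framework-free sketch of that pattern (the per-tool required-args table, the error-dict shape, and the handler are hypothetical stand-ins for the real middleware wiring, not the library's API):

```python
def validate_tool_args(tool_call: dict, handler) -> dict:
    """Reject tool calls with missing required args before execution.

    `tool_call` mirrors the ToolCallRequest.tool_call shape: id, name, args.
    `handler` stands in for the middleware's downstream tool executor.
    """
    required = {"search_database": ["query"]}  # hypothetical per-tool schema
    missing = [
        arg for arg in required.get(tool_call["name"], [])
        if arg not in tool_call.get("args", {})
    ]
    if missing:
        # Short-circuit: return an error result without running the tool
        return {"status": "error", "content": f"Missing args: {missing}"}
    return handler(tool_call)

# Usage: a call with no args is rejected before the tool runs
result = validate_tool_args(
    {"id": "call_1", "name": "search_database", "args": {}},
    handler=lambda tc: {"status": "ok", "content": "rows"},
)
# result["status"] == "error"
```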

Class-Based Middleware

For complex middleware with state and multiple hooks, extend AgentMiddleware:
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
    AgentState,
    ModelRequest,
    ModelResponse,
    ToolCallRequest,
)
from langchain_core.messages import ToolMessage
from langgraph.runtime import Runtime
import time

class PerformanceMonitor(AgentMiddleware):
    """Monitor agent performance and token usage."""
    
    def __init__(self):
        super().__init__()
        self.model_calls = 0
        self.tool_calls = 0
        self.total_tokens = 0
    
    def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Reset counters at start."""
        self.model_calls = 0
        self.tool_calls = 0
        self.total_tokens = 0
        self.start_time = time.time()
        return None
    
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler,
    ) -> ModelResponse:
        """Count model calls and tokens."""
        self.model_calls += 1
        response = handler(request)
        
        # Track token usage
        if response.result and hasattr(response.result[0], 'usage_metadata'):
            usage = response.result[0].usage_metadata
            self.total_tokens += usage.get('total_tokens', 0)
        
        return response
    
    def wrap_tool_call(self, request: ToolCallRequest, handler) -> ToolMessage:
        """Count tool executions."""
        self.tool_calls += 1
        return handler(request)
    
    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Report final statistics."""
        duration = time.time() - self.start_time
        
        print("\n=== Performance Report ===")
        print(f"Duration: {duration:.2f}s")
        print(f"Model calls: {self.model_calls}")
        print(f"Tool calls: {self.tool_calls}")
        print(f"Total tokens: {self.total_tokens}")

# Usage
agent = create_agent(
    model="openai:gpt-4",
    tools=[search_tool],
    middleware=[PerformanceMonitor()],
)

Built-in Middleware

LangChain provides production-ready middleware for common patterns:

ModelRetryMiddleware

Automatic retry with exponential backoff:
from langchain.agents.middleware import ModelRetryMiddleware
from anthropic import RateLimitError

retry = ModelRetryMiddleware(
    max_retries=3,
    retry_on=(RateLimitError, TimeoutError),
    backoff_factor=2.0,
    initial_delay=1.0,
    max_delay=60.0,
)

agent = create_agent(
    model="anthropic:claude-sonnet-4-5-20250929",
    middleware=[retry],
)

ModelFallbackMiddleware

Fallback to alternative models on errors:
from langchain.agents.middleware import ModelFallbackMiddleware

fallback = ModelFallbackMiddleware(
    "openai:gpt-4o-mini",  # First fallback
    "anthropic:claude-sonnet-4-5-20250929",  # Second fallback
)

agent = create_agent(
    model="openai:gpt-4o",  # Primary model
    middleware=[fallback],
)
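The fallback order can be pictured as trying each model in sequence until one succeeds. A framework-free sketch of that behavior (the `call` callable and function names are illustrative, not the middleware's internals):

```python
def call_with_fallbacks(primary: str, fallbacks: list[str], call):
    """Try the primary model, then each fallback in order.

    `call` stands in for invoking one model; it raises on failure.
    """
    last_error = None
    for model in [primary, *fallbacks]:
        try:
            return call(model)
        except Exception as e:
            last_error = e  # remember the failure, try the next model
    raise last_error  # every model failed

# Usage: the primary fails, so the first fallback answers
def fake_call(model):
    if model == "openai:gpt-4o":
        raise RuntimeError("rate limited")
    return f"response from {model}"

result = call_with_fallbacks(
    "openai:gpt-4o",
    ["openai:gpt-4o-mini", "anthropic:claude-sonnet-4-5-20250929"],
    fake_call,
)
# result == "response from openai:gpt-4o-mini"
```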

ToolRetryMiddleware

Retry failed tool calls:
from langchain.agents.middleware import ToolRetryMiddleware

retry = ToolRetryMiddleware(
    max_retries=2,
    tools=["search_database"],  # Only retry specific tools
    backoff_factor=1.5,
)

ModelCallLimitMiddleware

Prevent infinite loops:
from langchain.agents.middleware import ModelCallLimitMiddleware

limiter = ModelCallLimitMiddleware(
    max_calls=10,
    error_message="Maximum model calls reached. Please refine your request.",
)
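Conceptually the limiter is a counter around the model handler. A framework-free sketch of that idea (class and method names are illustrative, not the middleware's real internals):

```python
class CallLimiter:
    """Count model invocations and fail once max_calls is exceeded."""

    def __init__(self, max_calls: int, error_message: str):
        self.max_calls = max_calls
        self.error_message = error_message
        self.calls = 0

    def wrap(self, handler):
        def limited(request):
            self.calls += 1
            if self.calls > self.max_calls:
                raise RuntimeError(self.error_message)
            return handler(request)
        return limited

# Usage: two calls pass; a third raises RuntimeError
limiter = CallLimiter(max_calls=2, error_message="Maximum model calls reached.")
call_model = limiter.wrap(lambda request: "response")
call_model("first")
call_model("second")
```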

Streaming Custom Events

Middleware can emit custom events during execution:
import time

from langchain_core.messages import HumanMessage

@before_model
async def stream_status(state: AgentState, runtime: Runtime) -> None:
    """Stream custom status updates."""
    runtime.stream_writer({
        "type": "status",
        "message": "Thinking...",
        "timestamp": time.time(),
    })

# Consume custom events
async for mode, event in agent.astream(
    {"messages": [HumanMessage("Hello")]},
    stream_mode=["updates", "custom"],
):
    if mode == "custom":
        print(f"Status: {event['message']}")

Middleware Composition

Middleware executes in list order: the first middleware forms the outermost layer.
agent = create_agent(
    model="openai:gpt-4",
    middleware=[
        rate_limiter,      # Executes first (outermost)
        retry_middleware,  # Then retries
        logger,            # Then logs  
        monitor,           # Finally monitors (innermost)
    ],
)
For wrap_model_call and wrap_tool_call, the first middleware wraps all others.
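The wrapping order above can be sketched without the framework: list order maps to nesting depth, first entry outermost. The middleware names and `compose` helper below are illustrative, not the library's API:

```python
def compose(middlewares, base):
    """Wrap `base` so the FIRST middleware in the list is outermost."""
    handler = base
    for mw in reversed(middlewares):
        handler = mw(handler)
    return handler

order = []

def make(name):
    """Build a wrapper that records when it runs relative to the handler."""
    def mw(handler):
        def wrapped(request):
            order.append(f"{name}:before")
            result = handler(request)
            order.append(f"{name}:after")
            return result
        return wrapped
    return mw

call = compose(
    [make("rate_limiter"), make("retry"), make("logger")],
    base=lambda request: "response",
)
call("req")
print(order)
# ['rate_limiter:before', 'retry:before', 'logger:before',
#  'logger:after', 'retry:after', 'rate_limiter:after']
```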

Best Practices

Each middleware should have a single, clear responsibility:
# Good: Single responsibility
@before_model
def add_timestamp(state, runtime):
    return {"messages": [SystemMessage(content=f"Time: {datetime.now()}")]}

# Bad: Multiple responsibilities
@before_model  
def do_everything(state, runtime):
    # Logging, auth, rate limiting, caching all in one - hard to maintain!
    pass
Implement async versions (abefore_model, awrap_model_call, etc.) for I/O operations:
@before_agent
async def load_user_context(state: AgentState, runtime: Runtime) -> dict:
    """Load user data asynchronously."""
    user_id = runtime.config.get("configurable", {}).get("user_id")
    user_data = await db.fetch_user(user_id)  # Async I/O
    return {"user_context": user_data}
Always handle exceptions gracefully:
import logging

from langchain_core.messages import AIMessage

logger = logging.getLogger(__name__)

@wrap_model_call
def safe_model_call(request: ModelRequest, handler) -> ModelResponse:
    try:
        return handler(request)
    except Exception as e:
        # Log the error and return a graceful response instead of crashing the agent
        logger.error(f"Model call failed: {e}")
        return ModelResponse(
            result=[AIMessage(content="I encountered an error. Please try again.")]
        )

Next Steps

Custom Tools

Build custom tools with middleware hooks

Rate Limiting

Implement rate limiting with middleware

Performance

Optimize agent performance with monitoring middleware
