Skip to main content
Middleware provides a powerful way to intercept and modify the behavior of agents, chat clients, and function calls. Use middleware for logging, authentication, caching, error handling, and more.

Types of Middleware

The Agent Framework supports three types of middleware:

Agent Middleware

Intercepts agent run() calls:
from agent_framework import AgentMiddleware, AgentContext

class LoggingAgentMiddleware(AgentMiddleware):
    """Agent middleware that prints the agent name, its input and its output."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Print the request, run the rest of the pipeline, print the result.

        Args:
            context: Agent context; ``agent``, ``messages`` and ``result`` are read.
            call_next: Awaitable callback that continues the pipeline.
        """
        agent_name = context.agent.name
        print(f"Agent: {agent_name}")
        incoming = context.messages
        print(f"Input: {incoming}")

        await call_next()

        # The result is only read after the pipeline has finished.
        outcome = context.result
        print(f"Output: {outcome}")

Chat Middleware

Intercepts chat client get_response() calls:
from agent_framework import ChatMiddleware, ChatContext

class CachingChatMiddleware(ChatMiddleware):
    """Chat middleware that replays a stored response for a repeated prompt.

    The cache is an unbounded in-memory dict owned by this instance; add an
    eviction policy before using it in a long-running service.
    """

    def __init__(self):
        # Maps the tuple of message texts to the response produced for them.
        self.cache = {}

    async def process(
        self,
        context: ChatContext,
        call_next
    ) -> None:
        """Serve the response from the cache, or call the model and store it.

        Args:
            context: Chat context; ``messages`` and ``stream`` are read and
                ``result`` is read and written.
            call_next: Awaitable callback that continues the pipeline.
        """
        # When streaming, the result is an async iterable rather than a
        # finished response, so it is neither looked up nor stored.
        if context.stream:
            await call_next()
            return

        # Key on the texts themselves instead of hash(...): two different
        # conversations whose hashes collide must never share a response.
        cache_key = tuple(m.text for m in context.messages)

        if cache_key in self.cache:
            # Return cached response
            context.result = self.cache[cache_key]
            return

        # Call the model
        await call_next()

        # Cache the result, but never remember a missing one.
        if context.result is not None:
            self.cache[cache_key] = context.result

Function Middleware

Intercepts tool/function executions:
from agent_framework import FunctionMiddleware, FunctionInvocationContext
import time

class TimingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that prints how long each tool call took."""

    async def process(
        self,
        context: FunctionInvocationContext,
        call_next
    ) -> None:
        """Time the wrapped function call and print the elapsed seconds.

        Args:
            context: Function invocation context; ``function.name`` is read.
            call_next: Awaitable callback that runs the function (or the next
                middleware).
        """
        function_name = context.function.name

        # perf_counter() is monotonic and high resolution, so the measurement
        # is not distorted by system clock adjustments.
        start = time.perf_counter()
        try:
            await call_next()
        finally:
            # Report the duration even when the function raises; the
            # exception still propagates to the caller afterwards.
            duration = time.perf_counter() - start
            print(f"Function {function_name} took {duration:.2f}s")

Creating Middleware

Class-Based Middleware

Inherit from the appropriate base class:
from agent_framework import AgentContext, AgentMiddleware, AgentResponse, Message

class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that refuses requests mentioning sensitive terms."""

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Block the run when the latest message mentions a password or secret.

        Args:
            context: Agent context; ``messages`` is read and ``result`` is set
                when the request is blocked.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Only the most recent message is inspected; an empty input passes.
        last_message = context.messages[-1] if context.messages else None
        if last_message and last_message.text:
            query = last_message.text.lower()
            if "password" in query or "secret" in query:
                # Block the request by supplying the response ourselves.
                context.result = AgentResponse(
                    messages=[Message(
                        "assistant",
                        ["I cannot process requests involving sensitive information."]
                    )]
                )
                return  # Don't call call_next()

        # Security check passed
        await call_next()

Function-Based Middleware

Use type hints to indicate middleware type:
from agent_framework import AgentContext, FunctionInvocationContext

async def logging_agent_middleware(context: AgentContext, call_next) -> None:
    """Agent middleware, recognized as such through the ``context`` type hint."""
    incoming = context.messages
    print(f"Before: {incoming}")

    await call_next()

    # Read the result only after the rest of the pipeline has run.
    outcome = context.result
    print(f"After: {outcome}")

async def logging_function_middleware(context: FunctionInvocationContext, call_next) -> None:
    """Function middleware, recognized as such through the ``context`` type hint."""
    announce = f"Calling: {context.function.name}"
    print(announce)

    await call_next()

    # The name is looked up again after the call, as in a plain log line.
    done = f"Completed: {context.function.name}"
    print(done)

Decorator-Based Middleware

Use decorators to explicitly mark middleware type:
from agent_framework import (
    agent_middleware,
    chat_middleware,
    function_middleware
)

@agent_middleware
async def simple_agent_middleware(context, call_next):  # type: ignore
    """Agent middleware marked by the ``@agent_middleware`` decorator.

    No type hints are needed: the decorator states the middleware type.
    """
    print("Agent middleware executed")
    # Hand control to the rest of the pipeline.
    await call_next()

@chat_middleware
async def simple_chat_middleware(context, call_next):  # type: ignore
    """Chat middleware marked by the ``@chat_middleware`` decorator."""
    print("Chat middleware executed")
    # Hand control to the rest of the pipeline.
    await call_next()

@function_middleware
async def simple_function_middleware(context, call_next):  # type: ignore
    """Function middleware marked by the ``@function_middleware`` decorator."""
    # context is untyped here, hence the type-checker suppression on this line.
    print(f"Function {context.function.name} called")  # type: ignore
    # Hand control to the rest of the pipeline.
    await call_next()

Using Middleware

Agent-Level Middleware

Apply middleware to all agent runs:
from agent_framework.openai import OpenAIResponsesClient

# Middleware passed to as_agent() is attached to the agent itself.
agent = OpenAIResponsesClient().as_agent(
    name="Assistant",
    instructions="You are helpful.",
    middleware=[
        LoggingAgentMiddleware(),  # agent middleware; listed first
        SecurityAgentMiddleware(),  # agent middleware
        TimingFunctionMiddleware()  # function middleware; different types can share one list
    ]
)

response = await agent.run("Hello")  # Middleware applies

Run-Level Middleware

Apply middleware to a specific run:
# `client` and `SpecialMiddleware` are placeholders that this snippet does not define.
agent = client.as_agent(
    name="Assistant",
    instructions="You are helpful."
)

# No middleware
response1 = await agent.run("Query 1")

# With run-level middleware: passed to run(), so it applies to this call only
response2 = await agent.run(
    "Query 2",
    middleware=[SpecialMiddleware()]
)

Middleware Order

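Middleware executes in the order listed: the first entry is the outermost layer, so its code before call_next() runs first and its code after call_next() runs last: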
# middleware1..middleware3 are placeholders for any middleware objects.
agent = client.as_agent(
    name="Assistant",
    middleware=[
        middleware1,  # Executes first (outermost)
        middleware2,
        middleware3   # Executes last (innermost)
    ]
)

# Execution order (each entry wraps the ones listed after it):
# middleware1.process -> before
#   middleware2.process -> before
#     middleware3.process -> before
#       [AGENT RUNS]
#     middleware3.process -> after
#   middleware2.process -> after
# middleware1.process -> after

Middleware Context

AgentContext

class AgentContext:
    """Context object passed to agent middleware as ``context``."""

    agent: SupportsAgentRun  # The agent being invoked
    messages: list[Message]  # Input messages
    session: AgentSession | None  # Session (if provided)
    tools: list[ToolTypes] | None  # Available tools
    options: ChatOptions | None  # Model options
    stream: bool  # Whether streaming is enabled
    result: AgentResponse | None  # Set by middleware or agent

ChatContext

class ChatContext:
    """Context object passed to chat middleware as ``context``."""

    client: SupportsChatGetResponse  # The chat client
    messages: list[Message]  # Input messages
    tools: list[ToolTypes] | None  # Available tools
    options: ChatOptions | None  # Model options
    stream: bool  # Whether streaming is enabled
    result: ChatResponse | None  # Set by middleware or client

FunctionInvocationContext

class FunctionInvocationContext:
    """Context object passed to function middleware as ``context``."""

    function: FunctionTool  # The function being called
    arguments: dict[str, Any]  # Function arguments
    session: AgentSession | None  # Session (if available)
    result: Any  # Set by middleware or function
    error: Exception | None  # Set if function raises error

Advanced Patterns

Modifying Inputs

from agent_framework import AgentMiddleware, AgentContext, Message

class InputTransformMiddleware(AgentMiddleware):
    """Agent middleware that puts a tone-setting system message first."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Prepend a system message to the input, then continue the pipeline.

        Args:
            context: Agent context whose ``messages`` list is replaced.
            call_next: Awaitable callback that continues the pipeline.
        """
        tone_instruction = Message("system", ["Always respond in a professional tone."])

        # A new list is bound to the context; the original list is not mutated.
        context.messages = [tone_instruction] + context.messages

        await call_next()

Modifying Outputs

import re


class OutputFilterMiddleware(AgentMiddleware):
    """Agent middleware that masks e-mail addresses in the agent's response."""

    # Compiled once for all calls. The top-level-domain class is [A-Za-z];
    # a "|" inside a character class would match a literal pipe character.
    _EMAIL_PATTERN = re.compile(
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
    )

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Run the agent, then redact e-mail addresses from its text output.

        Args:
            context: Agent context; ``result`` is modified in place.
            call_next: Awaitable callback that continues the pipeline.
        """
        await call_next()

        # When streaming, the result is an async iterable rather than a
        # finished response, so there is no message list to rewrite here.
        if context.stream or not context.result:
            return

        # Filter each text part where it lives, so multi-message responses
        # keep their structure and an empty message list is not an error.
        # NOTE(review): assumes a message exposes its parts as `contents`;
        # confirm against the installed agent_framework version.
        for message in context.result.messages:
            for item in message.contents:
                text = getattr(item, "text", None)
                if text:
                    item.text = self.filter_sensitive_data(text)

    def filter_sensitive_data(self, text: str) -> str:
        """Return ``text`` with every e-mail address replaced by ``[EMAIL]``."""
        return self._EMAIL_PATTERN.sub('[EMAIL]', text)

Short-Circuiting Execution

Return a result without calling the agent:
class CacheMiddleware(AgentMiddleware):
    """Agent middleware that short-circuits repeated requests from a cache.

    The cache is an unbounded in-memory dict owned by this instance.
    """

    def __init__(self):
        # Maps the tuple of message texts to the response produced for them.
        self.cache = {}

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Return a cached result when available, otherwise run and store it.

        Args:
            context: Agent context; ``messages`` and ``stream`` are read and
                ``result`` is read and written.
            call_next: Awaitable callback that continues the pipeline.
        """
        # When streaming, the result is an async iterable rather than a
        # finished response, so it is neither looked up nor stored.
        if context.stream:
            await call_next()
            return

        # Key on the texts themselves instead of hash(...): two different
        # conversations whose hashes collide must never share a result.
        cache_key = tuple(m.text for m in context.messages)

        if cache_key in self.cache:
            # Return cached result, don't call agent
            context.result = self.cache[cache_key]
            return  # Skip call_next()

        # Not cached, call agent
        await call_next()

        # Cache the result
        if context.result:
            self.cache[cache_key] = context.result

Middleware Termination

Explicitly terminate execution:
from agent_framework import MiddlewareTermination, AgentResponse, Message

class TerminationMiddleware(AgentMiddleware):
    """Agent middleware that aborts the run by raising MiddlewareTermination."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Continue normally unless ``should_terminate(context)`` is truthy.

        Raises:
            MiddlewareTermination: Carrying a "Request denied." response when
                the request is to be terminated.
        """
        if not should_terminate(context):
            await call_next()
            return

        # The response travels with the exception via its ``result`` argument.
        denial = AgentResponse(
            messages=[Message("assistant", ["Request denied."])]
        )
        raise MiddlewareTermination("Request terminated by policy", result=denial)

Exception Handling

class ErrorHandlingMiddleware(AgentMiddleware):
    """Agent middleware that turns unexpected failures into a friendly reply."""

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Run the pipeline and replace any unexpected error with a message.

        Args:
            context: Agent context; ``result`` is set when an error is handled.
            call_next: Awaitable callback that continues the pipeline.

        Raises:
            MiddlewareTermination: Re-raised unchanged when another middleware
                terminates the run on purpose.
        """
        # Imported locally so this snippet does not depend on earlier imports.
        from agent_framework import MiddlewareTermination

        try:
            await call_next()
        except MiddlewareTermination:
            # A deliberate termination carries its own result; replacing it
            # with a generic error message would hide that decision.
            raise
        except Exception as e:
            # Log the error; exception() records the traceback as well.
            logger.exception("Agent failed: %s", e)

            # Return user-friendly error message
            context.result = AgentResponse(
                messages=[Message(
                    "assistant",
                    ["I encountered an error. Please try again."]
                )]
            )

Shared State Between Middleware

class SharedStateMiddleware(AgentMiddleware):
    """Agent middleware that records the request time in a caller-owned dict."""

    def __init__(self, shared_state: dict):
        # Keep a reference, not a copy, so other holders of the dict see writes.
        self.shared_state = shared_state

    async def process(self, context: AgentContext, call_next) -> None:
        """Stamp ``request_time`` into the shared dict, then continue."""
        started_at = time.time()
        self.shared_state["request_time"] = started_at

        await call_next()

# Create shared state
shared = {}

# The same dict object is handed to both middleware instances.
# `AnotherMiddleware` is a placeholder that this snippet does not define.
agent = client.as_agent(
    name="Assistant",
    middleware=[
        SharedStateMiddleware(shared),
        AnotherMiddleware(shared)  # Can access same state
    ]
)

Session Behavior Middleware

from agent_framework import AgentContext, AgentMiddleware, AgentSession, Message

class SessionBehaviorMiddleware(AgentMiddleware):
    """Agent middleware that counts interactions per session.

    Once a session has seen more than ten earlier interactions, a system
    message asking for concise answers is placed in front of the input.
    """

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Update the session's interaction counter and adapt the input.

        Args:
            context: Agent context; ``session.state`` is updated and
                ``messages`` may get a leading system message.
            call_next: Awaitable callback that continues the pipeline.
        """
        if context.session:
            # Track interaction count
            count = context.session.state.get("interaction_count", 0)
            context.session.state["interaction_count"] = count + 1

            # `count` is the number of interactions before this one.
            if count > 10:
                # Add context for long conversations
                context.messages.insert(0, Message(
                    "system",
                    ["This is a long conversation. Be concise."]
                ))

        await call_next()

Runtime Context Delegation

class DelegationMiddleware(AgentMiddleware):
    """Agent middleware that hands "technical" queries to a specialist agent."""

    def __init__(self, specialist_agent):
        # Agent whose run() answers the delegated queries.
        self.specialist = specialist_agent

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Delegate the latest message to the specialist when it is technical.

        Args:
            context: Agent context; ``messages`` and ``session`` are read and
                ``result`` is set when the query is delegated.
            call_next: Awaitable callback that runs the original agent.
        """
        # Guard against an empty message list instead of raising IndexError.
        last_message = context.messages[-1] if context.messages else None
        if last_message and last_message.text and "technical" in last_message.text.lower():
            # Delegate to specialist agent, keeping the caller's session
            result = await self.specialist.run(
                last_message.text,
                session=context.session
            )
            context.result = result
            return  # Don't call original agent

        await call_next()

Combining Middleware Types

from agent_framework import tool

# NOTE(review): approval_mode="always_require" presumably makes every call to
# this tool wait for approval before it executes -- confirm that is intended
# for an example meant to show function middleware running.
@tool(approval_mode="always_require")
def get_data() -> str:
    return "Data"

# One middleware list may mix agent, chat and function middleware.
agent = client.as_agent(
    name="Assistant",
    instructions="You are helpful.",
    tools=[get_data],
    middleware=[
        LoggingAgentMiddleware(),      # Logs agent runs
        CachingChatMiddleware(),        # Caches chat responses
        TimingFunctionMiddleware()      # Times function calls
    ]
)

# All three middleware types work together:
# - Agent middleware runs for agent.run()
# - Chat middleware runs for underlying chat client
# - Function middleware runs when tools are called
response = await agent.run("Get me some data")

Best Practices

Each middleware should have a single, well-defined responsibility. This makes them easier to test and compose.
Prefer using call_next() and only modify context.result when you need to override behavior.
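When context.stream is True, context.result is an async iterable of AgentResponseUpdate items rather than a complete AgentResponse, so check context.stream before reading the result: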
async def process(self, context: AgentContext, call_next) -> None:
    # Run the rest of the pipeline first; the result is inspected afterwards.
    await call_next()
    
    # The type of context.result depends on whether the run is streaming.
    if context.stream:
        # context.result is AsyncIterable[AgentResponseUpdate]
        pass
    else:
        # context.result is AgentResponse
        pass
  • Use Agent Middleware for high-level agent behavior (auth, routing)
  • Use Chat Middleware for model interaction (caching, retries)
  • Use Function Middleware for tool execution (logging, validation)
The order matters. Place authentication/authorization middleware before others:
# Uses the classes defined in this guide; AuthMiddleware requires an api_key.
middleware=[
    AuthMiddleware(api_key="secret123"),  # First: check auth
    CacheMiddleware(),                    # Second: check cache
    LoggingAgentMiddleware()              # Last: log requests that reach the agent
]

API Reference

AgentMiddleware Base Class

class AgentMiddleware:
    async def process(
        self, context: AgentContext, call_next: Callable[[], Awaitable[None]]
    ) -> None:
        """Handle one agent invocation.

        This base implementation adds no behavior of its own; it only
        forwards to the next handler.

        Args:
            context: Agent context holding the input and output data.
            call_next: Awaitable callback that runs the rest of the pipeline.
        """
        await call_next()

ChatMiddleware Base Class

class ChatMiddleware:
    async def process(
        self, context: ChatContext, call_next: Callable[[], Awaitable[None]]
    ) -> None:
        """Handle one chat client invocation.

        This base implementation adds no behavior of its own; it only
        forwards to the next handler.

        Args:
            context: Chat context holding the input and output data.
            call_next: Awaitable callback that runs the rest of the pipeline.
        """
        await call_next()

FunctionMiddleware Base Class

class FunctionMiddleware:
    async def process(
        self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
    ) -> None:
        """Handle one function/tool invocation.

        This base implementation adds no behavior of its own; it only
        forwards to the next handler.

        Args:
            context: Function context holding the input and output data.
            call_next: Awaitable callback that runs the rest of the pipeline.
        """
        await call_next()

Examples

Authentication Middleware

from agent_framework import AgentMiddleware, AgentContext, Message

class AuthMiddleware(AgentMiddleware):
    def __init__(self, api_key: str):
        self.api_key = api_key
    
    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        # Extract user token from session
        if not context.session:
            context.result = AgentResponse(
                messages=[Message("assistant", ["Authentication required."])]
            )
            return
        
        user_token = context.session.state.get("auth_token")
        if not self.verify_token(user_token):
            context.result = AgentResponse(
                messages=[Message("assistant", ["Invalid authentication."])]
            )
            return
        
        # Authenticated, proceed
        await call_next()
    
    def verify_token(self, token: str | None) -> bool:
        # Token verification logic
        return token == self.api_key

# Use the middleware
agent = client.as_agent(
    name="SecureAgent",
    middleware=[AuthMiddleware(api_key="secret123")]
)

session = AgentSession()
session.state["auth_token"] = "secret123"

response = await agent.run("Hello", session=session)

Rate Limiting Middleware

import math
import time
from collections import defaultdict

class RateLimitMiddleware(AgentMiddleware):
    """Agent middleware enforcing a per-user request limit over a time window.

    Runs without a session, or without a "user_id" in the session state,
    share the single "anonymous" bucket.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # user id -> monotonic timestamps of that user's recent requests
        self.requests: dict[str, list[float]] = defaultdict(list)

    async def process(
        self,
        context: AgentContext,
        call_next
    ) -> None:
        """Admit the run if the user is under the limit, otherwise reject it.

        Args:
            context: Agent context; ``session.state["user_id"]`` is read and
                ``result`` is set when the limit is exceeded.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Get user ID from session
        user_id = context.session.state.get("user_id", "anonymous") if context.session else "anonymous"

        # A monotonic clock keeps the window correct across system clock changes.
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Clean old requests
        recent = [t for t in self.requests[user_id] if t > cutoff]
        self.requests[user_id] = recent

        # Check rate limit
        if len(recent) >= self.max_requests:
            if recent:
                # The oldest remaining request is the first to leave the window.
                retry_after = max(1, math.ceil(recent[0] + self.window_seconds - now))
            else:
                # Only reachable when max_requests is zero or negative.
                retry_after = self.window_seconds
            context.result = AgentResponse(
                messages=[Message(
                    "assistant",
                    [f"Rate limit exceeded. Try again in {retry_after} seconds."]
                )]
            )
            return

        # Record request
        recent.append(now)

        # Proceed
        await call_next()

# Use the middleware
agent = client.as_agent(
    name="LimitedAgent",
    middleware=[RateLimitMiddleware(max_requests=10, window_seconds=60)]
)
  • Agents - Using middleware with agents
  • Tools - Function middleware for tools
  • Workflows - Middleware in workflows

Build docs developers (and LLMs) love