Types of Middleware
The Agent Framework supports three types of middleware:
Agent Middleware
Intercepts agent run() calls:
from agent_framework import AgentMiddleware, AgentContext
class LoggingAgentMiddleware(AgentMiddleware):
    """Agent middleware that prints the agent name, input and output of a run."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Print run details before and after the rest of the pipeline.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        agent_name = context.agent.name
        print(f"Agent: {agent_name}")
        print(f"Input: {context.messages}")

        await call_next()

        # The result is read only after the downstream call has returned.
        print(f"Output: {context.result}")
Chat Middleware
Intercepts chat client get_response() calls:
from agent_framework import ChatMiddleware, ChatContext
class CachingChatMiddleware(ChatMiddleware):
    """Chat middleware that reuses an earlier result for identical message texts."""

    def __init__(self):
        # Maps the tuple of message texts to the result stored for them.
        self.cache = {}

    async def process(self, context: ChatContext, call_next) -> None:
        """Serve a cached result, or call the model and remember its result.

        Args:
            context: Context of the chat-client call being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Key on the texts themselves rather than on hash(...) of them: two
        # different conversations can share a hash value, and a bare-hash key
        # would then hand one conversation's response to the other.
        cache_key = tuple(m.text for m in context.messages)
        if cache_key in self.cache:
            # Return cached response
            context.result = self.cache[cache_key]
        else:
            # Call the model
            await call_next()
            # Cache the result
            # NOTE(review): when context.stream is true the result may be a
            # stream rather than a complete response - confirm it is safe to
            # store and replay.
            self.cache[cache_key] = context.result
Function Middleware
Intercepts tool/function executions:
from agent_framework import FunctionMiddleware, FunctionInvocationContext
import time
class TimingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that prints how long each tool invocation took."""

    async def process(self, context: FunctionInvocationContext, call_next) -> None:
        """Time the rest of the pipeline and print the elapsed seconds.

        Args:
            context: Context of the function invocation being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        function_name = context.function.name
        # perf_counter() is the clock intended for measuring intervals; the
        # wall clock returned by time.time() can jump while the call runs.
        start = time.perf_counter()
        try:
            await call_next()
        finally:
            # Report the duration even when the invocation raises; the
            # exception still propagates to the caller.
            duration = time.perf_counter() - start
            print(f"Function {function_name} took {duration:.2f}s")
Creating Middleware
Class-Based Middleware
Inherit from the appropriate base class:
from agent_framework import AgentMiddleware, AgentContext, AgentResponse, Message
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that refuses requests mentioning "password" or "secret"."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Block the run when the latest message looks sensitive.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Only the most recent message is inspected.
        last_message = None
        if context.messages:
            last_message = context.messages[-1]

        if last_message and last_message.text:
            query = last_message.text.lower()
            flagged = "password" in query or "secret" in query
            if flagged:
                refusal = Message(
                    "assistant",
                    ["I cannot process requests involving sensitive information."],
                )
                context.result = AgentResponse(messages=[refusal])
                # Returning without awaiting call_next() blocks the request.
                return

        # Security check passed
        await call_next()
Function-Based Middleware
Use type hints to indicate middleware type:
from agent_framework import AgentContext, FunctionInvocationContext
async def logging_agent_middleware(context: AgentContext, call_next) -> None:
    """Print the input messages, run the pipeline, then print the result.

    The AgentContext annotation is what marks this function as agent middleware.

    Args:
        context: Context of the agent run being intercepted.
        call_next: Awaitable callback that continues the pipeline.
    """
    print(f"Before: {context.messages}")

    await call_next()

    print(f"After: {context.result}")
async def logging_function_middleware(
    context: FunctionInvocationContext, call_next
) -> None:
    """Print the function name before and after the invocation.

    The FunctionInvocationContext annotation is what marks this function as
    function middleware.

    Args:
        context: Context of the function invocation being intercepted.
        call_next: Awaitable callback that continues the pipeline.
    """
    print(f"Calling: {context.function.name}")

    await call_next()

    print(f"Completed: {context.function.name}")
Decorator-Based Middleware
Use decorators to explicitly mark middleware type:
from agent_framework import (
    agent_middleware,
    chat_middleware,
    function_middleware
)
@agent_middleware
async def simple_agent_middleware(context, call_next):  # type: ignore
    """Print a marker line, then continue the pipeline.

    The decorator declares the middleware type, so no annotations are needed.
    """
    print("Agent middleware executed")

    await call_next()
@chat_middleware
async def simple_chat_middleware(context, call_next):  # type: ignore
    """Print a marker line, then continue the pipeline.

    The decorator declares this as chat middleware.
    """
    print("Chat middleware executed")

    await call_next()
@function_middleware
async def simple_function_middleware(context, call_next):  # type: ignore
    """Print the invoked function's name, then continue the pipeline.

    The decorator declares this as function middleware.
    """
    print(f"Function {context.function.name} called")  # type: ignore

    await call_next()
Using Middleware
Agent-Level Middleware
Apply middleware to all agent runs:
from agent_framework.openai import OpenAIResponsesClient
agent = OpenAIResponsesClient().as_agent(
name="Assistant",
instructions="You are helpful.",
middleware=[
LoggingAgentMiddleware(),
SecurityAgentMiddleware(),
TimingFunctionMiddleware()
]
)
response = await agent.run("Hello") # Middleware applies
Run-Level Middleware
Apply middleware to a specific run:agent = client.as_agent(
name="Assistant",
instructions="You are helpful."
)
# No middleware
response1 = await agent.run("Query 1")
# With run-level middleware
response2 = await agent.run(
"Query 2",
middleware=[SpecialMiddleware()]
)
Middleware Order
# Middleware executes in the order specified:
agent = client.as_agent(
    name="Assistant",
    middleware=[
        middleware1,  # Executes first (outermost)
        middleware2,
        middleware3,  # Executes last (innermost)
    ],
)

# Execution order:
#   middleware1.process -> before
#   middleware2.process -> before
#   middleware3.process -> before
#   [AGENT RUNS]
#   middleware3.process -> after
#   middleware2.process -> after
#   middleware1.process -> after
Middleware Context
AgentContext
class AgentContext:
    """Fields available to agent middleware for one agent run."""

    # The agent being invoked
    agent: SupportsAgentRun
    # Input messages
    messages: list[Message]
    # Session (if provided)
    session: AgentSession | None
    # Available tools
    tools: list[ToolTypes] | None
    # Model options
    options: ChatOptions | None
    # Whether streaming is enabled
    stream: bool
    # Set by middleware or agent
    result: AgentResponse | None
ChatContext
class ChatContext:
    """Fields available to chat middleware for one chat-client call."""

    # The chat client
    client: SupportsChatGetResponse
    # Input messages
    messages: list[Message]
    # Available tools
    tools: list[ToolTypes] | None
    # Model options
    options: ChatOptions | None
    # Whether streaming is enabled
    stream: bool
    # Set by middleware or client
    result: ChatResponse | None
FunctionInvocationContext
class FunctionInvocationContext:
    """Fields available to function middleware for one tool invocation."""

    # The function being called
    function: FunctionTool
    # Function arguments
    arguments: dict[str, Any]
    # Session (if available)
    session: AgentSession | None
    # Set by middleware or function
    result: Any
    # Set if function raises error
    error: Exception | None
Advanced Patterns
Modifying Inputs
from agent_framework import AgentMiddleware, AgentContext, Message
class InputTransformMiddleware(AgentMiddleware):
    """Agent middleware that prepends a fixed system message to every request."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Put a system message in front of the input, then continue.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        tone_instruction = Message(
            "system",
            ["Always respond in a professional tone."],
        )
        # A new list is assigned; the original list object is not mutated.
        context.messages = [tone_instruction, *context.messages]

        await call_next()
Modifying Outputs
class OutputFilterMiddleware(AgentMiddleware):
    """Agent middleware that masks e-mail addresses in the agent's response."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Run the pipeline, then replace the response text with a filtered copy.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        await call_next()

        # A streaming result is an async iterable of updates, not a complete
        # response with .text, so there is nothing to rewrite here.
        if context.stream or not context.result:
            return
        # Nothing to overwrite when the response carries no messages.
        if not context.result.messages:
            return

        # Filter sensitive data from response
        filtered_text = self.filter_sensitive_data(context.result.text)
        # NOTE(review): the filtered text of the whole response is written into
        # the first content item of the first message only - confirm that is
        # intended for multi-message responses.
        context.result.messages[0].content[0].text = filtered_text

    def filter_sensitive_data(self, text: str) -> str:
        """Return *text* with every e-mail address replaced by ``[EMAIL]``.

        Args:
            text: Text to scan.

        Returns:
            The text with each matched address replaced.
        """
        import re

        # The TLD class is [A-Za-z]; the former [A-Z|a-z] also accepted a
        # literal "|" character inside the top-level domain.
        return re.sub(
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
            '[EMAIL]',
            text,
        )
Short-Circuiting Execution
# Return a result without calling the agent:
class CacheMiddleware(AgentMiddleware):
    """Agent middleware that short-circuits runs whose message texts were seen before."""

    def __init__(self):
        # Maps the tuple of message texts to the result stored for them.
        self.cache = {}

    async def process(self, context: AgentContext, call_next) -> None:
        """Serve a cached result, or run the agent and remember its result.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Key on the texts themselves rather than on hash(...) of them: two
        # different inputs can share a hash value, and a bare-hash key would
        # then return one input's result for the other.
        cache_key = tuple(m.text for m in context.messages)

        if cache_key in self.cache:
            # Return cached result, don't call agent
            context.result = self.cache[cache_key]
            return  # Skip call_next()

        # Not cached, call agent
        await call_next()

        # Cache the result (only when one was produced)
        if context.result:
            self.cache[cache_key] = context.result
Middleware Termination
Explicitly terminate execution:
from agent_framework import MiddlewareTermination, AgentResponse, Message
class TerminationMiddleware(AgentMiddleware):
    """Agent middleware that aborts a run by raising MiddlewareTermination."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Raise a termination carrying a denial response, or continue.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.

        Raises:
            MiddlewareTermination: When should_terminate(context) is truthy.
        """
        if should_terminate(context):
            # The denial response travels with the exception as its result.
            denial = AgentResponse(
                messages=[Message("assistant", ["Request denied."])]
            )
            raise MiddlewareTermination(
                "Request terminated by policy",
                result=denial,
            )

        await call_next()
Exception Handling
class ErrorHandlingMiddleware(AgentMiddleware):
    """Agent middleware that turns a failure into a generic apology response."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Run the pipeline; on any Exception, log it and set a fallback result.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        try:
            await call_next()
        except Exception as e:
            # NOTE(review): if MiddlewareTermination derives from Exception it
            # is caught here as well - confirm whether it should be re-raised.
            # Logs at ERROR level with the traceback attached.
            logger.exception(f"Agent failed: {e}")

            # Return user-friendly error message
            apology = Message(
                "assistant",
                ["I encountered an error. Please try again."],
            )
            context.result = AgentResponse(messages=[apology])
Shared State Between Middleware
class SharedStateMiddleware(AgentMiddleware):
    """Agent middleware that records the request time in a caller-supplied dict."""

    def __init__(self, shared_state: dict):
        # Kept by reference, so writes are visible to every holder of the dict.
        self.shared_state = shared_state

    async def process(self, context: AgentContext, call_next) -> None:
        """Stamp the shared dict with the current time, then continue.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Store data for other middleware
        now = time.time()
        self.shared_state["request_time"] = now

        await call_next()
# Create shared state: one dict handed to every middleware that needs it
shared = {}

agent = client.as_agent(
    name="Assistant",
    middleware=[
        SharedStateMiddleware(shared),
        AnotherMiddleware(shared),  # Can access same state
    ],
)
Session Behavior Middleware
from agent_framework import AgentSession
class SessionBehaviorMiddleware(AgentMiddleware):
    """Agent middleware that counts interactions per session and asks for brevity
    once the conversation gets long."""

    async def process(self, context: AgentContext, call_next) -> None:
        """Update the session's interaction counter, then continue.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        session = context.session
        if session:
            # Track interaction count
            count = session.state.get("interaction_count", 0)
            session.state["interaction_count"] = count + 1

            # The comparison uses the count from before this run's increment.
            if count > 10:
                # Add context for long conversations
                concise_hint = Message(
                    "system",
                    ["This is a long conversation. Be concise."],
                )
                context.messages.insert(0, concise_hint)

        await call_next()
Runtime Context Delegation
class DelegationMiddleware(AgentMiddleware):
    """Agent middleware that hands "technical" queries to a specialist agent."""

    def __init__(self, specialist_agent):
        # Agent whose run() answers the delegated queries.
        self.specialist = specialist_agent

    async def process(self, context: AgentContext, call_next) -> None:
        """Delegate the run when the latest message mentions "technical".

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Check if query should be delegated. Guard against an empty message
        # list, where indexing [-1] would raise IndexError.
        last_message = context.messages[-1] if context.messages else None

        if (
            last_message is not None
            and last_message.text
            and "technical" in last_message.text.lower()
        ):
            # Delegate to specialist agent, sharing the current session
            result = await self.specialist.run(
                last_message.text,
                session=context.session,
            )
            context.result = result
            return  # Don't call original agent

        await call_next()
Combining Middleware Types
from agent_framework import tool
@tool(approval_mode="always_require")
def get_data() -> str:
    """Return a fixed placeholder string."""
    return "Data"
agent = client.as_agent(
name="Assistant",
instructions="You are helpful.",
tools=[get_data],
middleware=[
LoggingAgentMiddleware(), # Logs agent runs
CachingChatMiddleware(), # Caches chat responses
TimingFunctionMiddleware() # Times function calls
]
)
# All three middleware types work together:
# - Agent middleware runs for agent.run()
# - Chat middleware runs for underlying chat client
# - Function middleware runs when tools are called
response = await agent.run("Get me some data")
Best Practices
Keep middleware focused
Keep middleware focused
Each middleware should have a single, well-defined responsibility. This makes them easier to test and compose.
Don't modify context objects directly unless necessary
Don't modify context objects directly unless necessary
Prefer using call_next() and only modify context.result when you need to override behavior.
Handle streaming responses carefully
Handle streaming responses carefully
# When context.stream is True, be aware that the result will be an async iterable:
async def process(self, context: AgentContext, call_next) -> None:
    """Skeleton showing where streaming and non-streaming results are handled.

    Args:
        context: Context of the agent run being intercepted.
        call_next: Awaitable callback that continues the pipeline.
    """
    await call_next()

    if context.stream:
        # context.result is AsyncIterable[AgentResponseUpdate]
        pass
    else:
        # context.result is AgentResponse
        pass
Use appropriate middleware type
Use appropriate middleware type
- Use Agent Middleware for high-level agent behavior (auth, routing)
- Use Chat Middleware for model interaction (caching, retries)
- Use Function Middleware for tool execution (logging, validation)
Be mindful of middleware order
Be mindful of middleware order
The order matters. Place authentication/authorization middleware before others:
middleware=[
AuthMiddleware(), # First: check auth
CachingMiddleware(), # Second: check cache
LoggingMiddleware() # Last: log everything
]
API Reference
AgentMiddleware Base Class
class AgentMiddleware:
    """Base class for middleware that intercepts agent invocations."""

    async def process(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Process agent invocation.

        The base implementation only awaits call_next().

        Args:
            context: Agent context with input/output data
            call_next: Call to continue middleware pipeline
        """
        await call_next()
ChatMiddleware Base Class
class ChatMiddleware:
    """Base class for middleware that intercepts chat client invocations."""

    async def process(
        self,
        context: ChatContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Process chat client invocation.

        The base implementation only awaits call_next().

        Args:
            context: Chat context with input/output data
            call_next: Call to continue middleware pipeline
        """
        await call_next()
FunctionMiddleware Base Class
class FunctionMiddleware:
    """Base class for middleware that intercepts function/tool invocations."""

    async def process(
        self,
        context: FunctionInvocationContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Process function/tool invocation.

        The base implementation only awaits call_next().

        Args:
            context: Function context with input/output data
            call_next: Call to continue middleware pipeline
        """
        await call_next()
Examples
Authentication Middleware
from agent_framework import AgentMiddleware, AgentContext, Message
class AuthMiddleware(AgentMiddleware):
    """Agent middleware that only lets a run proceed when the session carries
    the expected auth token."""

    def __init__(self, api_key: str):
        # Value the session's "auth_token" entry must equal.
        self.api_key = api_key

    async def process(self, context: AgentContext, call_next) -> None:
        """Reject the run unless the session holds a valid token.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Extract user token from session
        if not context.session:
            context.result = AgentResponse(
                messages=[Message("assistant", ["Authentication required."])]
            )
            return

        user_token = context.session.state.get("auth_token")
        if not self.verify_token(user_token):
            context.result = AgentResponse(
                messages=[Message("assistant", ["Invalid authentication."])]
            )
            return

        # Authenticated, proceed
        await call_next()

    def verify_token(self, token: str | None) -> bool:
        """Return True when *token* equals the configured API key.

        Args:
            token: Token supplied by the caller, or None when absent.

        Returns:
            True on an exact match, False otherwise (including for a missing
            or non-string token).
        """
        import hmac

        if not isinstance(token, str):
            return False
        # compare_digest takes time independent of where the inputs differ,
        # unlike ==, so response timing does not leak how much of the key a
        # guess got right. Encoding to bytes also permits non-ASCII tokens.
        return hmac.compare_digest(
            token.encode("utf-8"),
            self.api_key.encode("utf-8"),
        )
# Use the middleware
agent = client.as_agent(
name="SecureAgent",
middleware=[AuthMiddleware(api_key="secret123")]
)
session = AgentSession()
session.state["auth_token"] = "secret123"
response = await agent.run("Hello", session=session)
Rate Limiting Middleware
import time
from collections import defaultdict
class RateLimitMiddleware(AgentMiddleware):
    """Agent middleware enforcing a per-user sliding-window request limit."""

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Per-user timestamps (time.monotonic() readings) of accepted requests.
        # NOTE(review): entries for users who stop sending requests are never
        # removed, so this mapping grows with the number of distinct users.
        self.requests: dict[str, list[float]] = defaultdict(list)

    async def process(self, context: AgentContext, call_next) -> None:
        """Reject the run when the user exceeded the limit, else record it.

        Args:
            context: Context of the agent run being intercepted.
            call_next: Awaitable callback that continues the pipeline.
        """
        # Get user ID from session
        user_id = (
            context.session.state.get("user_id", "anonymous")
            if context.session
            else "anonymous"
        )

        # Clean old requests. The monotonic clock never goes backwards, so a
        # system clock adjustment cannot stretch or shrink the window the way
        # it can with time.time().
        now = time.monotonic()
        cutoff = now - self.window_seconds
        self.requests[user_id] = [
            t for t in self.requests[user_id] if t > cutoff
        ]

        # Check rate limit
        if len(self.requests[user_id]) >= self.max_requests:
            context.result = AgentResponse(
                messages=[Message(
                    "assistant",
                    [f"Rate limit exceeded. Try again in {self.window_seconds} seconds."]
                )]
            )
            return

        # Record request
        self.requests[user_id].append(now)

        # Proceed
        await call_next()
# Use the middleware: at most 10 requests per user in any 60-second window
agent = client.as_agent(
    name="LimitedAgent",
    middleware=[RateLimitMiddleware(max_requests=10, window_seconds=60)],
)