Skip to main content

Middleware API Reference

Middleware intercepts agent invocations, chat requests, and function calls to add custom processing.

AgentMiddleware

Abstract base class for agent middleware that intercepts agent invocations.
from agent_framework import AgentMiddleware, AgentContext

Abstract Method

process()

Process an agent invocation.
class LoggingMiddleware(AgentMiddleware):
    """Agent middleware that prints one line before and one line after each invocation."""

    async def process(self, context: AgentContext, call_next):
        """Print the agent name, run the rest of the pipeline, then print the result."""
        before_line = f"Before: {context.agent.name}"
        print(before_line)

        await call_next()

        # The result is read only after call_next() has returned.
        after_line = f"After: {context.result}"
        print(after_line)
context
AgentContext
required
Agent invocation context containing agent, messages, and metadata. Use context.stream to determine if this is a streaming call. Set context.result to override execution, or observe the actual execution result after calling call_next().
call_next
Callable[[], Awaitable[None]]
required
Function to call the next middleware or final agent execution. Does not return anything - all data flows through the context.

Example: Retry Middleware

from agent_framework import AgentMiddleware, AgentContext, Agent

class RetryMiddleware(AgentMiddleware):
    """Agent middleware that re-runs the agent invocation until it succeeds.

    An attempt counts as successful when ``context.result`` is truthy and does
    not expose a truthy ``is_error`` attribute.

    Args:
        max_retries: Maximum number of attempts. Values below 1 are treated
            as 1 so the agent is always invoked at least once.
    """

    def __init__(self, max_retries: int = 3):
        self.max_retries = max_retries

    async def process(self, context: AgentContext, call_next):
        """Invoke ``call_next`` up to ``max_retries`` times, stopping on success."""
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            await call_next()
            if context.result and not getattr(context.result, 'is_error', False):
                break
            if attempt == attempts:
                # Out of attempts: keep the last result on the context and
                # do not announce a retry that will never happen.
                break
            print(f"Retry {attempt}/{self.max_retries}")
            # Setting context.result is documented as overriding execution, so
            # clear the failed result before invoking the pipeline again.
            context.result = None

# Use with an agent: the middleware instance is passed in the `middleware` list.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(client=client, name="assistant", middleware=[RetryMiddleware()])

AgentContext

Context object for agent middleware invocations.
from agent_framework import AgentContext

Attributes

agent
SupportsAgentRun
The agent being invoked.
messages
list[Message]
The messages being sent to the agent.
session
AgentSession | None
The agent session for this invocation, if any.
options
Mapping[str, Any] | None
The options for the agent invocation as a dict.
stream
bool
Whether this is a streaming invocation.
metadata
dict[str, Any]
Metadata dictionary for sharing data between agent middleware.
result
AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None
Agent execution result. Can be observed after calling call_next() or set to override the execution result.
  • For non-streaming: should be AgentResponse
  • For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse]
kwargs
dict[str, Any]
Additional keyword arguments passed to the agent run method.

Example

from agent_framework import AgentMiddleware, AgentContext
import time

class TimingMiddleware(AgentMiddleware):
    """Agent middleware that prints invocation details and how long the run took."""

    async def process(self, context: AgentContext, call_next):
        """Print context details, run the pipeline, then print duration and result."""
        # Access context properties
        print(f"Agent: {context.agent.name}")
        print(f"Messages: {len(context.messages)}")
        print(f"Session: {context.session}")
        print(f"Streaming: {context.stream}")

        # Store metadata. perf_counter() is monotonic, so the measured duration
        # cannot be skewed by system clock adjustments; the stored value is
        # only meaningful as a difference, not as a wall-clock timestamp.
        context.metadata["start_time"] = time.perf_counter()

        # Continue execution
        await call_next()

        # Access result after execution
        duration = time.perf_counter() - context.metadata["start_time"]
        print(f"Duration: {duration}s")
        print(f"Result: {context.result}")

ChatMiddleware

Abstract base class for chat middleware that intercepts chat client requests.
from agent_framework import ChatMiddleware, ChatContext

Abstract Method

process()

Process a chat client request.
class SystemPromptMiddleware(ChatMiddleware):
    """Chat middleware that prepends a system message to every chat request.

    Args:
        system_prompt: Text of the system message to place first.
    """

    def __init__(self, system_prompt: str):
        self.system_prompt = system_prompt

    async def process(self, context: ChatContext, call_next):
        """Put the system message at the front of ``context.messages``, then continue."""
        from agent_framework import Message

        system_message = Message(role="system", text=self.system_prompt)
        if isinstance(context.messages, list):
            context.messages.insert(0, system_message)
        else:
            # messages is typed as Sequence[Message], which need not support
            # insert(); build a new list instead of mutating in place.
            context.messages = [system_message, *context.messages]
        await call_next()
context
ChatContext
required
Chat invocation context containing chat client, messages, options, and metadata. Use context.stream to determine if this is a streaming call. Set context.result to override execution.
call_next
Callable[[], Awaitable[None]]
required
Function to call the next middleware or final chat execution.

ChatContext

Context object for chat middleware invocations.
from agent_framework import ChatContext

Attributes

client
SupportsChatGetResponse
The chat client being invoked.
messages
Sequence[Message]
The messages being sent to the chat client.
options
Mapping[str, Any] | None
The options for the chat request as a dict.
stream
bool
Whether this is a streaming invocation.
metadata
dict[str, Any]
Metadata dictionary for sharing data between chat middleware.
result
ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None
Chat execution result. Can be observed after calling call_next() or set to override the execution result.
  • For non-streaming: should be ChatResponse
  • For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse]
kwargs
dict[str, Any]
Additional keyword arguments passed to the chat client.

Example

from agent_framework import ChatMiddleware, ChatContext, Agent

class TokenCounterMiddleware(ChatMiddleware):
    """Chat middleware that records input and output token counts in ``context.metadata``."""

    async def process(self, context: ChatContext, call_next):
        """Print request details and store token counts before and after the call."""
        print(f"Chat client: {context.client.__class__.__name__}")
        print(f"Messages: {len(context.messages)}")
        # options may be None, so fall back to an empty mapping before the lookup.
        options = context.options or {}
        print(f"Model: {options.get('model_id')}")

        # Store metadata
        context.metadata["input_tokens"] = self.count_tokens(context.messages)

        # Continue execution
        await call_next()

        # Access result and count output tokens
        if context.result:
            context.metadata["output_tokens"] = self.count_tokens(context.result)

    def count_tokens(self, data):
        """Return the token count for ``data`` (placeholder that always returns 0)."""
        # Token counting implementation
        return 0

# Use with an agent: the chat middleware instance goes in the same
# `middleware` list used for other middleware kinds.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(
    client=client,
    name="assistant",
    middleware=[TokenCounterMiddleware()]
)

FunctionMiddleware

Abstract base class for function middleware that intercepts function invocations.
from agent_framework import FunctionMiddleware, FunctionInvocationContext

Abstract Method

process()

Process a function invocation.
class CachingMiddleware(FunctionMiddleware):
    """Function middleware that caches results keyed by function name and arguments."""

    def __init__(self):
        self.cache = {}

    async def process(self, context: FunctionInvocationContext, call_next):
        """Serve a cached result when available; otherwise run the function and cache it."""
        # Imported locally because this snippet's import line does not bring
        # MiddlewareTermination into scope.
        from agent_framework import MiddlewareTermination

        cache_key = f"{context.function.name}:{context.arguments}"

        # Check cache
        if cache_key in self.cache:
            context.result = self.cache[cache_key]
            raise MiddlewareTermination()

        # Execute function
        await call_next()

        # Cache result. Compare against None so falsy results such as 0, "",
        # [] or False are cached as well.
        if context.result is not None:
            self.cache[cache_key] = context.result
context
FunctionInvocationContext
required
Function invocation context containing function, arguments, and metadata. Set context.result to override execution, or observe the actual execution result after calling call_next().
call_next
Callable[[], Awaitable[None]]
required
Function to call the next middleware or final function execution.

FunctionInvocationContext

Context object for function middleware invocations.
from agent_framework import FunctionInvocationContext

Attributes

function
FunctionTool
The function being invoked.
arguments
BaseModel | Mapping[str, Any]
The validated arguments for the function.
metadata
dict[str, Any]
Metadata dictionary for sharing data between function middleware.
result
Any
Function execution result. Can be observed after calling call_next() or set to override the execution result.
kwargs
dict[str, Any]
Additional keyword arguments passed to the chat method that invoked this function.

Example

from agent_framework import FunctionMiddleware, FunctionInvocationContext, Agent

class ValidationMiddleware(FunctionMiddleware):
    """Function middleware that stops the invocation when argument validation fails."""

    async def process(self, context: FunctionInvocationContext, call_next):
        """Print the call details, validate the arguments, then continue."""
        # Imported locally because this snippet's import line does not bring
        # MiddlewareTermination into scope.
        from agent_framework import MiddlewareTermination

        print(f"Function: {context.function.name}")
        print(f"Arguments: {context.arguments}")

        # Validate arguments
        if not self.validate(context.arguments):
            raise MiddlewareTermination("Validation failed")

        # Continue execution
        await call_next()

    def validate(self, arguments):
        """Return whether ``arguments`` are acceptable (placeholder that always returns True)."""
        # Validation logic
        return True

# Use with an agent: the function middleware instance goes in the
# `middleware` list.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(
    client=client,
    name="assistant",
    middleware=[ValidationMiddleware()]
)

Middleware Decorators

Function-based middleware using decorators.

@agent_middleware

Mark a function as agent middleware.
from agent_framework import agent_middleware, AgentContext, Agent

@agent_middleware
async def logging_middleware(context: AgentContext, call_next):
    """Print the agent name before the run and the result after it."""
    before_line = f"Before: {context.agent.name}"
    print(before_line)

    await call_next()

    # The result is read only after call_next() has returned.
    after_line = f"After: {context.result}"
    print(after_line)

# Use with an agent: the decorated function object itself (not a call) is
# listed in `middleware`.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(client=client, name="assistant", middleware=[logging_middleware])

@function_middleware

Mark a function as function middleware.
from agent_framework import function_middleware, FunctionInvocationContext, Agent

@function_middleware
async def logging_middleware(context: FunctionInvocationContext, call_next):
    """Print the function name before the call and its result afterwards."""
    calling_line = f"Calling: {context.function.name}"
    print(calling_line)

    await call_next()

    # The result is read only after call_next() has returned.
    result_line = f"Result: {context.result}"
    print(result_line)

# Use with an agent: the decorated function middleware is passed in the
# `middleware` list.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(client=client, name="assistant", middleware=[logging_middleware])

@chat_middleware

Mark a function as chat middleware.
from agent_framework import chat_middleware, ChatContext, Agent

@chat_middleware
async def logging_middleware(context: ChatContext, call_next):
    """Print the message count before the chat request and the response after it."""
    messages_line = f"Messages: {len(context.messages)}"
    print(messages_line)

    await call_next()

    # The result is read only after call_next() has returned.
    response_line = f"Response: {context.result}"
    print(response_line)

# Use with an agent: the decorated chat middleware is passed in the
# `middleware` list.
# NOTE: `client` is not defined in this snippet; it is assumed to be a chat
# client created beforehand.
agent = Agent(client=client, name="assistant", middleware=[logging_middleware])

MiddlewareTermination

Control-flow exception to terminate middleware execution early.
from agent_framework import MiddlewareTermination

Constructor

# The first positional argument is the termination message; `result` is the
# optional value to return when terminating.
# NOTE: `custom_result` is a placeholder that is not defined in this snippet.
raise MiddlewareTermination(
    "Custom termination message",
    result=custom_result
)
message
str
default:"Middleware terminated execution."
Error message.
result
Any
default:None
Optional result to return when terminating.

Example

from agent_framework import FunctionMiddleware, MiddlewareTermination

class AuthMiddleware(FunctionMiddleware):
    """Function middleware that blocks unauthorized function invocations."""

    async def process(self, context, call_next):
        """Terminate with an "Unauthorized" result unless the call is authorized."""
        if not self.is_authorized(context):
            context.result = "Unauthorized"
            raise MiddlewareTermination(
                "Authorization failed",
                result="Unauthorized"
            )
        await call_next()

    def is_authorized(self, context):
        """Return whether the invocation may proceed (placeholder that always returns True)."""
        # Authorization logic
        return True

categorize_middleware()

Categorize middleware from multiple sources into agent, function, and chat types.
from agent_framework import categorize_middleware

# Each positional argument is one middleware source: a single middleware, a
# sequence of middleware, or None.
# NOTE: `agent_mw`, `func_mw`, `chat_mw` and `more_middleware` are
# placeholders that are not defined in this snippet.
categorized = categorize_middleware(
    [agent_mw, func_mw, chat_mw],
    more_middleware
)
# Returns: {"agent": [...], "function": [...], "chat": [...]}
*middleware_sources
MiddlewareTypes | Sequence[MiddlewareTypes] | None
required
Variable number of middleware sources to categorize.
return
MiddlewareDict
Dictionary with keys "agent", "function", and "chat", each containing a list of the middleware categorized under that type.

Build docs developers (and LLMs) love