Skip to main content

FastAPI Integration Guide

Koreshield integrates seamlessly with FastAPI applications, allowing you to protect your endpoints with minimal code changes.

Overview

There are two primary ways to integrate Koreshield with FastAPI:
  1. Dependency injection: Allows fine-grained control inside route handlers (Recommended)
  2. Middleware approach: Protects all or specific routes automatically
Using dependency injection gives you access to the Koreshield client within your route, allowing you to handle blocked requests gracefully or return custom error responses.
Step 1

Install Dependencies

pip install Koreshield-sdk fastapi uvicorn
Step 2

Create Koreshield Dependency

from fastapi import FastAPI, Depends, HTTPException
from Koreshield.client import KoreshieldClient
import os

app = FastAPI()

def get_Koreshield():
    """Build the KoreshieldClient injected into route handlers via ``Depends``.

    The proxy URL is read from ``KORESHIELD_URL``, the spelling used by the
    ``.env`` example in this guide. The mixed-case ``Koreshield_URL`` is still
    honoured as a fallback: environment variable names are case-sensitive on
    POSIX systems, so the two spellings are different variables there.
    If neither is set (or both are empty), ``http://localhost:8000`` is used.
    """
    # Configure with your Koreshield Proxy URL
    base_url = (
        os.getenv("KORESHIELD_URL")
        or os.getenv("Koreshield_URL")
        or "http://localhost:8000"
    )
    return KoreshieldClient(base_url=base_url)
Step 3

Protect Your Routes

@app.post("/chat")
async def chat(message: str, ks: KoreshieldClient = Depends(get_Koreshield)):
    # NOTE(review): FastAPI treats a bare ``str`` parameter as a query
    # parameter, not a JSON body field — confirm that is what you want.

    # 1. Guard the input
    verdict = await ks.guard(message)

    # 2. Safe content: carry on with normal processing
    if verdict.is_safe:
        return {"response": "Processed successfully"}

    # 3. Unsafe content: reject and pass the guard's explanation to the caller
    blocked_detail = {
        "error": "Content blocked",
        "reason": verdict.reason,
        "details": verdict.details
    }
    raise HTTPException(status_code=403, detail=blocked_detail)

Complete Example

from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from Koreshield.client import KoreshieldClient
import os

app = FastAPI(title="Secure AI API")

# Request/Response models
class ChatRequest(BaseModel):
    # JSON body accepted by POST /chat.
    # message: required text; this is what gets passed to the Koreshield guard.
    message: str
    # context: optional, defaults to None; the /chat handler in this example
    # does not read it.
    context: str | None = None

class ChatResponse(BaseModel):
    # JSON body returned by POST /chat on success.
    # response: text sent back to the caller.
    response: str
    # safe: the /chat handler in this example always sets this to True, because
    # unsafe messages are rejected with an HTTPException instead.
    safe: bool

# Koreshield dependency
def get_koreshield():
    """Create a KoreshieldClient configured from environment variables.

    ``KORESHIELD_URL`` supplies the base URL, falling back to
    ``http://localhost:8000`` when unset. ``KORESHIELD_API_KEY`` supplies the
    API key and is passed as ``None`` when unset.
    """
    proxy_url = os.getenv("KORESHIELD_URL", "http://localhost:8000")
    key = os.getenv("KORESHIELD_API_KEY")
    return KoreshieldClient(base_url=proxy_url, api_key=key)

@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    ks: KoreshieldClient = Depends(get_koreshield)
):
    """Chat endpoint with Koreshield protection"""

    # Ask Koreshield for a verdict before the message goes anywhere else
    verdict = await ks.guard(request.message)

    if verdict.is_safe:
        # Your LLM call here
        # response = await call_llm(request.message)
        return ChatResponse(response="Message processed successfully", safe=True)

    # Blocked: report why, and how confident the guard was
    blocked_detail = {
        "error": "Content blocked",
        "reason": verdict.reason,
        "confidence": verdict.confidence
    }
    raise HTTPException(status_code=403, detail=blocked_detail)

@app.get("/health")
async def health():
    # Fixed status payload; this endpoint does not call Koreshield.
    payload = {"status": "healthy"}
    return payload

if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    # uvicorn is imported here, so it is only loaded in that case.
    import uvicorn
    # host "0.0.0.0" listens on all network interfaces. Port 8080 differs from
    # the port in the default Koreshield proxy URL (8000).
    uvicorn.run(app, host="0.0.0.0", port=8080)

Option 2: Middleware

A standard ASGI middleware package is coming soon for drop-in protection.

Custom Middleware Implementation

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from Koreshield.client import KoreshieldClient
import json

class KoreshieldMiddleware(BaseHTTPMiddleware):
    """Screen JSON POST bodies on selected paths with Koreshield.

    For a POST to one of ``protected_paths``, the JSON body's ``message``
    field (or ``prompt`` when ``message`` is missing or empty) is sent to
    ``KoreshieldClient.guard``. An unsafe verdict short-circuits the request
    with a 403 JSON response; every other request is forwarded unchanged.

    Args:
        app: The ASGI application being wrapped.
        base_url: URL of the Koreshield proxy.
        protected_paths: Exact request paths whose POST bodies are screened.
        fail_open: What to do when the guard call itself raises. ``True``
            (the default, and the only behaviour before this option existed)
            logs the error and lets the request through. ``False`` rejects
            the request with a 503 so that an unreachable proxy cannot
            silently disable protection.
    """

    def __init__(
        self,
        app,
        base_url: str,
        protected_paths: list[str],
        fail_open: bool = True,
    ):
        # Local import keeps this snippet self-contained.
        import logging

        super().__init__(app)
        self.client = KoreshieldClient(base_url=base_url)
        self.protected_paths = protected_paths
        self.fail_open = fail_open
        self._logger = logging.getLogger(__name__)

    async def dispatch(self, request: Request, call_next):
        """Screen the request if it targets a protected path, else pass it on."""
        # Only protect specific paths
        if request.method != "POST" or request.url.path not in self.protected_paths:
            return await call_next(request)

        message = await self._extract_message(request)

        if message:
            try:
                result = await self.client.guard(message)

                if not result.is_safe:
                    return self._json_response(403, {
                        "error": "Blocked by Koreshield",
                        "reason": result.reason
                    })
            except Exception:
                # Deliberately broad: any guard failure is logged with its
                # traceback, then handled according to ``fail_open``.
                self._logger.exception("Koreshield guard call failed")
                if not self.fail_open:
                    return self._json_response(503, {
                        "error": "Koreshield unavailable"
                    })

        return await call_next(request)

    async def _extract_message(self, request: Request):
        """Return the text to screen from the JSON body, or None if there is none."""
        # NOTE(review): the body is read here and again by the route handler;
        # this relies on Starlette replaying the body downstream — confirm
        # with the Starlette version you run.
        try:
            body = await request.json()
        except Exception as e:
            # Unreadable or malformed body: nothing to screen, so the request
            # is forwarded and the route handler does its own validation.
            self._logger.warning("Middleware error: %s", e)
            return None

        # Valid JSON need not be an object (it may be a list, string, ...).
        if not isinstance(body, dict):
            return None

        return body.get("message") or body.get("prompt")

    @staticmethod
    def _json_response(status_code: int, payload: dict) -> Response:
        """Build a JSON error response with the given status code."""
        return Response(
            content=json.dumps(payload),
            status_code=status_code,
            media_type="application/json"
        )

# Add middleware to app
app = FastAPI()
# The keyword arguments below are passed to KoreshieldMiddleware.__init__
# when the middleware is instantiated.
app.add_middleware(
    KoreshieldMiddleware,
    # Koreshield proxy URL (hard-coded in this example)
    base_url="http://localhost:8000",
    # Only POST requests to these exact paths are screened
    protected_paths=["/chat", "/completion"]
)

Advanced Patterns

Multi-Field Validation

from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from Koreshield.client import KoreshieldClient

class MultiFieldRequest(BaseModel):
    # JSON body accepted by POST /generate. All three fields are required,
    # and the /generate handler passes each of them to the Koreshield guard.
    prompt: str
    context: str
    system_prompt: str

@app.post("/generate")
async def generate(
    request: MultiFieldRequest,
    ks: KoreshieldClient = Depends(get_koreshield)
):
    # Screen every text field in turn; the first unsafe one aborts the
    # request with a 403 that names the offending field.
    for field_name in ("prompt", "context", "system_prompt"):
        verdict = await ks.guard(getattr(request, field_name))

        if verdict.is_safe:
            continue

        raise HTTPException(
            status_code=403,
            detail={
                "error": f"Threat detected in {field_name}",
                "reason": verdict.reason
            }
        )

    return {"status": "All fields validated"}

Custom Error Responses

from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse

class SecurityViolationError(Exception):
    """Raised when Koreshield flags a request as unsafe.

    Attributes are kept as ``reason`` and ``confidence`` for the exception
    handler. Both values are also forwarded to ``Exception.__init__`` so that
    ``args`` is populated however the error was constructed (positional or
    keyword arguments), and ``str(exc)`` is readable in logs and tracebacks.

    Args:
        reason: Explanation reported by the guard.
        confidence: Confidence score reported by the guard.
    """

    def __init__(self, reason: str, confidence: float):
        super().__init__(reason, confidence)
        self.reason = reason
        self.confidence = confidence

    def __str__(self) -> str:
        return f"{self.reason} (confidence={self.confidence})"

@app.exception_handler(SecurityViolationError)
async def security_violation_handler(request, exc: SecurityViolationError):
    # Turns a SecurityViolationError into a 403 JSON response carrying the
    # error's reason and confidence.
    # NOTE(review): "[email protected]" looks like an e-mail obfuscation
    # placeholder rather than a real address — confirm the support contact.
    body = {
        "error": "Security violation detected",
        "reason": exc.reason,
        "confidence": exc.confidence,
        "support": "[email protected]"
    }
    return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=body)

@app.post("/chat")
async def chat(message: str, ks: KoreshieldClient = Depends(get_koreshield)):
    # Unsafe input is reported by raising SecurityViolationError; the error
    # response itself is built by the handler registered for that exception.
    verdict = await ks.guard(message)

    if verdict.is_safe:
        return {"response": "Processed"}

    raise SecurityViolationError(
        reason=verdict.reason,
        confidence=verdict.confidence
    )

Background Scanning

from fastapi import BackgroundTasks

async def log_security_event(message: str, result: dict):
    """Record the outcome of a guard check.

    ``message`` is accepted so that a real sink can store it alongside the
    verdict; this example only prints ``result``.
    """
    # Log to database, monitoring service, etc.
    print("Security event:", result)

@app.post("/chat")
async def chat(
    message: str,
    background_tasks: BackgroundTasks,
    ks: KoreshieldClient = Depends(get_koreshield)
):
    # Screens ``message`` and records the verdict whether or not it is blocked.
    result = await ks.guard(message)
    event = result.dict()

    if not result.is_safe:
        # Tasks queued on ``background_tasks`` are attached to the response
        # built from the handler's return value. Raising HTTPException
        # bypasses that response, so a queued task would never run and
        # blocked requests would go unlogged. Log them inline instead.
        await log_security_event(message, event)
        raise HTTPException(status_code=403, detail="Blocked")

    # Log in background
    background_tasks.add_task(log_security_event, message, event)

    return {"response": "Processed"}

Configuration

Ensure your Koreshield Proxy is running and accessible. Configure via environment variables:
| Environment Variable | Description | Default |
| --- | --- | --- |
| `KORESHIELD_URL` | URL of the Koreshield Proxy | `http://localhost:8000` |
| `KORESHIELD_API_KEY` | API key (if auth enabled) | None |

Environment File

# .env
KORESHIELD_URL=http://localhost:8000
KORESHIELD_API_KEY=your-api-key-here

Load Configuration

from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    """Koreshield connection settings loaded from the environment and ``.env``.

    Configuration is declared through ``model_config`` rather than a nested
    ``class Config``; the nested form is deprecated in Pydantic v2, which
    ``pydantic_settings`` is built on.
    """

    model_config = SettingsConfigDict(env_file=".env")

    # URL of the Koreshield proxy.
    koreshield_url: str = "http://localhost:8000"
    # API key for the proxy; None when not configured.
    koreshield_api_key: str | None = None

settings = Settings()

def get_koreshield():
    """Create a KoreshieldClient from the module-level ``settings`` object."""
    client_options = {
        "base_url": settings.koreshield_url,
        "api_key": settings.koreshield_api_key,
    }
    return KoreshieldClient(**client_options)

Testing

from fastapi.testclient import TestClient
import pytest
from unittest.mock import AsyncMock, patch

# Module-level TestClient wrapping the FastAPI ``app``; the tests below share it.
client = TestClient(app)

@patch('Koreshield.client.KoreshieldClient.guard', new_callable=AsyncMock)
def test_safe_message(mock_guard):
    """A message the guard marks safe is processed and answered with 200."""
    # new_callable=AsyncMock makes the patched ``guard`` awaitable without
    # relying on patch() detecting that the real method is a coroutine.
    mock_guard.return_value = AsyncMock(
        is_safe=True,
        reason=None
    )

    response = client.post("/chat", json={"message": "Hello"})
    assert response.status_code == 200
    # Fail if the route stops consulting the guard.
    mock_guard.assert_awaited_once()

@patch('Koreshield.client.KoreshieldClient.guard', new_callable=AsyncMock)
def test_blocked_message(mock_guard):
    """A message the guard marks unsafe is rejected with 403 and a reason."""
    # Every attribute that ends up in the JSON error body must be a concrete
    # value: an auto-generated mock attribute is not JSON-serialisable.
    # NOTE(review): assumes the /chat route under test reports
    # error/reason/confidence in its HTTPException detail — adjust if yours
    # reports other fields.
    mock_guard.return_value = AsyncMock(
        is_safe=False,
        reason="Prompt injection detected",
        confidence=0.97
    )

    response = client.post("/chat", json={"message": "Malicious input"})
    assert response.status_code == 403
    # HTTPException(detail=...) is serialised as {"detail": ...}, so the
    # error fields sit one level down rather than at the top of the body.
    detail = response.json()["detail"]
    assert "blocked" in detail["error"].lower()
    assert detail["reason"] == "Prompt injection detected"

Python SDK

Complete Python SDK documentation

FastAPI Docs

Official FastAPI documentation

API Reference

Koreshield API reference

Build docs developers (and LLMs) love