Skip to main content

Flask Integration Guide

Protect your Flask applications from LLM attacks using Koreshield decorators or wrappers.

Setup

Install the package:
pip install Koreshield flask

Basic Usage

The most idiomatic way to use Koreshield in Flask is via a decorator.
from flask import Flask, request, abort
from Koreshield.client import KoreshieldClient
from functools import wraps

import asyncio

app = Flask(__name__)
client = KoreshieldClient()


def guard_route(f):
    """Block the request with a 403 if any prompt-like field is unsafe.

    The ``prompt`` and ``message`` fields are read from the JSON body, the
    form data and the query parameters. Every non-empty value is sent to
    ``client.guard``. All of them are checked, not just the first one found,
    so a payload cannot slip through in ``message`` next to a harmless
    ``prompt``.

    Args:
        f: The Flask view function to protect.

    Returns:
        The wrapped view. For unsafe input it returns
        ``({"error": "Blocked", "reason": ...}, 403)``; otherwise it calls
        ``f`` unchanged.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        json_body = request.get_json(silent=True)
        # A JSON array or scalar has no named fields (and no .get()); without
        # this check such a body would crash the view with a 500.
        if not isinstance(json_body, dict):
            json_body = {}

        for source in (json_body, request.form, request.args):
            for field in ("prompt", "message"):
                value = source.get(field)
                if not value:
                    continue
                # client.guard returns a coroutine, and asyncio.run drives it
                # from this synchronous view. asyncio.run cannot be used inside
                # an already-running event loop (e.g. an ``async def`` view);
                # await client.guard(...) there instead.
                result = asyncio.run(client.guard(value))
                if not result.is_safe:
                    return {
                        "error": "Blocked",
                        "reason": result.reason
                    }, 403

        return f(*args, **kwargs)
    return decorated_function


@app.route("/generate", methods=["POST"])
@guard_route
def generate():
    return {"status": "ok"}

Advanced Patterns

Multi-Field Protection

from flask import Flask, request, jsonify
from Koreshield.client import KoreshieldClient
from functools import wraps
import asyncio

app = Flask(__name__)
client = KoreshieldClient()


def guard_fields(*field_names):
    """Decorator factory that screens the named JSON fields with Koreshield.

    Each non-empty field listed in ``field_names`` is checked in order. The
    first unsafe one short-circuits with a 403 JSON response naming the field.

    Args:
        *field_names: Keys to look up in the request's JSON body.

    Returns:
        A decorator for Flask view functions.
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            data = request.get_json(silent=True)
            # A JSON array or scalar has no named fields; without this check
            # data.get() raises AttributeError and the request fails with 500.
            if not isinstance(data, dict):
                data = {}

            for field in field_names:
                value = data.get(field)
                if not value:
                    continue
                result = asyncio.run(client.guard(value))
                if not result.is_safe:
                    return jsonify({
                        "error": f"Blocked: {field}",
                        "reason": result.reason,
                        "confidence": result.confidence
                    }), 403

            return f(*args, **kwargs)
        return decorated_function
    return decorator

@app.route("/chat", methods=["POST"])
@guard_fields("message", "context", "system_prompt")
def chat():
    """Handle a chat request whose guarded fields were already screened."""
    payload = request.get_json()
    # guard_fields has checked every listed field before this body runs
    reply = {"response": "Processed successfully"}
    return jsonify(reply)

Custom Error Responses

from flask import Flask, request, jsonify
from Koreshield.client import KoreshieldClient
import asyncio

# Application and a module-level Koreshield client shared by the views below
app = Flask(import_name=__name__)
client = KoreshieldClient()

class SecurityViolation(Exception):
    """Raised when Koreshield flags a prompt as unsafe.

    The ``@app.errorhandler(SecurityViolation)`` handler turns it into a 403
    JSON response.

    Attributes:
        reason: Koreshield's explanation for the block.
        confidence: Koreshield's confidence score for the detection.
    """

    def __init__(self, reason, confidence):
        # Forward both values to Exception so ``args``, ``repr()`` and
        # pickling behave like any other exception, even when the class is
        # constructed with keyword arguments.
        super().__init__(reason, confidence)
        self.reason = reason
        self.confidence = confidence

    def __str__(self):
        return str(self.reason)

@app.errorhandler(SecurityViolation)
def handle_security_violation(error):
    """Translate a raised SecurityViolation into a 403 JSON response."""
    # NOTE(review): the support address below looks like scraped
    # email-obfuscation text; replace it with the real contact address.
    payload = dict(
        error="Security violation detected",
        reason=error.reason,
        confidence=error.confidence,
        support="[email protected]",
    )
    return jsonify(payload), 403

@app.route("/generate", methods=["POST"])
def generate():
    """Scan ``prompt`` and raise SecurityViolation if Koreshield flags it.

    The registered SecurityViolation handler converts the exception into a
    403 response. A body that is not a JSON object is rejected with 400.
    """
    data = request.get_json()
    # get_json() may return None (body "null") or a list/scalar, and .get()
    # would then raise AttributeError and surface as a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Expected a JSON object"}), 400

    prompt = data.get("prompt")

    if prompt:
        result = asyncio.run(client.guard(prompt))

        if not result.is_safe:
            raise SecurityViolation(
                reason=result.reason,
                confidence=result.confidence
            )

    # Process safe prompt
    return jsonify({"response": "Generated"})

Blueprint Integration

from flask import Blueprint, request, jsonify
from Koreshield.client import KoreshieldClient
import asyncio

ai_bp = Blueprint('ai', __name__, url_prefix='/api/ai')
client = KoreshieldClient()


def check_safety(content):
    """Scan ``content`` with Koreshield.

    Args:
        content: Text to scan. Empty or missing values (``None``, ``""``) are
            not sent to Koreshield, matching the ``if prompt:`` checks in the
            other examples. Callers decide whether a missing field is an error.

    Returns:
        ``None`` when the content is safe or empty. Otherwise a
        ``(response, 403)`` tuple that the view can return directly.
    """
    if not content:
        return None

    result = asyncio.run(client.guard(content))
    if not result.is_safe:
        return jsonify({
            "error": "Blocked",
            "reason": result.reason
        }), 403
    return None

@ai_bp.route('/chat', methods=['POST'])
def chat():
    """Screen ``message`` with Koreshield, then process it.

    Returns 400 when the body is not a JSON object or ``message`` is missing.
    Returns the 403 response from ``check_safety`` when Koreshield flags it.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Expected a JSON object"}), 400

    message = data.get('message')
    if not message:
        return jsonify({"error": "Missing 'message'"}), 400

    # Check safety
    error = check_safety(message)
    if error:
        return error

    # Process message
    return jsonify({"response": "Processed"})

@ai_bp.route('/complete', methods=['POST'])
def complete():
    """Screen ``prompt`` with Koreshield, then generate a completion.

    Returns 400 when the body is not a JSON object or ``prompt`` is missing.
    Returns the 403 response from ``check_safety`` when Koreshield flags it.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Expected a JSON object"}), 400

    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Missing 'prompt'"}), 400

    # Check safety
    error = check_safety(prompt)
    if error:
        return error

    # Generate completion
    return jsonify({"completion": "Generated"})

# Register blueprint on an application
from flask import Flask

app = Flask(__name__)
app.register_blueprint(ai_bp)

Application Factory Pattern

from flask import Flask
from Koreshield.client import KoreshieldClient
import os

def create_app(config=None):
    """Build and configure the Flask application.

    Args:
        config: Optional mapping applied on top of the environment-derived
            Koreshield settings (for example to point tests at another URL).

    Returns:
        The configured Flask app. The Koreshield client is available as
        ``app.extensions['koreshield']`` and, for backward compatibility, as
        ``app.koreshield``.
    """
    app = Flask(__name__)

    # Defaults come from the environment; ``config`` overrides them below
    app.config['KORESHIELD_URL'] = os.getenv('KORESHIELD_URL', 'http://localhost:8000')
    app.config['KORESHIELD_API_KEY'] = os.getenv('KORESHIELD_API_KEY')

    if config:
        app.config.update(config)

    # Build the client after the overrides so ``config`` can change URL/key
    client = KoreshieldClient(
        base_url=app.config['KORESHIELD_URL'],
        api_key=app.config.get('KORESHIELD_API_KEY')
    )

    # The client lives on the application object, not the per-request app
    # context. app.extensions is Flask's conventional registry for extension
    # state; views can reach it via current_app.extensions['koreshield'].
    app.extensions['koreshield'] = client
    app.koreshield = client

    # Register blueprints
    from .routes import ai_bp
    app.register_blueprint(ai_bp)

    return app

# Usage
app = create_app()

if __name__ == '__main__':
    # Do not hard-code debug=True: the Werkzeug debugger allows arbitrary code
    # execution. Flask's app.run() enables debug mode from FLASK_DEBUG
    # (e.g. FLASK_DEBUG=1) for local development.
    app.run()

Middleware Approach

from flask import Flask, request, jsonify
from werkzeug.wrappers import Response
from Koreshield.client import KoreshieldClient
import asyncio

import io
import json
import logging

logger = logging.getLogger(__name__)


class KoreshieldMiddleware:
    """WSGI middleware that screens JSON POST bodies with Koreshield.

    For POST requests whose ``PATH_INFO`` is in ``protected_paths``, the
    ``prompt`` and ``message`` fields of the JSON body are each sent to
    ``KoreshieldClient.guard``. An unsafe result short-circuits with a 403
    JSON response. Otherwise the request is forwarded to the wrapped app with
    its body intact.

    Errors raised while scanning are logged and the request is allowed
    through (fail-open).
    """

    def __init__(self, app, protected_paths=None):
        """
        Args:
            app: The WSGI application to wrap (typically ``app.wsgi_app``).
            protected_paths: Exact ``PATH_INFO`` values to screen; none when
                omitted.
        """
        self.app = app
        self.client = KoreshieldClient()
        self.protected_paths = protected_paths or []

    def __call__(self, environ, start_response):
        request_path = environ.get('PATH_INFO', '')
        request_method = environ.get('REQUEST_METHOD', '')

        if request_path in self.protected_paths and request_method == 'POST':
            try:
                blocked = self._screen(environ)
            except Exception:
                # Best-effort: log the failure and let the request continue
                logger.exception("Middleware error")
            else:
                if blocked is not None:
                    return blocked(environ, start_response)

        return self.app(environ, start_response)

    def _screen(self, environ):
        """Return a 403 Response if the JSON body is unsafe, else None.

        The body is buffered and ``wsgi.input`` is replaced so the wrapped app
        can still read the request body.
        """
        from flask import Request

        req = Request(environ)
        if not req.is_json:
            return None

        body = req.get_data(cache=True)
        # Reading consumed the input stream; hand downstream a fresh copy.
        environ['wsgi.input'] = io.BytesIO(body)
        environ['CONTENT_LENGTH'] = str(len(body))

        data = req.get_json(silent=True)
        if not isinstance(data, dict):
            return None

        # Check every field: `prompt or message` would skip `message`
        # whenever a (harmless) `prompt` is also present.
        for field in ('prompt', 'message'):
            value = data.get(field)
            if not value:
                continue
            result = asyncio.run(self.client.guard(value))
            if not result.is_safe:
                # jsonify() needs a Flask app context, which does not exist at
                # the WSGI layer. It would raise, and the except above would
                # then let the blocked request through. Serialize directly.
                return Response(
                    response=json.dumps({
                        "error": "Blocked",
                        "reason": result.reason
                    }),
                    status=403,
                    mimetype='application/json'
                )
        return None

# Apply middleware: wrap Flask's WSGI callable so the listed paths are
# screened before Flask handles the request.
app = Flask(__name__)
app.wsgi_app = KoreshieldMiddleware(
    app=app.wsgi_app,
    protected_paths=['/api/chat', '/api/generate'],
)

Configuration

import os

class Config:
    """Koreshield settings read from environment variables.

    Values are read once, when this class body executes, and are loaded into
    Flask via ``app.config.from_object(Config)``.
    """

    # Koreshield server endpoint; falls back to a local instance
    KORESHIELD_URL = os.environ.get('KORESHIELD_URL', 'http://localhost:8000')
    # API key; None when the variable is unset
    KORESHIELD_API_KEY = os.environ.get('KORESHIELD_API_KEY')
    # Detection sensitivity level, 'medium' unless overridden
    KORESHIELD_SENSITIVITY = os.environ.get('KORESHIELD_SENSITIVITY', 'medium')

from flask import Flask

app = Flask(__name__)
# Load the Koreshield settings defined on Config
app.config.from_object(Config)

Error Handling

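When a request is blocked, you can return a 403 Forbidden status or a custom JSON error response. The `result.reason` field used in the examples above (and `result.details`, if your SDK version provides it) explains why the request was blocked (e.g., “Prompt Injection Detected”, “PII Found”). Note that an `@app.errorhandler(403)` handler like the one below runs only when a 403 is raised, for example via `abort(403)`. It does not run when a view or decorator returns a response with status 403 directly, as the decorators above do.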
from flask import jsonify

@app.errorhandler(403)
def forbidden(e):
    """Render any raised 403 (e.g. ``abort(403)``) as a JSON body."""
    # NOTE(review): the support address below looks like scraped
    # email-obfuscation text; replace it with the real contact address.
    body = dict(
        error="Forbidden",
        message="Your request was blocked for security reasons",
        support="[email protected]",
    )
    return jsonify(body), 403

Testing

import json
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

# NOTE: `app` must be the Flask application under test, e.g.
# `from yourapp import app`; adjust the import to your project layout.


@pytest.fixture
def client():
    """Yield a Flask test client with TESTING enabled."""
    app.config['TESTING'] = True
    with app.test_client() as test_client:
        yield test_client


def _scan_result(is_safe, reason=None, confidence=None):
    """Build a plain stand-in for the object returned by KoreshieldClient.guard."""
    return SimpleNamespace(is_safe=is_safe, reason=reason, confidence=confidence)


# The views call asyncio.run(client.guard(...)), so the patched guard must be
# an AsyncMock: calling it returns a coroutine that resolves to return_value.
@patch('Koreshield.client.KoreshieldClient.guard', new_callable=AsyncMock)
def test_safe_message(mock_guard, client):
    mock_guard.return_value = _scan_result(is_safe=True)

    response = client.post(
        '/generate',
        data=json.dumps({"prompt": "Hello"}),
        content_type='application/json'
    )

    assert response.status_code == 200
    # Guard that the route actually consulted Koreshield
    mock_guard.assert_awaited()


@patch('Koreshield.client.KoreshieldClient.guard', new_callable=AsyncMock)
def test_blocked_message(mock_guard, client):
    mock_guard.return_value = _scan_result(
        is_safe=False,
        reason="Prompt injection detected",
        confidence=0.99,
    )

    response = client.post(
        '/generate',
        data=json.dumps({"prompt": "Malicious"}),
        content_type='application/json'
    )

    assert response.status_code == 403
    data = response.get_json()
    assert "Blocked" in data["error"]

Environment Variables

# .env
KORESHIELD_URL=http://localhost:8000
KORESHIELD_API_KEY=your-api-key-here
KORESHIELD_SENSITIVITY=medium

Production Considerations

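Consider caching scan results for identical inputs to reduce latency.
Implement fallback behavior for when Koreshield is unavailable. Decide explicitly whether to fail open (allow the request) or fail closed (reject it) when a scan errors or times out.
Log all blocked requests for security auditing.
Combine with Flask-Limiter for comprehensive protection.
The examples call asyncio.run() for every check. This starts a new event loop each time and raises RuntimeError if an event loop is already running (for example inside an async def view). In async views, await client.guard(...) directly instead.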

Python SDK

Complete Python SDK documentation

Flask Docs

Official Flask documentation

API Reference

Koreshield API reference

Build docs developers (and LLMs) love