Skip to main content

Overview

Guardrail providers detect security threats, policy violations, and content issues in requests and responses. The gateway supports input guardrails (validate requests) and output guardrails (validate responses).

GuardrailProvider Interface

All guardrail providers must implement the GuardrailProvider abstract base class from src/secure_mcp_gateway/plugins/guardrails/base.py.

Required Methods

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List

class GuardrailProvider(ABC):
    """Abstract base class that every guardrail provider must implement.

    The gateway identifies a provider by the string returned from
    get_name() and asks it to build input/output guardrail instances
    from a configuration dict.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Get the unique name/identifier for this provider."""
        pass

    @abstractmethod
    def get_version(self) -> str:
        """Get the version of this provider."""
        pass

    @abstractmethod
    def create_input_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[InputGuardrail]:
        """Create an input guardrail instance.

        May return None (e.g. when the config disables the guardrail).
        """
        pass

    @abstractmethod
    def create_output_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[OutputGuardrail]:
        """Create an output guardrail instance.

        May return None (e.g. when the config disables the guardrail).
        """
        pass

Optional Methods

def create_pii_handler(self, config: Dict[str, Any]) -> Optional[PIIHandler]:
    """Create a PII handler instance (optional); the default returns None."""
    return None

def validate_config(self, config: Dict[str, Any]) -> bool:
    """Validate provider-specific configuration; the default accepts anything."""
    return True

def get_required_config_keys(self) -> List[str]:
    """Get list of required configuration keys; the default requires none."""
    return []

def get_metadata(self) -> Dict[str, Any]:
    """Get provider metadata (capabilities, limits, etc.).

    Override to advertise different capabilities; the defaults claim
    input and output support but no PII handling.
    """
    return {
        "name": self.get_name(),
        "version": self.get_version(),
        "supports_input": True,
        "supports_output": True,
        "supports_pii": False,
    }

Protocol Interfaces

Guardrails use protocol interfaces for different capabilities:

InputGuardrail Protocol

from typing import Protocol, runtime_checkable

@runtime_checkable
class InputGuardrail(Protocol):
    """Structural (duck-typed) interface for guardrails that screen requests."""

    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Validate input content before it's sent to the MCP server."""
        ...

    def get_supported_detectors(self) -> List[ViolationType]:
        """Get list of violation types this guardrail can detect."""
        ...

OutputGuardrail Protocol

@runtime_checkable
class OutputGuardrail(Protocol):
    """Structural interface for guardrails that screen server responses."""

    async def validate(
        self, response_content: str, original_request: GuardrailRequest
    ) -> GuardrailResponse:
        """Validate output content after it's received from the MCP server."""
        ...

    def get_supported_detectors(self) -> List[ViolationType]:
        """Get list of violation types this guardrail can detect."""
        ...

PIIHandler Protocol

@runtime_checkable
class PIIHandler(Protocol):
    """Structural interface for PII detection, redaction, and restoration."""

    async def detect_pii(self, content: str) -> List[GuardrailViolation]:
        """Detect PII in content."""
        ...

    async def redact_pii(self, content: str) -> tuple[str, Dict[str, Any]]:
        """Redact PII from content and return mapping for restoration."""
        ...

    async def restore_pii(self, content: str, pii_mapping: Dict[str, Any]) -> str:
        """Restore PII using the mapping produced by redact_pii."""
        ...

Data Models

GuardrailRequest

@dataclass
class GuardrailRequest:
    """Input payload handed to a guardrail for evaluation."""

    content: str  # The content to evaluate
    tool_name: Optional[str] = None  # MCP tool being invoked, if any
    tool_args: Optional[Dict[str, Any]] = None  # Arguments passed to the tool
    server_name: Optional[str] = None  # Target MCP server name
    context: Optional[Dict[str, Any]] = None  # Extra caller-supplied context

GuardrailResponse

@dataclass
class GuardrailResponse:
    """Result of a guardrail evaluation."""

    is_safe: bool  # True when no blocking violations were found
    action: GuardrailAction  # What the gateway should do with the content
    violations: List[GuardrailViolation]  # All detected violations (may be empty)
    modified_content: Optional[str] = None  # If content was modified
    metadata: Optional[Dict[str, Any]] = None  # Provider-specific details; may be None
    processing_time_ms: Optional[float] = None  # Wall-clock evaluation time, if measured

GuardrailViolation

@dataclass
class GuardrailViolation:
    """A single policy/security violation detected in content."""

    violation_type: ViolationType
    severity: float  # 0.0 (low) to 1.0 (high)
    message: str  # Human-readable description of the violation
    action: GuardrailAction  # Action recommended for this violation
    metadata: Dict[str, Any]  # Detector-specific details (scores, matched terms, ...)
    suggested_action: Optional[str] = None  # Optional remediation hint
    redacted_content: Optional[str] = None  # Redacted form of the content, if produced

ViolationType Enum

class ViolationType(Enum):
    """Categories of violations a guardrail can report.

    The string values appear in configuration and results, so they must
    stay stable.
    """

    # Input violations
    PII = "pii"
    INJECTION_ATTACK = "injection_attack"
    TOXIC_CONTENT = "toxicity"
    NSFW_CONTENT = "nsfw"
    KEYWORD_VIOLATION = "keyword_detector"
    POLICY_VIOLATION = "policy_violation"
    BIAS = "bias"
    SPONGE_ATTACK = "sponge_attack"
    SYSTEM_PROMPT_PROTECTION = "system_prompt_protection"
    COPYRIGHT_PROTECTION = "copyright_protection"

    # Output violations
    RELEVANCY_FAILURE = "relevancy"
    ADHERENCE_FAILURE = "adherence"
    HALLUCINATION = "hallucination"

    # Generic
    CUSTOM = "custom"

GuardrailAction Enum

class GuardrailAction(Enum):
    """Action the gateway takes in response to a guardrail result."""

    ALLOW = "allow"      # Continue processing
    BLOCK = "block"      # Stop processing and return error
    WARN = "warn"        # Log warning but continue
    MODIFY = "modify"    # Modify content and continue

Example: OpenAI Moderation Provider

From src/secure_mcp_gateway/plugins/guardrails/example_providers.py:96:
import httpx
from secure_mcp_gateway.plugins.guardrails.base import (
    GuardrailProvider,
    GuardrailRequest,
    GuardrailResponse,
    GuardrailViolation,
    InputGuardrail,
    OutputGuardrail,
    ViolationType,
    GuardrailAction,
)

class OpenAIInputGuardrail:
    """OpenAI Moderation API input guardrail implementation.

    Sends the request content to the /v1/moderations endpoint and blocks
    content whose flagged category score meets the configured threshold.
    Fails closed: any API or network error yields a blocking response
    instead of propagating an exception into the gateway.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.api_key = config.get("api_key", "")
        # Minimum category score (0.0-1.0) required to block flagged content.
        self.threshold = config.get("threshold", 0.7)
        self.block_categories = config.get(
            "block_categories", ["hate", "violence", "sexual"]
        )
        # Bound the moderation call so a hung API cannot stall the gateway.
        self.timeout = config.get("timeout", 10.0)

    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Validate using OpenAI Moderation API.

        Returns a blocking GuardrailResponse on any HTTP/network failure
        (fail closed) rather than raising.
        """
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(
                    "https://api.openai.com/v1/moderations",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json={"input": request.content},
                )
                # Surface HTTP errors (401, 429, 5xx) instead of trying to
                # parse an error payload as a moderation result.
                response.raise_for_status()
                result = response.json()

            moderation_result = result["results"][0]

            violations = []
            is_safe = True

            # Check each flagged category against the configured block list
            # and threshold.
            categories = moderation_result.get("categories", {})
            category_scores = moderation_result.get("category_scores", {})

            for category, flagged in categories.items():
                if flagged and category in self.block_categories:
                    score = category_scores.get(category, 0.0)
                    if score >= self.threshold:
                        is_safe = False
                        violations.append(
                            GuardrailViolation(
                                violation_type=self._map_category_to_violation(category),
                                severity=score,
                                message=f"Content flagged for {category}",
                                action=GuardrailAction.BLOCK,
                                metadata={"category": category, "score": score},
                            )
                        )

            return GuardrailResponse(
                is_safe=is_safe,
                action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
                violations=violations,
                metadata={"provider": "openai-moderation"},
            )
        except Exception as e:
            # Fail closed: a moderation outage must not let content through.
            return GuardrailResponse(
                is_safe=False,
                action=GuardrailAction.BLOCK,
                violations=[
                    GuardrailViolation(
                        violation_type=ViolationType.CUSTOM,
                        severity=1.0,
                        message=f"Validation error: {str(e)}",
                        action=GuardrailAction.BLOCK,
                        metadata={"error": str(e)},
                    )
                ],
                metadata={"provider": "openai-moderation", "error": str(e)},
            )

    def get_supported_detectors(self) -> List[ViolationType]:
        """Violation types this guardrail can report."""
        return [
            ViolationType.TOXIC_CONTENT,
            ViolationType.NSFW_CONTENT,
            ViolationType.CUSTOM,
        ]

    def _map_category_to_violation(self, category: str) -> ViolationType:
        """Map an OpenAI moderation category name to a ViolationType."""
        mapping = {
            "hate": ViolationType.TOXIC_CONTENT,
            "violence": ViolationType.TOXIC_CONTENT,
            "sexual": ViolationType.NSFW_CONTENT,
            "self-harm": ViolationType.TOXIC_CONTENT,
        }
        return mapping.get(category, ViolationType.CUSTOM)


class OpenAIGuardrailProvider(GuardrailProvider):
    """OpenAI Moderation API provider.

    Injects its API key into a copy of each guardrail config at creation
    time so the caller's config dict is never mutated.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key

    def get_name(self) -> str:
        """Unique registry identifier for this provider."""
        return "openai-moderation"

    def get_version(self) -> str:
        return "1.0.0"

    def create_input_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[InputGuardrail]:
        """Create an input guardrail, or None when the config disables it."""
        if not config.get("enabled", False):
            return None

        # Copy the config so the secret key is not written into the
        # caller-owned dict as a side effect.
        guardrail_config = {**config, "api_key": self.api_key}
        return OpenAIInputGuardrail(guardrail_config)

    def create_output_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[OutputGuardrail]:
        """Create an output guardrail, or None when the config disables it."""
        if not config.get("enabled", False):
            return None

        guardrail_config = {**config, "api_key": self.api_key}

        class OpenAIOutputGuardrail:
            """Adapts the input guardrail to the OutputGuardrail protocol."""

            def __init__(self, input_guardrail):
                self._input_guardrail = input_guardrail

            async def validate(
                self, response_content: str, original_request: GuardrailRequest
            ) -> GuardrailResponse:
                # Moderation inspects raw text, so the output check simply
                # reuses the input guardrail on the response content.
                return await self._input_guardrail.validate(
                    GuardrailRequest(content=response_content)
                )

            def get_supported_detectors(self) -> List[ViolationType]:
                return self._input_guardrail.get_supported_detectors()

        return OpenAIOutputGuardrail(OpenAIInputGuardrail(guardrail_config))

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Config is valid unless the guardrail is enabled without an API key."""
        if config.get("enabled", False) and not self.api_key:
            return False
        return True

    def get_required_config_keys(self) -> List[str]:
        return ["enabled"]

Example: Custom Keyword Provider

From src/secure_mcp_gateway/plugins/guardrails/example_providers.py:250:
class CustomKeywordGuardrail:
    """Simple keyword-based guardrail.

    Blocks content containing any configured keyword (plain substring
    match). Matching is case-insensitive unless config sets
    ``case_sensitive`` to True.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.blocked_keywords = config.get("blocked_keywords", [])
        self.case_sensitive = config.get("case_sensitive", False)
        # Precompute the comparison form of each keyword once instead of
        # lower-casing the whole list on every request.
        if self.case_sensitive:
            self._match_keywords = list(self.blocked_keywords)
        else:
            self._match_keywords = [kw.lower() for kw in self.blocked_keywords]

    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Block the request when any blocked keyword appears in the content."""
        content = request.content if self.case_sensitive else request.content.lower()

        violations = [
            GuardrailViolation(
                violation_type=ViolationType.KEYWORD_VIOLATION,
                severity=1.0,
                message=f"Blocked keyword detected: {keyword}",
                action=GuardrailAction.BLOCK,
                metadata={"keyword": keyword},
            )
            for keyword in self._match_keywords
            if keyword in content
        ]

        is_safe = len(violations) == 0

        return GuardrailResponse(
            is_safe=is_safe,
            action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
            violations=violations,
            metadata={"provider": "custom-keyword"},
        )

    def get_supported_detectors(self) -> List[ViolationType]:
        """This guardrail only reports keyword violations."""
        return [ViolationType.KEYWORD_VIOLATION]


class CustomKeywordProvider(GuardrailProvider):
    """Provider wrapping CustomKeywordGuardrail for input and output checks."""

    def __init__(self, blocked_keywords: List[str]):
        self.blocked_keywords = blocked_keywords

    def get_name(self) -> str:
        return "custom-keyword"

    def get_version(self) -> str:
        return "1.0.0"

    def create_input_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[InputGuardrail]:
        """Create an input guardrail, or None when the config disables it."""
        if not config.get("enabled", False):
            return None

        # Copy so the caller's config dict is not mutated as a side effect.
        guardrail_config = {**config, "blocked_keywords": self.blocked_keywords}
        return CustomKeywordGuardrail(guardrail_config)

    def create_output_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[OutputGuardrail]:
        """Create an output guardrail, or None when the config disables it."""
        if not config.get("enabled", False):
            return None

        guardrail_config = {**config, "blocked_keywords": self.blocked_keywords}

        class CustomKeywordOutputGuardrail:
            """Adapts the input guardrail to the OutputGuardrail protocol."""

            def __init__(self, input_guardrail):
                self._input_guardrail = input_guardrail

            async def validate(
                self, response_content: str, original_request: GuardrailRequest
            ) -> GuardrailResponse:
                # Keyword scanning works on raw text, so reuse the input
                # guardrail on the response content.
                return await self._input_guardrail.validate(
                    GuardrailRequest(content=response_content)
                )

            def get_supported_detectors(self) -> List[ViolationType]:
                return self._input_guardrail.get_supported_detectors()

        return CustomKeywordOutputGuardrail(CustomKeywordGuardrail(guardrail_config))

    def get_required_config_keys(self) -> List[str]:
        return ["enabled", "blocked_keywords"]

Example: Composite Provider

Combine multiple guardrail providers with AND/OR logic (from src/secure_mcp_gateway/plugins/guardrails/example_providers.py:350):
class CompositeGuardrail:
    """Combines multiple guardrails with AND/OR logic.

    The logic applies to how sub-guardrail *violations* are combined:
    "OR" means any single unsafe result blocks; "AND" means all
    sub-guardrails must flag before blocking.
    """

    def __init__(self, guardrails: List[InputGuardrail], logic: str = "OR"):
        self.guardrails = guardrails
        # Normalize so "or"/"Or" behave like "OR".
        self.logic = logic.upper()

    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Run every sub-guardrail and combine their verdicts per self.logic."""
        all_violations = []
        all_safe = []

        for guardrail in self.guardrails:
            result = await guardrail.validate(request)
            all_violations.extend(result.violations)
            all_safe.append(result.is_safe)

        # "OR" ORs the violations: one unsafe sub-result blocks, so content
        # is safe only when every guardrail passed. "AND" blocks only when
        # all guardrails flagged, so a single safe result allows.
        # NOTE(review): with an empty guardrail list, all([]) is True, so
        # "OR" yields is_safe with no checks run — confirm this is intended.
        is_safe = all(all_safe) if self.logic == "OR" else any(all_safe)

        return GuardrailResponse(
            is_safe=is_safe,
            action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
            violations=all_violations,
            metadata={"provider": "composite", "logic": self.logic},
        )

    def get_supported_detectors(self) -> List[ViolationType]:
        """Union of the violation types supported by all sub-guardrails."""
        all_detectors = set()
        for guardrail in self.guardrails:
            all_detectors.update(guardrail.get_supported_detectors())
        return list(all_detectors)


class CompositeGuardrailProvider(GuardrailProvider):
    """Combines multiple providers into one composite provider."""

    def __init__(self, providers: List[GuardrailProvider], logic: str = "OR"):
        self.providers = providers
        # "OR": any sub-guardrail violation blocks; "AND": all must flag.
        self.logic = logic

    def get_name(self) -> str:
        """Derive a stable name from the wrapped providers' names."""
        provider_names = "_".join([p.get_name() for p in self.providers])
        return f"composite_{provider_names}"

    def get_version(self) -> str:
        return "1.0.0"

    def create_input_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[InputGuardrail]:
        """Create a composite input guardrail over all enabled sub-providers."""
        if not config.get("enabled", False):
            return None

        guardrails = []
        for provider in self.providers:
            guardrail = provider.create_input_guardrail(config)
            if guardrail:
                guardrails.append(guardrail)

        if not guardrails:
            return None

        return CompositeGuardrail(guardrails, self.logic)

    def create_output_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional[OutputGuardrail]:
        """Create a composite output guardrail over all enabled sub-providers.

        Mirrors create_input_guardrail; previously this was an unimplemented
        stub that always returned None, silently disabling output checks.
        """
        if not config.get("enabled", False):
            return None

        guardrails = []
        for provider in self.providers:
            guardrail = provider.create_output_guardrail(config)
            if guardrail:
                guardrails.append(guardrail)

        if not guardrails:
            return None

        class _CompositeOutputGuardrail:
            """Combines output guardrails with the same AND/OR semantics."""

            def __init__(self, guardrails, logic):
                self.guardrails = guardrails
                self.logic = logic.upper()

            async def validate(
                self, response_content: str, original_request: GuardrailRequest
            ) -> GuardrailResponse:
                all_violations = []
                all_safe = []

                for guardrail in self.guardrails:
                    result = await guardrail.validate(
                        response_content, original_request
                    )
                    all_violations.extend(result.violations)
                    all_safe.append(result.is_safe)

                # Same combination rule as CompositeGuardrail.
                is_safe = all(all_safe) if self.logic == "OR" else any(all_safe)

                return GuardrailResponse(
                    is_safe=is_safe,
                    action=GuardrailAction.ALLOW if is_safe else GuardrailAction.BLOCK,
                    violations=all_violations,
                    metadata={"provider": "composite", "logic": self.logic},
                )

            def get_supported_detectors(self) -> List[ViolationType]:
                detectors = set()
                for guardrail in self.guardrails:
                    detectors.update(guardrail.get_supported_detectors())
                return list(detectors)

        return _CompositeOutputGuardrail(guardrails, self.logic)

Configuration

Register your guardrail provider in the config file:
{
  "plugins": {
    "guardrails": {
      "provider": "openai-moderation",
      "config": {
        "api_key": "sk-...",
        "threshold": 0.7,
        "block_categories": ["hate", "violence", "sexual"]
      }
    }
  },
  "mcp_configs": {
    "config-id": {
      "mcp_config": [
        {
          "server_name": "github",
          "enable_tool_guardrails": true,
          "input_guardrails_policy": {
            "enabled": true,
            "block": ["toxicity", "nsfw"],
            "additional_config": {
              "threshold": 0.8
            }
          },
          "output_guardrails_policy": {
            "enabled": true,
            "block": ["toxicity", "pii"],
            "additional_config": {
              "pii_redaction": true
            }
          }
        }
      ]
    }
  }
}

Using the GuardrailFactory

from secure_mcp_gateway.plugins.guardrails.base import (
    GuardrailRegistry,
    GuardrailFactory,
)

# Register provider — presumably looked up later by its get_name() value
# ("openai-moderation"); verify against GuardrailRegistry.register.
registry = GuardrailRegistry()
registry.register(OpenAIGuardrailProvider(api_key="sk-..."))

# Create factory backed by the registry
factory = GuardrailFactory(registry)

# Create guardrails by provider name, passing each guardrail's config
input_guardrail = factory.create_input_guardrail(
    "openai-moderation",
    {"enabled": True, "threshold": 0.8}
)

output_guardrail = factory.create_output_guardrail(
    "openai-moderation",
    {"enabled": True}
)

Testing Your Provider

import pytest
from secure_mcp_gateway.plugins.guardrails.base import (
    GuardrailRequest,
    GuardrailAction,
    ViolationType,
)

@pytest.mark.asyncio
async def test_keyword_guardrail_blocks():
    """Content containing a blocked keyword must be flagged and blocked."""
    provider = CustomKeywordProvider(blocked_keywords=["password", "secret"])
    guardrail = provider.create_input_guardrail({"enabled": True})

    # "password" and "secret" both occur as substrings of the content.
    request = GuardrailRequest(content="My password is secret123")
    response = await guardrail.validate(request)

    assert not response.is_safe
    assert response.action == GuardrailAction.BLOCK
    assert len(response.violations) > 0
    assert response.violations[0].violation_type == ViolationType.KEYWORD_VIOLATION

@pytest.mark.asyncio
async def test_keyword_guardrail_allows():
    """Content with no blocked keywords must pass with no violations."""
    provider = CustomKeywordProvider(blocked_keywords=["password"])
    guardrail = provider.create_input_guardrail({"enabled": True})

    request = GuardrailRequest(content="Hello world")
    response = await guardrail.validate(request)

    assert response.is_safe
    assert response.action == GuardrailAction.ALLOW
    assert len(response.violations) == 0

Best Practices

  • Cache results for identical content
  • Use async operations for API calls
  • Set reasonable timeouts
  • Consider rate limiting

For example, cache repeated keyword checks with functools.lru_cache:
from functools import lru_cache

@lru_cache(maxsize=1000)
def _check_keyword(content: str, keyword: str) -> bool:
    return keyword in content
Fail closed on errors: always return a blocking GuardrailResponse instead of letting the exception propagate:
try:
    # Your validation logic
    pass
except Exception as e:
    return GuardrailResponse(
        is_safe=False,
        action=GuardrailAction.BLOCK,
        violations=[GuardrailViolation(
            violation_type=ViolationType.CUSTOM,
            severity=1.0,
            message=f"Validation error: {str(e)}",
            action=GuardrailAction.BLOCK,
            metadata={"error": str(e)}
        )],
        metadata={"error": str(e)}
    )
Use consistent severity scoring:
  • 0.0 - 0.3: Low severity (informational)
  • 0.3 - 0.7: Medium severity (warning)
  • 0.7 - 1.0: High severity (block)

Plugin Overview

Learn about the plugin system architecture

Guardrails Guide

Configure and use guardrails

Build docs developers (and LLMs) love