Skip to main content

Overview

Secure MCP Gateway features a comprehensive plugin architecture that allows you to extend and customize authentication, guardrails, and telemetry functionality. The plugin system follows SOLID principles and provides clean interfaces for building custom providers.

Plugin Types

Authentication

Custom authentication providers for API keys, OAuth, JWT, and more

Guardrails

Input/output validation providers for content safety

Telemetry

Logging, tracing, and metrics providers

Architecture Principles

The plugin system follows these design principles:

Open/Closed Principle

The system is open for extension (you add new plugins) but closed for modification (the core system stays unchanged):
# Core interfaces stay fixed; new behaviour arrives through subclasses
class AuthProvider(ABC):
    """Abstract authentication contract that every provider implements."""

    @abstractmethod
    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Check ``credentials`` and report the outcome as an ``AuthResult``."""

# A new provider plugs in simply by subclassing the interface
class MyCustomAuthProvider(AuthProvider):
    """Skeleton provider: replace the body of ``authenticate`` with real logic."""

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Custom implementation goes here (the skeleton returns ``None``)."""

Interface Segregation

Providers implement only the interfaces they need:
# Each capability is its own small protocol, so providers opt in piecemeal
class InputGuardrail(Protocol):
    """Structural type for guardrails that inspect an incoming request."""

    async def validate(self, request: GuardrailRequest) -> GuardrailResponse:
        """Return a verdict for ``request``."""

class OutputGuardrail(Protocol):
    """Structural type for guardrails that inspect a response text."""

    async def validate(self, response: str, request: GuardrailRequest) -> GuardrailResponse:
        """Return a verdict for ``response``, given the originating ``request``."""

class PIIHandler(Protocol):
    """Structural type for components that find and redact PII."""

    async def detect_pii(self, content: str) -> List[GuardrailViolation]:
        """Return the PII findings in ``content`` as guardrail violations."""

    async def redact_pii(self, content: str) -> tuple[str, Dict]:
        """Return redacted ``content`` and an accompanying dict (presumably redaction details)."""

Dependency Inversion

High-level modules depend on abstractions, not concrete implementations:
# The server talks only to plugin managers, never to a concrete provider class
class GatewayServer:
    """Gateway entry point wired to the three plugin managers.

    The managers come from the ``get_*_config_manager()`` factories; the
    server never names a concrete auth, guardrail, or telemetry provider.
    """

    def __init__(self):
        # Fetched once at construction, in order: auth, guardrails, telemetry.
        self.auth_manager = get_auth_config_manager()
        self.guardrail_manager = get_guardrail_config_manager()
        self.telemetry_manager = get_telemetry_config_manager()

Authentication Plugins

AuthProvider Interface

src/secure_mcp_gateway/plugins/auth/base.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

class AuthStatus(Enum):
    """Outcome category of an authentication attempt (``AuthResult.status``)."""
    SUCCESS = "success"
    FAILURE = "failure"  # OAuth2Provider uses this when a token fails validation
    EXPIRED = "expired"
    INVALID_CREDENTIALS = "invalid_credentials"  # OAuth2Provider uses this when no access token is supplied
    INSUFFICIENT_PERMISSIONS = "insufficient_permissions"

class AuthMethod(Enum):
    """Authentication mechanisms a provider advertises via ``get_supported_methods()``."""
    API_KEY = "api_key"
    OAUTH = "oauth"
    JWT = "jwt"
    BASIC_AUTH = "basic_auth"
    BEARER_TOKEN = "bearer_token"
    CUSTOM = "custom"  # presumably for provider-specific schemes not listed above

@dataclass
class AuthCredentials:
    """Credentials presented by a caller and passed to ``AuthProvider.authenticate``.

    Every field is optional; each provider reads only the ones it supports.

    Secret-bearing fields are declared with ``repr=False`` so that logging or
    printing an instance does not leak keys, tokens, or passwords. They are
    still constructor arguments and still take part in equality exactly as
    before; only ``repr()`` output changes.
    """

    # Primary credentials
    api_key: Optional[str] = field(default=None, repr=False)
    gateway_key: Optional[str] = field(default=None, repr=False)
    project_id: Optional[str] = None
    user_id: Optional[str] = None

    # OAuth/JWT
    access_token: Optional[str] = field(default=None, repr=False)
    refresh_token: Optional[str] = field(default=None, repr=False)

    # Basic auth
    username: Optional[str] = None
    password: Optional[str] = field(default=None, repr=False)

    # Additional metadata. ``headers`` is hidden from repr because request
    # headers may carry an Authorization value. default_factory gives each
    # instance its own dict.
    headers: Dict[str, str] = field(default_factory=dict, repr=False)
    context: Dict[str, Any] = field(default_factory=dict)

@dataclass
class AuthResult:
    """Outcome returned by ``AuthProvider.authenticate`` and ``refresh_authentication``.

    ``status`` gives the detailed outcome category and ``authenticated`` the
    plain yes/no flag. Nothing in this class enforces that the two agree, so
    providers must set them consistently.
    """
    status: AuthStatus
    authenticated: bool
    message: str  # human-readable summary, e.g. "Access token required"
    
    # User/Session information
    user_id: Optional[str] = None
    project_id: Optional[str] = None
    session_id: Optional[str] = None
    
    # Configuration
    gateway_config: Optional[Dict[str, Any]] = None
    mcp_config: Optional[List[Dict[str, Any]]] = None
    
    # Permissions (default_factory gives each result its own list)
    permissions: List[str] = field(default_factory=list)
    roles: List[str] = field(default_factory=list)
    
    metadata: Dict[str, Any] = field(default_factory=dict)  # provider-specific extras, e.g. OAuth user_info
    error: Optional[str] = None  # optional error detail alongside a failed status

class AuthProvider(ABC):
    """Base class for authentication plugins.

    A concrete provider identifies itself (name, version, supported methods)
    and implements credential checks, session validation and session
    refresh. Every abstract member must be overridden before the subclass
    can be instantiated.
    """

    # -- identity -----------------------------------------------------------

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider's name, e.g. ``'oauth2'`` or ``'jwt'``."""

    @abstractmethod
    def get_version(self) -> str:
        """Return the provider's version string."""

    @abstractmethod
    def get_supported_methods(self) -> List[AuthMethod]:
        """Return the ``AuthMethod`` values this provider handles."""

    # -- authentication -----------------------------------------------------

    @abstractmethod
    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Check ``credentials`` and describe the outcome in an ``AuthResult``."""

    @abstractmethod
    async def validate_session(self, session_id: str) -> bool:
        """Return whether the session identified by ``session_id`` is still valid."""

    @abstractmethod
    async def refresh_authentication(
        self, session_id: str, credentials: AuthCredentials
    ) -> AuthResult:
        """Re-authenticate an existing session and return the fresh ``AuthResult``."""

Example: OAuth2 Provider

src/secure_mcp_gateway/plugins/auth/example_providers.py
class OAuth2Provider(AuthProvider):
    """Authenticate callers by validating an OAuth2 access token.

    When ``user_info_url`` is configured, the token is validated by calling
    that endpoint with ``Authorization: Bearer <token>``. An HTTP 200 reply
    whose JSON body is a non-empty object counts as success, and its ``sub``
    claim becomes ``AuthResult.user_id``.

    Without ``user_info_url`` the token cannot be verified at all, so
    authentication fails closed unless ``allow_unverified_tokens=True`` is
    passed explicitly.

    Args:
        client_id: OAuth2 client identifier (stored; not used by ``authenticate``).
        client_secret: OAuth2 client secret (stored; not used by ``authenticate``).
        authorization_url: Authorization endpoint (stored).
        token_url: Token endpoint (stored).
        user_info_url: Endpoint used to validate access tokens.
        scopes: Requested scopes; defaults to an empty list.
        timeout: HTTP timeout in seconds for the user-info request.
        allow_unverified_tokens: Accept any non-empty token when
            ``user_info_url`` is unset. Insecure; intended for local testing only.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        authorization_url: str,
        token_url: str,
        user_info_url: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        timeout: int = 30,
        *,
        allow_unverified_tokens: bool = False,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.authorization_url = authorization_url
        self.token_url = token_url
        self.user_info_url = user_info_url
        # Copy so later changes to the caller's list do not leak in.
        self.scopes = list(scopes) if scopes else []
        self.timeout = timeout
        self.allow_unverified_tokens = allow_unverified_tokens
    
    def get_name(self) -> str:
        return "oauth2"
    
    def get_version(self) -> str:
        return "1.0.0"
    
    def get_supported_methods(self) -> List[AuthMethod]:
        return [AuthMethod.OAUTH, AuthMethod.BEARER_TOKEN]
    
    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Validate ``credentials.access_token`` and return the outcome."""
        access_token = credentials.access_token
        
        if not access_token:
            return AuthResult(
                status=AuthStatus.INVALID_CREDENTIALS,
                authenticated=False,
                message="Access token required",
                error="Missing access_token"
            )
        
        if not self.user_info_url:
            # No way to verify the token: accepting it would let any string
            # authenticate, so only do so when explicitly opted in.
            if self.allow_unverified_tokens:
                return AuthResult(
                    status=AuthStatus.SUCCESS,
                    authenticated=True,
                    message="Bearer token accepted"
                )
            return AuthResult(
                status=AuthStatus.FAILURE,
                authenticated=False,
                message="Cannot verify access token: user_info_url is not configured",
                error="Token verification unavailable"
            )
        
        # Validate token by fetching user info
        user_info = await self._get_user_info(access_token)
        
        # A JSON array/scalar body has no claims to read, so treat it as invalid
        # rather than crashing on .get().
        if not isinstance(user_info, dict) or not user_info:
            return AuthResult(
                status=AuthStatus.FAILURE,
                authenticated=False,
                message="Invalid access token"
            )
        
        return AuthResult(
            status=AuthStatus.SUCCESS,
            authenticated=True,
            message="OAuth authentication successful",
            user_id=user_info.get("sub"),
            metadata={"user_info": user_info}
        )
    
    async def _get_user_info(self, access_token: str) -> Optional[Dict]:
        """Fetch the user-info document for ``access_token``.

        Returns the decoded JSON body on HTTP 200 and ``None`` otherwise.
        Network and decoding errors are logged and also yield ``None``
        (best effort: any failure just means the token is not validated).
        """
        import asyncio  # local import keeps this example module self-contained

        headers = {"Authorization": f"Bearer {access_token}"}
        
        try:
            # requests is blocking; run it in a worker thread so the event
            # loop keeps serving other requests while we wait.
            response = await asyncio.to_thread(
                requests.get,
                self.user_info_url,
                headers=headers,
                timeout=self.timeout
            )
            
            if response.status_code != 200:
                logger.warning(
                    "User info request failed with HTTP %s", response.status_code
                )
                return None
            return response.json()
        except Exception as e:
            logger.error("Failed to get user info: %s", e)
        
        return None

Guardrail Plugins

GuardrailProvider Interface

src/secure_mcp_gateway/plugins/guardrails/base.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Any, List, Optional

class GuardrailAction(Enum):
    """Decision attached to a violation or to a whole guardrail response.

    Used as ``GuardrailViolation.action`` and ``GuardrailResponse.action``.
    """
    ALLOW = "allow"
    BLOCK = "block"
    WARN = "warn"
    MODIFY = "modify"  # presumably paired with GuardrailResponse.modified_content — confirm

class ViolationType(Enum):
    """Category of a guardrail finding (``GuardrailViolation.violation_type``).

    Several values differ from their member names (``TOXIC_CONTENT`` is
    ``"toxicity"``, ``NSFW_CONTENT`` is ``"nsfw"``, ``KEYWORD_VIOLATION`` is
    ``"keyword_detector"``), so compare members, or raw values, consistently.
    """
    # Input violations
    PII = "pii"
    INJECTION_ATTACK = "injection_attack"
    TOXIC_CONTENT = "toxicity"
    NSFW_CONTENT = "nsfw"
    KEYWORD_VIOLATION = "keyword_detector"
    POLICY_VIOLATION = "policy_violation"
    
    # Output violations
    RELEVANCY_FAILURE = "relevancy"
    ADHERENCE_FAILURE = "adherence"
    HALLUCINATION = "hallucination"

@dataclass
class GuardrailViolation:
    """A single finding produced by a guardrail check."""
    violation_type: ViolationType
    severity: float  # documented range 0.0 to 1.0 (not validated here)
    message: str  # human-readable description of the finding
    action: GuardrailAction  # action the guardrail attaches to this finding
    metadata: Dict[str, Any]  # detector-specific details; required (no default)
    redacted_content: Optional[str] = None  # presumably the content with this finding redacted — confirm

@dataclass
class GuardrailRequest:
    """Input handed to a guardrail's ``validate`` method."""
    content: str  # text the guardrail inspects
    # Optional context; presumably describes the MCP tool call being checked.
    tool_name: Optional[str] = None
    tool_args: Optional[Dict[str, Any]] = None
    server_name: Optional[str] = None
    context: Optional[Dict[str, Any]] = None

@dataclass
class GuardrailResponse:
    """Verdict returned by a guardrail's ``validate`` call.

    ``is_safe`` and ``action`` summarise the outcome; ``violations`` lists
    each individual finding. ``modified_content`` is presumably set when the
    guardrail rewrote the content (e.g. with ``GuardrailAction.MODIFY``);
    confirm against implementations.
    """
    is_safe: bool
    action: GuardrailAction
    violations: List[GuardrailViolation]
    modified_content: Optional[str] = None
    # Defaults to None rather than an empty dict, so it is annotated Optional.
    metadata: Optional[Dict[str, Any]] = None
    processing_time_ms: Optional[float] = None

class GuardrailProvider(ABC):
    """Factory for the guardrail components of one guardrail backend.

    Subclasses must implement ``get_name``, ``get_version``,
    ``create_input_guardrail`` and ``create_output_guardrail``;
    ``create_pii_handler`` is optional and returns ``None`` by default.

    The guardrail protocol types are written as string (forward) annotations,
    so this class can be defined without them being imported into this module.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Provider name (e.g., 'enkrypt', 'openai')"""
    
    @abstractmethod
    def get_version(self) -> str:
        """Provider version"""
    
    @abstractmethod
    def create_input_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional["InputGuardrail"]:
        """Create an input guardrail from ``config``; may return ``None``."""
    
    @abstractmethod
    def create_output_guardrail(
        self, config: Dict[str, Any]
    ) -> Optional["OutputGuardrail"]:
        """Create an output guardrail from ``config``; may return ``None``."""
    
    def create_pii_handler(
        self, config: Dict[str, Any]
    ) -> Optional["PIIHandler"]:
        """Create a PII handler (optional); the default provides none."""
        return None

Example: Custom Keyword Guardrail

src/secure_mcp_gateway/plugins/guardrails/example_providers.py
class CustomKeywordGuardrail:
    """Guardrail that blocks content containing any configured keyword.

    Matching is a plain substring search. Config keys:

    * ``blocked_keywords``: keywords to block (default: none).
    * ``case_sensitive``: match exact case when true (default ``False``);
      otherwise both sides are ``casefold()``-ed for Unicode-aware matching.

    ``validate`` accepts both the input-guardrail call shape
    ``validate(request)`` and the output-guardrail shape
    ``validate(response, request)``, so one class can serve both roles
    (``CustomKeywordProvider`` uses it for both).
    """

    def __init__(self, config: Dict[str, Any]):
        self.blocked_keywords = config.get("blocked_keywords", [])
        self.case_sensitive = config.get("case_sensitive", False)

    def _normalize(self, text: str) -> str:
        """Return ``text`` in the form used for matching."""
        return text if self.case_sensitive else text.casefold()

    async def validate(
        self,
        request: "GuardrailRequest | str",
        original_request: "Optional[GuardrailRequest]" = None,
    ) -> GuardrailResponse:
        """Check content for blocked keywords.

        Args:
            request: The ``GuardrailRequest`` to check (input guardrail), or
                the response text to check (output guardrail).
            original_request: The originating request when called as an
                output guardrail; not used for matching.

        Returns:
            An ALLOW response marked safe when nothing matched; otherwise a
            BLOCK response with one violation per distinct matched keyword.
        """
        # Output guardrails are called as validate(response_text, request).
        raw = request if isinstance(request, str) else request.content
        haystack = self._normalize(raw or "")

        violations = []
        seen = set()
        for keyword in self.blocked_keywords:
            # "" is a substring of every string and would block all content.
            if not keyword:
                continue
            needle = self._normalize(keyword)
            if needle in seen or needle not in haystack:
                continue
            seen.add(needle)
            violations.append(
                GuardrailViolation(
                    violation_type=ViolationType.KEYWORD_VIOLATION,
                    severity=0.8,
                    message=f"Blocked keyword detected: {keyword}",
                    action=GuardrailAction.BLOCK,
                    metadata={"keyword": keyword}
                )
            )

        return GuardrailResponse(
            is_safe=not violations,
            action=GuardrailAction.BLOCK if violations else GuardrailAction.ALLOW,
            violations=violations
        )

    def get_supported_detectors(self) -> List[ViolationType]:
        return [ViolationType.KEYWORD_VIOLATION]

class CustomKeywordProvider(GuardrailProvider):
    """Guardrail provider that blocks a fixed list of keywords.

    The keyword list given to the constructor overrides any
    ``blocked_keywords`` entry in the per-guardrail config. The same
    ``CustomKeywordGuardrail`` type is used for input and output checks.
    """

    def __init__(self, blocked_keywords: List[str]):
        self.blocked_keywords = blocked_keywords
    
    def get_name(self) -> str:
        return "custom_keyword"
    
    def get_version(self) -> str:
        return "1.0.0"

    def _build_guardrail(self, config: Dict[str, Any]) -> "CustomKeywordGuardrail":
        """Build a guardrail from a copy of ``config`` plus this provider's keywords."""
        # Copy instead of writing into the caller's dict, which may be shared
        # configuration used elsewhere.
        merged = dict(config or {})
        merged["blocked_keywords"] = list(self.blocked_keywords or [])
        return CustomKeywordGuardrail(merged)
    
    def create_input_guardrail(self, config: Dict[str, Any]):
        return self._build_guardrail(config)
    
    def create_output_guardrail(self, config: Dict[str, Any]):
        return self._build_guardrail(config)

Telemetry Plugins

TelemetryProvider Interface

src/secure_mcp_gateway/plugins/telemetry/base.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional

class TelemetryLevel(Enum):
    """Telemetry severity levels."""
    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"

@dataclass
class TelemetryResult:
    """Outcome of a telemetry provider operation such as ``initialize`` or ``shutdown``.

    ``error`` is typed ``Optional[str]`` rather than ``str | None`` because
    class annotations are evaluated at import time and the ``|`` form fails
    on Python versions before 3.10.
    """
    success: bool
    provider_name: str
    message: str = ""
    # default_factory gives each result its own dict instead of a shared one.
    data: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None
    # Naive local time from datetime.now(); no timezone is attached.
    timestamp: datetime = field(default_factory=datetime.now)

class TelemetryProvider(ABC):
    """Base class for telemetry plugins (logging, tracing, metrics).

    Subclasses must supply ``name``, ``version``, ``initialize``,
    ``create_logger`` and ``create_tracer``. ``create_meter`` and
    ``shutdown`` come with no-op defaults.
    """

    # -- identity -----------------------------------------------------------

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name."""

    @property
    @abstractmethod
    def version(self) -> str:
        """Provider version."""

    # -- lifecycle and factories --------------------------------------------

    @abstractmethod
    def initialize(self, config: Dict[str, Any]) -> TelemetryResult:
        """Set the provider up from ``config`` and report the outcome."""

    @abstractmethod
    def create_logger(self, name: str) -> Any:
        """Return a logger object for ``name``."""

    @abstractmethod
    def create_tracer(self, name: str) -> Any:
        """Return a tracer object for ``name``."""

    def create_meter(self, name: str) -> Any:
        """Return a meter for ``name``; by default no meter is provided."""
        return None

    def shutdown(self) -> TelemetryResult:
        """Release provider resources; the default simply reports success."""
        outcome = TelemetryResult(
            success=True,
            provider_name=self.name,
            message="Shutdown successful",
        )
        return outcome

Plugin Configuration

Configure Plugins in Gateway Config

enkrypt_mcp_config.json
{
  "plugins": {
    "auth": {
      "provider": "local_apikey",
      "config": {}
    },
    "guardrails": {
      "provider": "enkrypt",
      "config": {
        "api_key": "YOUR_ENKRYPT_API_KEY",
        "base_url": "https://api.enkryptai.com"
      }
    },
    "telemetry": {
      "provider": "opentelemetry",
      "config": {
        "url": "http://localhost:4317",
        "insecure": true
      }
    }
  }
}

Creating Custom Plugins

1. Create Provider Class

   Implement the appropriate abstract base class (AuthProvider, GuardrailProvider, or TelemetryProvider).

2. Implement Required Methods

   Implement every method marked @abstractmethod; a subclass that leaves any of them unimplemented cannot be instantiated.

3. Add to example_providers.py

   Place your provider in the appropriate example_providers.py file.

4. Register Provider

   Update the plugin loader to include your provider.

5. Configure in Config File

   Set the provider field of the matching entry in the plugins section (for example, plugins.telemetry.provider) to your provider's name.

Example: Custom Logging Provider

custom_telemetry_provider.py
import logging
from typing import Any, Dict

from secure_mcp_gateway.plugins.telemetry.base import (
    TelemetryProvider,
    TelemetryResult
)

class CustomLoggerProvider(TelemetryProvider):
    """Telemetry provider that sends stdlib ``logging`` output to a file.

    Config keys read by ``initialize``:

    * ``log_file``: destination path (default ``"gateway.log"``).
    * ``log_level``: a level name in any case (``"INFO"``, ``"debug"``) or
      a numeric level (default ``"INFO"``).

    Tracing is not supported: ``create_tracer`` returns ``None``.
    """

    @property
    def name(self) -> str:
        return "custom_logger"
    
    @property
    def version(self) -> str:
        return "1.0.0"
    
    def initialize(self, config: Dict[str, Any]) -> TelemetryResult:
        """Configure root logging to write to ``log_file`` at ``log_level``.

        Returns a failed ``TelemetryResult`` (instead of raising) when the
        level is unknown or the log file cannot be opened.
        """
        self.log_file = config.get("log_file", "gateway.log")
        self.log_level = config.get("log_level", "INFO")

        level = self._resolve_level(self.log_level)
        if level is None:
            return TelemetryResult(
                success=False,
                provider_name=self.name,
                message="Custom logger initialization failed",
                error=f"Unknown log_level: {self.log_level!r}"
            )

        try:
            # NOTE: basicConfig does nothing if the root logger already has
            # handlers; in that case no file handler is added.
            logging.basicConfig(
                filename=self.log_file,
                level=level,
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
        except OSError as exc:
            return TelemetryResult(
                success=False,
                provider_name=self.name,
                message="Custom logger initialization failed",
                error=str(exc)
            )
        
        return TelemetryResult(
            success=True,
            provider_name=self.name,
            message="Custom logger initialized"
        )

    @staticmethod
    def _resolve_level(level: Any) -> "int | None":
        """Map a level name (any case) or number to a numeric logging level.

        Returns ``None`` when ``level`` is not a known level. The ``int``
        check rejects module attributes that are not levels (e.g.
        ``BASIC_FORMAT``).
        """
        if isinstance(level, int):
            return level
        value = getattr(logging, str(level).upper(), None)
        return value if isinstance(value, int) else None
    
    def create_logger(self, name: str) -> logging.Logger:
        """Return the stdlib logger named ``name``."""
        return logging.getLogger(name)
    
    def create_tracer(self, name: str) -> None:
        """No tracing support in this simple example; always returns ``None``."""
        return None

Plugin Registry System

Each plugin type has a singleton registry:
import asyncio

from secure_mcp_gateway.plugins.auth import get_auth_config_manager
from secure_mcp_gateway.plugins.auth.base import AuthCredentials
from secure_mcp_gateway.plugins.guardrails import get_guardrail_config_manager
from secure_mcp_gateway.plugins.telemetry import get_telemetry_config_manager

# Get plugin managers (singletons)
auth_manager = get_auth_config_manager()
guardrail_manager = get_guardrail_config_manager()
telemetry_manager = get_telemetry_config_manager()


async def authenticate_request(credentials: AuthCredentials):
    """Authenticate ``credentials`` with the currently configured auth provider."""
    # `await` is only valid inside a coroutine, so the provider call lives
    # here instead of at module level.
    provider = auth_manager.get_provider()
    return await provider.authenticate(credentials)


if __name__ == "__main__":
    result = asyncio.run(authenticate_request(AuthCredentials(api_key="YOUR_API_KEY")))
    print(result.status, result.message)

Best Practices

Version Compatibility: Always specify provider version in get_version() for compatibility tracking.
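Error Handling: Providers should handle errors gracefully and report failures through their result objects (for example, an AuthResult with a non-success AuthStatus, or a TelemetryResult with success=False) rather than letting exceptions escape.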
Testing: Test custom providers with the included test MCP servers in the bad_mcps/ directory.

Build docs developers (and LLMs) love