Skip to main content

Overview

Auth providers validate credentials and manage user sessions. The gateway supports multiple authentication methods including API keys, OAuth 2.0, JWT, and custom schemes.

AuthProvider Interface

All authentication providers must implement the AuthProvider abstract base class from src/secure_mcp_gateway/plugins/auth/base.py.

Required Methods

from abc import ABC, abstractmethod
from typing import List

class AuthProvider(ABC):
    """Abstract contract that every authentication provider implements."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the unique name that identifies this provider."""

    @abstractmethod
    def get_version(self) -> str:
        """Return this provider's version string."""

    @abstractmethod
    def get_supported_methods(self) -> List[AuthMethod]:
        """Return the authentication methods this provider can handle."""

    @abstractmethod
    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Authenticate a user from the supplied credentials."""

    @abstractmethod
    async def validate_session(self, session_id: str) -> bool:
        """Report whether the given session is still valid."""

    @abstractmethod
    async def refresh_authentication(
        self, session_id: str, credentials: AuthCredentials
    ) -> AuthResult:
        """Refresh the authentication of an existing session."""

Optional Methods

def validate_config(self, config: Dict[str, Any]) -> bool:
    """Check a provider configuration.

    This default implementation accepts every configuration; providers
    override it when they need real validation.
    """
    accepted = True
    return accepted

def get_required_config_keys(self) -> List[str]:
    """Name the configuration keys this provider requires.

    This default implementation requires none and returns a fresh empty list.
    """
    required_keys: List[str] = list()
    return required_keys

Data Models

AuthCredentials

Container for various credential types:
@dataclass
class AuthCredentials:
    """Container for the credential types a provider may receive.

    Every field has a default, so callers populate only the fields that apply
    to their authentication method (for example ``api_key`` alone, or
    ``access_token`` plus ``refresh_token``).
    """

    # Primary credentials
    api_key: Optional[str] = None
    gateway_key: Optional[str] = None
    project_id: Optional[str] = None
    user_id: Optional[str] = None

    # OAuth/JWT tokens
    access_token: Optional[str] = None
    refresh_token: Optional[str] = None
    id_token: Optional[str] = None

    # Basic auth
    username: Optional[str] = None
    password: Optional[str] = None

    # Additional metadata; default_factory gives each instance its own dict
    headers: Dict[str, str] = field(default_factory=dict)
    context: Dict[str, Any] = field(default_factory=dict)
Sensitive fields are automatically masked in the string representation returned by __repr__().

AuthResult

Authentication result with user information:
@dataclass
class AuthResult:
    """Outcome of an authentication attempt, with optional user details.

    ``status``, ``authenticated`` and ``message`` are required; every other
    field has a default and is filled in only when the provider has the
    information.
    """

    # Required outcome fields
    status: AuthStatus
    authenticated: bool
    message: str

    # User/Session information
    user_id: Optional[str] = None
    project_id: Optional[str] = None
    session_id: Optional[str] = None

    # Configuration
    gateway_config: Optional[Dict[str, Any]] = None
    mcp_config: Optional[List[Dict[str, Any]]] = None

    # Permissions; default_factory gives each instance its own list
    permissions: List[str] = field(default_factory=list)
    roles: List[str] = field(default_factory=list)

    # Metadata and, on failure, a short error description
    metadata: Dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """Return whether authentication succeeded.

        Both conditions must hold: ``status`` equals ``AuthStatus.SUCCESS``
        and ``authenticated`` is truthy.
        """
        return self.status == AuthStatus.SUCCESS and self.authenticated

AuthStatus Enum

class AuthStatus(Enum):
    """Outcome codes carried in ``AuthResult.status``."""

    # Required (together with ``authenticated``) for AuthResult.is_success
    SUCCESS = "success"
    FAILURE = "failure"
    EXPIRED = "expired"
    INVALID_CREDENTIALS = "invalid_credentials"
    INSUFFICIENT_PERMISSIONS = "insufficient_permissions"
    RATE_LIMITED = "rate_limited"
    ERROR = "error"

AuthMethod Enum

class AuthMethod(Enum):
    """Authentication methods a provider can list in get_supported_methods()."""

    API_KEY = "api_key"
    OAUTH = "oauth"
    JWT = "jwt"
    BASIC_AUTH = "basic_auth"
    BEARER_TOKEN = "bearer_token"
    CUSTOM = "custom"

Example: API Key Provider

Here’s a complete implementation from src/secure_mcp_gateway/plugins/auth/example_providers.py:376:
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

from secure_mcp_gateway.plugins.auth.base import (
    AuthProvider,
    AuthCredentials,
    AuthResult,
    AuthStatus,
    AuthMethod
)

class APIKeyProvider(AuthProvider):
    """Simple API key authentication provider.

    A key is accepted either because the optional ``key_validator`` callable
    approves it or, when no validator is configured, because it is a key of
    the ``valid_keys`` mapping.
    """

    def __init__(
        self,
        valid_keys: Optional[Dict[str, Dict[str, Any]]] = None,
        key_validator: Optional[
            Callable[[str], Union[bool, Awaitable[bool]]]
        ] = None,
    ):
        """Configure the provider.

        Args:
            valid_keys: Mapping of API key to user info. The ``user_id`` and
                ``project_id`` entries are copied into the result on success.
                Defaults to an empty mapping.
            key_validator: Optional callable that receives the API key and
                returns (or resolves to) a truthy value when the key is
                valid. Plain functions and coroutine functions are both
                accepted. When set, ``valid_keys`` is not consulted.
        """
        # Compare against None rather than truthiness so an empty dict passed
        # by the caller is kept and keys added to it later are still seen.
        self.valid_keys = valid_keys if valid_keys is not None else {}
        self.key_validator = key_validator

    def get_name(self) -> str:
        """Return the provider name, ``"apikey"``."""
        return "apikey"

    def get_version(self) -> str:
        """Return the provider version."""
        return "1.0.0"

    def get_supported_methods(self) -> List[AuthMethod]:
        """Return the supported methods: API key only."""
        return [AuthMethod.API_KEY]

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Authenticate using ``credentials.api_key`` or ``gateway_key``.

        Returns:
            An ``AuthResult`` whose status is ``SUCCESS`` for an accepted key,
            ``INVALID_CREDENTIALS`` for a missing or rejected key, and
            ``ERROR`` when an unexpected exception occurs. Exceptions are
            never propagated to the caller.
        """
        try:
            api_key = credentials.api_key or credentials.gateway_key

            if not api_key:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="API key required",
                    error="Missing api_key",
                )

            # Use custom validator if provided
            if self.key_validator is not None:
                is_valid = self.key_validator(api_key)
                # Support synchronous validators as well as coroutines:
                # only await when the call actually produced an awaitable.
                if inspect.isawaitable(is_valid):
                    is_valid = await is_valid

                if not is_valid:
                    return AuthResult(
                        status=AuthStatus.INVALID_CREDENTIALS,
                        authenticated=False,
                        message="Invalid API key",
                        error="Key validation failed",
                    )

                return AuthResult(
                    status=AuthStatus.SUCCESS,
                    authenticated=True,
                    message="API key authentication successful",
                    metadata={"provider": "apikey"},
                )

            # Check against valid keys
            if api_key not in self.valid_keys:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="Invalid API key",
                    error="Key not found",
                )

            user_info = self.valid_keys[api_key]

            return AuthResult(
                status=AuthStatus.SUCCESS,
                authenticated=True,
                message="API key authentication successful",
                user_id=user_info.get("user_id"),
                project_id=user_info.get("project_id"),
                metadata={"provider": "apikey", "user_info": user_info},
            )

        except Exception as e:
            # Boundary handler: report failures as a result, not an exception.
            return AuthResult(
                status=AuthStatus.ERROR,
                authenticated=False,
                message="API key authentication failed",
                error=str(e),
            )

    async def validate_session(self, session_id: str) -> bool:
        """Always return True; this provider keeps no session state."""
        return True

    async def refresh_authentication(
        self, session_id: str, credentials: AuthCredentials
    ) -> AuthResult:
        """Refresh by re-running ``authenticate`` with the same credentials."""
        return await self.authenticate(credentials)

Example: OAuth2 Provider

From src/secure_mcp_gateway/plugins/auth/example_providers.py:31:
import requests

class OAuth2Provider(AuthProvider):
    """OAuth 2.0 authentication provider.

    Access tokens are validated by calling the configured user-info endpoint,
    and refresh tokens are exchanged at the token endpoint. The HTTP calls use
    the blocking ``requests`` library, so they are run in the event loop's
    default executor to avoid stalling other coroutines.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        authorization_url: str,
        token_url: str,
        user_info_url: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        timeout: int = 30,
    ):
        """Store the OAuth client settings.

        Args:
            client_id: OAuth client identifier sent on token refresh.
            client_secret: OAuth client secret sent on token refresh.
            authorization_url: Authorization endpoint URL (stored only).
            token_url: Token endpoint used to exchange refresh tokens.
            user_info_url: Endpoint used to validate access tokens. When it
                is not set, access tokens are not validated remotely.
            scopes: Requested scopes (stored only). Defaults to an empty list.
            timeout: Timeout passed to every HTTP request.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.authorization_url = authorization_url
        self.token_url = token_url
        self.user_info_url = user_info_url
        self.scopes = scopes or []
        self.timeout = timeout

    def get_name(self) -> str:
        """Return the provider name, ``"oauth2"``."""
        return "oauth2"

    def get_version(self) -> str:
        """Return the provider version."""
        return "1.0.0"

    def get_supported_methods(self) -> List[AuthMethod]:
        """Return the supported methods: OAuth and bearer token."""
        return [AuthMethod.OAUTH, AuthMethod.BEARER_TOKEN]

    @staticmethod
    async def _run_blocking(call: Callable[[], Any]) -> Any:
        """Run a blocking zero-argument callable in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, call)

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Authenticate using ``credentials.access_token``.

        Returns:
            An ``AuthResult`` whose status is ``SUCCESS`` for an accepted
            token, ``INVALID_CREDENTIALS`` for a missing or rejected token,
            and ``ERROR`` when an unexpected exception occurs. Exceptions are
            never propagated to the caller.
        """
        try:
            access_token = credentials.access_token

            if not access_token:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="Access token required",
                    error="Missing access_token",
                )

            # Validate token by fetching user info
            if self.user_info_url:
                user_info = await self._get_user_info(access_token)
                if not user_info:
                    return AuthResult(
                        status=AuthStatus.INVALID_CREDENTIALS,
                        authenticated=False,
                        message="Invalid or expired token",
                        error="Token validation failed",
                    )

                return AuthResult(
                    status=AuthStatus.SUCCESS,
                    authenticated=True,
                    message="OAuth authentication successful",
                    user_id=user_info.get("sub") or user_info.get("id"),
                    metadata={"user_info": user_info, "provider": "oauth2"},
                )

            # NOTE(review): without a user_info_url the token is not checked
            # against any endpoint, so every non-empty token is accepted.
            # Configure user_info_url whenever tokens must be verified.
            return AuthResult(
                status=AuthStatus.SUCCESS,
                authenticated=True,
                message="OAuth authentication successful",
                metadata={"provider": "oauth2"},
            )

        except Exception as e:
            # Boundary handler: report failures as a result, not an exception.
            return AuthResult(
                status=AuthStatus.ERROR,
                authenticated=False,
                message="OAuth authentication failed",
                error=str(e),
            )

    async def _get_user_info(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Fetch the user-info document for an access token.

        Returns:
            The decoded JSON object on an HTTP 200 response, or ``None`` when
            the request fails, the status is not 200, or the body is not a
            JSON object.
        """
        try:
            response = await self._run_blocking(
                lambda: requests.get(
                    self.user_info_url,
                    headers={"Authorization": f"Bearer {access_token}"},
                    timeout=self.timeout,
                )
            )

            if response.status_code != 200:
                return None

            user_info = response.json()
            # Callers use .get() on the result, so reject non-object JSON.
            return user_info if isinstance(user_info, dict) else None
        except (requests.RequestException, ValueError):
            # Transport errors and undecodable bodies count as "no user
            # info"; anything else reaches authenticate()'s ERROR handler.
            return None

    async def validate_session(self, session_id: str) -> bool:
        """Always return True; this provider keeps no session state."""
        return True

    async def refresh_authentication(
        self, session_id: str, credentials: AuthCredentials
    ) -> AuthResult:
        """Exchange ``credentials.refresh_token`` for a new access token.

        The new access token is then passed to ``authenticate``.

        Returns:
            The ``AuthResult`` of authenticating with the new token, or a
            failure result: ``INVALID_CREDENTIALS`` when no refresh token was
            supplied, ``EXPIRED`` when the token endpoint does not answer
            with HTTP 200, and ``ERROR`` when the response carries no access
            token or an unexpected exception occurs.
        """
        try:
            refresh_token = credentials.refresh_token

            if not refresh_token:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="Refresh token required",
                    error="Missing refresh_token",
                )

            # Exchange refresh token for new access token
            response = await self._run_blocking(
                lambda: requests.post(
                    self.token_url,
                    data={
                        "grant_type": "refresh_token",
                        "refresh_token": refresh_token,
                        "client_id": self.client_id,
                        "client_secret": self.client_secret,
                    },
                    timeout=self.timeout,
                )
            )

            if response.status_code != 200:
                return AuthResult(
                    status=AuthStatus.EXPIRED,
                    authenticated=False,
                    message="Token refresh failed",
                    error="Invalid refresh token",
                )

            token_data = response.json()
            new_access_token = token_data.get("access_token")

            if not new_access_token:
                # Fail here with an accurate message; otherwise authenticate()
                # would report "Access token required" to a caller who
                # supplied a refresh token.
                return AuthResult(
                    status=AuthStatus.ERROR,
                    authenticated=False,
                    message="Token refresh failed",
                    error="Token response did not include an access_token",
                )

            # NOTE(review): the new access token is used only to
            # re-authenticate; it is not handed back in the AuthResult.
            new_credentials = AuthCredentials(access_token=new_access_token)
            return await self.authenticate(new_credentials)

        except Exception as e:
            # Boundary handler: report failures as a result, not an exception.
            return AuthResult(
                status=AuthStatus.ERROR,
                authenticated=False,
                message="Token refresh failed",
                error=str(e),
            )

Example: JWT Provider

From src/secure_mcp_gateway/plugins/auth/example_providers.py:230:
import jwt

class JWTProvider(AuthProvider):
    """Authentication provider that verifies JSON Web Tokens."""

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        verify_exp: bool = True,
        verify_signature: bool = True,
        audience: Optional[str] = None,
        issuer: Optional[str] = None,
    ):
        """Store the key, algorithm and verification settings for decoding."""
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.verify_exp = verify_exp
        self.verify_signature = verify_signature
        self.audience = audience
        self.issuer = issuer

    def get_name(self) -> str:
        """Return the provider name."""
        return "jwt"

    def get_version(self) -> str:
        """Return the provider version."""
        return "1.0.0"

    def get_supported_methods(self) -> List[AuthMethod]:
        """Return the supported methods: JWT and bearer token."""
        return [AuthMethod.JWT, AuthMethod.BEARER_TOKEN]

    async def authenticate(self, credentials: AuthCredentials) -> AuthResult:
        """Decode the token from ``access_token`` (or ``api_key``) and report
        the outcome as an ``AuthResult``; exceptions are never propagated.
        """
        try:
            raw_token = credentials.access_token or credentials.api_key

            if not raw_token:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="JWT token required",
                    error="Missing token",
                )

            # Verification switches come straight from the constructor.
            decode_options = {
                "verify_exp": self.verify_exp,
                "verify_signature": self.verify_signature,
            }

            try:
                claims = jwt.decode(
                    raw_token,
                    self.secret_key,
                    algorithms=[self.algorithm],
                    options=decode_options,
                    audience=self.audience,
                    issuer=self.issuer,
                )
            except jwt.ExpiredSignatureError:
                return AuthResult(
                    status=AuthStatus.EXPIRED,
                    authenticated=False,
                    message="Token has expired",
                    error="Token expired",
                )
            except jwt.InvalidTokenError as exc:
                return AuthResult(
                    status=AuthStatus.INVALID_CREDENTIALS,
                    authenticated=False,
                    message="Invalid token",
                    error=str(exc),
                )

            # Prefer the standard "sub" claim, then fall back to "user_id".
            subject = claims.get("sub") or claims.get("user_id")
            result_metadata = {
                "jwt_payload": claims,
                "provider": "jwt",
                "exp": claims.get("exp"),
            }

            return AuthResult(
                status=AuthStatus.SUCCESS,
                authenticated=True,
                message="JWT authentication successful",
                user_id=subject,
                metadata=result_metadata,
            )

        except Exception as exc:
            return AuthResult(
                status=AuthStatus.ERROR,
                authenticated=False,
                message="JWT authentication failed",
                error=str(exc),
            )

    async def validate_session(self, session_id: str) -> bool:
        """Always return True, whatever session id is given."""
        return True

    async def refresh_authentication(
        self, session_id: str, credentials: AuthCredentials
    ) -> AuthResult:
        """Refresh by authenticating again with the supplied credentials."""
        refreshed = await self.authenticate(credentials)
        return refreshed

Configuration

Register your auth provider in the config file:
{
  "plugins": {
    "auth": {
      "provider": "oauth2",
      "config": {
        "client_id": "your-client-id",
        "client_secret": "your-client-secret",
        "authorization_url": "https://auth.example.com/authorize",
        "token_url": "https://auth.example.com/token",
        "user_info_url": "https://auth.example.com/userinfo"
      }
    }
  }
}

Helper Utilities

Credential Extraction

The gateway automatically extracts credentials from the request context:
from secure_mcp_gateway.plugins.auth.config_manager import AuthConfigManager

# Build the config manager, then read the gateway credentials from `ctx`.
# NOTE(review): `ctx` is not defined in this snippet; presumably it is the
# request context object supplied by the caller. Confirm before copying.
auth_manager = AuthConfigManager()
credentials = auth_manager.get_gateway_credentials(ctx)
# Returns AuthCredentials with gateway_key from headers

Sensitive Data Masking

from secure_mcp_gateway.plugins.auth.base import mask_sensitive_value

# Mask a secret so that only its last four characters stay readable
# (visible_chars=4), as the result shown below illustrates.
masked = mask_sensitive_value("my-secret-key-12345", visible_chars=4)
# Returns: "****2345"

Testing Your Provider

import pytest
from secure_mcp_gateway.plugins.auth.base import AuthCredentials

@pytest.mark.asyncio
async def test_authenticate_valid_key():
    """A key listed in ``valid_keys`` authenticates and yields its user_id."""
    # Imported locally: the snippet's module-level imports only bring in
    # AuthCredentials, so APIKeyProvider would otherwise be undefined.
    from secure_mcp_gateway.plugins.auth.example_providers import (
        APIKeyProvider,
    )

    provider = APIKeyProvider(
        valid_keys={"test-key-123": {"user_id": "user1"}}
    )

    credentials = AuthCredentials(api_key="test-key-123")
    result = await provider.authenticate(credentials)

    assert result.is_success
    assert result.user_id == "user1"

@pytest.mark.asyncio
async def test_authenticate_invalid_key():
    """An unknown key is rejected with the INVALID_CREDENTIALS status."""
    # Imported locally: the snippet's module-level imports only bring in
    # AuthCredentials, so AuthStatus and APIKeyProvider would otherwise be
    # undefined.
    from secure_mcp_gateway.plugins.auth.base import AuthStatus
    from secure_mcp_gateway.plugins.auth.example_providers import (
        APIKeyProvider,
    )

    provider = APIKeyProvider(valid_keys={})

    credentials = AuthCredentials(api_key="invalid-key")
    result = await provider.authenticate(credentials)

    assert not result.is_success
    assert result.status == AuthStatus.INVALID_CREDENTIALS

Plugin Overview

Learn about the plugin system architecture

Configuration

Configure authentication settings

Build docs developers (and LLMs) love