Skip to main content

Objectives

By the end of this lab you will be able to:
  • Build a FastMCP server with proper architecture and organization
  • Implement database integration with connection pooling and error handling
  • Create MCP tools for database schema introspection and query execution
  • Configure Row Level Security context management
  • Add health monitoring and observability features
  • Test your MCP server locally and with VS Code

Prerequisites

  • Completed Lab 4: Database Design
  • PostgreSQL running with the retail schema loaded
  • Virtual environment activated with dependencies installed

Project structure

mcp_server/
├── __init__.py
├── config.py                          # Configuration management
├── health_check.py                    # Health monitoring endpoints
├── sales_analysis.py                  # Main MCP server implementation
├── sales_analysis_postgres.py         # Database integration layer
└── sales_analysis_text_embeddings.py  # AI/semantic search integration

Step 1: Configuration management

# mcp_server/config.py
import os
import logging
from typing import Optional, Dict, Any
from dataclasses import dataclass
from dotenv import load_dotenv

load_dotenv()
logger = logging.getLogger(__name__)

@dataclass
class DatabaseConfig:
    """Connection settings for the PostgreSQL database.

    Values come either from explicit construction or from POSTGRES_*
    environment variables via :meth:`from_env`.
    """
    host: str
    port: int
    database: str
    user: str
    password: str
    min_connections: int = 2
    max_connections: int = 10
    command_timeout: int = 30  # seconds; also mirrored into statement_timeout

    @classmethod
    def from_env(cls) -> 'DatabaseConfig':
        """Build a config from POSTGRES_* environment variables with sane defaults."""
        env = os.getenv
        return cls(
            host=env('POSTGRES_HOST', 'localhost'),
            port=int(env('POSTGRES_PORT', '5432')),
            database=env('POSTGRES_DB', 'zava'),
            user=env('POSTGRES_USER', 'postgres'),
            password=env('POSTGRES_PASSWORD', ''),
            min_connections=int(env('POSTGRES_MIN_CONNECTIONS', '2')),
            max_connections=int(env('POSTGRES_MAX_CONNECTIONS', '10')),
            command_timeout=int(env('POSTGRES_COMMAND_TIMEOUT', '30')),
        )

    def to_asyncpg_params(self) -> Dict[str, Any]:
        """Translate this config into keyword arguments for asyncpg.create_pool."""
        # Per-session server settings; statement_timeout enforces the same
        # budget server-side that command_timeout enforces client-side.
        server_settings = {
            'application_name': 'zava-mcp-server',
            'jit': 'off',
            'work_mem': '4MB',
            'statement_timeout': f'{self.command_timeout}s',
        }
        params: Dict[str, Any] = {
            'host': self.host,
            'port': self.port,
            'database': self.database,
            'user': self.user,
            'password': self.password,
            'command_timeout': self.command_timeout,
        }
        params['server_settings'] = server_settings
        return params

@dataclass
class AzureConfig:
    """Azure AI project / OpenAI endpoint and credential settings."""
    project_endpoint: str
    openai_endpoint: str
    embedding_model_deployment: str
    client_id: str
    client_secret: str
    tenant_id: str

    @classmethod
    def from_env(cls) -> 'AzureConfig':
        """Read Azure settings from environment variables (empty string when unset)."""
        env = os.getenv
        return cls(
            project_endpoint=env('PROJECT_ENDPOINT', ''),
            openai_endpoint=env('AZURE_OPENAI_ENDPOINT', ''),
            embedding_model_deployment=env('EMBEDDING_MODEL_DEPLOYMENT_NAME', 'text-embedding-3-small'),
            client_id=env('AZURE_CLIENT_ID', ''),
            client_secret=env('AZURE_CLIENT_SECRET', ''),
            tenant_id=env('AZURE_TENANT_ID', ''),
        )

    def is_configured(self) -> bool:
        """True when every required endpoint/credential is non-empty.

        Note: the embedding deployment name is not required — it always has
        a default value.
        """
        required = (self.project_endpoint, self.openai_endpoint,
                    self.client_id, self.client_secret, self.tenant_id)
        return all(required)

@dataclass
class ServerConfig:
    """HTTP server settings for the MCP application."""
    host: str = '0.0.0.0'
    port: int = 8000
    log_level: str = 'INFO'
    enable_cors: bool = True
    enable_health_check: bool = True
    applicationinsights_connection_string: Optional[str] = None

    @classmethod
    def from_env(cls) -> 'ServerConfig':
        """Read server settings from environment variables.

        Boolean flags accept any casing of "true"; anything else is False.
        """
        def flag(name: str, default: str) -> bool:
            # Case-insensitive "true" check, matching the env-var convention.
            return os.getenv(name, default).lower() == 'true'

        return cls(
            host=os.getenv('MCP_SERVER_HOST', '0.0.0.0'),
            port=int(os.getenv('MCP_SERVER_PORT', '8000')),
            log_level=os.getenv('LOG_LEVEL', 'INFO').upper(),
            enable_cors=flag('ENABLE_CORS', 'true'),
            enable_health_check=flag('ENABLE_HEALTH_CHECK', 'true'),
            applicationinsights_connection_string=os.getenv('APPLICATIONINSIGHTS_CONNECTION_STRING'),
        )

class MCPServerConfig:
    """Aggregate of database, Azure, and server configuration, all read from
    the environment at construction time."""

    def __init__(self):
        self.database = DatabaseConfig.from_env()
        self.azure = AzureConfig.from_env()
        self.server = ServerConfig.from_env()
        self._validate_config()

    def _validate_config(self):
        """Warn about missing settings and log the effective endpoints."""
        problems = []
        if not self.database.password:
            problems.append("Database password is empty.")
        if not self.azure.is_configured():
            problems.append("Azure configuration is incomplete. AI features may not work.")
        for message in problems:
            logger.warning(message)
        logger.info(f"Database: {self.database.host}:{self.database.port}")
        logger.info(f"Server: {self.server.host}:{self.server.port}")

# Shared configuration singleton imported by the other server modules.
config = MCPServerConfig()

Step 2: Database integration layer

# mcp_server/sales_analysis_postgres.py
import asyncio
import asyncpg
import logging
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from datetime import datetime
import json

from .config import config

logger = logging.getLogger(__name__)

class PostgreSQLSchemaProvider:
    """Async PostgreSQL access layer for the MCP server.

    Responsibilities: lazily-created connection pool, per-request Row Level
    Security context, query execution with row limiting, schema introspection,
    and a health probe.

    Fix: the MCP tool layer calls ``get_multiple_table_schemas`` and
    ``get_current_utc_date`` on this provider, but neither method existed —
    both are added below. asyncpg annotations are written as strings so the
    module can be imported (e.g. for unit tests) without asyncpg installed.
    """

    def __init__(self):
        # Pool is created lazily on first use (see get_connection).
        self.connection_pool: "Optional[asyncpg.Pool]" = None
        self.postgres_config = config.database.to_asyncpg_params()

    async def create_pool(self) -> None:
        """Create the shared connection pool if it does not exist yet (idempotent)."""
        if self.connection_pool is None:
            self.connection_pool = await asyncpg.create_pool(
                **self.postgres_config,
                min_size=config.database.min_connections,
                max_size=config.database.max_connections,
                max_inactive_connection_lifetime=300  # recycle idle connections after 5 min
            )
            logger.info("Database connection pool created")

    async def close_pool(self) -> None:
        """Close the pool and reset it so a later call can recreate it."""
        if self.connection_pool:
            await self.connection_pool.close()
            self.connection_pool = None

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection, creating the pool on first use."""
        if not self.connection_pool:
            await self.create_pool()
        async with self.connection_pool.acquire() as connection:
            yield connection

    async def set_rls_context(self, connection: "asyncpg.Connection", rls_user_id: str) -> None:
        """Bind the RLS user id to this connection's session.

        set_config(..., false) makes the setting session-scoped (not
        transaction-scoped); RLS policies are expected to read
        app.current_rls_user_id — confirm against the schema's policies.
        """
        await connection.execute(
            "SELECT set_config('app.current_rls_user_id', $1, false)",
            rls_user_id
        )

    async def execute_query(self, sql_query: str, rls_user_id: str, max_rows: int = 20) -> str:
        """Execute a SQL query with Row Level Security context.

        Returns a formatted text table limited to *max_rows* rows; the fetch
        is cancelled if it exceeds the configured command timeout.
        """
        async with self.get_connection() as conn:
            await self.set_rls_context(conn, rls_user_id)

            rows = await asyncio.wait_for(
                conn.fetch(sql_query),
                timeout=config.database.command_timeout
            )

            if not rows:
                return "Query executed successfully. No rows returned."

            limited_rows = rows[:max_rows]
            return self._format_query_results(limited_rows, len(rows), max_rows)

    async def get_multiple_table_schemas(self, table_names: List[str], rls_user_id: str) -> str:
        """Describe the columns of each schema-qualified table in *table_names*.

        Added because the MCP tool layer calls this method. Reads
        information_schema.columns under the caller's RLS context; table names
        are expected as "schema.table".
        """
        async with self.get_connection() as conn:
            await self.set_rls_context(conn, rls_user_id)
            sections = []
            for full_name in table_names:
                schema_name, _, table_name = full_name.partition('.')
                rows = await conn.fetch(
                    """
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                    """,
                    schema_name, table_name
                )
                if not rows:
                    sections.append(f"Table {full_name}: not found or no visible columns.")
                    continue
                lines = [f"Table: {full_name}", "-" * 40]
                for row in rows:
                    nullable = "NULL" if row['is_nullable'] == 'YES' else "NOT NULL"
                    default = f" DEFAULT {row['column_default']}" if row['column_default'] else ""
                    lines.append(f"  {row['column_name']}: {row['data_type']} {nullable}{default}")
                sections.append("\n".join(lines))
            return "\n\n".join(sections)

    async def get_current_utc_date(self) -> str:
        """Return the current UTC date/time as an ISO-8601 string.

        Added because the MCP tool layer and health endpoints call this method.
        """
        from datetime import timezone  # local: file-level import only brings in datetime
        return datetime.now(timezone.utc).isoformat()

    def _format_query_results(self, rows, total_rows: int, max_rows: int) -> str:
        """Render result rows as a plain-text table, noting any truncation.

        *rows* is the already-limited slice; *total_rows* is the pre-limit count.
        """
        if not rows:
            return "No results found."

        columns = list(rows[0].keys())
        result_lines = [f"Results ({len(rows)} of {total_rows} rows):", "=" * 50]
        result_lines.append(" | ".join(columns))
        result_lines.append("-" * len(" | ".join(columns)))

        for row in rows:
            formatted_values = []
            for col in columns:
                value = row[col]
                if value is None:
                    formatted_values.append("NULL")
                elif isinstance(value, datetime):
                    formatted_values.append(value.strftime("%Y-%m-%d %H:%M:%S"))
                elif isinstance(value, (dict, list)):
                    # JSON/array columns come back as dict/list; serialize for display.
                    formatted_values.append(json.dumps(value))
                else:
                    formatted_values.append(str(value))
            result_lines.append(" | ".join(formatted_values))

        if total_rows > max_rows:
            result_lines.append(f"\n... and {total_rows - max_rows} more rows (truncated)")

        return "\n".join(result_lines)

    async def health_check(self) -> Dict[str, Any]:
        """Probe the database with SELECT 1; never raises — returns a status dict."""
        try:
            async with self.get_connection() as conn:
                result = await conn.fetchval("SELECT 1")
                return {
                    "status": "healthy",
                    "database_responsive": result == 1
                }
        except Exception as e:
            return {"status": "unhealthy", "error": str(e)}

# Module-level singleton shared by the MCP tools and health endpoints.
db_provider = PostgreSQLSchemaProvider()

Step 3: Main MCP server

# mcp_server/sales_analysis.py
import logging
from typing import List, Annotated
from contextlib import asynccontextmanager

from fastmcp import FastMCP, Context
from pydantic import Field

from .config import config
from .sales_analysis_postgres import db_provider
from .health_check import setup_health_endpoints

# Configure root logging once at import time, using the level from ServerConfig.
logging.basicConfig(
    level=getattr(logging, config.server.log_level),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastMCP server instance; the tools below register against it via @mcp.tool().
mcp = FastMCP("Zava Retail Sales Analysis")

# Whitelist of schema-qualified tables accepted by get_multiple_table_schemas.
VALID_TABLES = [
    "retail.stores",
    "retail.customers",
    "retail.categories",
    "retail.product_types",
    "retail.products",
    "retail.orders",
    "retail.order_items",
    "retail.inventory"
]

def get_rls_user_id(ctx: Context) -> str:
    """Pull the Row Level Security user ID from the request context headers.

    Falls back to the all-zeros UUID (and logs a warning) when the context
    carries no "x-rls-user-id" header.
    """
    headers = getattr(ctx, 'headers', None)
    if headers:
        header_value = headers.get("x-rls-user-id")
        if header_value:
            return header_value
    default_id = "00000000-0000-0000-0000-000000000000"
    logger.warning(f"No RLS User ID found, using default: {default_id}")
    return default_id

@mcp.tool()
async def get_multiple_table_schemas(
    ctx: Context,
    table_names: Annotated[List[str], Field(
        description="List of table names to retrieve schemas for. Valid tables: " + ", ".join(VALID_TABLES)
    )]
) -> str:
    """
    Fetch database schemas for several tables at once: column names, types,
    constraints, foreign keys, and indexes.
    """
    rls_user_id = get_rls_user_id(ctx)

    # Reject any name outside the whitelist before touching the database.
    unknown = [name for name in table_names if name not in VALID_TABLES]
    if unknown:
        return f"Error: Invalid table names: {', '.join(unknown)}. Valid: {', '.join(VALID_TABLES)}"

    try:
        return await db_provider.get_multiple_table_schemas(table_names, rls_user_id)
    except Exception as e:
        logger.error(f"Error retrieving schemas: {e}")
        return f"Error retrieving table schemas: {e!s}"

@mcp.tool()
async def execute_sales_query(
    ctx: Context,
    postgresql_query: Annotated[str, Field(
        description="A well-formed PostgreSQL query. Always get table schemas first before writing queries."
    )]
) -> str:
    """
    Run a PostgreSQL query against the retail sales database under Row Level
    Security. Results are filtered to the caller's authorized store context,
    and at most 20 rows are returned for readability.
    """
    rls_user_id = get_rls_user_id(ctx)

    try:
        logger.info(f"Executing query for user: {rls_user_id}")
        return await db_provider.execute_query(postgresql_query, rls_user_id)
    except Exception as query_error:
        logger.error(f"Query error: {query_error}")
        return f"Error executing database query: {query_error!s}"

@mcp.tool()
async def get_current_utc_date(ctx: Context) -> str:
    """Return the current UTC date and time in ISO format for time-sensitive queries."""
    try:
        return await db_provider.get_current_utc_date()
    except Exception as err:
        return f"Error getting current UTC date: {err!s}"

Step 4: Health check endpoints

# mcp_server/health_check.py
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

def setup_health_endpoints(app: FastAPI, db_provider) -> None:
    """Register health endpoints on *app*: basic, detailed, readiness, liveness."""

    @app.get("/health")
    async def health_check() -> JSONResponse:
        # Basic probe: the service is up and the DB can report the current time.
        payload = {
            "status": "healthy",
            "service": "zava-retail-mcp-server",
            "timestamp": await db_provider.get_current_utc_date()
        }
        return JSONResponse(status_code=200, content=payload)

    @app.get("/health/detailed")
    async def detailed_health_check() -> JSONResponse:
        # Component-level report; 503 when any component is unhealthy.
        db_health = await db_provider.health_check()
        healthy = db_health["status"] == "healthy"
        body = {
            "service": "zava-retail-mcp-server",
            "status": "healthy" if healthy else "unhealthy",
            "components": {"database": db_health}
        }
        return JSONResponse(status_code=200 if healthy else 503, content=body)

    @app.get("/health/ready")
    async def readiness_check() -> JSONResponse:
        # Kubernetes-style readiness: fail fast when the database is down.
        current = await db_provider.health_check()
        if current["status"] != "healthy":
            raise HTTPException(status_code=503, detail="Database not ready")
        return JSONResponse(status_code=200, content={"status": "ready"})

    @app.get("/health/live")
    async def liveness_check() -> JSONResponse:
        # Liveness never touches dependencies — only proves the process responds.
        return JSONResponse(status_code=200, content={"status": "alive"})

Step 5: Application lifecycle and startup

# Continued in mcp_server/sales_analysis.py

@asynccontextmanager
async def lifespan(app):
    """Manage application startup and shutdown.

    Fix: the original try/except wrapped the ``yield``, so any exception
    thrown back into the generator while the app was RUNNING was logged as
    "Startup failed". The startup handler now covers only startup; the
    shutdown cleanup is a separate try/finally around the yield.
    """
    logger.info("Starting Zava Retail MCP Server...")
    try:
        await db_provider.create_pool()
        health_status = await db_provider.health_check()
        if health_status["status"] != "healthy":
            raise Exception("Database not healthy at startup")
        logger.info("MCP Server startup complete")
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        # Release any partially-created pool before propagating.
        await db_provider.close_pool()
        raise

    try:
        yield
    finally:
        logger.info("Shutting down MCP Server...")
        await db_provider.close_pool()

def create_app():
    """Build the SSE application: lifecycle hook, optional health endpoints,
    and optional CORS middleware, driven by ServerConfig flags."""
    app = mcp.sse_app()
    app.router.lifespan_context = lifespan

    if config.server.enable_cors:
        from fastapi.middleware.cors import CORSMiddleware
        # NOTE(review): wildcard origins combined with allow_credentials=True is
        # a permissive combination — acceptable for a lab, confirm before prod.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    if config.server.enable_health_check:
        setup_health_endpoints(app, db_provider)

    return app

# Module-level ASGI application, imported by uvicorn as "sales_analysis:app".
app = create_app()

if __name__ == "__main__":
    import uvicorn
    # NOTE(review): this module uses package-relative imports (from .config ...),
    # which normally fail when the file is executed directly — confirm the lab's
    # "cd mcp_server && python sales_analysis.py" run command actually works.
    uvicorn.run("sales_analysis:app",
                host=config.server.host,
                port=config.server.port,
                reload=True)

Testing the server

Step 1: Start the MCP server

source mcp-env/bin/activate
cd mcp_server
python sales_analysis.py
Step 2: Test health endpoints

curl http://localhost:8000/health
curl http://localhost:8000/health/detailed
Step 3: Test MCP protocol endpoints

# List available tools
curl -X POST http://localhost:8000/mcp \
  -H "Content-Type: application/json" \
  -H "x-rls-user-id: 00000000-0000-0000-0000-000000000000" \
  -d '{"method": "tools/list", "params": {}}'

# Get table schema
curl -X POST http://localhost:8000/mcp \
  -H "Content-Type: application/json" \
  -H "x-rls-user-id: 00000000-0000-0000-0000-000000000000" \
  -d '{
    "method": "tools/call",
    "params": {
      "name": "get_multiple_table_schemas",
      "arguments": {"table_names": ["retail.stores", "retail.products"]}
    }
  }'
Step 4: Test in VS Code AI Chat

Open VS Code, go to AI Chat (Ctrl+Shift+P → “AI Chat”), type #zava, select your server, and ask:
  • “What tables are available in the database?”
  • “Show me the top 5 stores by number of orders.”

Unit tests

# tests/test_mcp_server.py
import pytest
import asyncio
from mcp_server.sales_analysis_postgres import PostgreSQLSchemaProvider

@pytest.mark.asyncio
async def test_database_connection():
    """The pool comes up and the provider reports a healthy database."""
    provider = PostgreSQLSchemaProvider()
    try:
        await provider.create_pool()
        status = await provider.health_check()
        assert status["status"] == "healthy"
    finally:
        # Always release the pool, even when the health assertion fails.
        await provider.close_pool()

@pytest.mark.asyncio
async def test_query_execution():
    """A simple aggregate query under the default RLS user returns the column name."""
    provider = PostgreSQLSchemaProvider()
    try:
        await provider.create_pool()
        output = await provider.execute_query(
            "SELECT COUNT(*) as store_count FROM retail.stores",
            "00000000-0000-0000-0000-000000000000"
        )
        assert "store_count" in output
    finally:
        await provider.close_pool()

Key takeaways

  • FastMCP provides declarative tool registration with automatic type validation
  • RLS context is set per-request from the x-rls-user-id header
  • Connection pooling with lifecycle management ensures clean startup and shutdown
  • Health endpoints (/health, /health/ready, /health/live) support both human and Kubernetes probes
  • Error handling returns user-friendly messages without exposing sensitive internal details

Next: Lab 6 — Tool Development

Expand the tool collection with advanced query patterns, SQL validation, business intelligence tools, and schema introspection.

Build docs developers (and LLMs) love