Skip to main content

Objectives

By the end of this lab you will be able to:
  • Design advanced MCP tools with complex parameter validation
  • Implement secure database query tools with SQL injection protection
  • Create schema introspection capabilities for dynamic queries
  • Build custom analytics tools for business intelligence
  • Apply comprehensive error handling and graceful degradation
  • Optimize tool performance for production workloads

Prerequisites

Core tool architecture

Base tool classes

# mcp_server/tools/base.py
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from dataclasses import dataclass
from enum import Enum
import time
import logging

logger = logging.getLogger(__name__)

class ToolCategory(Enum):
    """Functional grouping of MCP tools; ``.value`` is the string emitted in logs and statistics."""
    DATABASE_QUERY = "database_query"
    SCHEMA_INTROSPECTION = "schema_introspection"
    ANALYTICS = "analytics"
    UTILITY = "utility"

@dataclass
class ToolResult:
    """Standardized tool result structure.

    Every tool's execute()/call() returns one of these; on failure
    ``success`` is False and ``error`` carries a human-readable message.
    """
    success: bool  # True when the tool completed without error
    data: Any = None  # payload on success (rows, dicts, scalars)
    error: Optional[str] = None  # human-readable failure description
    metadata: Optional[Dict[str, Any]] = None  # tool-specific context
    execution_time_ms: Optional[float] = None  # populated by BaseTool.call
    row_count: Optional[int] = None  # number of rows returned, when applicable

class BaseTool(ABC):
    """Abstract base class for all MCP tools.

    Subclasses implement ``execute`` (the tool's actual work) and
    ``get_input_schema`` (a JSON-Schema dict describing parameters).
    Callers should go through ``call``, which wraps ``execute`` with
    input validation, logging, timing, and uniform error handling.
    """

    def __init__(self, name: str, description: str, category: ToolCategory):
        self.name = name
        self.description = description
        self.category = category
        # Lightweight per-tool metrics, surfaced by get_statistics().
        self.call_count = 0
        self.total_execution_time = 0.0  # cumulative milliseconds

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Perform the tool's work; implemented by subclasses."""

    @abstractmethod
    def get_input_schema(self) -> Dict[str, Any]:
        """Return a JSON-Schema dict describing accepted parameters."""

    async def call(self, **kwargs) -> ToolResult:
        """Execute with metrics, timing, and error handling.

        Never raises: any exception from validation or execution is
        converted into a failed ToolResult so callers always get a
        uniform result shape.
        """
        # perf_counter() is monotonic; time.time() can jump (NTP sync,
        # clock changes) and would corrupt the duration metric.
        start_time = time.perf_counter()
        self.call_count += 1

        try:
            self._validate_input(kwargs)
            logger.info(f"Executing tool: {self.name}", extra={
                'tool_name': self.name,
                'tool_category': self.category.value
            })

            result = await self.execute(**kwargs)
            execution_time = (time.perf_counter() - start_time) * 1000
            result.execution_time_ms = execution_time
            self.total_execution_time += execution_time
            return result

        except Exception as e:
            execution_time = (time.perf_counter() - start_time) * 1000
            # Count failed calls in the cumulative time too; otherwise
            # average_execution_time_ms divides a success-only total by
            # the all-calls count and under-reports latency.
            self.total_execution_time += execution_time
            logger.error(f"Tool execution failed: {self.name}: {e}", exc_info=True)
            return ToolResult(
                success=False,
                error=f"Tool execution failed: {str(e)}",
                execution_time_ms=execution_time
            )

    def _validate_input(self, kwargs: Dict[str, Any]):
        """Raise ValueError if any schema-declared required parameter is absent."""
        schema = self.get_input_schema()
        missing = [p for p in schema.get('required', []) if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {missing}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return usage metrics for this tool (call count and average latency)."""
        return {
            'name': self.name,
            'category': self.category.value,
            'call_count': self.call_count,
            'average_execution_time_ms': (
                self.total_execution_time / self.call_count if self.call_count > 0 else 0
            )
        }

SQL query validation

Always validate user-provided SQL before execution. The QueryValidator class blocks dangerous keywords and common injection patterns; the allowed-schema and allowed-table sets it declares are not yet consulted by validate_query and would need to be wired in for schema-level enforcement.
# mcp_server/tools/query_validator.py
import re
import sqlparse
from enum import Enum
from typing import Dict, Any

class QueryRisk(Enum):
    """Risk classification assigned by QueryValidator.validate_query."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

class QueryValidator:
    """Validate and analyze SQL queries for security risks.

    Only read-only statements (SELECT / WITH) are accepted; mutating
    keywords and common injection shapes mark the query unsafe.
    """

    # Keywords indicating a mutating or privileged statement.
    DANGEROUS_KEYWORDS = {
        'DROP', 'DELETE', 'TRUNCATE', 'ALTER', 'CREATE', 'INSERT',
        'UPDATE', 'GRANT', 'REVOKE', 'EXEC', 'EXECUTE'
    }

    # NOTE(review): these allow-lists are declared but never consulted by
    # validate_query — enforce them (schema/table extraction) or remove.
    ALLOWED_SCHEMAS = {'retail', 'information_schema', 'pg_catalog'}

    ALLOWED_TABLES = {
        'customers', 'products', 'sales_transactions',
        'sales_transaction_items', 'product_categories',
        'product_embeddings', 'stores'
    }

    def __init__(self):
        # Patterns are compiled with re.IGNORECASE below, so explicit
        # UPPER/lower alternations (as in the original) are redundant.
        self.injection_patterns = [
            r"\bUNION\s+(ALL\s+)?SELECT",
            r"\bDROP\s+(TABLE|DATABASE)",
            r"\bDELETE\s+FROM",
            r"\bINSERT\s+INTO",
            r"\bUPDATE\s+\w+\s+SET",
            r"\b(EXEC|EXECUTE)\s*\(",
            r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER)",
        ]
        self.compiled_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.injection_patterns
        ]

    def validate_query(self, query: str) -> Dict[str, Any]:
        """Return {'is_safe', 'risk_level', 'issues', 'warnings'} for *query*."""
        result = {
            'is_safe': True,
            'risk_level': QueryRisk.LOW,
            'issues': [],
            'warnings': []
        }

        parsed = sqlparse.parse(query)
        if not parsed:
            result['is_safe'] = False
            result['issues'].append("Unable to parse query")
            result['risk_level'] = QueryRisk.HIGH
            return result

        for statement in parsed:
            stmt_type = statement.get_type()
            # Only read-only statement types are allowed through.
            if stmt_type and stmt_type.upper() not in ['SELECT', 'WITH']:
                result['issues'].append(f"Disallowed statement type: {stmt_type}")
                result['is_safe'] = False

            for token in statement.flatten():
                if token.ttype is sqlparse.tokens.Keyword:
                    if token.value.upper() in self.DANGEROUS_KEYWORDS:
                        result['issues'].append(f"Dangerous keyword: {token.value.upper()}")
                        result['is_safe'] = False

        for pattern in self.compiled_patterns:
            # search() is sufficient (and cheaper than findall) — we only
            # need to know whether the pattern occurs at all.
            if pattern.search(query):
                result['issues'].append("Potential injection pattern detected")
                result['is_safe'] = False

        if not result['is_safe']:
            # Injection-shaped findings escalate to CRITICAL; any other
            # validation failure is HIGH.
            result['risk_level'] = QueryRisk.CRITICAL if 'injection' in str(result['issues']).lower() else QueryRisk.HIGH

        return result

# Module-level singleton shared by the tool implementations.
query_validator = QueryValidator()

Sales analysis tool

# mcp_server/tools/sales_analysis.py
from datetime import datetime, timedelta
from .base import DatabaseTool, ToolResult

class SalesAnalysisTool(DatabaseTool):
    """Advanced sales analysis and reporting tool.

    Supports three pre-validated template queries (daily_sales,
    top_products, sales_trends) plus validated custom SELECT queries.
    NOTE: this module must also import ``query_validator`` from
    ``.query_validator`` and ``Dict, Any`` from ``typing``.
    """

    def __init__(self, db_provider):
        super().__init__("execute_sales_query",
                         "Execute sales analysis queries with template support",
                         db_provider)

        # Parameterized templates: $1 = start_date, $2 = end_date,
        # $3 (top_products only) = row limit.
        self.query_templates = {
            'daily_sales': """
                SELECT
                    DATE(transaction_date) as sales_date,
                    COUNT(*) as transaction_count,
                    SUM(total_amount) as total_revenue,
                    AVG(total_amount) as avg_transaction_value,
                    COUNT(DISTINCT customer_id) as unique_customers
                FROM retail.sales_transactions
                WHERE transaction_date >= $1 AND transaction_date <= $2
                  AND transaction_type = 'sale'
                GROUP BY DATE(transaction_date)
                ORDER BY sales_date DESC
            """,

            'top_products': """
                SELECT
                    p.product_name,
                    p.brand,
                    SUM(sti.quantity) as total_quantity_sold,
                    SUM(sti.total_price) as total_revenue,
                    COUNT(DISTINCT st.transaction_id) as transaction_count,
                    AVG(sti.unit_price) as avg_price
                FROM retail.sales_transaction_items sti
                JOIN retail.sales_transactions st ON sti.transaction_id = st.transaction_id
                JOIN retail.products p ON sti.product_id = p.product_id
                WHERE st.transaction_date >= $1 AND st.transaction_date <= $2
                  AND st.transaction_type = 'sale'
                GROUP BY p.product_id, p.product_name, p.brand
                ORDER BY total_revenue DESC
                LIMIT $3
            """,

            'sales_trends': """
                WITH daily_sales AS (
                    SELECT
                        DATE(transaction_date) as sales_date,
                        SUM(total_amount) as daily_revenue,
                        COUNT(*) as daily_transactions
                    FROM retail.sales_transactions
                    WHERE transaction_date >= $1 AND transaction_date <= $2
                      AND transaction_type = 'sale'
                    GROUP BY DATE(transaction_date)
                )
                SELECT
                    sales_date,
                    daily_revenue,
                    daily_transactions,
                    AVG(daily_revenue) OVER (
                        ORDER BY sales_date ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
                    ) as rolling_7day_avg,
                    CASE
                        WHEN LAG(daily_revenue, 1) OVER (ORDER BY sales_date) IS NOT NULL THEN
                            ROUND(((daily_revenue - LAG(daily_revenue, 1) OVER (ORDER BY sales_date))
                                   / LAG(daily_revenue, 1) OVER (ORDER BY sales_date) * 100)::numeric, 2)
                        ELSE NULL
                    END as day_over_day_growth_pct
                FROM daily_sales
                ORDER BY sales_date DESC
            """
        }

    async def execute(self, **kwargs) -> ToolResult:
        """Dispatch to a named template or a validated custom query."""
        query_type = kwargs.get('query_type', 'custom')
        store_id = kwargs.get('store_id')

        if not store_id:
            return ToolResult(success=False, error="store_id is required")

        if query_type in self.query_templates:
            return await self._execute_template_query(query_type, kwargs)
        elif query_type == 'custom':
            return await self._execute_custom_query(kwargs)
        else:
            return ToolResult(success=False, error=f"Unknown query type: {query_type}")

    async def _execute_template_query(self, query_type: str, kwargs) -> ToolResult:
        """Run a named template with defaulted date range and limit.

        This helper was previously missing, so every template query_type
        raised AttributeError. Defaults to the last 30 days ending today
        (UTC) when start_date/end_date are not supplied.
        """
        today = datetime.utcnow().date()
        end_date = kwargs.get('end_date') or today.isoformat()
        start_date = kwargs.get('start_date') or (today - timedelta(days=30)).isoformat()

        params = [start_date, end_date]
        if query_type == 'top_products':
            # $3: row limit; bounded 1-1000 by the input schema.
            params.append(kwargs.get('limit', 20))

        return await self.execute_query(
            self.query_templates[query_type], params, kwargs['store_id']
        )

    async def _execute_custom_query(self, kwargs) -> ToolResult:
        """Validate and run a caller-supplied SELECT query."""
        custom_query = kwargs.get('query')
        store_id = kwargs['store_id']

        if not custom_query:
            return ToolResult(success=False, error="query is required for custom query_type")

        # Reject anything that is not a read-only SELECT/WITH before it
        # ever reaches the database.
        validation = query_validator.validate_query(custom_query)
        if not validation['is_safe']:
            return ToolResult(
                success=False,
                error=f"Query validation failed: {', '.join(validation['issues'])}",
                metadata={'risk_level': validation['risk_level'].value}
            )

        result = await self.execute_query(custom_query, None, store_id)
        return result

    def get_input_schema(self) -> Dict[str, Any]:
        """JSON Schema: store_id is required; dates/limit apply to templates."""
        return {
            "type": "object",
            "properties": {
                "query_type": {
                    "type": "string",
                    "enum": list(self.query_templates.keys()) + ["custom"],
                    "default": "daily_sales"
                },
                "store_id": {"type": "string", "pattern": "^[a-zA-Z0-9_-]+$"},
                "start_date": {"type": "string", "format": "date"},
                "end_date": {"type": "string", "format": "date"},
                "limit": {"type": "integer", "minimum": 1, "maximum": 1000, "default": 20},
                "query": {"type": "string"}
            },
            "required": ["store_id"],
            "additionalProperties": False
        }

Schema introspection tool

# mcp_server/tools/schema_introspection.py
class SchemaIntrospectionTool(DatabaseTool):
    """Tool for exploring database schema and metadata.

    Reads column definitions (and optionally index definitions) for a
    single table in the ``retail`` schema using parameterized queries.
    """

    def __init__(self, db_provider):
        super().__init__("get_table_schema",
                         "Get detailed schema information for database tables",
                         db_provider)
        self.category = ToolCategory.SCHEMA_INTROSPECTION

    async def execute(self, **kwargs) -> ToolResult:
        """Return column (and optional index) metadata for ``table_name``."""
        table_name = kwargs.get('table_name')
        include_indexes = kwargs.get('include_indexes', True)

        # Previously a missing table_name fell through to the database as
        # a NULL parameter and produced a confusing empty result.
        if not table_name:
            return ToolResult(success=False, error="table_name is required")

        try:
            async with self.get_connection() as conn:
                columns_query = """
                    SELECT
                        column_name, data_type, is_nullable,
                        column_default, character_maximum_length,
                        numeric_precision, numeric_scale, ordinal_position
                    FROM information_schema.columns
                    WHERE table_schema = 'retail' AND table_name = $1
                    ORDER BY ordinal_position
                """
                columns = await conn.fetch(columns_query, table_name)

                schema_info = {
                    'table_name': table_name,
                    'columns': [dict(col) for col in columns],
                    'indexes': []
                }

                if include_indexes:
                    indexes_query = """
                        SELECT indexname, indexdef
                        FROM pg_indexes
                        WHERE schemaname = 'retail' AND tablename = $1
                    """
                    indexes = await conn.fetch(indexes_query, table_name)
                    schema_info['indexes'] = [dict(idx) for idx in indexes]

            return ToolResult(
                success=True,
                data=schema_info,
                metadata={'table_name': table_name}
            )

        except Exception as e:
            return ToolResult(success=False, error=f"Schema introspection failed: {str(e)}")

    def get_input_schema(self) -> Dict[str, Any]:
        """JSON Schema: table_name (required identifier) + include_indexes flag."""
        return {
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
                },
                "include_indexes": {"type": "boolean", "default": True}
            },
            # table_name was de facto mandatory; declaring it lets
            # schema-driven validation reject calls without it.
            "required": ["table_name"],
            "additionalProperties": False
        }

Business intelligence tool

# mcp_server/tools/business_intelligence.py
class BusinessIntelligenceTool(DatabaseTool):
    """Advanced analytics tool for business intelligence queries."""

    def __init__(self, db_provider):
        super().__init__("generate_business_insights",
                         "Generate comprehensive business intelligence reports and insights",
                         db_provider)
        self.category = ToolCategory.ANALYTICS

    async def execute(self, **kwargs) -> ToolResult:
        """Dispatch to the requested analysis (only 'summary' is implemented)."""
        analysis_type = kwargs.get('analysis_type', 'summary')
        store_id = kwargs.get('store_id')
        days = kwargs.get('days', 30)

        if not store_id:
            return ToolResult(success=False, error="store_id is required")

        # `days` is interpolated into SQL below; the JSON schema declares
        # it an integer, but enforce that at runtime so a malicious string
        # can never reach the f-string query.
        try:
            days = int(days)
        except (TypeError, ValueError):
            return ToolResult(success=False, error="days must be an integer")
        if not 1 <= days <= 365:
            return ToolResult(success=False, error="days must be between 1 and 365")

        if analysis_type == 'summary':
            return await self._generate_business_summary(store_id, days)
        else:
            return ToolResult(success=False, error=f"Unknown analysis type: {analysis_type}")

    async def _generate_business_summary(self, store_id: str, days: int) -> ToolResult:
        """Aggregate transaction counts/revenue over the trailing *days* window.

        *days* must already be validated as an int (see execute) because it
        is interpolated directly into the interval literal.
        """
        summary_query = f"""
            WITH date_range AS (
                SELECT CURRENT_DATE - INTERVAL '{days} days' as start_date,
                       CURRENT_DATE as end_date
            ),
            sales_summary AS (
                SELECT
                    COUNT(*) as total_transactions,
                    COUNT(DISTINCT customer_id) as unique_customers,
                    SUM(total_amount) as total_revenue,
                    AVG(total_amount) as avg_transaction_value
                FROM retail.sales_transactions st, date_range dr
                WHERE st.transaction_date >= dr.start_date
                  AND st.transaction_date <= dr.end_date
                  AND st.transaction_type = 'sale'
            )
            SELECT * FROM sales_summary
        """

        result = await self.execute_query(summary_query, None, store_id)

        if result.success:
            result.metadata = {
                'analysis_type': 'business_summary',
                'period_days': days,
                'store_id': store_id
            }

        return result

    def get_input_schema(self) -> Dict[str, Any]:
        """JSON Schema: store_id required; days bounded to 1-365."""
        return {
            "type": "object",
            "properties": {
                "analysis_type": {
                    "type": "string",
                    "enum": ["summary", "customer_segmentation", "product_performance", "seasonal_trends"],
                    "default": "summary"
                },
                "store_id": {"type": "string", "pattern": "^[a-zA-Z0-9_-]+$"},
                "days": {"type": "integer", "minimum": 1, "maximum": 365, "default": 30}
            },
            "required": ["store_id"],
            "additionalProperties": False
        }

Utility tool

# mcp_server/tools/utility.py
class UtilityTool(DatabaseTool):
    """Utility tool for common operations.

    Asks the database (not the application host) for the current UTC
    time, in one of three output formats.
    """

    # SQL text for each supported output format.
    _FORMAT_QUERIES = {
        'iso':       "SELECT CURRENT_TIMESTAMP AT TIME ZONE 'UTC' as current_utc_datetime",
        'epoch':     "SELECT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP AT TIME ZONE 'UTC') as current_utc_epoch",
        'date_only': "SELECT CURRENT_DATE as current_date"
    }

    def __init__(self, db_provider):
        super().__init__("get_current_utc_date",
                         "Get current UTC date and time for reference",
                         db_provider)
        self.category = ToolCategory.UTILITY

    async def execute(self, **kwargs) -> ToolResult:
        """Fetch the server's current UTC time in the requested format."""
        requested_format = kwargs.get('format', 'iso')

        try:
            async with self.get_connection() as conn:
                sql = self._FORMAT_QUERIES.get(requested_format)
                if sql is None:
                    return ToolResult(success=False, error=f"Unknown format type: {requested_format}")

                row = await conn.fetchrow(sql)
                return ToolResult(
                    success=True,
                    data=dict(row),
                    metadata={'format_type': requested_format, 'timezone': 'UTC'}
                )
        except Exception as e:
            return ToolResult(success=False, error=f"Utility operation failed: {str(e)}")

    def get_input_schema(self) -> Dict[str, Any]:
        """JSON Schema: optional 'format' restricted to the supported set."""
        return {
            "type": "object",
            "properties": {
                "format": {
                    "type": "string",
                    "enum": ["iso", "epoch", "date_only"],
                    "default": "iso"
                }
            },
            "additionalProperties": False
        }

Key takeaways

  • Base tool classes enforce consistent structure, error handling, and metrics collection
  • Query validator blocks dangerous SQL patterns before they reach the database
  • Template queries provide safe, pre-validated analytics without custom SQL
  • Schema introspection enables AI to discover table structure dynamically before writing queries
  • Business intelligence tools aggregate business metrics into actionable summaries
  • Tool input schemas drive automatic parameter validation and AI discovery

Next: Lab 7 — Semantic Search

Integrate Azure OpenAI embeddings and pgvector to enable natural language product search.

Build docs developers (and LLMs) love