Objectives
By the end of this lab you will be able to:
- Design advanced MCP tools with complex parameter validation
- Implement secure database query tools with SQL injection protection
- Create schema introspection capabilities for dynamic queries
- Build custom analytics tools for business intelligence
- Apply comprehensive error handling and graceful degradation
- Optimize tool performance for production workloads
Prerequisites
- Completed Lab 5: MCP Server Implementation
- Running MCP server connected to the retail database
Core tool architecture
Base tool classes
# mcp_server/tools/base.py
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from dataclasses import dataclass
from enum import Enum
import time
import logging

logger = logging.getLogger(__name__)


class ToolCategory(Enum):
    DATABASE_QUERY = "database_query"
    SCHEMA_INTROSPECTION = "schema_introspection"
    ANALYTICS = "analytics"
    UTILITY = "utility"


@dataclass
class ToolResult:
    """Standardized tool result structure."""
    success: bool
    data: Any = None
    error: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    execution_time_ms: Optional[float] = None
    row_count: Optional[int] = None


class BaseTool(ABC):
    """Abstract base class for all MCP tools."""

    def __init__(self, name: str, description: str, category: ToolCategory):
        self.name = name
        self.description = description
        self.category = category
        self.call_count = 0
        self.total_execution_time = 0.0

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        pass

    @abstractmethod
    def get_input_schema(self) -> Dict[str, Any]:
        pass

    async def call(self, **kwargs) -> ToolResult:
        """Execute with metrics, timing, and error handling."""
        start_time = time.time()
        self.call_count += 1
        try:
            self._validate_input(kwargs)
            logger.info(f"Executing tool: {self.name}", extra={
                'tool_name': self.name,
                'tool_category': self.category.value
            })
            result = await self.execute(**kwargs)
            execution_time = (time.time() - start_time) * 1000
            result.execution_time_ms = execution_time
            self.total_execution_time += execution_time
            return result
        except Exception as e:
            execution_time = (time.time() - start_time) * 1000
            logger.error(f"Tool execution failed: {self.name}: {e}", exc_info=True)
            return ToolResult(
                success=False,
                error=f"Tool execution failed: {str(e)}",
                execution_time_ms=execution_time
            )

    def _validate_input(self, kwargs: Dict[str, Any]):
        schema = self.get_input_schema()
        missing = [p for p in schema.get('required', []) if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {missing}")

    def get_statistics(self) -> Dict[str, Any]:
        return {
            'name': self.name,
            'category': self.category.value,
            'call_count': self.call_count,
            'average_execution_time_ms': (
                self.total_execution_time / self.call_count if self.call_count > 0 else 0
            )
        }
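The concrete tools below subclass a DatabaseTool helper that layers connection access and parameterized query execution on top of BaseTool. Its implementation is not shown in this lab; the following is a minimal sketch, assuming the db_provider you built in Lab 5 exposes an asyncpg-style connection pool (adapt the attribute names to your provider's actual interface):

# mcp_server/tools/base.py (continued) -- hypothetical sketch, not the lab's canonical code
from contextlib import asynccontextmanager


class DatabaseTool(BaseTool):
    """Base class for tools that query the retail database."""

    def __init__(self, name: str, description: str, db_provider):
        super().__init__(name, description, ToolCategory.DATABASE_QUERY)
        self.db_provider = db_provider  # assumed to wrap an asyncpg pool

    @asynccontextmanager
    async def get_connection(self):
        # Assumes the provider exposes pool.acquire(); adjust to your Lab 5 interface.
        async with self.db_provider.pool.acquire() as conn:
            yield conn

    async def execute_query(self, query: str, params=None,
                            store_id: Optional[str] = None) -> ToolResult:
        """Run a read-only query and wrap the rows in a ToolResult."""
        async with self.get_connection() as conn:
            rows = await conn.fetch(query, *(params or []))
            return ToolResult(
                success=True,
                data=[dict(row) for row in rows],
                row_count=len(rows),
                metadata={'store_id': store_id} if store_id else None
            )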
SQL query validation
Always validate user-provided SQL before execution. The QueryValidator class below blocks dangerous keywords and common injection patterns, and defines allow-lists of schemas and tables that query tools can enforce.
# mcp_server/tools/query_validator.py
import re
import sqlparse
from enum import Enum
from typing import Dict, Any


class QueryRisk(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class QueryValidator:
    """Validate and analyze SQL queries for security risks."""

    DANGEROUS_KEYWORDS = {
        'DROP', 'DELETE', 'TRUNCATE', 'ALTER', 'CREATE', 'INSERT',
        'UPDATE', 'GRANT', 'REVOKE', 'EXEC', 'EXECUTE'
    }

    ALLOWED_SCHEMAS = {'retail', 'information_schema', 'pg_catalog'}
    ALLOWED_TABLES = {
        'customers', 'products', 'sales_transactions',
        'sales_transaction_items', 'product_categories',
        'product_embeddings', 'stores'
    }

    def __init__(self):
        self.injection_patterns = [
            r"(\b(UNION|union)\s+(ALL\s+)?(SELECT|select))",
            r"(\b(DROP|drop)\s+(TABLE|table|DATABASE|database))",
            r"(\b(DELETE|delete)\s+(FROM|from))",
            r"(\b(INSERT|insert)\s+(INTO|into))",
            r"(\b(UPDATE|update)\s+\w+\s+(SET|set))",
            r"(\b(EXEC|exec|EXECUTE|execute)\s*\()",
            r"(;\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER))",
        ]
        self.compiled_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.injection_patterns
        ]

    def validate_query(self, query: str) -> Dict[str, Any]:
        result = {
            'is_safe': True,
            'risk_level': QueryRisk.LOW,
            'issues': [],
            'warnings': []
        }
        parsed = sqlparse.parse(query)
        if not parsed:
            result['is_safe'] = False
            result['issues'].append("Unable to parse query")
            result['risk_level'] = QueryRisk.HIGH
            return result
        for statement in parsed:
            stmt_type = statement.get_type()
            if stmt_type and stmt_type.upper() not in ['SELECT', 'WITH']:
                result['issues'].append(f"Disallowed statement type: {stmt_type}")
                result['is_safe'] = False
            for token in statement.flatten():
                if token.ttype is sqlparse.tokens.Keyword:
                    if token.value.upper() in self.DANGEROUS_KEYWORDS:
                        result['issues'].append(f"Dangerous keyword: {token.value.upper()}")
                        result['is_safe'] = False
        for pattern in self.compiled_patterns:
            if pattern.findall(query):
                result['issues'].append("Potential injection pattern detected")
                result['is_safe'] = False
        if not result['is_safe']:
            result['risk_level'] = (
                QueryRisk.CRITICAL
                if 'injection' in str(result['issues']).lower()
                else QueryRisk.HIGH
            )
        return result


query_validator = QueryValidator()
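A quick way to confirm the validator's behavior, for example from a test:

# Example: a stacked DROP statement is rejected. Both the statement-type
# check and the injection-pattern regex fire, so the risk escalates to CRITICAL.
report = query_validator.validate_query(
    "SELECT * FROM retail.products; DROP TABLE retail.products"
)
assert report['is_safe'] is False
print(report['issues'], report['risk_level'])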
Sales analysis tool
# mcp_server/tools/sales_analysis.py
from datetime import datetime, timedelta
from typing import Any, Dict

from .base import DatabaseTool, ToolResult
from .query_validator import query_validator
class SalesAnalysisTool(DatabaseTool):
    """Advanced sales analysis and reporting tool."""

    def __init__(self, db_provider):
        super().__init__("execute_sales_query",
                         "Execute sales analysis queries with template support",
                         db_provider)
        self.query_templates = {
            'daily_sales': """
                SELECT
                    DATE(transaction_date) as sales_date,
                    COUNT(*) as transaction_count,
                    SUM(total_amount) as total_revenue,
                    AVG(total_amount) as avg_transaction_value,
                    COUNT(DISTINCT customer_id) as unique_customers
                FROM retail.sales_transactions
                WHERE transaction_date >= $1 AND transaction_date <= $2
                  AND transaction_type = 'sale'
                GROUP BY DATE(transaction_date)
                ORDER BY sales_date DESC
            """,
            'top_products': """
                SELECT
                    p.product_name,
                    p.brand,
                    SUM(sti.quantity) as total_quantity_sold,
                    SUM(sti.total_price) as total_revenue,
                    COUNT(DISTINCT st.transaction_id) as transaction_count,
                    AVG(sti.unit_price) as avg_price
                FROM retail.sales_transaction_items sti
                JOIN retail.sales_transactions st ON sti.transaction_id = st.transaction_id
                JOIN retail.products p ON sti.product_id = p.product_id
                WHERE st.transaction_date >= $1 AND st.transaction_date <= $2
                  AND st.transaction_type = 'sale'
                GROUP BY p.product_id, p.product_name, p.brand
                ORDER BY total_revenue DESC
                LIMIT $3
            """,
            'sales_trends': """
                WITH daily_sales AS (
                    SELECT
                        DATE(transaction_date) as sales_date,
                        SUM(total_amount) as daily_revenue,
                        COUNT(*) as daily_transactions
                    FROM retail.sales_transactions
                    WHERE transaction_date >= $1 AND transaction_date <= $2
                      AND transaction_type = 'sale'
                    GROUP BY DATE(transaction_date)
                )
                SELECT
                    sales_date,
                    daily_revenue,
                    daily_transactions,
                    AVG(daily_revenue) OVER (
                        ORDER BY sales_date ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
                    ) as rolling_7day_avg,
                    CASE
                        WHEN LAG(daily_revenue, 1) OVER (ORDER BY sales_date) IS NOT NULL THEN
                            ROUND(((daily_revenue - LAG(daily_revenue, 1) OVER (ORDER BY sales_date))
                                / LAG(daily_revenue, 1) OVER (ORDER BY sales_date) * 100)::numeric, 2)
                        ELSE NULL
                    END as day_over_day_growth_pct
                FROM daily_sales
                ORDER BY sales_date DESC
            """
        }

    async def execute(self, **kwargs) -> ToolResult:
        query_type = kwargs.get('query_type', 'custom')
        store_id = kwargs.get('store_id')
        if not store_id:
            return ToolResult(success=False, error="store_id is required")
        if query_type in self.query_templates:
            return await self._execute_template_query(query_type, kwargs)
        elif query_type == 'custom':
            return await self._execute_custom_query(kwargs)
        else:
            return ToolResult(success=False, error=f"Unknown query type: {query_type}")

    async def _execute_custom_query(self, kwargs) -> ToolResult:
        custom_query = kwargs.get('query')
        store_id = kwargs['store_id']
        if not custom_query:
            return ToolResult(success=False, error="query is required for custom query_type")
        validation = query_validator.validate_query(custom_query)
        if not validation['is_safe']:
            return ToolResult(
                success=False,
                error=f"Query validation failed: {', '.join(validation['issues'])}",
                metadata={'risk_level': validation['risk_level'].value}
            )
        result = await self.execute_query(custom_query, None, store_id)
        return result

    def get_input_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query_type": {
                    "type": "string",
                    "enum": list(self.query_templates.keys()) + ["custom"],
                    "default": "daily_sales"
                },
                "store_id": {"type": "string", "pattern": "^[a-zA-Z0-9_-]+$"},
                "start_date": {"type": "string", "format": "date"},
                "end_date": {"type": "string", "format": "date"},
                "limit": {"type": "integer", "minimum": 1, "maximum": 1000, "default": 20},
                "query": {"type": "string"}
            },
            "required": ["store_id"],
            "additionalProperties": False
        }
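The execute method dispatches template requests to _execute_template_query, which the listing above omits. Here is a minimal sketch, under the assumption that the schema's start_date, end_date, and limit parameters map onto the templates' $1, $2, and $3 placeholders; the parameter handling is illustrative, not the lab's canonical implementation:

    # Inside SalesAnalysisTool -- hypothetical sketch of the omitted helper.
    async def _execute_template_query(self, query_type: str, kwargs) -> ToolResult:
        # Default to the trailing 30 days when no explicit range is given.
        end_date = (datetime.fromisoformat(kwargs['end_date'])
                    if kwargs.get('end_date') else datetime.utcnow())
        start_date = (datetime.fromisoformat(kwargs['start_date'])
                      if kwargs.get('start_date') else end_date - timedelta(days=30))
        params = [start_date, end_date]
        if query_type == 'top_products':  # the only template with a $3 LIMIT slot
            params.append(kwargs.get('limit', 20))
        return await self.execute_query(
            self.query_templates[query_type], params, kwargs['store_id']
        )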
Schema introspection tool
# mcp_server/tools/schema_introspection.py
from typing import Any, Dict

from .base import DatabaseTool, ToolCategory, ToolResult


class SchemaIntrospectionTool(DatabaseTool):
    """Tool for exploring database schema and metadata."""

    def __init__(self, db_provider):
        super().__init__("get_table_schema",
                         "Get detailed schema information for database tables",
                         db_provider)
        self.category = ToolCategory.SCHEMA_INTROSPECTION

    async def execute(self, **kwargs) -> ToolResult:
        table_name = kwargs.get('table_name')
        include_indexes = kwargs.get('include_indexes', True)
        try:
            async with self.get_connection() as conn:
                columns_query = """
                    SELECT
                        column_name, data_type, is_nullable,
                        column_default, character_maximum_length,
                        numeric_precision, numeric_scale, ordinal_position
                    FROM information_schema.columns
                    WHERE table_schema = 'retail' AND table_name = $1
                    ORDER BY ordinal_position
                """
                columns = await conn.fetch(columns_query, table_name)
                schema_info = {
                    'table_name': table_name,
                    'columns': [dict(col) for col in columns],
                    'indexes': []
                }
                if include_indexes:
                    indexes_query = """
                        SELECT indexname, indexdef
                        FROM pg_indexes
                        WHERE schemaname = 'retail' AND tablename = $1
                    """
                    indexes = await conn.fetch(indexes_query, table_name)
                    schema_info['indexes'] = [dict(idx) for idx in indexes]
                return ToolResult(
                    success=True,
                    data=schema_info,
                    metadata={'table_name': table_name}
                )
        except Exception as e:
            return ToolResult(success=False, error=f"Schema introspection failed: {str(e)}")

    def get_input_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
                },
                "include_indexes": {"type": "boolean", "default": True}
            },
            "required": ["table_name"],
            "additionalProperties": False
        }
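A short usage example from an async context (the schema_tool variable name is illustrative; db_provider comes from Lab 5):

# Inspect the products table before composing a custom query.
schema_tool = SchemaIntrospectionTool(db_provider)
result = await schema_tool.call(table_name='products', include_indexes=False)
if result.success:
    for col in result.data['columns']:
        print(col['column_name'], col['data_type'], col['is_nullable'])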
Business intelligence tool
# mcp_server/tools/business_intelligence.py
from typing import Any, Dict

from .base import DatabaseTool, ToolCategory, ToolResult


class BusinessIntelligenceTool(DatabaseTool):
    """Advanced analytics tool for business intelligence queries."""

    def __init__(self, db_provider):
        super().__init__("generate_business_insights",
                         "Generate comprehensive business intelligence reports and insights",
                         db_provider)
        self.category = ToolCategory.ANALYTICS

    async def execute(self, **kwargs) -> ToolResult:
        analysis_type = kwargs.get('analysis_type', 'summary')
        store_id = kwargs.get('store_id')
        days = kwargs.get('days', 30)
        if not store_id:
            return ToolResult(success=False, error="store_id is required")
        if analysis_type == 'summary':
            return await self._generate_business_summary(store_id, days)
        else:
            # Only 'summary' is implemented here; the remaining enum values
            # in the input schema are not yet wired up.
            return ToolResult(success=False, error=f"Unknown analysis type: {analysis_type}")

    async def _generate_business_summary(self, store_id: str, days: int) -> ToolResult:
        # Interpolating `days` into the SQL is safe here: the input schema
        # restricts it to an integer between 1 and 365 before execute() runs.
        summary_query = f"""
            WITH date_range AS (
                SELECT CURRENT_DATE - INTERVAL '{days} days' as start_date,
                       CURRENT_DATE as end_date
            ),
            sales_summary AS (
                SELECT
                    COUNT(*) as total_transactions,
                    COUNT(DISTINCT customer_id) as unique_customers,
                    SUM(total_amount) as total_revenue,
                    AVG(total_amount) as avg_transaction_value
                FROM retail.sales_transactions st, date_range dr
                WHERE st.transaction_date >= dr.start_date
                  AND st.transaction_date <= dr.end_date
                  AND st.transaction_type = 'sale'
            )
            SELECT * FROM sales_summary
        """
        result = await self.execute_query(summary_query, None, store_id)
        if result.success:
            result.metadata = {
                'analysis_type': 'business_summary',
                'period_days': days,
                'store_id': store_id
            }
        return result

    def get_input_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "analysis_type": {
                    "type": "string",
                    "enum": ["summary", "customer_segmentation", "product_performance", "seasonal_trends"],
                    "default": "summary"
                },
                "store_id": {"type": "string", "pattern": "^[a-zA-Z0-9_-]+$"},
                "days": {"type": "integer", "minimum": 1, "maximum": 365, "default": 30}
            },
            "required": ["store_id"],
            "additionalProperties": False
        }
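Invoking the tool through call() runs parameter validation, timing, and logging before _generate_business_summary executes. A minimal invocation from an async context (the store id is illustrative only):

bi_tool = BusinessIntelligenceTool(db_provider)  # db_provider from Lab 5
result = await bi_tool.call(analysis_type='summary', store_id='store-001', days=90)
if result.success:
    print(result.data, result.metadata)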
Utility tool
# mcp_server/tools/utility.py
from typing import Any, Dict

from .base import DatabaseTool, ToolCategory, ToolResult


class UtilityTool(DatabaseTool):
    """Utility tool for common operations."""

    def __init__(self, db_provider):
        super().__init__("get_current_utc_date",
                         "Get current UTC date and time for reference",
                         db_provider)
        self.category = ToolCategory.UTILITY

    async def execute(self, **kwargs) -> ToolResult:
        format_type = kwargs.get('format', 'iso')
        try:
            async with self.get_connection() as conn:
                query_map = {
                    'iso': "SELECT CURRENT_TIMESTAMP AT TIME ZONE 'UTC' as current_utc_datetime",
                    'epoch': "SELECT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP AT TIME ZONE 'UTC') as current_utc_epoch",
                    'date_only': "SELECT CURRENT_DATE as current_date"
                }
                if format_type not in query_map:
                    return ToolResult(success=False, error=f"Unknown format type: {format_type}")
                result = await conn.fetchrow(query_map[format_type])
                return ToolResult(
                    success=True,
                    data=dict(result),
                    metadata={'format_type': format_type, 'timezone': 'UTC'}
                )
        except Exception as e:
            return ToolResult(success=False, error=f"Utility operation failed: {str(e)}")

    def get_input_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "format": {
                    "type": "string",
                    "enum": ["iso", "epoch", "date_only"],
                    "default": "iso"
                }
            },
            "additionalProperties": False
        }
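Because every tool routes through BaseTool.call, per-tool metrics come for free. A small demo, assuming the db_provider from Lab 5:

import asyncio

async def demo(db_provider):
    tool = UtilityTool(db_provider)
    result = await tool.call(format='iso')
    print(result.success, result.data, result.execution_time_ms)
    print(tool.get_statistics())  # call_count and average latency per tool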
Key takeaways
- Base tool classes enforce consistent structure, error handling, and metrics collection
- Query validator blocks dangerous SQL patterns before they reach the database
- Template queries provide safe, pre-validated analytics without custom SQL
- Schema introspection enables AI to discover table structure dynamically before writing queries
- Business intelligence tools roll up raw transaction data into actionable summaries
- Tool input schemas drive automatic parameter validation and AI discovery
Next: Lab 7 — Semantic Search
Integrate Azure OpenAI embeddings and pgvector to enable natural language product search.