from nemoguardrails import LLMRails
from nemoguardrails.actions.actions import ActionResult
import re
# PII detectors used by ``check_pii``, checked in this order. The first pattern
# that matches determines the label reported under ``pii_detected``.
_PII_PATTERNS = (
    # Email address. The TLD class is ``[A-Za-z]``; a ``|`` inside a character
    # class is a literal pipe, not alternation, so ``[A-Z|a-z]`` was wrong.
    ("email", re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')),
    # 10-digit phone number grouped 3-3-4, with optional ``-`` or ``.`` separators.
    ("phone", re.compile(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b')),
    # SSN written as three digits, two digits, four digits joined by ``-``.
    ("ssn", re.compile(r'\b\d{3}-\d{2}-\d{4}\b')),
)


async def check_pii(context: dict) -> ActionResult:
    """Check if user input contains PII (emails, phone numbers, SSN).

    Reads ``context["last_user_message"]`` and tests it against the email,
    phone and SSN patterns in that order.

    Args:
        context: Action context. Only ``"last_user_message"`` is read. A missing
            or ``None`` message is treated as empty, so no PII is reported.

    Returns:
        ``ActionResult(return_value=False, context_updates={"pii_detected": label})``
        for the first matching category, where ``label`` is ``"email"``,
        ``"phone"`` or ``"ssn"``. Otherwise ``ActionResult(return_value=True)``.
    """
    # Guard against a missing message: re.search(pattern, None) raises TypeError.
    user_message = (context or {}).get("last_user_message") or ""
    for label, pattern in _PII_PATTERNS:
        if pattern.search(user_message):
            return ActionResult(
                return_value=False,
                context_updates={"pii_detected": label},
            )
    return ActionResult(return_value=True)
# Four groups of four digits, optionally separated by whitespace or ``-``.
_CREDIT_CARD_PATTERN = re.compile(r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b')
# Any standalone run of 32 or more ASCII letters/digits is treated as a
# possible API key. NOTE(review): tokens containing ``_`` or ``-`` are not
# matched by this pattern.
_API_KEY_PATTERN = re.compile(r'\b[A-Za-z0-9]{32,}\b')


async def mask_sensitive_data(context: dict) -> ActionResult:
    """Mask sensitive data in bot responses.

    Replaces credit-card-like digit groups with ``****-****-****-****``, then
    replaces long alphanumeric tokens with ``[REDACTED]``.

    Args:
        context: Action context. Only ``"bot_message"`` is read.

    Returns:
        ``ActionResult`` whose ``return_value`` is the masked message. A missing
        or empty message is returned unchanged instead of raising.
    """
    bot_message = (context or {}).get("bot_message")
    if not bot_message:
        # Nothing to mask; re.sub on None would raise TypeError.
        return ActionResult(return_value=bot_message)
    # Card numbers are masked first; the API-key pass then runs on that result.
    masked_message = _CREDIT_CARD_PATTERN.sub('****-****-****-****', bot_message)
    masked_message = _API_KEY_PATTERN.sub('[REDACTED]', masked_message)
    return ActionResult(return_value=masked_message)
def init(app: LLMRails):
    """Register this module's custom actions on ``app``.

    ``check_pii`` and ``mask_sensitive_data`` are registered, in that order,
    under action names identical to their function names.
    """
    registrations = (
        (check_pii, "check_pii"),
        (mask_sensitive_data, "mask_sensitive_data"),
    )
    for action_fn, action_name in registrations:
        app.register_action(action_fn, action_name)