Custom Types
SQLAlchemy allows you to create custom types that add application-specific behavior to existing types or define entirely new type mappings.
TypeDecorator
TypeDecorator wraps an existing type and adds custom processing for bind parameters and result values.
Basic TypeDecorator
from sqlalchemy import TypeDecorator, String
import json
class JSONEncodedDict(TypeDecorator):
    """Persist a JSON-serializable Python value as text in a String column.

    Values are encoded with ``json.dumps`` when bound to a statement and
    decoded with ``json.loads`` when read back. ``None`` passes through
    unchanged in both directions.
    """

    impl = String  # underlying column type
    cache_ok = True  # opt in to statement caching

    def process_bind_param(self, value, dialect):
        """Return *value* serialized as a JSON string, or None."""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Return the Python object decoded from a JSON string, or None."""
        return None if value is None else json.loads(value)
# Use custom type: the 'settings' column stores a dict serialized to JSON text
users = Table(
    'users',
    metadata,
    Column('id', Integer, primary_key=True),
    # The positional 200 is forwarded to the String impl as its length
    Column('settings', JSONEncodedDict(200))
)
Key Methods:
- impl: The underlying SQLAlchemy type to use
- process_bind_param(): Process Python value before sending to database
- process_result_value(): Process database value before returning to Python
- cache_ok: Set to True if type is safe for statement caching
TypeDecorator Examples
- Encrypted String
- Lowercase Email
- Enum to String
- Timezone-Aware DateTime
from sqlalchemy import TypeDecorator, String
from cryptography.fernet import Fernet
class EncryptedString(TypeDecorator):
    """Encrypt string values with Fernet on the way in, decrypt on the way out.

    Args:
        key: Fernet key used for both encryption and decryption.
        *args, **kwargs: Forwarded to the ``String`` impl (e.g. the length).
    """

    impl = String
    cache_ok = True

    def __init__(self, key, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # With cache_ok = True the type's cache key is built from the
        # __init__ parameters that are also stored as same-named instance
        # attributes. Keeping ``key`` on the instance makes two types with
        # different keys distinct to the statement cache, so a cached
        # statement can never reuse another key's cipher.
        self.key = key
        self.cipher = Fernet(key)

    def process_bind_param(self, value, dialect):
        """Return the encrypted token (as str) for *value*, or None."""
        if value is None:
            return None
        # Fernet works on bytes; the token is decoded so it fits a String column.
        return self.cipher.encrypt(value.encode()).decode()

    def process_result_value(self, value, dialect):
        """Return the decrypted plaintext for a stored token, or None."""
        if value is None:
            return None
        return self.cipher.decrypt(value.encode()).decode()
# Usage
# NOTE(review): generate_key() produces a new random key on every call, and
# data encrypted under one key cannot be decrypted with another. A real
# deployment presumably loads a persistent key from configuration instead.
encryption_key = Fernet.generate_key()
users = Table(
    'users',
    metadata,
    Column('id', Integer, primary_key=True),
    # First argument is the Fernet key; 255 is forwarded to String as the length
    Column('ssn', EncryptedString(encryption_key, 255))
)
from sqlalchemy import TypeDecorator, String
class LowerCaseString(TypeDecorator):
    """String column type that lowercases every value bound to it."""

    impl = String
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return *value* lowercased; None is passed through."""
        return None if value is None else value.lower()

    def process_result_value(self, value, dialect):
        """Return the stored value as-is (it was lowercased when written)."""
        return value
users = Table(
    'users',
    metadata,
    Column('email', LowerCaseString(100), unique=True)
)
# Automatically lowercased on insert. The address is a placeholder: the
# original literal was lost to e-mail obfuscation when the page was extracted.
conn.execute(
    users.insert().values(email='Alice@Example.COM')
)
# Stored as: 'alice@example.com'
from sqlalchemy import TypeDecorator, String
from enum import Enum
class EnumType(TypeDecorator):
    """Store members of an Enum class as their string ``value``.

    Args:
        enum_class: The Enum class whose members are persisted.
        *args, **kwargs: Forwarded to the ``String`` impl after the length.
    """

    impl = String
    cache_ok = True

    def __init__(self, enum_class, *args, **kwargs):
        self.enum_class = enum_class
        # Size the column to the longest member value.
        longest = max(len(member.value) for member in enum_class)
        super().__init__(longest, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Return the member's value; anything else (incl. None) is passed through."""
        if isinstance(value, self.enum_class):
            return value.value
        return value

    def process_result_value(self, value, dialect):
        """Return the enum member for a stored value, or None."""
        if value is None:
            return value
        return self.enum_class(value)
class Status(Enum):
    """Order status values; each member's value is the string stored in the database."""

    PENDING = 'pending'
    ACTIVE = 'active'
    INACTIVE = 'inactive'
# The 'status' column is sized by EnumType from the longest Status value
orders = Table(
    'orders',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('status', EnumType(Status))
)
# Use enum directly
conn.execute(
    orders.insert().values(status=Status.ACTIVE)
)
# Result is enum
result = conn.execute(select(orders.c.status))
status = result.scalar()
print(type(status)) # <enum 'Status'>
from sqlalchemy import TypeDecorator, DateTime
from datetime import datetime, timezone
import pytz
class TZDateTime(TypeDecorator):
    """Store datetimes as timezone-naive UTC and return them as aware UTC.

    The impl is a plain ``DateTime`` (no timezone), so the value handed to
    the driver must itself be naive. Aware inputs are converted to UTC and
    then stripped of ``tzinfo``; naive inputs are assumed to already be UTC.
    """

    impl = DateTime
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return *value* as a naive UTC datetime, or None."""
        if value is None:
            return None
        if value.tzinfo is not None:
            # Normalize to UTC, then drop tzinfo: binding an aware datetime
            # to a naive column is rejected by some drivers.
            value = value.astimezone(timezone.utc).replace(tzinfo=None)
        return value

    def process_result_value(self, value, dialect):
        """Return the stored value with UTC attached, or None."""
        if value is None:
            return None
        return value.replace(tzinfo=timezone.utc)
events = Table(
    'events',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('created_at', TZDateTime, nullable=False)
)
# Store with timezone: 'now' is an aware datetime in US/Eastern
eastern = pytz.timezone('US/Eastern')
now = datetime.now(eastern)
conn.execute(
    events.insert().values(created_at=now)
)
# Stored as UTC in database (TZDateTime converts on bind)
Advanced TypeDecorator Features
Comparator Operations
from sqlalchemy import TypeDecorator, String, func
from sqlalchemy.types import TypeEngine
class CaseInsensitiveString(TypeDecorator):
    """String type whose ``==`` and ``in_()`` compare via SQL ``lower()``."""

    impl = String
    cache_ok = True

    class comparator_factory(String.Comparator):
        """Custom comparator for case-insensitive operations."""

        def __eq__(self, other):
            # Lowercase both sides in SQL so the comparison ignores case.
            return func.lower(self.expr) == func.lower(other)

        def in_(self, other):
            # Case-insensitive IN: lowercase the column and every candidate.
            return func.lower(self.expr).in_(
                [func.lower(x) for x in other]
            )
users = Table(
    'users',
    metadata,
    Column('username', CaseInsensitiveString(50))
)
# Case-insensitive comparison: '==' is routed through the type's comparator
stmt = select(users).where(users.c.username == 'ALICE')
# Generates: WHERE lower(username) = lower('ALICE')
# (the literal is sent as a bound parameter, shown inline here for readability)
Coercion
from sqlalchemy import TypeDecorator, Integer
class IntBoolean(TypeDecorator):
    """Persist truthy/falsy values as the integers 1/0 and read them back as bool."""

    impl = Integer
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return 1 for truthy and 0 for falsy input; None is passed through."""
        if value is not None:
            return 1 if value else 0
        return value

    def process_result_value(self, value, dialect):
        """Return ``bool(value)`` for a stored integer; None is passed through."""
        if value is not None:
            return bool(value)
        return value

    def coerce_compared_value(self, op, value):
        """Pick the type used for the other operand of a comparison.

        A Python bool is handled by this type (so it is bound as 0/1);
        any other value falls back to the underlying impl type.
        """
        return self if isinstance(value, bool) else self.impl
users = Table(
    'users',
    metadata,
    Column('is_active', IntBoolean)
)
# Boolean comparison works: '== True' builds a SQL expression here, so it
# must stay an equality operator and cannot be replaced by 'is True'.
stmt = select(users).where(users.c.is_active == True)
UserDefinedType
For types that don’t wrap existing types, use UserDefinedType.
Basic UserDefinedType
# UserDefinedType lives in sqlalchemy.types; it is not exported from the
# top-level sqlalchemy package.
from sqlalchemy.types import UserDefinedType
class Point(UserDefinedType):
    """Custom geometric point type exchanged with Python as an (x, y) tuple."""

    cache_ok = True

    def get_col_spec(self, **kw):
        """Return SQL type name."""
        return "POINT"

    def bind_processor(self, dialect):
        """Return function to process bind values."""
        def process(value):
            if value is None:
                return None
            # Convert (x, y) tuple to the text form "POINT(x y)"
            return f"POINT({value[0]} {value[1]})"
        return process

    def result_processor(self, dialect, coltype):
        """Return function to process result values."""
        def process(value):
            if value is None:
                return None
            # Parse "POINT(x y)" to (x, y) tuple; assumes the database
            # returns the same text form that bind_processor produced.
            coords = value.replace('POINT(', '').replace(')', '').split()
            return (float(coords[0]), float(coords[1]))
        return process
locations = Table(
    'locations',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('coordinates', Point)
)
# Use custom type: the tuple is rendered as "POINT(x y)" by Point.bind_processor
conn.execute(
    locations.insert().values(coordinates=(40.7128, -74.0060)) # NYC
)
result = conn.execute(select(locations.c.coordinates))
point = result.scalar()
print(point) # (40.7128, -74.006)
PostgreSQL Custom Types
from sqlalchemy import event
from sqlalchemy import DDL
# UserDefinedType lives in sqlalchemy.types; it is not exported from the
# top-level sqlalchemy package.
from sqlalchemy.types import UserDefinedType
class CITEXT(UserDefinedType):
    """PostgreSQL case-insensitive text type.

    Only the DDL type name is customized; values pass through unchanged.
    """

    cache_ok = True

    def get_col_spec(self, **kw):
        """Return the SQL type name used in CREATE TABLE."""
        return "CITEXT"
# Ensure the citext extension exists before any table of this metadata is created
@event.listens_for(metadata, 'before_create')
def create_citext_extension(target, connection, **kw):
    """Run CREATE EXTENSION on the connection that is about to emit the DDL."""
    enable_citext = DDL('CREATE EXTENSION IF NOT EXISTS citext')
    connection.execute(enable_citext)
users = Table(
    'users',
    metadata,
    Column('email', CITEXT, unique=True)
)
# Case-insensitive storage and comparison. The addresses are placeholders:
# the original literals were lost to e-mail obfuscation during extraction.
conn.execute(users.insert().values(email='Alice@Example.com'))
stmt = select(users).where(users.c.email == 'alice@example.com')
# Matches due to CITEXT
Variant Types
Variant allows different type implementations per dialect.
from sqlalchemy import String, Text
from sqlalchemy.dialects import postgresql, mysql, sqlite
# Base type: used on every dialect that has no variant registered below
base_json = String(1000)
# Use native JSON where available; each with_variant() call returns a new
# type carrying the extra per-dialect mapping.
json_type = base_json.with_variant(
    postgresql.JSON(), 'postgresql'
).with_variant(
    mysql.JSON(), 'mysql'
).with_variant(
    Text(), 'sqlite'
)
data = Table(
    'data',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('payload', json_type)
)
# PostgreSQL: JSON type
# MySQL: JSON type
# SQLite: TEXT type
# Others: VARCHAR(1000)
Dialect-Specific Variants
from sqlalchemy import Numeric, Float
from sqlalchemy.dialects import postgresql, mysql
# Use NUMERIC(10, 2) on PostgreSQL, DECIMAL(10, 2) on MySQL, FLOAT elsewhere
price_type = Float().with_variant(
    postgresql.NUMERIC(10, 2), 'postgresql'
).with_variant(
    mysql.DECIMAL(10, 2), 'mysql'
)
products = Table(
    'products',
    metadata,
    Column('price', price_type)
)
Type Utilities
Type Coercion
from sqlalchemy import type_coerce
# JSON must be the PostgreSQL dialect type: the .astext accessor used below
# is only provided by its comparator.
from sqlalchemy.dialects.postgresql import JSON
# Treat column as different type (no CAST in SQL)
stmt = select(
    type_coerce(users.c.metadata, JSON)['key'].astext
)
# vs CAST (generates SQL CAST)
from sqlalchemy import cast
stmt = select(
    cast(users.c.age, String)
)
Type Inspection
from sqlalchemy import inspect
# Get column type: every Column exposes its type object as .type
col_type = users.c.email.type
print(type(col_type)) # e.g. <class 'sqlalchemy.sql.sqltypes.String'> for a plain String column
print(col_type.length) # 100
print(col_type.python_type) # <class 'str'>
# Check if type is specific class
from sqlalchemy import String, Integer
if isinstance(users.c.email.type, String):
    print(f"Email is String with length {users.c.email.type.length}")
Common Custom Type Patterns
URL Type
from sqlalchemy import TypeDecorator, String
from urllib.parse import urlparse, urlunparse
class URLType(TypeDecorator):
    """String type that adds a default ``https://`` scheme to scheme-less URLs."""

    impl = String
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return *value* re-assembled by urlparse/urlunparse, or None.

        When the parsed URL has no scheme, ``https://`` is prepended and the
        result is parsed again before being re-assembled.
        """
        if value is None:
            return value
        parsed = urlparse(value)
        if not parsed.scheme:
            # Ensure scheme
            value = f"https://{value}"
            parsed = urlparse(value)
        return urlunparse(parsed)
websites = Table(
    'websites',
    metadata,
    Column('url', URLType(500))  # 500 is forwarded to String as the length
)
# The value has no scheme, so URLType prepends one on bind
conn.execute(
    websites.insert().values(url='example.com')
)
# Stored as: 'https://example.com'
Money Type
from sqlalchemy import TypeDecorator, Numeric
from decimal import Decimal
class Money(TypeDecorator):
    """Store money in a NUMERIC(10, 2) column, rounded to two decimal places."""

    impl = Numeric(10, 2)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return *value* as a Decimal quantized to cents, or None.

        Rounding follows the active decimal context (banker's rounding,
        ROUND_HALF_EVEN, unless the application changed it).
        """
        if value is None:
            return None
        if not isinstance(value, Decimal):
            # Go through str() so a float such as 2.675 is read as the
            # decimal the caller wrote, not as its binary approximation
            # (Decimal(2.675) is 2.67499999..., which would round to 2.67).
            value = Decimal(str(value))
        # Round to 2 decimal places
        return value.quantize(Decimal('0.01'))

    def process_result_value(self, value, dialect):
        """Return the stored value as a Decimal, or None."""
        if value is None:
            return None
        return Decimal(value)
transactions = Table(
    'transactions',
    metadata,
    # Money takes no arguments: precision and scale are fixed by its impl
    Column('amount', Money, nullable=False)
)
Phone Number Type
from sqlalchemy import TypeDecorator, String
import re
class PhoneNumber(TypeDecorator):
    """String(20) type that formats 10- and 11-digit phone numbers on bind."""

    impl = String(20)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return *value* formatted, or unchanged when it is not recognized.

        Ten digits become ``(XXX) XXX-XXXX``; eleven digits starting with 1
        become ``+1 (XXX) XXX-XXXX``. Anything else, including None, is
        returned exactly as given.
        """
        if value is None:
            return value
        # Remove all non-digits
        digits = re.sub(r'\D', '', value)
        if len(digits) == 10:
            return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
        if len(digits) == 11 and digits[0] == '1':
            # Leading 1 is the country code
            return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        return value
contacts = Table(
    'contacts',
    metadata,
    Column('phone', PhoneNumber)
)
# Ten bare digits are reformatted by PhoneNumber on bind
conn.execute(
    contacts.insert().values(phone='5551234567')
)
# Stored as: '(555) 123-4567'
Compressed Text
from sqlalchemy import TypeDecorator, LargeBinary
import zlib
class CompressedText(TypeDecorator):
    """Keep text zlib-compressed in a LargeBinary column."""

    impl = LargeBinary
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the UTF-8 bytes of *value* compressed with zlib, or None."""
        if value is None:
            return value
        raw = value.encode('utf-8')
        return zlib.compress(raw)

    def process_result_value(self, value, dialect):
        """Return the text decompressed from stored bytes, or None."""
        if value is None:
            return value
        raw = zlib.decompress(value)
        return raw.decode('utf-8')
logs = Table(
    'logs',
    metadata,
    Column('id', Integer, primary_key=True),
    # Written as compressed bytes, read back as str
    Column('message', CompressedText)
)
Testing Custom Types
import pytest
from sqlalchemy import create_engine, MetaData, Table, Column, Integer
from sqlalchemy import select
def test_json_encoded_dict():
    """Round-trip a dict through JSONEncodedDict on an in-memory SQLite database."""
    metadata = MetaData()
    test_table = Table(
        'test',
        metadata,
        Column('id', Integer, primary_key=True),
        Column('data', JSONEncodedDict(200)),
    )
    engine = create_engine('sqlite:///:memory:')
    metadata.create_all(engine)

    test_data = {'key': 'value', 'number': 42}
    with engine.begin() as conn:
        # Insert, then read the same row back on the same connection.
        insert_stmt = test_table.insert().values(data=test_data)
        conn.execute(insert_stmt)

        retrieved = conn.execute(select(test_table.c.data)).scalar()

        assert retrieved == test_data
        assert isinstance(retrieved, dict)
Performance Considerations
cache_ok
Always set cache_ok = True on custom types if they’re safe for caching. This enables statement caching.
Process Functions
Keep process_bind_param() and process_result_value() lightweight. Heavy processing impacts performance.
Validation
Validate in application code when possible. Type processors run for every value.
Dialect Checks
Use dialect.name checks sparingly. Consider Variant for dialect-specific behavior.
Next Steps
Engines
Configure database engines and connections
Standard Types
Review built-in types