import importlib
import os
from dataclasses import dataclass, field
from string import Template
from typing import Dict, List, Union, Any
@dataclass
class PromptTemplateManager:
role_mapping: Dict[str, str] = field(
default_factory=lambda: {
"system": "system",
"user": "user",
"assistant": "assistant"
}
)
templates: Dict[str, Union[Template, List[Dict[str, Any]]]] = field(
init=False,
default_factory=dict,
)
def __post_init__(self):
"""Load all templates from the templates directory."""
self.templates_dir = os.path.join(
os.path.dirname(__file__), "templates"
)
self._load_templates()
def _load_templates(self):
"""Load all .py files from templates/ directory."""
for filename in os.listdir(self.templates_dir):
if filename.endswith(".py") and filename != "__init__.py":
script_name = os.path.splitext(filename)[0]
module = importlib.import_module(f"remem.prompts.templates.{script_name}")
if not hasattr(module, "prompt_template"):
raise AttributeError(
f"Module '{script_name}' does not define 'prompt_template'"
)
prompt_template = module.prompt_template
# Convert to Template if string
if isinstance(prompt_template, str):
self.templates[script_name] = Template(prompt_template)
# Handle chat history format
elif isinstance(prompt_template, list):
for item in prompt_template:
item["role"] = self.role_mapping.get(item["role"], item["role"])
item["content"] = Template(item["content"]) \
if isinstance(item["content"], str) else item["content"]
self.templates[script_name] = prompt_template
def render(self, name: str, **kwargs) -> Union[str, List[Dict[str, Any]]]:
"""Render a template with variables."""
template = self.get_template(name)
if isinstance(template, Template):
return template.substitute(**kwargs)
elif isinstance(template, list):
return [
{"role": item["role"], "content": item["content"].substitute(**kwargs)}
for item in template
]
def get_template(self, name: str):
"""Get a template by name."""
if name not in self.templates:
raise KeyError(f"Template '{name}' not found.")
return self.templates[name]