import logging
from typing import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.driver.my_backend.operations import (
MyBackendEntityNodeOperations,
MyBackendEpisodeNodeOperations,
MyBackendCommunityNodeOperations,
MyBackendSagaNodeOperations,
MyBackendEntityEdgeOperations,
MyBackendEpisodicEdgeOperations,
MyBackendCommunityEdgeOperations,
MyBackendHasEpisodeEdgeOperations,
MyBackendNextEpisodeEdgeOperations,
MyBackendSearchOperations,
MyBackendGraphMaintenanceOperations,
)
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
class MyBackendDriver(GraphDriver):
    """Template ``GraphDriver`` implementation for a custom graph backend.

    Holds a backend client plus one instance of each of the eleven
    ``MyBackend*Operations`` classes, exposed through the ``*_ops``
    properties below.

    NOTE(review): ``MyBackendClient`` and ``MyBackendDriverSession`` are not
    imported in this module; they must be defined or imported before this
    class can be instantiated.
    """

    provider = GraphProvider.MY_BACKEND
    # Set if using external search (like Neptune).
    aoss_client: None = None

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 7687,
        username: str | None = None,
        password: str | None = None,
        database: str = 'default_db',
    ):
        """Create the backend client and all operation objects.

        Args:
            host: Database host passed to the client.
            port: Database port passed to the client.
            username: Optional user name passed to the client.
            password: Optional password passed to the client.
            database: Database name used by ``session()`` when none is given.
        """
        super().__init__()
        self._database = database

        # Initialize your database client.
        self.client = MyBackendClient(
            host=host,
            port=port,
            username=username,
            password=password,
        )

        # Instantiate all 11 operation classes.
        self._entity_node_ops = MyBackendEntityNodeOperations()
        self._episode_node_ops = MyBackendEpisodeNodeOperations()
        self._community_node_ops = MyBackendCommunityNodeOperations()
        self._saga_node_ops = MyBackendSagaNodeOperations()
        self._entity_edge_ops = MyBackendEntityEdgeOperations()
        self._episodic_edge_ops = MyBackendEpisodicEdgeOperations()
        self._community_edge_ops = MyBackendCommunityEdgeOperations()
        self._has_episode_edge_ops = MyBackendHasEpisodeEdgeOperations()
        self._next_episode_edge_ops = MyBackendNextEpisodeEdgeOperations()
        self._search_ops = MyBackendSearchOperations()
        self._graph_ops = MyBackendGraphMaintenanceOperations()

    # --- Operation accessors (one read-only property per operation object) ---

    @property
    def entity_node_ops(self) -> MyBackendEntityNodeOperations:
        """Operations on entity nodes."""
        return self._entity_node_ops

    @property
    def episode_node_ops(self) -> MyBackendEpisodeNodeOperations:
        """Operations on episode nodes."""
        return self._episode_node_ops

    @property
    def community_node_ops(self) -> MyBackendCommunityNodeOperations:
        """Operations on community nodes."""
        return self._community_node_ops

    @property
    def saga_node_ops(self) -> MyBackendSagaNodeOperations:
        """Operations on saga nodes."""
        return self._saga_node_ops

    @property
    def entity_edge_ops(self) -> MyBackendEntityEdgeOperations:
        """Operations on entity edges."""
        return self._entity_edge_ops

    @property
    def episodic_edge_ops(self) -> MyBackendEpisodicEdgeOperations:
        """Operations on episodic edges."""
        return self._episodic_edge_ops

    @property
    def community_edge_ops(self) -> MyBackendCommunityEdgeOperations:
        """Operations on community edges."""
        return self._community_edge_ops

    @property
    def has_episode_edge_ops(self) -> MyBackendHasEpisodeEdgeOperations:
        """Operations on has-episode edges."""
        return self._has_episode_edge_ops

    @property
    def next_episode_edge_ops(self) -> MyBackendNextEpisodeEdgeOperations:
        """Operations on next-episode edges."""
        return self._next_episode_edge_ops

    @property
    def search_ops(self) -> MyBackendSearchOperations:
        """Search operations."""
        return self._search_ops

    @property
    def graph_ops(self) -> MyBackendGraphMaintenanceOperations:
        """Graph maintenance operations."""
        return self._graph_ops

    # --- GraphDriver abstract methods ---

    def execute_query(self, cypher_query: str, **kwargs: Any):
        """Execute a query against the database.

        Args:
            cypher_query: Query text forwarded to ``client.execute``.
            **kwargs: Extra arguments forwarded unchanged to ``client.execute``.

        Returns:
            Whatever ``client.execute`` returns.
        """
        # NOTE(review): confirm whether callers expect an awaitable here and
        # whether client.execute returns one; this method returns its result as-is.
        return self.client.execute(cypher_query, **kwargs)

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Create a database session.

        Args:
            database: Database to open; a falsy value (None or '') falls back
                to the database given at construction time.
        """
        db = database or self._database
        return MyBackendDriverSession(self.client, db)

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """Build indices and constraints.

        Args:
            delete_existing: When true, drop all existing indexes first.
        """
        if delete_existing:
            await self.delete_all_indexes()

        # Get index queries from the shared builders, specialised for this provider.
        from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices

        fulltext_indices = get_fulltext_indices(self.provider)
        range_indices = get_range_indices(self.provider)

        # Run every index query sequentially within a single session.
        async with self.session() as session:
            for query in fulltext_indices + range_indices:
                await session.run(query)

    async def delete_all_indexes(self):
        """Delete all indices."""
        # Implementation depends on your database.
        # NOTE(review): "DROP ALL INDEXES" is a placeholder; replace it with
        # your backend's actual statement.
        async with self.session() as session:
            await session.run("DROP ALL INDEXES")

    def close(self):
        """Close the database connection."""
        # NOTE(review): this is synchronous; confirm callers do not await close().
        self.client.close()