from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator, List, Sequence, Type
from uuid import UUID
from eventsourcing.persistence import (
AggregateRecorder,
ApplicationRecorder,
Connection,
ConnectionPool,
Cursor,
DatabaseError,
DataError,
InfrastructureFactory,
IntegrityError,
InterfaceError,
InternalError,
Notification,
NotSupportedError,
OperationalError,
PersistenceError,
ProcessRecorder,
ProgrammingError,
StoredEvent,
Tracking,
)
from eventsourcing.utils import Environment, strtobool
if TYPE_CHECKING: # pragma: nocover
from types import TracebackType
SQLITE3_DEFAULT_LOCK_TIMEOUT = 5
[docs]
class SQLiteCursor(Cursor):
[docs]
def __init__(self, sqlite_cursor: sqlite3.Cursor):
self.sqlite_cursor = sqlite_cursor
def __enter__(self) -> sqlite3.Cursor:
return self.sqlite_cursor
def __exit__(self, *args: object, **kwargs: Any) -> None:
self.sqlite_cursor.close()
[docs]
def execute(self, *args: Any, **kwargs: Any) -> None:
self.sqlite_cursor.execute(*args, **kwargs)
def executemany(self, *args: Any, **kwargs: Any) -> None:
self.sqlite_cursor.executemany(*args, **kwargs)
[docs]
def fetchall(self) -> Any:
return self.sqlite_cursor.fetchall()
[docs]
def fetchone(self) -> Any:
return self.sqlite_cursor.fetchone()
@property
def lastrowid(self) -> Any:
return self.sqlite_cursor.lastrowid
[docs]
class SQLiteConnection(Connection[SQLiteCursor]):
[docs]
def __init__(self, sqlite_conn: sqlite3.Connection, max_age: float | None):
super().__init__(max_age=max_age)
self._sqlite_conn = sqlite_conn
@contextmanager
def transaction(self, *, commit: bool) -> Iterator[SQLiteCursor]:
# Context managed cursor, and context managed transaction.
with SQLiteTransaction(self, commit=commit) as curs, curs:
yield curs
[docs]
def cursor(self) -> SQLiteCursor:
return SQLiteCursor(self._sqlite_conn.cursor())
[docs]
def rollback(self) -> None:
self._sqlite_conn.rollback()
[docs]
def commit(self) -> None:
self._sqlite_conn.commit()
def _close(self) -> None:
self._sqlite_conn.close()
super()._close()
class SQLiteTransaction:
def __init__(self, connection: SQLiteConnection, *, commit: bool = False):
self.connection = connection
self.commit = commit
def __enter__(self) -> SQLiteCursor:
# We must issue a "BEGIN" explicitly
# when running in auto-commit mode.
cursor = self.connection.cursor()
cursor.execute("BEGIN")
return cursor
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
try:
if exc_val:
# Roll back all changes
# if an exception occurs.
self.connection.rollback()
raise exc_val
if not self.commit:
self.connection.rollback()
else:
self.connection.commit()
except sqlite3.InterfaceError as e:
raise InterfaceError(e) from e
except sqlite3.DataError as e:
raise DataError(e) from e
except sqlite3.OperationalError as e:
raise OperationalError(e) from e
except sqlite3.IntegrityError as e:
raise IntegrityError(e) from e
except sqlite3.InternalError as e:
raise InternalError(e) from e
except sqlite3.ProgrammingError as e:
raise ProgrammingError(e) from e
except sqlite3.NotSupportedError as e:
raise NotSupportedError(e) from e
except sqlite3.DatabaseError as e:
raise DatabaseError(e) from e
except sqlite3.Error as e:
raise PersistenceError(e) from e
[docs]
class SQLiteConnectionPool(ConnectionPool[SQLiteConnection]):
[docs]
def __init__(
self,
*,
db_name: str,
lock_timeout: int | None = None,
pool_size: int = 5,
max_overflow: int = 10,
pool_timeout: float = 5.0,
max_age: float | None = None,
pre_ping: bool = False,
):
self.db_name = db_name
self.lock_timeout = lock_timeout
self.is_sqlite_memory_mode = self.detect_memory_mode(db_name)
self.is_journal_mode_wal = False
self.journal_mode_was_changed_to_wal = False
super().__init__(
pool_size=pool_size,
max_overflow=max_overflow,
pool_timeout=pool_timeout,
max_age=max_age,
pre_ping=pre_ping,
mutually_exclusive_read_write=self.is_sqlite_memory_mode,
)
@staticmethod
def detect_memory_mode(db_name: str) -> bool:
return bool(db_name) and (":memory:" in db_name or "mode=memory" in db_name)
def _create_connection(self) -> SQLiteConnection:
# Make a connection to an SQLite database.
try:
c = sqlite3.connect(
database=self.db_name,
uri=True,
check_same_thread=False,
isolation_level=None, # Auto-commit mode.
cached_statements=True,
timeout=self.lock_timeout or SQLITE3_DEFAULT_LOCK_TIMEOUT,
)
except (sqlite3.Error, TypeError) as e:
raise InterfaceError(e) from e
# Use WAL (write-ahead log) mode if file-based database.
if not self.is_sqlite_memory_mode and not self.is_journal_mode_wal:
cursor = c.cursor()
cursor.execute("PRAGMA journal_mode;")
mode = cursor.fetchone()[0]
if mode.lower() == "wal":
self.is_journal_mode_wal = True
else:
cursor.execute("PRAGMA journal_mode=WAL;")
self.is_journal_mode_wal = True
self.journal_mode_was_changed_to_wal = True
# Set the row factory.
c.row_factory = sqlite3.Row
# Return the connection.
return SQLiteConnection(sqlite_conn=c, max_age=self.max_age)
class SQLiteDatastore:
def __init__(
self,
db_name: str,
*,
lock_timeout: int | None = None,
pool_size: int = 5,
max_overflow: int = 10,
pool_timeout: float = 5.0,
max_age: float | None = None,
pre_ping: bool = False,
):
self.pool = SQLiteConnectionPool(
db_name=db_name,
lock_timeout=lock_timeout,
pool_size=pool_size,
max_overflow=max_overflow,
pool_timeout=pool_timeout,
max_age=max_age,
pre_ping=pre_ping,
)
@contextmanager
def transaction(self, *, commit: bool) -> Iterator[SQLiteCursor]:
connection = self.get_connection(commit=commit)
with connection as conn, conn.transaction(commit=commit) as curs:
yield curs
@contextmanager
def get_connection(self, *, commit: bool) -> Iterator[SQLiteConnection]:
# Using reader-writer interlocking is necessary for in-memory databases,
# but also speeds up (and provides "fairness") to file-based databases.
conn = self.pool.get_connection(is_writer=commit)
try:
yield conn
finally:
self.pool.put_connection(conn)
def close(self) -> None:
self.pool.close()
def __del__(self) -> None:
self.close()
[docs]
class SQLiteAggregateRecorder(AggregateRecorder):
[docs]
def __init__(
self,
datastore: SQLiteDatastore,
events_table_name: str = "stored_events",
):
assert isinstance(datastore, SQLiteDatastore)
self.datastore = datastore
self.events_table_name = events_table_name
self.create_table_statements = self.construct_create_table_statements()
self.insert_events_statement = (
f"INSERT INTO {self.events_table_name} VALUES (?,?,?,?)"
)
self.select_events_statement = (
f"SELECT * FROM {self.events_table_name} WHERE originator_id=? "
)
def construct_create_table_statements(self) -> List[str]:
statement = (
"CREATE TABLE IF NOT EXISTS "
f"{self.events_table_name} ("
"originator_id TEXT, "
"originator_version INTEGER, "
"topic TEXT, "
"state BLOB, "
"PRIMARY KEY "
"(originator_id, originator_version)) "
"WITHOUT ROWID"
)
return [statement]
def create_table(self) -> None:
with self.datastore.transaction(commit=True) as c:
for statement in self.create_table_statements:
c.execute(statement)
[docs]
def insert_events(
self, stored_events: List[StoredEvent], **kwargs: Any
) -> Sequence[int] | None:
with self.datastore.transaction(commit=True) as c:
return self._insert_events(c, stored_events, **kwargs)
def _insert_events(
self,
c: SQLiteCursor,
stored_events: List[StoredEvent],
**_: Any,
) -> Sequence[int] | None:
params = [
(
s.originator_id.hex,
s.originator_version,
s.topic,
s.state,
)
for s in stored_events
]
c.executemany(self.insert_events_statement, params)
return None
[docs]
def select_events(
self,
originator_id: UUID,
*,
gt: int | None = None,
lte: int | None = None,
desc: bool = False,
limit: int | None = None,
) -> List[StoredEvent]:
statement = self.select_events_statement
params: List[Any] = [originator_id.hex]
if gt is not None:
statement += "AND originator_version>? "
params.append(gt)
if lte is not None:
statement += "AND originator_version<=? "
params.append(lte)
statement += "ORDER BY originator_version "
if desc is False:
statement += "ASC "
else:
statement += "DESC "
if limit is not None:
statement += "LIMIT ? "
params.append(limit)
with self.datastore.transaction(commit=False) as c:
c.execute(statement, params)
return [
StoredEvent(
originator_id=UUID(row["originator_id"]),
originator_version=row["originator_version"],
topic=row["topic"],
state=row["state"],
)
for row in c.fetchall()
]
[docs]
class SQLiteApplicationRecorder(
SQLiteAggregateRecorder,
ApplicationRecorder,
):
[docs]
def __init__(
self,
datastore: SQLiteDatastore,
events_table_name: str = "stored_events",
):
super().__init__(datastore, events_table_name)
self.select_max_notification_id_statement = (
f"SELECT MAX(rowid) FROM {self.events_table_name}"
)
def construct_create_table_statements(self) -> List[str]:
statement = (
"CREATE TABLE IF NOT EXISTS "
f"{self.events_table_name} ("
"originator_id TEXT, "
"originator_version INTEGER, "
"topic TEXT, "
"state BLOB, "
"PRIMARY KEY "
"(originator_id, originator_version))"
)
return [statement]
def _insert_events(
self,
c: SQLiteCursor,
stored_events: List[StoredEvent],
**_: Any,
) -> Sequence[int] | None:
returning = []
for stored_event in stored_events:
c.execute(
self.insert_events_statement,
(
stored_event.originator_id.hex,
stored_event.originator_version,
stored_event.topic,
stored_event.state,
),
)
returning.append(c.lastrowid)
return returning
[docs]
def select_notifications(
self,
start: int,
limit: int,
stop: int | None = None,
topics: Sequence[str] = (),
) -> List[Notification]:
"""
Returns a list of event notifications
from 'start', limited by 'limit'.
"""
params: List[int | str] = [start]
statement = f"SELECT rowid, * FROM {self.events_table_name} WHERE rowid>=? "
if stop is not None:
params.append(stop)
statement += "AND rowid<=? "
if topics:
params += list(topics)
statement += "AND topic IN (%s) " % ",".join("?" * len(topics))
params.append(limit)
statement += "ORDER BY rowid LIMIT ?"
with self.datastore.transaction(commit=False) as c:
c.execute(statement, params)
return [
Notification(
id=row["rowid"],
originator_id=UUID(row["originator_id"]),
originator_version=row["originator_version"],
topic=row["topic"],
state=row["state"],
)
for row in c.fetchall()
]
[docs]
def max_notification_id(self) -> int:
"""
Returns the maximum notification ID.
"""
with self.datastore.transaction(commit=False) as c:
return self._max_notification_id(c)
def _max_notification_id(self, c: SQLiteCursor) -> int:
c.execute(self.select_max_notification_id_statement)
return c.fetchone()[0] or 0
[docs]
class SQLiteProcessRecorder(
SQLiteApplicationRecorder,
ProcessRecorder,
):
[docs]
def __init__(
self,
datastore: SQLiteDatastore,
events_table_name: str = "stored_events",
):
super().__init__(datastore, events_table_name)
self.insert_tracking_statement = "INSERT INTO tracking VALUES (?,?)"
self.select_max_tracking_id_statement = (
"SELECT MAX(notification_id) FROM tracking WHERE application_name=?"
)
self.count_tracking_id_statement = (
"SELECT COUNT(*) FROM tracking WHERE "
"application_name=? AND notification_id=?"
)
def construct_create_table_statements(self) -> List[str]:
statements = super().construct_create_table_statements()
statements.append(
"CREATE TABLE IF NOT EXISTS tracking ("
"application_name TEXT, "
"notification_id INTEGER, "
"PRIMARY KEY "
"(application_name, notification_id)) "
"WITHOUT ROWID"
)
return statements
[docs]
def max_tracking_id(self, application_name: str) -> int:
params = [application_name]
with self.datastore.transaction(commit=False) as c:
c.execute(self.select_max_tracking_id_statement, params)
return c.fetchone()[0] or 0
[docs]
def has_tracking_id(self, application_name: str, notification_id: int) -> bool:
params = [application_name, notification_id]
with self.datastore.transaction(commit=False) as c:
c.execute(self.count_tracking_id_statement, params)
return bool(c.fetchone()[0])
def _insert_events(
self,
c: SQLiteCursor,
stored_events: List[StoredEvent],
**kwargs: Any,
) -> Sequence[int] | None:
returning = super()._insert_events(c, stored_events, **kwargs)
tracking: Tracking | None = kwargs.get("tracking", None)
if tracking is not None:
c.execute(
self.insert_tracking_statement,
(
tracking.application_name,
tracking.notification_id,
),
)
return returning
[docs]
class Factory(InfrastructureFactory):
SQLITE_DBNAME = "SQLITE_DBNAME"
SQLITE_LOCK_TIMEOUT = "SQLITE_LOCK_TIMEOUT"
CREATE_TABLE = "CREATE_TABLE"
aggregate_recorder_class = SQLiteAggregateRecorder
application_recorder_class = SQLiteApplicationRecorder
process_recorder_class = SQLiteProcessRecorder
[docs]
def __init__(self, env: Environment):
super().__init__(env)
db_name = self.env.get(self.SQLITE_DBNAME)
if not db_name:
msg = (
"SQLite database name not found "
"in environment with keys: "
f"{', '.join(self.env.create_keys(self.SQLITE_DBNAME))}"
)
raise OSError(msg)
lock_timeout_str = (
self.env.get(self.SQLITE_LOCK_TIMEOUT) or ""
).strip() or None
lock_timeout: int | None = None
if lock_timeout_str is not None:
try:
lock_timeout = int(lock_timeout_str)
except ValueError:
msg = (
"SQLite environment value for key "
f"'{self.SQLITE_LOCK_TIMEOUT}' is invalid. "
"If set, an int or empty string is expected: "
f"'{lock_timeout_str}'"
)
raise OSError(msg) from None
self.datastore = SQLiteDatastore(db_name=db_name, lock_timeout=lock_timeout)
[docs]
def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
events_table_name = "stored_" + purpose
recorder = self.aggregate_recorder_class(
datastore=self.datastore,
events_table_name=events_table_name,
)
if self.env_create_table():
recorder.create_table()
return recorder
[docs]
def application_recorder(self) -> ApplicationRecorder:
recorder = self.application_recorder_class(datastore=self.datastore)
if self.env_create_table():
recorder.create_table()
return recorder
[docs]
def process_recorder(self) -> ProcessRecorder:
recorder = self.process_recorder_class(datastore=self.datastore)
if self.env_create_table():
recorder.create_table()
return recorder
def env_create_table(self) -> bool:
default = "yes"
return bool(strtobool(self.env.get(self.CREATE_TABLE, default) or default))
[docs]
def close(self) -> None:
self.datastore.close()