from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import UUID
from eventsourcing.persistence import (
AggregateRecorder,
ApplicationRecorder,
Connection,
ConnectionPool,
Cursor,
DatabaseError,
DataError,
InfrastructureFactory,
IntegrityError,
InterfaceError,
InternalError,
Notification,
NotSupportedError,
OperationalError,
PersistenceError,
ProcessRecorder,
ProgrammingError,
Recorder,
StoredEvent,
Subscription,
Tracking,
TrackingRecorder,
)
from eventsourcing.utils import Environment, EnvType, resolve_topic, strtobool
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from types import TracebackType
# Fallback value for the ``timeout`` argument of ``sqlite3.connect()``, used
# by the connection pool when no lock timeout has been configured (sqlite3
# interprets this value as a number of seconds).
SQLITE3_DEFAULT_LOCK_TIMEOUT = 5
class SQLiteCursor(Cursor):
    """Wraps an ``sqlite3.Cursor`` and delegates cursor operations to it."""

    def __init__(self, sqlite_cursor: sqlite3.Cursor):
        """Keeps a reference to the given SQLite cursor."""
        self.sqlite_cursor = sqlite_cursor

    def __enter__(self) -> sqlite3.Cursor:
        """Returns the wrapped SQLite cursor (not this wrapper object)."""
        return self.sqlite_cursor

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        """Closes the wrapped SQLite cursor, whether or not an error occurred."""
        wrapped = self.sqlite_cursor
        wrapped.close()

    def execute(self, *args: Any, **kwargs: Any) -> None:
        """Executes a statement on the wrapped cursor, discarding its result."""
        wrapped = self.sqlite_cursor
        wrapped.execute(*args, **kwargs)

    def executemany(self, *args: Any, **kwargs: Any) -> None:
        """Executes a statement once per parameter set on the wrapped cursor."""
        wrapped = self.sqlite_cursor
        wrapped.executemany(*args, **kwargs)

    def fetchall(self) -> Any:
        """Returns all remaining rows from the wrapped cursor."""
        rows = self.sqlite_cursor.fetchall()
        return rows

    def fetchone(self) -> Any:
        """Returns the next row from the wrapped cursor."""
        row = self.sqlite_cursor.fetchone()
        return row

    @property
    def lastrowid(self) -> Any:
        """The ``lastrowid`` attribute of the wrapped cursor."""
        return self.sqlite_cursor.lastrowid
class SQLiteConnection(Connection[SQLiteCursor]):
    """Poolable connection object that wraps an ``sqlite3.Connection``."""

    def __init__(self, sqlite_conn: sqlite3.Connection, max_age: float | None):
        """Initialises the base connection with ``max_age`` and keeps the
        given SQLite connection.
        """
        super().__init__(max_age=max_age)
        self._sqlite_conn = sqlite_conn

    @contextmanager
    def transaction(self, *, commit: bool) -> Iterator[SQLiteCursor]:
        """Yields a cursor inside a transaction that is committed on exit
        when ``commit`` is true, and rolled back otherwise.
        """
        # The outer manager scopes the transaction; the cursor is also
        # entered as a context manager so that it is closed afterwards.
        with SQLiteTransaction(self, commit=commit) as curs:
            with curs:
                yield curs

    def cursor(self) -> SQLiteCursor:
        """Returns a new wrapped cursor for the SQLite connection."""
        raw_cursor = self._sqlite_conn.cursor()
        return SQLiteCursor(raw_cursor)

    def rollback(self) -> None:
        """Rolls back the current transaction on the SQLite connection."""
        self._sqlite_conn.rollback()

    def commit(self) -> None:
        """Commits the current transaction on the SQLite connection."""
        self._sqlite_conn.commit()

    def _close(self) -> None:
        """Closes the SQLite connection, then runs the base class close logic."""
        self._sqlite_conn.close()
        super()._close()
class SQLiteTransaction:
    """Context manager that brackets statements in an explicit transaction.

    Entering issues "BEGIN" and returns the cursor used for that. Exiting
    commits or rolls back, and re-raises any ``sqlite3`` error as the
    corresponding persistence error.
    """

    def __init__(self, connection: SQLiteConnection, *, commit: bool = False):
        """Remembers the connection and whether to commit on a clean exit."""
        self.connection = connection
        self.commit = commit

    def __enter__(self) -> SQLiteCursor:
        """Begins a transaction and returns the cursor it was begun on."""
        # The connection runs in auto-commit mode, so the
        # transaction has to be opened with an explicit "BEGIN".
        curs = self.connection.cursor()
        curs.execute("BEGIN")
        return curs

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Finishes the transaction.

        Rolls back and re-raises if the body raised. Otherwise commits when
        ``commit`` is true and rolls back when it is not.
        """
        try:
            if exc_val:
                # Undo everything, then re-raise from inside this "try"
                # so that the error passes through the translation below.
                self.connection.rollback()
                raise exc_val
            if self.commit:
                self.connection.commit()
            else:
                self.connection.rollback()
        except sqlite3.Error as e:
            # Ordered from most to least specific: the first match wins,
            # and the final entry matches every sqlite3 error.
            translations = (
                (sqlite3.InterfaceError, InterfaceError),
                (sqlite3.DataError, DataError),
                (sqlite3.OperationalError, OperationalError),
                (sqlite3.IntegrityError, IntegrityError),
                (sqlite3.InternalError, InternalError),
                (sqlite3.ProgrammingError, ProgrammingError),
                (sqlite3.NotSupportedError, NotSupportedError),
                (sqlite3.DatabaseError, DatabaseError),
                (sqlite3.Error, PersistenceError),
            )
            for sqlite_error_type, translated_type in translations:
                if isinstance(e, sqlite_error_type):
                    raise translated_type(e) from e
class SQLiteConnectionPool(ConnectionPool[SQLiteConnection]):
    """Connection pool that creates connections to one SQLite database."""

    def __init__(
        self,
        *,
        db_name: str,
        lock_timeout: int | None = None,
        pool_size: int = 5,
        max_overflow: int = 10,
        pool_timeout: float = 5.0,
        max_age: float | None = None,
        pre_ping: bool = False,
    ):
        """Records the database settings and initialises the base pool.

        Reads and writes are made mutually exclusive in the base pool only
        when ``db_name`` refers to an in-memory database.
        """
        self.db_name = db_name
        self.lock_timeout = lock_timeout
        self.is_sqlite_memory_mode = self.detect_memory_mode(db_name)
        # Set once a connection has confirmed (or switched to) WAL mode.
        self.is_journal_mode_wal = False
        # Set only if this pool had to switch the journal mode itself.
        self.journal_mode_was_changed_to_wal = False
        super().__init__(
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_timeout=pool_timeout,
            max_age=max_age,
            pre_ping=pre_ping,
            mutually_exclusive_read_write=self.is_sqlite_memory_mode,
        )

    @staticmethod
    def detect_memory_mode(db_name: str) -> bool:
        """Returns True if ``db_name`` contains ":memory:" or "mode=memory"."""
        return bool(db_name) and (":memory:" in db_name or "mode=memory" in db_name)

    def _create_connection(self) -> SQLiteConnection:
        """Opens a new SQLite connection and wraps it for the pool.

        Raises ``InterfaceError`` if ``sqlite3.connect()`` fails. An error
        raised while setting the journal mode propagates unchanged, after
        the new connection has been closed.
        """
        try:
            # NOTE: ``cached_statements`` is a cache *size* (an int), not an
            # on/off flag. It was previously passed as True, which sqlite3
            # reads as 1. Leaving it out uses sqlite3's default cache size.
            c = sqlite3.connect(
                database=self.db_name,
                uri=True,
                check_same_thread=False,
                isolation_level=None,  # Auto-commit mode.
                timeout=self.lock_timeout or SQLITE3_DEFAULT_LOCK_TIMEOUT,
            )
        except (sqlite3.Error, TypeError) as e:
            raise InterfaceError(e) from e

        # Use WAL (write-ahead log) mode if file-based database.
        if not self.is_sqlite_memory_mode and not self.is_journal_mode_wal:
            try:
                self._ensure_journal_mode_wal(c)
            except BaseException:
                # Don't leak the connection if the journal mode can't be set.
                c.close()
                raise

        # Set the row factory, so that rows can be read by column name.
        c.row_factory = sqlite3.Row

        return SQLiteConnection(sqlite_conn=c, max_age=self.max_age)

    def _ensure_journal_mode_wal(self, c: sqlite3.Connection) -> None:
        """Switches the database to WAL journal mode unless it already is."""
        cursor = c.cursor()
        try:
            cursor.execute("PRAGMA journal_mode;")
            mode = cursor.fetchone()[0]
            if mode.lower() != "wal":
                cursor.execute("PRAGMA journal_mode=WAL;")
                self.journal_mode_was_changed_to_wal = True
            self.is_journal_mode_wal = True
        finally:
            cursor.close()
class SQLiteDatastore:
    """Holds a pool of SQLite connections and the settings shared by recorders."""

    def __init__(
        self,
        db_name: str,
        *,
        lock_timeout: int | None = None,
        pool_size: int = 5,
        max_overflow: int = 10,
        pool_timeout: float = 5.0,
        max_age: float | None = None,
        pre_ping: bool = False,
        single_row_tracking: bool = True,
        originator_id_type: Literal["uuid", "text"] = "uuid",
    ):
        """Creates the connection pool and stores the recorder settings.

        The pool arguments are passed through to ``SQLiteConnectionPool``.
        ``single_row_tracking`` and ``originator_id_type`` are kept as
        attributes for the recorders that use this datastore.
        """
        self.pool = SQLiteConnectionPool(
            db_name=db_name,
            lock_timeout=lock_timeout,
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_timeout=pool_timeout,
            max_age=max_age,
            pre_ping=pre_ping,
        )
        self.single_row_tracking = single_row_tracking
        self.originator_id_type = originator_id_type

    @contextmanager
    def transaction(self, *, commit: bool) -> Iterator[SQLiteCursor]:
        """Yields a cursor inside a transaction on a pooled connection.

        The connection is returned to the pool when the block exits.
        """
        with self.get_connection(commit=commit) as conn:
            with conn.transaction(commit=commit) as curs:
                yield curs

    @contextmanager
    def get_connection(self, *, commit: bool) -> Iterator[SQLiteConnection]:
        """Yields a pooled connection, obtained as a writer when ``commit``
        is true, and puts it back in the pool afterwards.
        """
        # Using reader-writer interlocking is necessary for in-memory databases,
        # but also speeds up (and provides "fairness") to file-based databases.
        conn = self.pool.get_connection(is_writer=commit)
        try:
            yield conn
        finally:
            self.pool.put_connection(conn)

    def close(self) -> None:
        """Closes the connection pool, if one was created."""
        # The pool attribute is missing if __init__ raised before assigning
        # it; __del__ still runs in that case, so tolerate its absence.
        pool = getattr(self, "pool", None)
        if pool is not None:
            pool.close()

    def __del__(self) -> None:
        """Closes the connection pool when the datastore is garbage collected."""
        self.close()
class SQLiteRecorder(Recorder):
    """Base class for recorders that keep their records in an SQLiteDatastore."""

    def __init__(
        self,
        datastore: SQLiteDatastore,
    ):
        """Stores the datastore and builds the "create table" statements."""
        assert isinstance(datastore, SQLiteDatastore)
        self.datastore = datastore
        self.create_table_statements = self.construct_create_table_statements()

    def construct_create_table_statements(self) -> list[str]:
        """Returns the statements used to create tables (none in this class)."""
        return []

    def create_table(self) -> None:
        """Executes the "create table" statements in a committed transaction."""
        with self.datastore.transaction(commit=True) as curs:
            self._create_table(curs)

    def _create_table(self, c: SQLiteCursor) -> None:
        """Executes each "create table" statement on the given cursor."""
        for sql in self.create_table_statements:
            c.execute(sql)

    def convert_originator_id(self, originator_id: str) -> UUID | str:
        """Returns a UUID when the datastore's originator ID type is "uuid",
        otherwise returns the given string unchanged.
        """
        if self.datastore.originator_id_type == "uuid":
            return UUID(originator_id)
        return originator_id
class SQLiteAggregateRecorder(SQLiteRecorder, AggregateRecorder):
    """Records and retrieves stored events in a table keyed by originator ID
    and originator version.
    """

    def __init__(
        self,
        datastore: SQLiteDatastore,
        events_table_name: str = "stored_events",
    ):
        """Sets the table name and prepares the insert and select statements."""
        # The table name must be set before the base initialiser runs,
        # because that builds the "create table" statements from it.
        self.events_table_name = events_table_name
        super().__init__(datastore)
        table = self.events_table_name
        self.insert_events_statement = f"INSERT INTO {table} VALUES (?,?,?,?)"
        self.select_events_statement = (
            f"SELECT * FROM {table} WHERE originator_id=? "
        )

    def construct_create_table_statements(self) -> list[str]:
        """Returns the inherited statements plus one that creates the events table."""
        create_events_table = (
            "CREATE TABLE IF NOT EXISTS "
            f"{self.events_table_name} ("
            "originator_id TEXT, "
            "originator_version INTEGER, "
            "topic TEXT, "
            "state BLOB, "
            "PRIMARY KEY "
            "(originator_id, originator_version)) "
            "WITHOUT ROWID"
        )
        statements = super().construct_create_table_statements()
        statements.append(create_events_table)
        return statements

    def insert_events(
        self, stored_events: Sequence[StoredEvent], **kwargs: Any
    ) -> Sequence[int] | None:
        """Inserts the stored events in a single committed transaction."""
        with self.datastore.transaction(commit=True) as curs:
            return self._insert_events(curs, stored_events, **kwargs)

    def _insert_events(
        self,
        c: SQLiteCursor,
        stored_events: Sequence[StoredEvent],
        **_: Any,
    ) -> Sequence[int] | None:
        """Inserts all the stored events with one ``executemany`` call.

        Always returns None.
        """
        rows = []
        for stored_event in stored_events:
            originator_id = stored_event.originator_id
            if isinstance(originator_id, UUID):
                # UUIDs are written as their hexadecimal string.
                originator_id = originator_id.hex
            rows.append(
                (
                    originator_id,
                    stored_event.originator_version,
                    stored_event.topic,
                    stored_event.state,
                )
            )
        c.executemany(self.insert_events_statement, rows)
        return None

    def select_events(
        self,
        originator_id: UUID | str,
        *,
        gt: int | None = None,
        lte: int | None = None,
        desc: bool = False,
        limit: int | None = None,
    ) -> Sequence[StoredEvent]:
        """Returns the stored events of one originator, ordered by version.

        ``gt`` and ``lte`` bound the originator version, ``desc`` reverses
        the order, and ``limit`` caps the number of rows.
        """
        if isinstance(originator_id, UUID):
            params: list[Any] = [originator_id.hex]
        else:
            params = [originator_id]
        fragments = [self.select_events_statement]
        if gt is not None:
            fragments.append("AND originator_version>? ")
            params.append(gt)
        if lte is not None:
            fragments.append("AND originator_version<=? ")
            params.append(lte)
        fragments.append("ORDER BY originator_version ")
        fragments.append("ASC " if desc is False else "DESC ")
        if limit is not None:
            fragments.append("LIMIT ? ")
            params.append(limit)
        statement = "".join(fragments)

        with self.datastore.transaction(commit=False) as curs:
            curs.execute(statement, params)
            stored_events = []
            for row in curs.fetchall():
                stored_events.append(
                    StoredEvent(
                        originator_id=self.convert_originator_id(
                            row["originator_id"]
                        ),
                        originator_version=row["originator_version"],
                        topic=row["topic"],
                        state=row["state"],
                    )
                )
            return stored_events
class SQLiteApplicationRecorder(
    SQLiteAggregateRecorder,
    ApplicationRecorder,
):
    """Records stored events in a rowid table and uses each row's rowid as
    the notification ID.
    """

    def __init__(
        self,
        datastore: SQLiteDatastore,
        events_table_name: str = "stored_events",
    ):
        """Initialises the aggregate recorder and prepares the statement that
        selects the maximum notification ID.
        """
        super().__init__(datastore, events_table_name)
        self.select_max_notification_id_statement = (
            f"SELECT MAX(rowid) FROM {self.events_table_name}"
        )

    def construct_create_table_statements(self) -> list[str]:
        """Returns one statement that creates the events table.

        Unlike the aggregate recorder's table, this one is not declared
        "WITHOUT ROWID", since the rowid is used as the notification ID.
        """
        statement = (
            "CREATE TABLE IF NOT EXISTS "
            f"{self.events_table_name} ("
            "originator_id TEXT, "
            "originator_version INTEGER, "
            "topic TEXT, "
            "state BLOB, "
            "PRIMARY KEY "
            "(originator_id, originator_version))"
        )
        return [statement]

    def _insert_events(
        self,
        c: SQLiteCursor,
        stored_events: Sequence[StoredEvent],
        **_: Any,
    ) -> Sequence[int] | None:
        """Inserts the stored events one at a time and returns the list of
        ``lastrowid`` values reported by the cursor after each insert.
        """
        notification_ids = []
        for stored_event in stored_events:
            originator_id = stored_event.originator_id
            if isinstance(originator_id, UUID):
                # UUIDs are written as their hexadecimal string.
                originator_id = originator_id.hex
            c.execute(
                self.insert_events_statement,
                (
                    originator_id,
                    stored_event.originator_version,
                    stored_event.topic,
                    stored_event.state,
                ),
            )
            notification_ids.append(c.lastrowid)
        return notification_ids

    def select_notifications(
        self,
        start: int | None,
        limit: int,
        stop: int | None = None,
        topics: Sequence[str] = (),
        *,
        inclusive_of_start: bool = True,
    ) -> Sequence[Notification]:
        """Returns a list of event notifications
        from 'start', limited by 'limit'.

        'stop' is an inclusive upper bound on the notification ID, 'topics'
        restricts the result to the given topics, and 'inclusive_of_start'
        selects whether 'start' itself is included.
        """
        conditions: list[str] = []
        params: list[int | str] = []
        if start is not None:
            conditions.append("rowid>=? " if inclusive_of_start else "rowid>? ")
            params.append(start)
        if stop is not None:
            conditions.append("rowid<=? ")
            params.append(stop)
        if topics:
            # One placeholder per topic.
            conditions.append(f"topic IN ({','.join('?' * len(topics))}) ")
            params += list(topics)

        statement = f"SELECT rowid, * FROM {self.events_table_name} "
        if conditions:
            statement += "WHERE " + "AND ".join(conditions)
        statement += "ORDER BY rowid LIMIT ?"
        params.append(limit)

        with self.datastore.transaction(commit=False) as c:
            c.execute(statement, params)
            return [
                Notification(
                    id=row["rowid"],
                    originator_id=self.convert_originator_id(row["originator_id"]),
                    originator_version=row["originator_version"],
                    topic=row["topic"],
                    state=row["state"],
                )
                for row in c.fetchall()
            ]

    def max_notification_id(self) -> int | None:
        """Returns the maximum notification ID, or None if the events
        table has no rows.
        """
        with self.datastore.transaction(commit=False) as c:
            return self._max_notification_id(c)

    def _max_notification_id(self, c: SQLiteCursor) -> int | None:
        """Selects the maximum rowid of the events table using the given cursor."""
        c.execute(self.select_max_notification_id_statement)
        # MAX() over zero rows yields NULL, which arrives here as None.
        return c.fetchone()[0]

    def subscribe(
        self, gt: int | None = None, topics: Sequence[str] = ()
    ) -> Subscription[ApplicationRecorder]:
        """This method is not implemented on this class."""
        msg = f"The {type(self).__qualname__} recorder does not support subscriptions"
        raise NotImplementedError(msg)
class SQLiteTrackingRecorder(SQLiteRecorder, TrackingRecorder):
    """Records tracking positions (application name and notification ID) in
    the "tracking" table, with either one row per application name or one
    row per tracked notification.
    """

    def __init__(
        self,
        datastore: SQLiteDatastore,
        **kwargs: Any,
    ):
        """Initialises migration state and prepares the tracking statements."""
        super().__init__(datastore, **kwargs)
        self.tracking_table_exists: bool = False
        self.tracking_migration_previous: int | None = None
        self.tracking_migration_current: int | None = None
        # Application name under which the migration version is recorded.
        self.table_migration_identifier = "__migration__"
        self.has_checked_for_multi_row_tracking_table: bool = False
        if not self.datastore.single_row_tracking:
            self.insert_tracking_statement = (
                "INSERT INTO tracking VALUES (:application_name, :notification_id)"
            )
        else:
            # Upsert that only moves the position forwards. No row is
            # returned when the stored position is not lower than the new one.
            self.insert_tracking_statement = (
                "INSERT INTO tracking "
                "VALUES (:application_name, :notification_id) "
                "ON CONFLICT (application_name) DO UPDATE "
                "SET notification_id = :notification_id "
                "WHERE tracking.notification_id < :notification_id "
                "RETURNING notification_id"
            )
        self.select_max_tracking_id_statement = (
            "SELECT MAX(notification_id) FROM tracking WHERE application_name=?"
        )

    def construct_create_table_statements(self) -> list[str]:
        """Returns the inherited statements plus one that creates the tracking
        table, keyed according to the datastore's tracking mode.
        """
        if self.datastore.single_row_tracking:
            primary_key_columns = "(application_name)) "
        else:
            primary_key_columns = "(application_name, notification_id)) "
        statements = super().construct_create_table_statements()
        statements.append(
            "CREATE TABLE IF NOT EXISTS tracking ("
            "application_name TEXT, "
            "notification_id INTEGER, "
            "PRIMARY KEY "
            f"{primary_key_columns}"
            "WITHOUT ROWID"
        )
        return statements

    def create_table(self) -> None:
        """Creates the tracking table, migrating an existing table to
        single-row tracking when needed.

        Raises ``OperationalError`` if multi-row tracking is configured but a
        migration version has been recorded.
        """
        # Get the migration version. Selecting it fails with an
        # OperationalError when the tracking table doesn't exist yet.
        try:
            migration_version = self.max_tracking_id(self.table_migration_identifier)
        except OperationalError:
            pass
        else:
            self.tracking_migration_current = migration_version
            self.tracking_migration_previous = migration_version
            self.tracking_table_exists = True

        super().create_table()

        multi_row_tracking = not self.datastore.single_row_tracking
        if multi_row_tracking and self.tracking_migration_current is not None:
            msg = "Can't do multi-row tracking with single-row tracking table"
            raise OperationalError(msg)

    def _create_table(self, c: SQLiteCursor) -> None:
        """Creates the tracking table on the given cursor.

        In single-row mode, an existing table that has no migration version
        is first renamed, and its maximum positions are copied into the new
        table along with a migration record.
        """
        single_row = self.datastore.single_row_tracking
        table_exists = self.tracking_table_exists
        not_migrated = not self.tracking_migration_previous

        max_tracking_ids: dict[str, int] = {}
        if single_row and table_exists and not_migrated:
            # Migrate tracking to use single-row per application name.
            # - Get all application names.
            c.execute("SELECT DISTINCT application_name FROM tracking")
            application_names: list[str] = []
            for name_row in c.fetchall():
                application_names.append(name_row["application_name"])

            # - Get max tracking ID for each application name.
            for application_name in application_names:
                c.execute(self.select_max_tracking_id_statement, (application_name,))
                max_row = c.fetchone()
                assert max_row is not None
                max_tracking_ids[application_name] = max_row[0]

            # - Rename the table, so a new one can be created in its place.
            c.execute("ALTER TABLE tracking RENAME TO old1_tracking")

        # Create the table.
        super()._create_table(c)

        # - Maybe insert migration tracking record and application tracking records.
        if single_row and (not table_exists or not_migrated):
            # - Assume we just created a table for single-row tracking.
            self._insert_tracking(c, Tracking(self.table_migration_identifier, 1))
            self.tracking_migration_current = 1
            for application_name, max_tracking_id in max_tracking_ids.items():
                self._insert_tracking(c, Tracking(application_name, max_tracking_id))

    def insert_tracking(self, tracking: Tracking) -> None:
        """Records the given tracking object in a committed transaction."""
        with self.datastore.transaction(commit=True) as curs:
            self._insert_tracking(curs, tracking)

    def _insert_tracking(
        self,
        c: SQLiteCursor,
        tracking: Tracking,
    ) -> None:
        """Records the given tracking object using the given cursor.

        In single-row mode, raises ``IntegrityError`` if the statement
        returns no row.
        """
        self._check_has_multi_row_tracking_table(c)
        values = {
            "application_name": tracking.application_name,
            "notification_id": tracking.notification_id,
        }
        c.execute(self.insert_tracking_statement, values)
        if self.datastore.single_row_tracking and c.fetchone() is None:
            msg = (
                "Failed to record tracking for "
                f"{tracking.application_name} {tracking.notification_id}"
            )
            raise IntegrityError(msg)

    def _check_has_multi_row_tracking_table(self, c: SQLiteCursor) -> None:
        """Raises ``OperationalError`` if multi-row tracking is configured but
        the table holds a migration version.

        The check is skipped once it has passed.
        """
        needs_check = (
            not self.datastore.single_row_tracking
            and not self.has_checked_for_multi_row_tracking_table
        )
        if needs_check and self._max_tracking_id(self.table_migration_identifier, c):
            msg = "Can't do multi-row tracking with single-row tracking table"
            raise OperationalError(msg)
        self.has_checked_for_multi_row_tracking_table = True

    def max_tracking_id(self, application_name: str) -> int | None:
        """Returns the maximum notification ID recorded for the given
        application name, or None if there isn't one.
        """
        with self.datastore.transaction(commit=False) as curs:
            return self._max_tracking_id(application_name, curs)

    def _max_tracking_id(self, application_name: str, c: SQLiteCursor) -> int | None:
        """Selects the maximum tracked notification ID using the given cursor."""
        c.execute(self.select_max_tracking_id_statement, [application_name])
        max_row = c.fetchone()
        return max_row[0]
class SQLiteProcessRecorder(
    SQLiteTrackingRecorder,
    SQLiteApplicationRecorder,
    ProcessRecorder,
):
    """Records stored events together with an optional tracking object,
    using the same cursor for both.
    """

    def __init__(
        self,
        datastore: SQLiteDatastore,
        *,
        events_table_name: str = "stored_events",
    ):
        """Initialises the tracking and application recorder base classes."""
        super().__init__(datastore, events_table_name=events_table_name)

    def _insert_events(
        self,
        c: SQLiteCursor,
        stored_events: Sequence[StoredEvent],
        **kwargs: Any,
    ) -> Sequence[int] | None:
        """Inserts the stored events, then records ``kwargs["tracking"]`` if
        it was given, and returns what the inherited insert returned.
        """
        notification_ids = super()._insert_events(c, stored_events, **kwargs)
        tracking: Tracking | None = kwargs.get("tracking")
        if tracking is None:
            return notification_ids
        self._insert_tracking(c, tracking)
        return notification_ids
class SQLiteFactory(InfrastructureFactory[SQLiteTrackingRecorder]):
    """Infrastructure factory that constructs SQLite recorders which all
    share one datastore configured from the environment.
    """

    # Names of the environment variables read by this factory.
    SQLITE_DBNAME = "SQLITE_DBNAME"
    SQLITE_LOCK_TIMEOUT = "SQLITE_LOCK_TIMEOUT"
    SQLITE_SINGLE_ROW_TRACKING = "SINGLE_ROW_TRACKING"
    ORIGINATOR_ID_TYPE = "ORIGINATOR_ID_TYPE"
    CREATE_TABLE = "CREATE_TABLE"

    # Recorder classes used when the environment doesn't name another one.
    aggregate_recorder_class = SQLiteAggregateRecorder
    application_recorder_class = SQLiteApplicationRecorder
    tracking_recorder_class = SQLiteTrackingRecorder
    process_recorder_class = SQLiteProcessRecorder

    def __init__(self, env: Environment | EnvType | None):
        """Reads the SQLite settings from the environment and creates the datastore.

        Raises ``OSError`` if the database name is missing, if the lock
        timeout is not an integer, or if the originator ID type is neither
        'uuid' nor 'text' (compared case-insensitively).
        """
        super().__init__(env)
        db_name = self.env.get(self.SQLITE_DBNAME)
        if not db_name:
            msg = (
                "SQLite database name not found "
                "in environment with keys: "
                f"{', '.join(self.env.create_keys(self.SQLITE_DBNAME))}"
            )
            raise OSError(msg)

        # An unset, empty or blank lock timeout means "not configured".
        lock_timeout_str = (
            self.env.get(self.SQLITE_LOCK_TIMEOUT) or ""
        ).strip() or None

        lock_timeout: int | None = None
        if lock_timeout_str is not None:
            try:
                lock_timeout = int(lock_timeout_str)
            except ValueError:
                msg = (
                    "SQLite environment value for key "
                    f"'{self.SQLITE_LOCK_TIMEOUT}' is invalid. "
                    "If set, an int or empty string is expected: "
                    f"'{lock_timeout_str}'"
                )
                raise OSError(msg) from None

        single_row_tracking = strtobool(
            self.env.get(self.SQLITE_SINGLE_ROW_TRACKING, "t")
        )

        configured_id_type = self.env.get(self.ORIGINATOR_ID_TYPE, "uuid")
        # Normalise the case before use. The value was previously validated
        # case-insensitively but passed on unchanged, so e.g. "UUID" passed
        # validation and then never matched the recorders' comparison
        # with "uuid".
        originator_id_type = cast(
            Literal["uuid", "text"],
            configured_id_type.lower(),
        )
        if originator_id_type not in ("uuid", "text"):
            msg = (
                f"Invalid {self.ORIGINATOR_ID_TYPE} '{configured_id_type}', "
                f"must be 'uuid' or 'text'"
            )
            raise OSError(msg)

        self.datastore = SQLiteDatastore(
            db_name=db_name,
            lock_timeout=lock_timeout,
            single_row_tracking=single_row_tracking,
            originator_id_type=originator_id_type,
        )

    def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
        """Returns an aggregate recorder for the table "stored_<purpose>",
        creating the table first if table creation is enabled.
        """
        events_table_name = "stored_" + purpose
        recorder = self.aggregate_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def application_recorder(self) -> ApplicationRecorder:
        """Returns an application recorder, using the class named by the
        application recorder topic in the environment if there is one.
        """
        application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
        if application_recorder_topic:
            application_recorder_class: type[SQLiteApplicationRecorder] = resolve_topic(
                application_recorder_topic
            )
            assert issubclass(application_recorder_class, SQLiteApplicationRecorder)
        else:
            application_recorder_class = self.application_recorder_class

        recorder = application_recorder_class(datastore=self.datastore)
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def tracking_recorder(
        self, tracking_recorder_class: type[SQLiteTrackingRecorder] | None = None
    ) -> SQLiteTrackingRecorder:
        """Returns a tracking recorder.

        The class is the given one if any, else the one named by the tracking
        recorder topic in the environment, else the factory's default.
        """
        if tracking_recorder_class is None:
            tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
            if tracking_recorder_topic:
                tracking_recorder_class = resolve_topic(tracking_recorder_topic)
            else:
                tracking_recorder_class = self.tracking_recorder_class
        assert tracking_recorder_class is not None
        assert issubclass(tracking_recorder_class, SQLiteTrackingRecorder)
        recorder = tracking_recorder_class(datastore=self.datastore)
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def process_recorder(self) -> ProcessRecorder:
        """Returns a process recorder, using the class named by the process
        recorder topic in the environment if there is one.
        """
        process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
        if process_recorder_topic:
            process_recorder_class: type[SQLiteProcessRecorder] = resolve_topic(
                process_recorder_topic
            )
            assert issubclass(process_recorder_class, SQLiteProcessRecorder)
        else:
            process_recorder_class = self.process_recorder_class

        recorder = process_recorder_class(datastore=self.datastore)
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def env_create_table(self) -> bool:
        """Returns whether tables should be created, according to the
        CREATE_TABLE environment value (treated as "yes" when unset or empty).
        """
        default = "yes"
        return bool(strtobool(self.env.get(self.CREATE_TABLE, default) or default))

    def close(self) -> None:
        """Closes the datastore."""
        self.datastore.close()
# Module-level alias for SQLiteFactory.
# NOTE(review): presumably lets callers resolve this module's factory under
# a generic name — confirm against the code that looks it up.
Factory = SQLiteFactory