from __future__ import annotations
import contextlib
import logging
from asyncio import CancelledError
from contextlib import contextmanager
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, cast
import psycopg
import psycopg.errors
import psycopg_pool
from psycopg import Connection, Cursor, Error
from psycopg.generators import notifies
from psycopg.rows import DictRow, dict_row
from psycopg.sql import SQL, Composed, Identifier
from typing_extensions import TypeVar
from eventsourcing.persistence import (
AggregateRecorder,
ApplicationRecorder,
DatabaseError,
DataError,
InfrastructureFactory,
IntegrityError,
InterfaceError,
InternalError,
ListenNotifySubscription,
Notification,
NotSupportedError,
OperationalError,
PersistenceError,
ProcessRecorder,
ProgrammingError,
StoredEvent,
Subscription,
Tracking,
TrackingRecorder,
)
from eventsourcing.utils import Environment, resolve_topic, retry, strtobool
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from uuid import UUID
from psycopg.abc import Query
from typing_extensions import Self
# Quieten psycopg's own loggers (pool first, then the driver) so that only
# CRITICAL records from them are emitted.
for _psycopg_logger_name in ("psycopg.pool", "psycopg"):
    logging.getLogger(_psycopg_logger_name).setLevel(logging.CRITICAL)
del _psycopg_logger_name

# Copy of "private" psycopg.errors._NO_TRACEBACK (in case it changes).
# From psycopg: "Don't show a complete traceback upon raising these exception.
# Usually the traceback starts from internal functions (for instance in the
# server communication callbacks) but, for the end user, it's more important
# to get the high level information about where the exception was raised, for
# instance in a certain `Cursor.execute()`."
NO_TRACEBACK = (
    Error,
    KeyboardInterrupt,
    CancelledError,
)
class ConnectionPool(psycopg_pool.ConnectionPool[Any]):
    """psycopg connection pool with an optional password callback.

    When ``get_password_func`` is given, it is called each time the pool
    creates a connection (psycopg_pool's ``_connect`` hook). The returned
    value replaces the ``password`` connection keyword argument.
    """

    def __init__(
        self,
        *args: Any,
        get_password_func: Callable[[], str] | None = None,
        **kwargs: Any,
    ) -> None:
        # Set before the base initialiser runs, because _connect() reads it.
        self.get_password_func = get_password_func
        super().__init__(*args, **kwargs)

    def _connect(self, timeout: float | None = None) -> Connection[Any]:
        password_func = self.get_password_func
        if password_func:
            self.kwargs["password"] = password_func()
        return super()._connect(timeout=timeout)
class PostgresDatastore:
    """Owns a psycopg connection pool for one PostgreSQL database.

    Provides context managers for pooled connections and transactions. psycopg
    errors raised while they are in use are translated into eventsourcing
    persistence errors.
    """

    def __init__(  # noqa: PLR0913
        self,
        dbname: str,
        host: str,
        port: str | int,
        user: str,
        password: str,
        *,
        connect_timeout: int = 30,
        idle_in_transaction_session_timeout: int = 0,
        pool_size: int = 2,
        max_overflow: int = 2,
        max_waiting: int = 0,
        conn_max_age: float = 60 * 60.0,
        pre_ping: bool = False,
        lock_timeout: int = 0,
        schema: str = "",
        pool_open_timeout: int | None = None,
        get_password_func: Callable[[], str] | None = None,
    ):
        # Must be set before after_connect_func() is called below.
        self.idle_in_transaction_session_timeout = idle_in_transaction_session_timeout
        self.pre_ping = pre_ping
        self.pool_open_timeout = pool_open_timeout
        # Connections are only checked by the pool when pre-ping is enabled.
        check = ConnectionPool.check_connection if pre_ping else None
        self.pool = ConnectionPool(
            get_password_func=get_password_func,
            connection_class=Connection[DictRow],
            kwargs={
                "dbname": dbname,
                "host": host,
                "port": port,
                "user": user,
                "password": password,
                "row_factory": dict_row,
            },
            min_size=pool_size,
            max_size=pool_size + max_overflow,
            # The pool is opened by get_connection(), not here.
            open=False,
            configure=self.after_connect_func(),
            timeout=connect_timeout,
            max_waiting=max_waiting,
            max_lifetime=conn_max_age,
            check=check,
        )
        self.lock_timeout = lock_timeout
        self.schema = schema.strip() or "public"

    def after_connect_func(self) -> Callable[[Connection[Any]], None]:
        """Return the callback used to configure each new pool connection.

        The callback switches the connection to autocommit and sets
        ``idle_in_transaction_session_timeout`` (in seconds) from this
        datastore's setting.
        """
        statement = SQL("SET idle_in_transaction_session_timeout = '{0}s'").format(
            self.idle_in_transaction_session_timeout
        )

        def after_connect(conn: Connection[DictRow]) -> None:
            conn.autocommit = True
            conn.cursor().execute(statement)

        return after_connect

    @contextmanager
    def get_connection(self) -> Iterator[Connection[DictRow]]:
        """Yield a connection from the pool, opening the pool if necessary.

        psycopg errors raised inside the context (including by the caller's
        ``with`` body) are re-raised as the matching eventsourcing error,
        chained to the original. Any other exception propagates unchanged.
        """
        try:
            # Block while opening the pool only if an open timeout was given.
            wait = self.pool_open_timeout is not None
            timeout = self.pool_open_timeout or 30.0
            self.pool.open(wait, timeout)
            with self.pool.connection() as conn:
                yield conn
        # The DatabaseError subclasses must be matched before DatabaseError,
        # and everything before the psycopg.Error base class.
        except psycopg.InterfaceError as e:
            raise InterfaceError(str(e)) from e
        except psycopg.OperationalError as e:
            raise OperationalError(str(e)) from e
        except psycopg.DataError as e:
            raise DataError(str(e)) from e
        except psycopg.IntegrityError as e:
            raise IntegrityError(str(e)) from e
        except psycopg.InternalError as e:
            raise InternalError(str(e)) from e
        except psycopg.ProgrammingError as e:
            raise ProgrammingError(str(e)) from e
        except psycopg.NotSupportedError as e:
            raise NotSupportedError(str(e)) from e
        except psycopg.DatabaseError as e:
            raise DatabaseError(str(e)) from e
        except psycopg.Error as e:
            raise PersistenceError(str(e)) from e

    @contextmanager
    def transaction(self, *, commit: bool = False) -> Iterator[Cursor[DictRow]]:
        """Yield a cursor inside a transaction on a pooled connection.

        The transaction is rolled back on exit unless ``commit`` is true.
        """
        with self.get_connection() as conn, conn.transaction(force_rollback=not commit):
            yield conn.cursor()

    def close(self) -> None:
        """Close the connection pool."""
        self.pool.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        self.close()

    def __del__(self) -> None:
        # If __init__ raised before self.pool was assigned, there is nothing
        # to close. Calling close() anyway would only report a spurious
        # AttributeError from the finalizer.
        if hasattr(self, "pool"):
            self.close()
class PostgresRecorder:
    """Base class for recorders that use PostgreSQL."""

    def __init__(
        self,
        datastore: PostgresDatastore,
    ):
        self.datastore = datastore
        # Subclasses append or replace DDL statements in this list.
        self.create_table_statements = self.construct_create_table_statements()

    def construct_create_table_statements(self) -> list[Composed]:
        """Return the initial list of CREATE statements (empty here)."""
        return []

    def check_table_name_length(self, table_name: str) -> None:
        """Raise ProgrammingError if ``table_name`` exceeds PostgreSQL's limit.

        PostgreSQL uses at most NAMEDATALEN-1 (63 by default) *bytes* of an
        identifier and silently truncates the rest. The encoded length is
        therefore checked, not the character count. This assumes a UTF-8
        database encoding.
        """
        if len(table_name.encode("utf-8")) > 63:
            msg = f"Table name too long: {table_name}"
            raise ProgrammingError(msg)

    def create_table(self) -> None:
        """Execute all create-table statements in one committed transaction."""
        with self.datastore.transaction(commit=True) as curs:
            for statement in self.create_table_statements:
                curs.execute(statement, prepare=False)
class PostgresAggregateRecorder(PostgresRecorder, AggregateRecorder):
    """Stores events in a table keyed by (originator_id, originator_version)."""

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
    ):
        super().__init__(datastore)
        self.check_table_name_length(events_table_name)
        self.events_table_name = events_table_name
        # Index names can't be qualified names, but
        # are created in the same schema as the table.
        self.notification_id_index_name = (
            f"{self.events_table_name}_notification_id_idx"
        )
        schema_ident = Identifier(self.datastore.schema)
        table_ident = Identifier(self.events_table_name)
        self.create_table_statements.append(
            SQL(
                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
                "originator_id uuid NOT NULL, "
                "originator_version bigint NOT NULL, "
                "topic text, "
                "state bytea, "
                "PRIMARY KEY "
                "(originator_id, originator_version)) "
                "WITH (autovacuum_enabled=false)"
            ).format(schema_ident, table_ident)
        )
        self.insert_events_statement = SQL(
            "INSERT INTO {0}.{1} VALUES (%s, %s, %s, %s)"
        ).format(schema_ident, table_ident)
        self.select_events_statement = SQL(
            "SELECT * FROM {0}.{1} WHERE originator_id = %s"
        ).format(schema_ident, table_ident)
        self.lock_table_statements: list[Query] = []

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def insert_events(
        self, stored_events: list[StoredEvent], **kwargs: Any
    ) -> Sequence[int] | None:
        """Insert ``stored_events`` in one pipelined transaction.

        Returns whatever ``_fetch_ids_after_insert_events`` returns (None
        here, notification IDs in subclasses). Errors are raised after the
        pipeline has exited, but still inside ``get_connection()``, so they
        are translated into eventsourcing errors.
        """
        exc: Exception | None = None
        notification_ids: Sequence[int] | None = None
        with self.datastore.get_connection() as conn:
            with conn.pipeline() as pipeline, conn.transaction():
                # Do other things first, so they can be pipelined too.
                with conn.cursor() as curs:
                    self._insert_events(curs, stored_events, **kwargs)
                # Then use a different cursor for the executemany() call.
                with conn.cursor() as curs:
                    try:
                        self._insert_stored_events(curs, stored_events, **kwargs)
                        # Sync now, so any uniqueness constraint violation causes an
                        # IntegrityError to be raised here, rather than an
                        # InternalError being raised later, e.g. when commit() is
                        # called.
                        pipeline.sync()
                        notification_ids = self._fetch_ids_after_insert_events(
                            curs, stored_events, **kwargs
                        )
                    except Exception as e:
                        # Avoid psycopg emitting a pipeline warning.
                        exc = e
            if exc:
                # Reraise exception after pipeline context manager has exited.
                raise exc
        return notification_ids

    def _insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **_: Any,
    ) -> None:
        """Hook for subclasses to queue extra statements before the insert."""

    def _insert_stored_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **_: Any,
    ) -> None:
        # Only do something if there is something to do.
        if len(stored_events) > 0:
            self._lock_table(curs)
            self._notify_channel(curs)
            # Insert events.
            curs.executemany(
                query=self.insert_events_statement,
                params_seq=[
                    (
                        stored_event.originator_id,
                        stored_event.originator_version,
                        stored_event.topic,
                        stored_event.state,
                    )
                    for stored_event in stored_events
                ],
                returning="RETURNING" in self.insert_events_statement.as_string(),
            )

    def _lock_table(self, curs: Cursor[DictRow]) -> None:
        """Hook for subclasses that need a table lock (no-op here)."""

    def _notify_channel(self, curs: Cursor[DictRow]) -> None:
        """Hook for subclasses that send NOTIFY (no-op here)."""

    def _fetch_ids_after_insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> Sequence[int] | None:
        """Hook for subclasses that return notification IDs (None here)."""
        return None

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def select_events(
        self,
        originator_id: UUID,
        *,
        gt: int | None = None,
        lte: int | None = None,
        desc: bool = False,
        limit: int | None = None,
    ) -> list[StoredEvent]:
        """Return events of ``originator_id`` ordered by version.

        Optionally only versions ``> gt`` and/or ``<= lte`` are included.
        Order is descending when ``desc`` is truthy, otherwise ascending.
        At most ``limit`` rows are returned when it is given.
        """
        statement = self.select_events_statement
        params: list[Any] = [originator_id]
        if gt is not None:
            params.append(gt)
            statement += SQL(" AND originator_version > %s")
        if lte is not None:
            params.append(lte)
            statement += SQL(" AND originator_version <= %s")
        statement += SQL(" ORDER BY originator_version")
        # Truthiness test, so that None/0 mean ascending like False does.
        statement += SQL(" DESC") if desc else SQL(" ASC")
        if limit is not None:
            params.append(limit)
            statement += SQL(" LIMIT %s")
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(statement, params, prepare=True)
            return [
                StoredEvent(
                    originator_id=row["originator_id"],
                    originator_version=row["originator_version"],
                    topic=row["topic"],
                    state=bytes(row["state"]),
                )
                for row in curs.fetchall()
            ]
class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder):
    """Aggregate recorder whose events table also has a ``notification_id``.

    The ``notification_id`` column is a bigserial with a unique index. The
    recorder supports selecting notifications, reading the max notification
    ID, and subscribing via LISTEN/NOTIFY.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
    ):
        super().__init__(datastore, events_table_name=events_table_name)
        schema_ident = Identifier(self.datastore.schema)
        table_ident = Identifier(self.events_table_name)
        # Replace the parent's CREATE TABLE (the last statement it appended)
        # with one that adds the notification_id column.
        self.create_table_statements[-1] = SQL(
            "CREATE TABLE IF NOT EXISTS {0}.{1} ("
            "originator_id uuid NOT NULL, "
            "originator_version bigint NOT NULL, "
            "topic text, "
            "state bytea, "
            "notification_id bigserial, "
            "PRIMARY KEY "
            "(originator_id, originator_version)) "
            "WITH (autovacuum_enabled=false)"
        ).format(schema_ident, table_ident)
        unique_index_statement = SQL(
            "CREATE UNIQUE INDEX IF NOT EXISTS {0} "
            "ON {1}.{2} (notification_id ASC);"
        ).format(
            Identifier(self.notification_id_index_name),
            schema_ident,
            table_ident,
        )
        self.create_table_statements.append(unique_index_statement)
        self.channel_name = self.events_table_name.replace(".", "_")
        returning_clause = SQL(" RETURNING notification_id")
        self.insert_events_statement = self.insert_events_statement + returning_clause
        self.max_notification_id_statement = SQL(
            "SELECT MAX(notification_id) FROM {0}.{1}"
        ).format(schema_ident, table_ident)
        set_lock_timeout = SQL("SET LOCAL lock_timeout = '{0}s'").format(
            self.datastore.lock_timeout
        )
        lock_table = SQL("LOCK TABLE {0}.{1} IN EXCLUSIVE MODE").format(
            schema_ident, table_ident
        )
        self.lock_table_statements = [set_lock_timeout, lock_table]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def select_notifications(
        self,
        start: int | None,
        limit: int,
        stop: int | None = None,
        topics: Sequence[str] = (),
        *,
        inclusive_of_start: bool = True,
    ) -> list[Notification]:
        """Returns a list of event notifications
        from 'start', limited by 'limit'.

        Optionally bounded above by ``stop`` (inclusive) and restricted to
        the given ``topics``. Results are ordered by notification ID.
        """
        conditions: list[SQL] = []
        params: list[int | str | Sequence[str]] = []
        if start is not None:
            params.append(start)
            if inclusive_of_start:
                conditions.append(SQL("notification_id>=%s"))
            else:
                conditions.append(SQL("notification_id>%s"))
        if stop is not None:
            params.append(stop)
            conditions.append(SQL("notification_id <= %s"))
        if topics:
            # Check sequence and ensure list of strings.
            assert isinstance(topics, (tuple, list)), topics
            topics = list(topics) if isinstance(topics, tuple) else topics
            assert all(isinstance(t, str) for t in topics), topics
            params.append(topics)
            conditions.append(SQL("topic = ANY(%s)"))
        params.append(limit)

        statement = SQL("SELECT * FROM {0}.{1}").format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        if conditions:
            statement += SQL(" WHERE ") + SQL(" AND ").join(conditions)
        statement += SQL(" ORDER BY notification_id LIMIT %s")

        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(statement, params, prepare=True)
            rows = curs.fetchall()
            return [
                Notification(
                    id=row["notification_id"],
                    originator_id=row["originator_id"],
                    originator_version=row["originator_version"],
                    topic=row["topic"],
                    state=bytes(row["state"]),
                )
                for row in rows
            ]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def max_notification_id(self) -> int | None:
        """Returns the maximum notification ID (None if the table is empty)."""
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(self.max_notification_id_statement)
            result_row = curs.fetchone()
            assert result_row is not None
            return result_row["max"]

    def _lock_table(self, curs: Cursor[DictRow]) -> None:
        # Acquire an "EXCLUSIVE" table lock to serialize transactions that
        # insert stored events, so that readers don't pass over gaps that are
        # filled in later. Each transaction should receive all its
        # notification IDs from the sequence and then commit, before another
        # transaction receives any notification IDs. In other words, insert
        # order should equal commit order, which locking the table for writes
        # achieves. The EXCLUSIVE lock mode does not conflict with the ACCESS
        # SHARE locks taken by SELECT, so the table can still be read
        # concurrently with writes and other reads. Plain INSERT statements
        # only take ROW EXCLUSIVE locks. That would allow stored events of
        # different transactions to interleave within the sequence of
        # notification IDs. Because one transaction always commits before the
        # other, readers tailing the notification log could then miss items
        # that are committed later but were issued lower notification IDs.
        # https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES
        # https://www.postgresql.org/docs/9.1/sql-lock.html
        # https://stackoverflow.com/questions/45866187/guarantee-monotonicity-of
        # -postgresql-serial-column-values-by-commit-order
        for statement in self.lock_table_statements:
            curs.execute(statement, prepare=True)

    def _notify_channel(self, curs: Cursor[DictRow]) -> None:
        notify_statement = SQL("NOTIFY {0}").format(Identifier(self.channel_name))
        curs.execute(notify_statement)

    def _fetch_ids_after_insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> Sequence[int] | None:
        # Walk the cursor's result sets and take the returned notification_id
        # from each INSERT result until every event is accounted for.
        expected_count = len(stored_events)
        collected: list[int] = []
        if expected_count:
            while curs.nextset() and len(collected) != expected_count:
                status = curs.statusmessage
                if status and status.startswith("INSERT"):
                    id_row = curs.fetchone()
                    assert id_row is not None
                    collected.append(id_row["notification_id"])
        if len(collected) != expected_count:
            msg = "Couldn't get all notification IDs "
            msg += f"(got {len(collected)}, expected {expected_count})"
            raise ProgrammingError(msg)
        return collected

    def subscribe(
        self, gt: int | None = None, topics: Sequence[str] = ()
    ) -> Subscription[ApplicationRecorder]:
        """Return a LISTEN/NOTIFY-based subscription to this recorder."""
        return PostgresSubscription(recorder=self, gt=gt, topics=topics)
class PostgresSubscription(ListenNotifySubscription[PostgresApplicationRecorder]):
    """Subscription that is woken by PostgreSQL NOTIFY on the recorder's channel.

    A background thread LISTENs on the channel and sets
    ``_has_been_notified`` whenever a notification arrives.
    """

    def __init__(
        self,
        recorder: PostgresApplicationRecorder,
        gt: int | None = None,
        topics: Sequence[str] = (),
    ) -> None:
        assert isinstance(recorder, PostgresApplicationRecorder)
        super().__init__(recorder=recorder, gt=gt, topics=topics)
        listener = Thread(target=self._listen)
        self._listen_thread = listener
        listener.start()

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        super().__exit__(*args, **kwargs)
        # Wait for the listener thread to finish.
        self._listen_thread.join()

    def _listen(self) -> None:
        try:
            with self._recorder.datastore.get_connection() as conn:
                listen_statement = SQL("LISTEN {0}").format(
                    Identifier(self._recorder.channel_name)
                )
                conn.execute(listen_statement)
                while not (self._has_been_stopped or self._thread_error):
                    # A simplified psycopg conn.notifies(): the notify
                    # payloads are not needed, and consumption must stop once
                    # the subscription errors or is otherwise stopped.
                    with conn.lock:
                        try:
                            got_notification = conn.wait(
                                notifies(conn.pgconn), interval=0.1
                            )
                            if got_notification:
                                self._has_been_notified.set()
                        except NO_TRACEBACK as ex:  # pragma: no cover
                            raise ex.with_traceback(None) from None
        except BaseException as e:
            # Record only the first error, then stop the subscription.
            if self._thread_error is None:
                self._thread_error = e
            self.stop()
class PostgresTrackingRecorder(PostgresRecorder, TrackingRecorder):
    """Records (application_name, notification_id) tracking rows."""

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        tracking_table_name: str = "notification_tracking",
        **kwargs: Any,
    ):
        # Remaining kwargs go to the next class in the MRO (e.g. the
        # events_table_name of an application recorder in a process recorder).
        super().__init__(datastore, **kwargs)
        self.check_table_name_length(tracking_table_name)
        self.tracking_table_name = tracking_table_name
        schema_ident = Identifier(self.datastore.schema)
        table_ident = Identifier(self.tracking_table_name)
        self.create_table_statements.append(
            SQL(
                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
                "application_name text, "
                "notification_id bigint, "
                "PRIMARY KEY "
                "(application_name, notification_id))"
            ).format(schema_ident, table_ident)
        )
        self.insert_tracking_statement = SQL(
            "INSERT INTO {0}.{1} VALUES (%s, %s)"
        ).format(schema_ident, table_ident)
        self.max_tracking_id_statement = SQL(
            "SELECT MAX(notification_id) FROM {0}.{1} WHERE application_name=%s"
        ).format(schema_ident, table_ident)
        self.count_tracking_id_statement = SQL(
            "SELECT COUNT(*) FROM {0}.{1} "
            "WHERE application_name=%s AND notification_id=%s"
        ).format(schema_ident, table_ident)

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def insert_tracking(self, tracking: Tracking) -> None:
        """Insert one tracking row in its own transaction."""
        with self.datastore.get_connection() as conn, conn.transaction(), conn.cursor() as curs:
            self._insert_tracking(curs, tracking)

    def _insert_tracking(
        self,
        curs: Cursor[DictRow],
        tracking: Tracking,
    ) -> None:
        row_values = (tracking.application_name, tracking.notification_id)
        curs.execute(
            query=self.insert_tracking_statement,
            params=row_values,
            prepare=True,
        )

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def max_tracking_id(self, application_name: str) -> int | None:
        """Return the highest tracked notification ID for the application."""
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(
                query=self.max_tracking_id_statement,
                params=(application_name,),
                prepare=True,
            )
            result_row = curs.fetchone()
            assert result_row is not None
            return result_row["max"]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def has_tracking_id(
        self, application_name: str, notification_id: int | None
    ) -> bool:
        """Return whether the tracking row exists (always True for None)."""
        if notification_id is not None:
            with self.datastore.get_connection() as conn, conn.cursor() as curs:
                curs.execute(
                    query=self.count_tracking_id_statement,
                    params=(application_name, notification_id),
                    prepare=True,
                )
                result_row = curs.fetchone()
                assert result_row is not None
                return bool(result_row["count"])
        return True
# Type of tracking recorder returned by PostgresFactory.tracking_recorder();
# defaults to PostgresTrackingRecorder when no class is given.
TPostgresTrackingRecorder = TypeVar(
    "TPostgresTrackingRecorder",
    default=PostgresTrackingRecorder,
    bound=PostgresTrackingRecorder,
)
class PostgresProcessRecorder(
    PostgresTrackingRecorder, PostgresApplicationRecorder, ProcessRecorder
):
    """Application recorder that also records tracking rows.

    When a ``tracking`` object is passed to ``insert_events``, its tracking
    row is inserted in the same transaction as the events.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
        tracking_table_name: str = "notification_tracking",
    ):
        super().__init__(
            datastore,
            tracking_table_name=tracking_table_name,
            events_table_name=events_table_name,
        )

    def _insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> None:
        tracking: Tracking | None
        if (tracking := kwargs.get("tracking")) is not None:
            self._insert_tracking(curs, tracking=tracking)
        super()._insert_events(curs, stored_events, **kwargs)
class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
    """Infrastructure factory for PostgreSQL.

    Reads connection and pool settings from the environment (keys below) and
    constructs one shared PostgresDatastore. It then builds recorders on that
    datastore, creating their tables unless CREATE_TABLE is false.
    """

    POSTGRES_DBNAME = "POSTGRES_DBNAME"
    POSTGRES_HOST = "POSTGRES_HOST"
    POSTGRES_PORT = "POSTGRES_PORT"
    POSTGRES_USER = "POSTGRES_USER"
    POSTGRES_PASSWORD = "POSTGRES_PASSWORD"  # noqa: S105
    POSTGRES_GET_PASSWORD_TOPIC = "POSTGRES_GET_PASSWORD_TOPIC"  # noqa: S105
    POSTGRES_CONNECT_TIMEOUT = "POSTGRES_CONNECT_TIMEOUT"
    POSTGRES_CONN_MAX_AGE = "POSTGRES_CONN_MAX_AGE"
    POSTGRES_PRE_PING = "POSTGRES_PRE_PING"
    POSTGRES_MAX_WAITING = "POSTGRES_MAX_WAITING"
    POSTGRES_LOCK_TIMEOUT = "POSTGRES_LOCK_TIMEOUT"
    POSTGRES_POOL_SIZE = "POSTGRES_POOL_SIZE"
    POSTGRES_MAX_OVERFLOW = "POSTGRES_MAX_OVERFLOW"
    POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT = (
        "POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT"
    )
    POSTGRES_SCHEMA = "POSTGRES_SCHEMA"
    CREATE_TABLE = "CREATE_TABLE"

    aggregate_recorder_class = PostgresAggregateRecorder
    application_recorder_class = PostgresApplicationRecorder
    tracking_recorder_class = PostgresTrackingRecorder
    process_recorder_class = PostgresProcessRecorder

    def __init__(self, env: Environment):
        """Build the datastore from environment settings.

        Raises OSError if a required setting is missing, or if a numeric
        setting cannot be parsed.
        """
        super().__init__(env)
        # Settings are read in a fixed order, so the first problem reported
        # is always the same for a given environment.
        dbname = self._require_env(self.POSTGRES_DBNAME, "database name")
        host = self._require_env(self.POSTGRES_HOST, "host")
        port = self.env.get(self.POSTGRES_PORT) or "5432"
        user = self._require_env(self.POSTGRES_USER, "user")

        get_password_func = None
        get_password_topic = self.env.get(self.POSTGRES_GET_PASSWORD_TOPIC)
        if get_password_topic:
            # The resolved function supplies the password for each new pooled
            # connection, so no static password is required.
            get_password_func = resolve_topic(get_password_topic)
            password = ""
        else:
            password = self._require_env(self.POSTGRES_PASSWORD, "password")

        connect_timeout = self._read_int_env(self.POSTGRES_CONNECT_TIMEOUT, 30)
        idle_in_transaction_session_timeout = self._read_int_env(
            self.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT, 5
        )
        pool_size = self._read_int_env(self.POSTGRES_POOL_SIZE, 5)
        pool_max_overflow = self._read_int_env(self.POSTGRES_MAX_OVERFLOW, 10)
        max_waiting = self._read_int_env(self.POSTGRES_MAX_WAITING, 0)
        conn_max_age = self._read_float_env(self.POSTGRES_CONN_MAX_AGE, 60 * 60.0)
        pre_ping = strtobool(self.env.get(self.POSTGRES_PRE_PING) or "no")
        lock_timeout = self._read_int_env(self.POSTGRES_LOCK_TIMEOUT, 0)
        schema = self.env.get(self.POSTGRES_SCHEMA) or ""

        self.datastore = PostgresDatastore(
            dbname=dbname,
            host=host,
            port=port,
            user=user,
            password=password,
            get_password_func=get_password_func,
            connect_timeout=connect_timeout,
            idle_in_transaction_session_timeout=idle_in_transaction_session_timeout,
            pool_size=pool_size,
            max_overflow=pool_max_overflow,
            max_waiting=max_waiting,
            conn_max_age=conn_max_age,
            pre_ping=pre_ping,
            lock_timeout=lock_timeout,
            schema=schema,
        )

    def _require_env(self, key: str, description: str) -> str:
        """Return the environment value for ``key``; raise OSError if unset."""
        value = self.env.get(key)
        if value is None:
            # TODO: Indicate all candidate keys (cf. self.env.create_keys(key))
            # in this message, not only ``key``.
            msg = (
                f"Postgres {description} not found "
                "in environment with key "
                f"'{key}'"
            )
            raise OSError(msg)
        return value

    def _read_int_env(self, key: str, default: int) -> int:
        """Parse ``key`` as an int. Unset or empty values give ``default``."""
        value = self.env.get(key)
        if not value:
            return default
        try:
            return int(value)
        except ValueError:
            raise self._invalid_env_value(key, value, "an integer") from None

    def _read_float_env(self, key: str, default: float) -> float:
        """Parse ``key`` as a float. Unset or empty values give ``default``."""
        value = self.env.get(key)
        if not value:
            return default
        try:
            return float(value)
        except ValueError:
            raise self._invalid_env_value(key, value, "a float") from None

    @staticmethod
    def _invalid_env_value(key: str, value: str, expected: str) -> OSError:
        """Build the OSError reported for an unparsable environment value."""
        msg = (
            "Postgres environment value for key "
            f"'{key}' is invalid. "
            f"If set, {expected} or empty string is expected: "
            f"'{value}'"
        )
        return OSError(msg)

    def env_create_table(self) -> bool:
        """Return whether recorders should create their tables (default yes)."""
        return strtobool(self.env.get(self.CREATE_TABLE) or "yes")

    def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
        """Construct an aggregate recorder for table ``<prefix>_<purpose>``."""
        prefix = self.env.name.lower() or "stored"
        events_table_name = prefix + "_" + purpose
        recorder = type(self).aggregate_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def application_recorder(self) -> ApplicationRecorder:
        """Construct an application recorder for table ``<prefix>_events``.

        The class may be overridden by the APPLICATION_RECORDER_TOPIC setting.
        """
        prefix = self.env.name.lower() or "stored"
        events_table_name = prefix + "_events"
        application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
        if application_recorder_topic:
            application_recorder_class: type[PostgresApplicationRecorder] = (
                resolve_topic(application_recorder_topic)
            )
            assert issubclass(application_recorder_class, PostgresApplicationRecorder)
        else:
            application_recorder_class = type(self).application_recorder_class
        recorder = application_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def tracking_recorder(
        self, tracking_recorder_class: type[TPostgresTrackingRecorder] | None = None
    ) -> TPostgresTrackingRecorder:
        """Construct a tracking recorder for table ``<prefix>_tracking``.

        The class comes from the argument, then from the
        TRACKING_RECORDER_TOPIC setting, then from the class default.
        """
        prefix = self.env.name.lower() or "notification"
        tracking_table_name = prefix + "_tracking"
        if tracking_recorder_class is None:
            tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
            if tracking_recorder_topic:
                tracking_recorder_class = resolve_topic(tracking_recorder_topic)
            else:
                tracking_recorder_class = cast(
                    "type[TPostgresTrackingRecorder]",
                    type(self).tracking_recorder_class,
                )
        assert tracking_recorder_class is not None
        assert issubclass(tracking_recorder_class, PostgresTrackingRecorder)
        recorder = tracking_recorder_class(
            datastore=self.datastore,
            tracking_table_name=tracking_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def process_recorder(self) -> ProcessRecorder:
        """Construct a process recorder for ``<prefix>_events`` and
        ``<prefix>_tracking``.

        The class may be overridden by the PROCESS_RECORDER_TOPIC setting.
        """
        prefix = self.env.name.lower()
        events_table_name = (prefix or "stored") + "_events"
        tracking_table_name = (prefix or "notification") + "_tracking"
        process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
        if process_recorder_topic:
            process_recorder_class: type[PostgresProcessRecorder] = resolve_topic(
                process_recorder_topic
            )
            assert issubclass(process_recorder_class, PostgresProcessRecorder)
        else:
            process_recorder_class = type(self).process_recorder_class
        recorder = process_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
            tracking_table_name=tracking_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def close(self) -> None:
        """Close the datastore, if __init__ got far enough to create it."""
        with contextlib.suppress(AttributeError):
            self.datastore.close()
# Module-level alias for PostgresFactory (presumably looked up by the name
# "Factory" when this module is selected as the persistence module; TODO confirm).
Factory = PostgresFactory