"""Source code for eventsourcing.postgres: PostgreSQL datastore, recorders and infrastructure factory."""

from __future__ import annotations

import contextlib
import logging
from asyncio import CancelledError
from contextlib import contextmanager
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, cast

import psycopg
import psycopg.errors
import psycopg_pool
from psycopg import Connection, Cursor, Error
from psycopg.generators import notifies
from psycopg.rows import DictRow, dict_row
from psycopg.sql import SQL, Composed, Identifier
from typing_extensions import TypeVar

from eventsourcing.persistence import (
    AggregateRecorder,
    ApplicationRecorder,
    DatabaseError,
    DataError,
    InfrastructureFactory,
    IntegrityError,
    InterfaceError,
    InternalError,
    ListenNotifySubscription,
    Notification,
    NotSupportedError,
    OperationalError,
    PersistenceError,
    ProcessRecorder,
    ProgrammingError,
    StoredEvent,
    Subscription,
    Tracking,
    TrackingRecorder,
)
from eventsourcing.utils import Environment, resolve_topic, retry, strtobool

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
    from uuid import UUID

    from psycopg.abc import Query
    from typing_extensions import Self

# Silence psycopg's loggers (the driver and its connection pool): only
# CRITICAL records from them get through.
for _psycopg_logger_name in ("psycopg", "psycopg.pool"):
    logging.getLogger(_psycopg_logger_name).setLevel(logging.CRITICAL)
del _psycopg_logger_name

# Exception types that are re-raised with their traceback stripped.
# This duplicates psycopg's private ``psycopg.errors._NO_TRACEBACK`` so that we
# don't depend on a private name that may change. psycopg's rationale: "Don't
# show a complete traceback upon raising these exception. Usually the
# traceback starts from internal functions (for instance in the server
# communication callbacks) but, for the end user, it's more important to get
# the high level information about where the exception was raised, for
# instance in a certain `Cursor.execute()`."
NO_TRACEBACK: tuple[type[BaseException], ...] = (
    Error,
    KeyboardInterrupt,
    CancelledError,
)


class ConnectionPool(psycopg_pool.ConnectionPool[Any]):
    """psycopg connection pool with an optional password-refresh hook.

    If ``get_password_func`` is given, it is called in ``_connect`` (i.e. each
    time the pool makes a new connection) and its result replaces the
    ``password`` entry of the pool's connection ``kwargs``.
    """

    def __init__(
        self,
        *args: Any,
        get_password_func: Callable[[], str] | None = None,
        **kwargs: Any,
    ) -> None:
        # Assigned before the base initialiser runs, in case that initialiser
        # already opens the pool and starts connecting.
        self.get_password_func = get_password_func
        super().__init__(*args, **kwargs)

    def _connect(self, timeout: float | None = None) -> Connection[Any]:
        password_func = self.get_password_func
        if password_func:
            self.kwargs["password"] = password_func()
        return super()._connect(timeout=timeout)
class PostgresDatastore:
    """Owns a psycopg connection pool and hands out connections/transactions.

    psycopg errors raised while obtaining or using a connection are re-raised
    as the corresponding eventsourcing persistence exception types.
    """

    def __init__(  # noqa: PLR0913
        self,
        dbname: str,
        host: str,
        port: str | int,
        user: str,
        password: str,
        *,
        connect_timeout: int = 30,
        idle_in_transaction_session_timeout: int = 0,
        pool_size: int = 2,
        max_overflow: int = 2,
        max_waiting: int = 0,
        conn_max_age: float = 60 * 60.0,
        pre_ping: bool = False,
        lock_timeout: int = 0,
        schema: str = "",
        pool_open_timeout: int | None = None,
        get_password_func: Callable[[], str] | None = None,
    ):
        self.idle_in_transaction_session_timeout = idle_in_transaction_session_timeout
        self.pre_ping = pre_ping
        self.pool_open_timeout = pool_open_timeout
        # Plain attributes are set before the pool is built, so they exist even
        # if building the pool fails.
        self.lock_timeout = lock_timeout
        self.schema = schema.strip() or "public"
        # With pre_ping, the pool checks a connection before handing it out.
        check = ConnectionPool.check_connection if pre_ping else None
        self.pool = ConnectionPool(
            get_password_func=get_password_func,
            connection_class=Connection[DictRow],
            kwargs={
                "dbname": dbname,
                "host": host,
                "port": port,
                "user": user,
                "password": password,
                "row_factory": dict_row,
            },
            min_size=pool_size,
            max_size=pool_size + max_overflow,
            open=False,  # Opened lazily by get_connection().
            configure=self.after_connect_func(),
            timeout=connect_timeout,
            max_waiting=max_waiting,
            max_lifetime=conn_max_age,
            check=check,
        )

    def after_connect_func(self) -> Callable[[Connection[Any]], None]:
        """Return the ``configure`` callback applied to each new pooled connection.

        The callback switches the connection to autocommit and sets the
        session's ``idle_in_transaction_session_timeout`` (in seconds).
        """
        statement = SQL("SET idle_in_transaction_session_timeout = '{0}s'").format(
            self.idle_in_transaction_session_timeout
        )

        def after_connect(conn: Connection[DictRow]) -> None:
            conn.autocommit = True
            conn.cursor().execute(statement)

        return after_connect

    @contextmanager
    def get_connection(self) -> Iterator[Connection[DictRow]]:
        """Yield a connection from the pool, opening the pool if necessary.

        The pool is opened without waiting when ``pool_open_timeout`` is None,
        otherwise it waits up to that many seconds. (A value of 0 falls back
        to 30 seconds.)

        Raises:
            InterfaceError, OperationalError, DataError, IntegrityError,
            InternalError, ProgrammingError, NotSupportedError, DatabaseError,
            PersistenceError: translated from the matching psycopg error.
            Other exceptions propagate unchanged.
        """
        try:
            wait = self.pool_open_timeout is not None
            timeout = self.pool_open_timeout or 30.0
            # NOTE(review): presumably a no-op once the pool is open — confirm.
            self.pool.open(wait, timeout)
            with self.pool.connection() as conn:
                yield conn
        # Subclasses are caught before their base classes (DatabaseError,
        # then Error), so the most specific translation wins.
        except psycopg.InterfaceError as e:
            raise InterfaceError(str(e)) from e
        except psycopg.OperationalError as e:
            raise OperationalError(str(e)) from e
        except psycopg.DataError as e:
            raise DataError(str(e)) from e
        except psycopg.IntegrityError as e:
            raise IntegrityError(str(e)) from e
        except psycopg.InternalError as e:
            raise InternalError(str(e)) from e
        except psycopg.ProgrammingError as e:
            raise ProgrammingError(str(e)) from e
        except psycopg.NotSupportedError as e:
            raise NotSupportedError(str(e)) from e
        except psycopg.DatabaseError as e:
            raise DatabaseError(str(e)) from e
        except psycopg.Error as e:
            raise PersistenceError(str(e)) from e

    @contextmanager
    def transaction(self, *, commit: bool = False) -> Iterator[Cursor[DictRow]]:
        """Yield a cursor inside a transaction.

        The transaction is rolled back on exit unless ``commit`` is True.
        """
        with self.get_connection() as conn, conn.transaction(force_rollback=not commit):
            yield conn.cursor()

    def close(self) -> None:
        """Close the connection pool.

        Does nothing if ``__init__`` failed before the pool was created, so
        that ``__del__`` on a partially constructed instance doesn't raise.
        """
        pool = getattr(self, "pool", None)
        if pool is not None:
            pool.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()
class PostgresRecorder:
    """Base class for recorders that use PostgreSQL."""

    def __init__(
        self,
        datastore: PostgresDatastore,
    ):
        self.datastore = datastore
        # Subclasses append their DDL statements to this list in __init__.
        self.create_table_statements = self.construct_create_table_statements()

    def construct_create_table_statements(self) -> list[Composed]:
        """Return the initial list of statements executed by create_table()."""
        return []

    def check_table_name_length(self, table_name: str) -> None:
        """Raise ProgrammingError if ``table_name`` is too long for PostgreSQL.

        PostgreSQL identifiers are limited to 63 *bytes* (NAMEDATALEN - 1 by
        default) and longer names are silently truncated by the server, so the
        length is measured on the UTF-8 encoding rather than in characters.
        (Assumes a UTF-8 server encoding; for ASCII names the two are equal.)
        """
        if len(table_name.encode("utf-8")) > 63:
            msg = f"Table name too long: {table_name}"
            raise ProgrammingError(msg)

    def create_table(self) -> None:
        """Execute all create-table statements in a single committed transaction."""
        with self.datastore.transaction(commit=True) as curs:
            for statement in self.create_table_statements:
                curs.execute(statement, prepare=False)
class PostgresAggregateRecorder(PostgresRecorder, AggregateRecorder):
    """Stores and selects aggregate events in a PostgreSQL table.

    Events are keyed by (originator_id, originator_version), so inserting a
    duplicate version violates the primary key.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
    ):
        super().__init__(datastore)
        self.check_table_name_length(events_table_name)
        self.events_table_name = events_table_name
        # Index names can't be qualified names, but
        # are created in the same schema as the table.
        self.notification_id_index_name = (
            f"{self.events_table_name}_notification_id_idx"
        )
        self.create_table_statements.append(
            SQL(
                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
                "originator_id uuid NOT NULL, "
                "originator_version bigint NOT NULL, "
                "topic text, "
                "state bytea, "
                "PRIMARY KEY "
                "(originator_id, originator_version)) "
                "WITH (autovacuum_enabled=false)"
            ).format(
                Identifier(self.datastore.schema),
                Identifier(self.events_table_name),
            )
        )
        self.insert_events_statement = SQL(
            "INSERT INTO {0}.{1} VALUES (%s, %s, %s, %s)"
        ).format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        self.select_events_statement = SQL(
            "SELECT * FROM {0}.{1} WHERE originator_id = %s"
        ).format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        # Empty here; subclasses may supply statements for _lock_table().
        self.lock_table_statements: list[Query] = []

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def insert_events(
        self, stored_events: list[StoredEvent], **kwargs: Any
    ) -> Sequence[int] | None:
        """Insert ``stored_events`` atomically, in one pipelined transaction.

        Returns whatever _fetch_ids_after_insert_events() returns (None in
        this class). Errors are captured and re-raised only after the
        pipeline has exited, but still inside get_connection(), so psycopg
        errors are translated to persistence exceptions.
        """
        exc: Exception | None = None
        notification_ids: Sequence[int] | None = None
        with self.datastore.get_connection() as conn:
            with conn.pipeline() as pipeline, conn.transaction():
                # Do other things first, so they can be pipelined too.
                with conn.cursor() as curs:
                    self._insert_events(curs, stored_events, **kwargs)
                # Then use a different cursor for the executemany() call.
                with conn.cursor() as curs:
                    try:
                        self._insert_stored_events(curs, stored_events, **kwargs)
                        # Sync now, so any uniqueness constraint violation causes an
                        # IntegrityError to be raised here, rather an InternalError
                        # being raised sometime later e.g. when commit() is called.
                        pipeline.sync()
                        notification_ids = self._fetch_ids_after_insert_events(
                            curs, stored_events, **kwargs
                        )
                    except Exception as e:
                        # Avoid psycopg emitting a pipeline warning.
                        exc = e
            if exc is not None:
                # Reraise exception after pipeline context manager has exited.
                raise exc
        return notification_ids

    def _insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **_: Any,
    ) -> None:
        """Hook for extra statements in the insert transaction (none here)."""

    def _insert_stored_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **_: Any,
    ) -> None:
        """Lock, notify, then insert all ``stored_events`` with executemany()."""
        # Only do something if there is something to do.
        if stored_events:
            self._lock_table(curs)
            self._notify_channel(curs)
            # Insert events.
            curs.executemany(
                query=self.insert_events_statement,
                params_seq=[
                    (
                        stored_event.originator_id,
                        stored_event.originator_version,
                        stored_event.topic,
                        stored_event.state,
                    )
                    for stored_event in stored_events
                ],
                # Keep per-row results only if the statement returns something.
                returning="RETURNING" in self.insert_events_statement.as_string(),
            )

    def _lock_table(self, curs: Cursor[DictRow]) -> None:
        """Hook for locking the events table before inserting (no-op here)."""

    def _notify_channel(self, curs: Cursor[DictRow]) -> None:
        """Hook for notifying listeners of new events (no-op here)."""

    def _fetch_ids_after_insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> Sequence[int] | None:
        """Hook for collecting notification IDs after insert (None here)."""
        return None

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def select_events(
        self,
        originator_id: UUID,
        *,
        gt: int | None = None,
        lte: int | None = None,
        desc: bool = False,
        limit: int | None = None,
    ) -> list[StoredEvent]:
        """Return stored events of ``originator_id``, ordered by version.

        Optional filters: version > ``gt`` and version <= ``lte``. The order is
        descending when ``desc`` is truthy, otherwise ascending. At most
        ``limit`` events are returned if ``limit`` is given.
        """
        statement = self.select_events_statement
        params: list[Any] = [originator_id]
        if gt is not None:
            params.append(gt)
            statement += SQL(" AND originator_version > %s")
        if lte is not None:
            params.append(lte)
            statement += SQL(" AND originator_version <= %s")
        statement += SQL(" ORDER BY originator_version")
        # Truthiness test: a falsy non-bool (e.g. None) must mean ascending.
        statement += SQL(" DESC") if desc else SQL(" ASC")
        if limit is not None:
            params.append(limit)
            statement += SQL(" LIMIT %s")

        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(statement, params, prepare=True)
            return [
                StoredEvent(
                    originator_id=row["originator_id"],
                    originator_version=row["originator_version"],
                    topic=row["topic"],
                    state=bytes(row["state"]),
                )
                for row in curs.fetchall()
            ]
class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder):
    """Application recorder: stored events plus a notification log.

    The events table gets a ``notification_id bigserial`` column and a unique
    index on it. Inserts return the new notification IDs, take an EXCLUSIVE
    table lock first, and issue a ``NOTIFY`` on a channel named after the
    table.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
    ):
        super().__init__(datastore, events_table_name=events_table_name)
        # Replace the base class's CREATE TABLE statement with one that also
        # defines the notification_id column.
        self.create_table_statements[-1] = SQL(
            "CREATE TABLE IF NOT EXISTS {0}.{1} ("
            "originator_id uuid NOT NULL, "
            "originator_version bigint NOT NULL, "
            "topic text, "
            "state bytea, "
            "notification_id bigserial, "
            "PRIMARY KEY "
            "(originator_id, originator_version)) "
            "WITH (autovacuum_enabled=false)"
        ).format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        self.create_table_statements.append(
            SQL(
                "CREATE UNIQUE INDEX IF NOT EXISTS {0} "
                "ON {1}.{2} (notification_id ASC);"
            ).format(
                Identifier(self.notification_id_index_name),
                Identifier(self.datastore.schema),
                Identifier(self.events_table_name),
            )
        )
        # LISTEN/NOTIFY channel name: the table name with "." replaced by "_".
        self.channel_name = self.events_table_name.replace(".", "_")
        # Make each inserted row report its notification ID.
        self.insert_events_statement = self.insert_events_statement + SQL(
            " RETURNING notification_id"
        )
        self.max_notification_id_statement = SQL(
            "SELECT MAX(notification_id) FROM {0}.{1}"
        ).format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        # Executed in order by _lock_table(): bound the lock wait (seconds),
        # then take the EXCLUSIVE table lock.
        self.lock_table_statements = [
            SQL("SET LOCAL lock_timeout = '{0}s'").format(self.datastore.lock_timeout),
            SQL("LOCK TABLE {0}.{1} IN EXCLUSIVE MODE").format(
                Identifier(self.datastore.schema),
                Identifier(self.events_table_name),
            ),
        ]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def select_notifications(
        self,
        start: int | None,
        limit: int,
        stop: int | None = None,
        topics: Sequence[str] = (),
        *,
        inclusive_of_start: bool = True,
    ) -> list[Notification]:
        """Return up to ``limit`` notifications, ordered by notification ID.

        Filters (all optional): notification_id >= ``start`` (or > ``start``
        when ``inclusive_of_start`` is False), notification_id <= ``stop``,
        and topic in ``topics`` (a tuple or list of str).
        """
        params: list[int | str | Sequence[str]] = []
        statement = SQL("SELECT * FROM {0}.{1}").format(
            Identifier(self.datastore.schema),
            Identifier(self.events_table_name),
        )
        # Tracks whether a WHERE has been emitted, so later filters use AND.
        has_where = False
        if start is not None:
            statement += SQL(" WHERE")
            has_where = True
            params.append(start)
            if inclusive_of_start:
                statement += SQL(" notification_id>=%s")
            else:
                statement += SQL(" notification_id>%s")

        if stop is not None:
            if not has_where:
                has_where = True
                statement += SQL(" WHERE")
            else:
                statement += SQL(" AND")

            params.append(stop)
            statement += SQL(" notification_id <= %s")

        if topics:
            # Check sequence and ensure list of strings.
            # NOTE(review): these asserts are stripped under ``python -O``.
            assert isinstance(topics, (tuple, list)), topics
            topics = list(topics) if isinstance(topics, tuple) else topics
            assert all(isinstance(t, str) for t in topics), topics
            if not has_where:
                statement += SQL(" WHERE")
            else:
                statement += SQL(" AND")
            params.append(topics)
            statement += SQL(" topic = ANY(%s)")

        params.append(limit)
        statement += SQL(" ORDER BY notification_id LIMIT %s")

        connection = self.datastore.get_connection()
        with connection as conn, conn.cursor() as curs:
            curs.execute(statement, params, prepare=True)
            return [
                Notification(
                    id=row["notification_id"],
                    originator_id=row["originator_id"],
                    originator_version=row["originator_version"],
                    topic=row["topic"],
                    state=bytes(row["state"]),
                )
                for row in curs.fetchall()
            ]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def max_notification_id(self) -> int | None:
        """Returns the maximum notification ID (None if the table is empty)."""
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(self.max_notification_id_statement)
            fetchone = curs.fetchone()
            assert fetchone is not None
            return fetchone["max"]

    def _lock_table(self, curs: Cursor[DictRow]) -> None:
        # Acquire an "EXCLUSIVE" table lock, to serialize transactions that insert
        # stored events, so that readers don't pass over gaps that are filled in
        # later. We want each transaction that is issued notification IDs by the
        # notification ID sequence to receive all of its notification IDs and then
        # commit, before another transaction is issued any notification IDs. In
        # other words, we want the insert order to be the same as the commit
        # order. We can accomplish this by locking the table for writes. The
        # EXCLUSIVE lock mode does not block SELECT statements, which acquire an
        # ACCESS SHARE lock, so the stored events table can be read concurrently
        # with writes and other reads. However, INSERT statements normally just
        # acquire ROW EXCLUSIVE locks, which risks the interleaving (within the
        # recorded sequence of notification IDs) of stored events from one
        # transaction with those of another transaction. And since one transaction
        # will always commit before another, with only ROW EXCLUSIVE locks readers
        # that are tailing a notification log can miss items that were inserted
        # later but issued lower notification IDs.
        # The statements (built in __init__) first SET LOCAL lock_timeout, so a
        # blocked writer gives up after that many seconds (0 means no timeout).
        # https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES
        # https://www.postgresql.org/docs/9.1/sql-lock.html
        # https://stackoverflow.com/questions/45866187/guarantee-monotonicity-of
        # -postgresql-serial-column-values-by-commit-order
        for lock_statement in self.lock_table_statements:
            curs.execute(lock_statement, prepare=True)

    def _notify_channel(self, curs: Cursor[DictRow]) -> None:
        # Listeners (see PostgresSubscription) LISTEN on this channel.
        curs.execute(SQL("NOTIFY {0}").format(Identifier(self.channel_name)))

    def _fetch_ids_after_insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> Sequence[int] | None:
        """Collect the RETURNING notification IDs of the rows just inserted.

        Raises ProgrammingError if fewer IDs than events were found.
        """
        notification_ids: list[int] = []
        len_events = len(stored_events)
        if len_events:
            # Walk the cursor's result sets, keeping only INSERT results (the
            # lock and NOTIFY statements ran on this cursor too). NOTE(review):
            # nextset() is called before the first fetch, so the current
            # result set is skipped; presumably that is never an INSERT result
            # here because the lock statements run first — confirm.
            while curs.nextset() and len(notification_ids) != len_events:
                if curs.statusmessage and curs.statusmessage.startswith("INSERT"):
                    row = curs.fetchone()
                    assert row is not None
                    notification_ids.append(row["notification_id"])
            if len(notification_ids) != len(stored_events):
                msg = "Couldn't get all notification IDs "
                msg += f"(got {len(notification_ids)}, expected {len(stored_events)})"
                raise ProgrammingError(msg)
        return notification_ids

    def subscribe(
        self, gt: int | None = None, topics: Sequence[str] = ()
    ) -> Subscription[ApplicationRecorder]:
        """Return a PostgresSubscription for notifications after ``gt``."""
        return PostgresSubscription(recorder=self, gt=gt, topics=topics)
class PostgresSubscription(ListenNotifySubscription[PostgresApplicationRecorder]):
    """Subscription woken by PostgreSQL NOTIFY messages.

    A background thread LISTENs on the recorder's channel and sets
    ``self._has_been_notified`` whenever notify messages arrive. How that flag
    is consumed is defined by ListenNotifySubscription (not visible here).
    """

    def __init__(
        self,
        recorder: PostgresApplicationRecorder,
        gt: int | None = None,
        topics: Sequence[str] = (),
    ) -> None:
        assert isinstance(recorder, PostgresApplicationRecorder)
        super().__init__(recorder=recorder, gt=gt, topics=topics)
        # The listener thread starts immediately on construction.
        self._listen_thread = Thread(target=self._listen)
        self._listen_thread.start()

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        super().__exit__(*args, **kwargs)
        # Wait for the listener thread to finish.
        self._listen_thread.join()

    def _listen(self) -> None:
        """Thread target: LISTEN on the channel until stopped or failed.

        Any exception is recorded as ``self._thread_error`` (if none was
        recorded already) and the subscription is stopped.
        """
        try:
            with self._recorder.datastore.get_connection() as conn:
                conn.execute(
                    SQL("LISTEN {0}").format(Identifier(self._recorder.channel_name))
                )
                while not self._has_been_stopped and not self._thread_error:
                    # This block simplifies psycopg's conn.notifies(), because
                    # we aren't interested in the actual notify messages, and
                    # also we want to stop consuming notify messages when the
                    # subscription has an error or is otherwise stopped.
                    with conn.lock:
                        try:
                            # Presumably returns a falsy result after about
                            # ``interval`` seconds with no messages, so the
                            # loop condition is re-checked regularly — confirm.
                            if conn.wait(notifies(conn.pgconn), interval=0.1):
                                self._has_been_notified.set()
                        except NO_TRACEBACK as ex:  # pragma: no cover
                            raise ex.with_traceback(None) from None

        except BaseException as e:
            if self._thread_error is None:
                self._thread_error = e
            self.stop()
class PostgresTrackingRecorder(PostgresRecorder, TrackingRecorder):
    """Records which notifications each application has processed.

    Tracking rows are (application_name, notification_id) pairs, which form
    the primary key of the tracking table.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        tracking_table_name: str = "notification_tracking",
        **kwargs: Any,
    ):
        super().__init__(datastore, **kwargs)
        self.check_table_name_length(tracking_table_name)
        self.tracking_table_name = tracking_table_name
        # Schema and table identifiers, used for every "{0}.{1}" below.
        qualified_table = (
            Identifier(self.datastore.schema),
            Identifier(self.tracking_table_name),
        )
        self.create_table_statements.append(
            SQL(
                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
                "application_name text, "
                "notification_id bigint, "
                "PRIMARY KEY "
                "(application_name, notification_id))"
            ).format(*qualified_table)
        )
        self.insert_tracking_statement = SQL(
            "INSERT INTO {0}.{1} VALUES (%s, %s)"
        ).format(*qualified_table)
        self.max_tracking_id_statement = SQL(
            "SELECT MAX(notification_id) FROM {0}.{1} WHERE application_name=%s"
        ).format(*qualified_table)
        self.count_tracking_id_statement = SQL(
            "SELECT COUNT(*) FROM {0}.{1} "
            "WHERE application_name=%s AND notification_id=%s"
        ).format(*qualified_table)

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def insert_tracking(self, tracking: Tracking) -> None:
        """Insert one tracking record in its own transaction."""
        with self.datastore.get_connection() as conn:
            with conn.transaction():
                with conn.cursor() as curs:
                    self._insert_tracking(curs, tracking)

    def _insert_tracking(
        self,
        curs: Cursor[DictRow],
        tracking: Tracking,
    ) -> None:
        """Execute the tracking INSERT on the given cursor."""
        row_values = (tracking.application_name, tracking.notification_id)
        curs.execute(
            query=self.insert_tracking_statement,
            params=row_values,
            prepare=True,
        )

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def max_tracking_id(self, application_name: str) -> int | None:
        """Return the highest tracked notification ID (None if none recorded)."""
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(
                query=self.max_tracking_id_statement,
                params=(application_name,),
                prepare=True,
            )
            result_row = curs.fetchone()
            assert result_row is not None
            return result_row["max"]

    @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
    def has_tracking_id(
        self, application_name: str, notification_id: int | None
    ) -> bool:
        """Return whether (application_name, notification_id) is tracked.

        A ``notification_id`` of None is always considered tracked.
        """
        if notification_id is None:
            return True
        with self.datastore.get_connection() as conn, conn.cursor() as curs:
            curs.execute(
                query=self.count_tracking_id_statement,
                params=(application_name, notification_id),
                prepare=True,
            )
            result_row = curs.fetchone()
            assert result_row is not None
            return bool(result_row["count"])
# Type variable for PostgresFactory.tracking_recorder(): any subclass of
# PostgresTrackingRecorder, defaulting to PostgresTrackingRecorder itself.
TPostgresTrackingRecorder = TypeVar(
    "TPostgresTrackingRecorder",
    default=PostgresTrackingRecorder,
    bound=PostgresTrackingRecorder,
)
class PostgresProcessRecorder(
    PostgresTrackingRecorder, PostgresApplicationRecorder, ProcessRecorder
):
    """Application recorder that also records tracking with each insert.

    When ``insert_events`` is given a ``tracking`` keyword argument, that
    tracking record is written in the same transaction as the events.
    """

    def __init__(
        self,
        datastore: PostgresDatastore,
        *,
        events_table_name: str = "stored_events",
        tracking_table_name: str = "notification_tracking",
    ):
        super().__init__(
            datastore,
            events_table_name=events_table_name,
            tracking_table_name=tracking_table_name,
        )

    def _insert_events(
        self,
        curs: Cursor[DictRow],
        stored_events: list[StoredEvent],
        **kwargs: Any,
    ) -> None:
        # Write the tracking record (if any) before delegating up the MRO.
        if (tracking := kwargs.get("tracking")) is not None:
            self._insert_tracking(curs, tracking=tracking)
        super()._insert_events(curs, stored_events, **kwargs)
class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
    """Infrastructure factory that builds PostgreSQL recorders.

    Connection and pool settings are read from the environment (keys below)
    when the factory is constructed. All recorders share one PostgresDatastore.
    """

    POSTGRES_DBNAME = "POSTGRES_DBNAME"
    POSTGRES_HOST = "POSTGRES_HOST"
    POSTGRES_PORT = "POSTGRES_PORT"
    POSTGRES_USER = "POSTGRES_USER"
    POSTGRES_PASSWORD = "POSTGRES_PASSWORD"  # noqa: S105
    POSTGRES_GET_PASSWORD_TOPIC = "POSTGRES_GET_PASSWORD_TOPIC"  # noqa: S105
    POSTGRES_CONNECT_TIMEOUT = "POSTGRES_CONNECT_TIMEOUT"
    POSTGRES_CONN_MAX_AGE = "POSTGRES_CONN_MAX_AGE"
    POSTGRES_PRE_PING = "POSTGRES_PRE_PING"
    POSTGRES_MAX_WAITING = "POSTGRES_MAX_WAITING"
    POSTGRES_LOCK_TIMEOUT = "POSTGRES_LOCK_TIMEOUT"
    POSTGRES_POOL_SIZE = "POSTGRES_POOL_SIZE"
    POSTGRES_MAX_OVERFLOW = "POSTGRES_MAX_OVERFLOW"
    POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT = (
        "POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT"
    )
    POSTGRES_SCHEMA = "POSTGRES_SCHEMA"
    CREATE_TABLE = "CREATE_TABLE"

    aggregate_recorder_class = PostgresAggregateRecorder
    application_recorder_class = PostgresApplicationRecorder
    tracking_recorder_class = PostgresTrackingRecorder
    process_recorder_class = PostgresProcessRecorder

    def __init__(self, env: Environment):
        """Read settings from ``env`` and create the shared PostgresDatastore.

        Raises:
            OSError: if the database name, host or user is missing, or if the
                password is missing while no password topic is set, or if a
                numeric setting cannot be parsed.
        """
        super().__init__(env)
        # TODO: Indicate both keys in the "not found" messages, also for other
        #  environment variables, e.g. " or ".join(
        #  [f"'{key}'" for key in self.env.create_keys(self.POSTGRES_DBNAME)])
        dbname = self._require_env(self.POSTGRES_DBNAME, "database name")
        host = self._require_env(self.POSTGRES_HOST, "host")
        port = self.env.get(self.POSTGRES_PORT) or "5432"
        user = self._require_env(self.POSTGRES_USER, "user")

        get_password_func: Callable[[], str] | None = None
        get_password_topic = self.env.get(self.POSTGRES_GET_PASSWORD_TOPIC)
        if get_password_topic:
            # The pool calls this function for the password of each new
            # connection, so no static password is needed.
            get_password_func = resolve_topic(get_password_topic)
            password = ""
        else:
            password = self._require_env(self.POSTGRES_PASSWORD, "password")

        # Unset or empty values fall back to the defaults given here.
        connect_timeout = self._env_int(self.POSTGRES_CONNECT_TIMEOUT, 30)
        idle_in_transaction_session_timeout = self._env_int(
            self.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT, 5
        )
        pool_size = self._env_int(self.POSTGRES_POOL_SIZE, 5)
        pool_max_overflow = self._env_int(self.POSTGRES_MAX_OVERFLOW, 10)
        max_waiting = self._env_int(self.POSTGRES_MAX_WAITING, 0)
        conn_max_age = self._env_float(self.POSTGRES_CONN_MAX_AGE, 60 * 60.0)
        pre_ping = strtobool(self.env.get(self.POSTGRES_PRE_PING) or "no")
        lock_timeout = self._env_int(self.POSTGRES_LOCK_TIMEOUT, 0)
        schema = self.env.get(self.POSTGRES_SCHEMA) or ""

        self.datastore = PostgresDatastore(
            dbname=dbname,
            host=host,
            port=port,
            user=user,
            password=password,
            get_password_func=get_password_func,
            connect_timeout=connect_timeout,
            idle_in_transaction_session_timeout=idle_in_transaction_session_timeout,
            pool_size=pool_size,
            max_overflow=pool_max_overflow,
            max_waiting=max_waiting,
            conn_max_age=conn_max_age,
            pre_ping=pre_ping,
            lock_timeout=lock_timeout,
            schema=schema,
        )

    def _require_env(self, key: str, description: str) -> str:
        """Return the environment value for ``key``; raise OSError if unset."""
        value = self.env.get(key)
        if value is None:
            msg = (
                f"Postgres {description} not found "
                "in environment with key "
                f"'{key}'"
            )
            raise OSError(msg)
        return value

    def _env_int(self, key: str, default: int) -> int:
        """Return the value for ``key`` as an int, or ``default`` if unset/empty."""
        value = self.env.get(key)
        if not value:
            return default
        try:
            return int(value)
        except ValueError:
            msg = self._invalid_env_value_msg(key, "an integer", value)
            raise OSError(msg) from None

    def _env_float(self, key: str, default: float) -> float:
        """Return the value for ``key`` as a float, or ``default`` if unset/empty."""
        value = self.env.get(key)
        if not value:
            return default
        try:
            return float(value)
        except ValueError:
            msg = self._invalid_env_value_msg(key, "a float", value)
            raise OSError(msg) from None

    @staticmethod
    def _invalid_env_value_msg(key: str, expected: str, value: str) -> str:
        """Build the error message for an unparsable environment value."""
        return (
            "Postgres environment value for key "
            f"'{key}' is invalid. "
            f"If set, {expected} or empty string is expected: "
            f"'{value}'"
        )

    def env_create_table(self) -> bool:
        """Return whether recorders should create their tables (default: yes)."""
        return strtobool(self.env.get(self.CREATE_TABLE) or "yes")

    def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
        """Return an aggregate recorder using table "<env name or stored>_<purpose>"."""
        prefix = self.env.name.lower() or "stored"
        events_table_name = prefix + "_" + purpose
        recorder = type(self).aggregate_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def application_recorder(self) -> ApplicationRecorder:
        """Return an application recorder, optionally of a class given by topic."""
        prefix = self.env.name.lower() or "stored"
        events_table_name = prefix + "_events"
        application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
        if application_recorder_topic:
            application_recorder_class: type[PostgresApplicationRecorder] = (
                resolve_topic(application_recorder_topic)
            )
            assert issubclass(application_recorder_class, PostgresApplicationRecorder)
        else:
            application_recorder_class = type(self).application_recorder_class

        recorder = application_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def tracking_recorder(
        self, tracking_recorder_class: type[TPostgresTrackingRecorder] | None = None
    ) -> TPostgresTrackingRecorder:
        """Return a tracking recorder of the given, configured or default class."""
        prefix = self.env.name.lower() or "notification"
        tracking_table_name = prefix + "_tracking"
        if tracking_recorder_class is None:
            tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
            if tracking_recorder_topic:
                tracking_recorder_class = resolve_topic(tracking_recorder_topic)
            else:
                tracking_recorder_class = cast(
                    "type[TPostgresTrackingRecorder]",
                    type(self).tracking_recorder_class,
                )
        assert tracking_recorder_class is not None
        assert issubclass(tracking_recorder_class, PostgresTrackingRecorder)
        recorder = tracking_recorder_class(
            datastore=self.datastore,
            tracking_table_name=tracking_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def process_recorder(self) -> ProcessRecorder:
        """Return a process recorder, optionally of a class given by topic."""
        prefix = self.env.name.lower() or "stored"
        events_table_name = prefix + "_events"
        prefix = self.env.name.lower() or "notification"
        tracking_table_name = prefix + "_tracking"
        process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
        if process_recorder_topic:
            process_recorder_class: type[PostgresProcessRecorder] = resolve_topic(
                process_recorder_topic
            )
            assert issubclass(process_recorder_class, PostgresProcessRecorder)
        else:
            process_recorder_class = type(self).process_recorder_class

        recorder = process_recorder_class(
            datastore=self.datastore,
            events_table_name=events_table_name,
            tracking_table_name=tracking_table_name,
        )
        if self.env_create_table():
            recorder.create_table()
        return recorder

    def close(self) -> None:
        """Close the datastore, if __init__ got as far as creating it."""
        with contextlib.suppress(AttributeError):
            self.datastore.close()
# Module-level alias for PostgresFactory. Presumably this is the name the
# eventsourcing library looks up when this module is chosen as the
# persistence module — confirm.
Factory = PostgresFactory