from sqlalchemy import asc, bindparam, desc, select, text
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from sqlalchemy.sql import func
from eventsourcing.exceptions import ProgrammingError
from eventsourcing.infrastructure.base import SQLRecordManager
class SQLAlchemyRecordManager(SQLRecordManager):
    """
    Record manager that reads and writes records through an SQLAlchemy session.

    The session is supplied by the caller. Most operations close the session
    when they have finished with it.
    """

    _where_application_name_tmpl = (
        " WHERE application_name=:application_name AND pipeline_id=:pipeline_id"
    )

    def __init__(self, session, *args, **kwargs):
        """
        Initialises the record manager.

        :param session: SQLAlchemy session used for all queries and writes.

        All other arguments are passed through to the base class.
        """
        super(SQLAlchemyRecordManager, self).__init__(*args, **kwargs)
        self.session = session

    def _prepare_insert(
        self, tmpl, record_class, field_names, placeholder_for_id=False
    ):
        """
        Builds a textual insert statement with explicitly typed bind parameters.

        With transaction isolation level of "read committed" this should
        generate records with a contiguous sequence of integer IDs, assumes
        an indexed ID column, the database-side SQL max function, the
        insert-select-from form, and optimistic concurrency control.

        :param tmpl: format string, formatted with the keyword arguments
            ``tablename``, ``columns``, ``placeholders`` and ``notification_id``.
        :param record_class: record class supplying the table name and the
            column types.
        :param field_names: names of the fields of the stored item.
        :param placeholder_for_id: if true (and a notification ID name is
            configured), a placeholder is also included for the notification ID.
        :return: the text statement with its bind parameters attached.
        """
        field_names = list(field_names)

        # Columns that are not fields of the stored item itself, but need a
        # placeholder whenever the record class declares them.
        for extra_name in ("application_name", "pipeline_id", "causal_dependencies"):
            if hasattr(record_class, extra_name) and extra_name not in field_names:
                field_names.append(extra_name)

        if (
            self.notification_id_name
            and placeholder_for_id
            and self.notification_id_name not in field_names
        ):
            field_names.append(self.notification_id_name)

        statement = text(
            tmpl.format(
                tablename=self.get_record_table_name(record_class),
                columns=", ".join(field_names),
                placeholders=", ".join(":{}".format(f) for f in field_names),
                notification_id=self.notification_id_name,
            )
        )

        # Define bind parameters with explicit types taken from record column types.
        typed_params = [
            bindparam(col_name, type_=getattr(record_class, col_name).type)
            for col_name in field_names
        ]

        # Redefine statement with explicitly typed bind parameters.
        return statement.bindparams(*typed_params)

    def write_records(
        self,
        records,
        tracking_kwargs=None,
        orm_objs_pending_save=None,
        orm_objs_pending_delete=None,
    ):
        """
        Writes records, a tracking record and custom ORM changes in one commit.

        Nothing is committed if there is nothing to write. The session is
        rolled back if anything fails, and is always closed afterwards.

        :param records: records to be inserted (may be empty or None).
        :param tracking_kwargs: parameters for the tracking record insert
            statement, if a tracking record is to be written.
        :param orm_objs_pending_save: custom ORM objects to add to the session.
        :param orm_objs_pending_delete: custom ORM objects to delete.
        :raises ProgrammingError: if only some of the records have IDs, or if
            no record has an ID but IDs are required because the record class
            has an application_name field and contiguous IDs are not enabled.
        """
        all_params = []
        statement = None
        if records:
            # Prepare to insert event and notification records.
            statement = self.insert_values
            if self.notification_id_name:
                all_ids = {getattr(r, self.notification_id_name) for r in records}
                if None in all_ids:
                    if len(all_ids) > 1:
                        # Either all or zero records must have IDs.
                        raise ProgrammingError("Only some records have IDs")
                    elif self.contiguous_record_ids:
                        # Do an "insert select max" from existing.
                        statement = self.insert_select_max
                    elif hasattr(self.record_class, "application_name"):
                        # Can't allow auto-incrementing ID if table has field
                        # application_name. We need values and don't have them.
                        raise ProgrammingError("record ID not set when required")

            # These checks do not depend on the individual record.
            has_application_name = hasattr(self.record_class, "application_name")
            has_pipeline_id = hasattr(self.record_class, "pipeline_id")

            for record in records:
                # Params for stored item itself (e.g. event).
                params = {name: getattr(record, name) for name in self.field_names}

                # Params for application partition (bounded context).
                if has_application_name:
                    params["application_name"] = self.application_name

                # Params for notification log.
                if self.notification_id_name:
                    notification_id = getattr(record, self.notification_id_name)
                    # Records marked as not notifiable get neither a
                    # notification ID nor a pipeline ID.
                    not_notifiable = notification_id == "event-not-notifiable"
                    params[self.notification_id_name] = (
                        None if not_notifiable else notification_id
                    )
                    if has_pipeline_id:
                        params["pipeline_id"] = (
                            None if not_notifiable else self.pipeline_id
                        )

                if hasattr(record, "causal_dependencies"):
                    params["causal_dependencies"] = record.causal_dependencies

                all_params.append(params)

        try:
            nothing_to_commit = True

            # Commit custom ORM objects.
            if orm_objs_pending_save:
                for orm_obj in orm_objs_pending_save:
                    self.session.add(orm_obj)
                nothing_to_commit = False

            if orm_objs_pending_delete:
                for orm_obj in orm_objs_pending_delete:
                    self.session.delete(orm_obj)
                nothing_to_commit = False

            # Insert tracking record.
            if tracking_kwargs:
                self.session.execute(self.insert_tracking_record, tracking_kwargs)
                nothing_to_commit = False

            # Bulk insert event records.
            if all_params:
                self.session.execute(statement, all_params)
                nothing_to_commit = False

            if nothing_to_commit:
                return

            self.session.commit()

        except IntegrityError as e:
            self.session.rollback()
            self.raise_record_integrity_error(e)

        except DBAPIError as e:
            self.session.rollback()
            self.raise_operational_error(e)

        except BaseException:
            # Roll back on anything else (including interrupts), then
            # let the original exception propagate unchanged.
            self.session.rollback()
            raise

        finally:
            self.session.close()

    def get_records(
        self,
        sequence_id,
        gt=None,
        gte=None,
        lt=None,
        lte=None,
        limit=None,
        query_ascending=True,
        results_ascending=True,
    ):
        """
        Returns the records of a sequence, optionally bounded by position.

        The session is closed once the query has been executed.

        :param sequence_id: ID of the sequence.
        :param gt: only records with position greater than this value.
        :param gte: only records with position greater than or equal to this.
        :param lt: only records with position less than this value.
        :param lte: only records with position less than or equal to this.
        :param limit: maximum number of records (None, or at least 1).
        :param query_ascending: whether the query orders by ascending position.
        :param results_ascending: whether the returned list is in ascending
            order; the results are reversed if this differs from the query order.
        :return: list of records.
        """
        assert limit is None or limit >= 1, limit

        try:
            # Filter by sequence_id.
            filter_kwargs = {self.field_names.sequence_id: sequence_id}

            # Optionally, filter by application_name.
            if hasattr(self.record_class, "application_name"):
                filter_kwargs["application_name"] = self.application_name

            query = self.filter_by(**filter_kwargs)

            # Filter and order by position.
            position_field = getattr(self.record_class, self.field_names.position)
            if query_ascending:
                query = query.order_by(asc(position_field))
            else:
                query = query.order_by(desc(position_field))

            if gt is not None:
                query = query.filter(position_field > gt)
            if gte is not None:
                query = query.filter(position_field >= gte)
            if lt is not None:
                query = query.filter(position_field < lt)
            if lte is not None:
                query = query.filter(position_field <= lte)

            # Limit the results.
            if limit is not None:
                query = query.limit(limit)

            # Get the result set.
            results = query.all()

        finally:
            self.session.close()

        # Reverse if necessary.
        if results_ascending != query_ascending:
            # This code path is under test, but not otherwise used ATM.
            results.reverse()

        return results

    def get_notifications(self, start=None, stop=None, *args, **kwargs):
        """
        Returns records for a section of the notification log.

        Records are restricted to this application and pipeline when the
        record class has those fields. The session is closed afterwards.

        :param start: zero-based index of the first notification (inclusive).
        :param stop: zero-based index at which to stop (exclusive).
        :return: list of records, ordered by notification ID when a
            notification ID name is configured.

        Additional positional and keyword arguments are accepted but not used.
        """
        try:
            query = self.orm_query()

            if hasattr(self.record_class, "application_name"):
                query = query.filter(
                    self.record_class.application_name == self.application_name
                )

            if hasattr(self.record_class, "pipeline_id"):
                query = query.filter(self.record_class.pipeline_id == self.pipeline_id)

            if self.notification_id_name:
                notification_id_col = getattr(
                    self.record_class, self.notification_id_name
                )
                # Order by the mapped column itself, not by its bare name.
                query = query.order_by(asc(notification_id_col))

                # NB '+1' because record IDs start from 1.
                if start is not None:
                    query = query.filter(notification_id_col >= start + 1)
                if stop is not None:
                    query = query.filter(notification_id_col < stop + 1)

            # Todo: Should some tables with an ID not be ordered by ID?
            # Todo: Which order do other tables have?
            return query.all()
        finally:
            self.session.close()

    def get_record(self, sequence_id, position):
        """
        Returns the single record at the given position of a sequence.

        The session is closed afterwards, whether or not a record was found.

        :param sequence_id: ID of the sequence.
        :param position: position of the record in the sequence.
        :raises IndexError: if there is not exactly one matching record.
        """
        try:
            filter_args = {self.field_names.sequence_id: sequence_id}
            query = self.filter_by(**filter_args)

            if hasattr(self.record_class, "application_name"):
                query = query.filter(
                    self.record_class.application_name == self.application_name
                )

            position_field = getattr(self.record_class, self.field_names.position)
            query = query.filter(position_field == position)
            return query.one()
        except (NoResultFound, MultipleResultsFound):
            raise IndexError(self.application_name, sequence_id, position)
        finally:
            # Release the session, as the other read methods do.
            self.session.close()

    def filter_by(self, **kwargs):
        """
        Returns an ORM query on the record class filtered by the given kwargs.
        """
        return self.orm_query().filter_by(**kwargs)

    def orm_query(self):
        """
        Returns a new ORM query on the record class, using the session.
        """
        return self.session.query(self.record_class)

    def get_max_record_id(self):
        """
        Returns the maximum notification ID of the records.

        Records are restricted to this application and pipeline when the
        record class has those fields. The session is closed afterwards.

        :return: scalar result of the max query.
        """
        try:
            notification_id_col = getattr(self.record_class, self.notification_id_name)
            query = self.session.query(func.max(notification_id_col))

            if hasattr(self.record_class, "application_name"):
                query = query.filter(
                    self.record_class.application_name == self.application_name
                )

            if hasattr(self.record_class, "pipeline_id"):
                query = query.filter(self.record_class.pipeline_id == self.pipeline_id)

            return query.scalar()
        finally:
            self.session.close()

    def get_max_tracking_record_id(self, upstream_application_name):
        """
        Returns the maximum notification ID tracked for an upstream application.

        Only tracking records of this application and this pipeline are
        considered.

        :param upstream_application_name: name of the upstream application.
        :return: scalar result of the max query.
        """
        # NOTE(review): unlike get_max_record_id(), the session is not closed
        # here -- confirm whether callers rely on it staying open.
        query = self.session.query(func.max(self.tracking_record_class.notification_id))
        query = query.filter(
            self.tracking_record_class.application_name == self.application_name
        )
        query = query.filter(
            self.tracking_record_class.upstream_application_name
            == upstream_application_name
        )
        query = query.filter(self.tracking_record_class.pipeline_id == self.pipeline_id)
        return query.scalar()

    def has_tracking_record(
        self, upstream_application_name, pipeline_id, notification_id
    ):
        """
        Tells whether exactly one matching tracking record exists.

        :param upstream_application_name: name of the upstream application.
        :param pipeline_id: pipeline ID of the tracking record.
        :param notification_id: notification ID of the tracking record.
        :return: True if exactly one tracking record matches; False if there
            is none (or more than one).
        """
        # NOTE(review): the session is not closed here -- confirm whether
        # callers rely on it staying open.
        query = self.session.query(self.tracking_record_class)
        query = query.filter(
            self.tracking_record_class.application_name == self.application_name
        )
        query = query.filter(
            self.tracking_record_class.upstream_application_name
            == upstream_application_name
        )
        query = query.filter(self.tracking_record_class.pipeline_id == pipeline_id)
        query = query.filter(
            self.tracking_record_class.notification_id == notification_id
        )
        try:
            query.one()
        except (MultipleResultsFound, NoResultFound):
            return False
        else:
            return True

    def all_sequence_ids(self):
        """
        Yields the distinct sequence IDs found in the record table.

        Sequence IDs are restricted to this application when the record class
        has an application_name field. The session is closed when the
        generator finishes or is closed.
        """
        c = self.record_class.__table__.c
        sequence_id_col = getattr(c, self.field_names.sequence_id)
        expr = select([sequence_id_col], distinct=True)

        if hasattr(self.record_class, "application_name"):
            expr = expr.where(c.application_name == self.application_name)

        try:
            for row in self.session.query(expr):
                yield row[0]
        finally:
            self.session.close()

    def delete_record(self, record):
        """
        Permanently removes record from table.

        The session is rolled back on failure, and is always closed afterwards.

        :param record: record object to delete.
        :raises ProgrammingError: wrapping any exception raised while deleting
            or committing.
        """
        try:
            self.session.delete(record)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise ProgrammingError(e)
        finally:
            self.session.close()

    def get_record_table_name(self, record_class):
        """
        Returns the name of the table mapped by the given record class.
        """
        return record_class.__table__.name

    def clone(self, **kwargs):
        """
        Returns a clone of this record manager that uses the same session.

        Keyword arguments are passed through to the base class ``clone()``.
        """
        return super(SQLAlchemyRecordManager, self).clone(
            session=self.session, **kwargs
        )