Source code for eventsourcing.infrastructure.sqlalchemy.manager

import six
from sqlalchemy import asc, bindparam, desc, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from sqlalchemy.sql import func

from eventsourcing.exceptions import ProgrammingError
from eventsourcing.infrastructure.base import RelationalRecordManager


class SQLAlchemyRecordManager(RelationalRecordManager):
    """
    Record manager that reads and writes records using an SQLAlchemy session.

    Every method that touches the database closes the session when it
    has finished, whether or not the operation succeeded.
    """

    def __init__(self, session, *args, **kwargs):
        """
        :param session: SQLAlchemy session used for all database access.

        Remaining positional and keyword arguments are passed unchanged
        to the base class constructor.
        """
        super(SQLAlchemyRecordManager, self).__init__(*args, **kwargs)
        self.session = session

    def _write_records(self, records):
        """
        Inserts the given record objects, one statement execution per record.

        Uses the "insert select max" statement when ``contiguous_record_ids``
        is true, otherwise the "insert values" statement.

        An IntegrityError causes the session to be rolled back and is then
        handed to ``raise_after_integrity_error()``.
        """
        try:
            # Choose the prepared statement once, it is the same for all records.
            if self.contiguous_record_ids:
                statement = self.insert_select_max
            else:
                statement = self.insert_values

            for record in records:
                # Execute the statement with values taken from the record obj.
                params = {c: getattr(record, c) for c in self.field_names}
                self.session.bind.execute(statement, **params)

            self.session.commit()
        except IntegrityError as e:
            self.session.rollback()
            self.raise_after_integrity_error(e)
        finally:
            self.session.close()

    @property
    def record_table_name(self):
        """Name of the database table mapped by the record class."""
        return self.record_class.__table__.name

    def _prepare_insert(self, tmpl):
        """
        Returns a compiled insert statement built from the given template.

        The template is formatted with ``tablename``, ``columns`` and
        ``placeholders`` values derived from the record class and the
        field names. Each placeholder is a named bind parameter that is
        given the type of the corresponding record class column.

        With transaction isolation level of "read committed" this should
        generate records with a contiguous sequence of integer IDs, assumes
        an indexed ID column, the database-side SQL max function, the
        insert-select-from form, and optimistic concurrency control.
        """
        statement = text(tmpl.format(
            tablename=self.record_table_name,
            columns=", ".join(self.field_names),
            placeholders=", ".join([":{}".format(f) for f in self.field_names]),
        ))

        # Define bind parameters with explicit types taken from record column types.
        bindparams = []
        for col_name in self.field_names:
            column_type = getattr(self.record_class, col_name).type
            bindparams.append(bindparam(col_name, type_=column_type))

        # Redefine statement with explicitly typed bind parameters.
        statement = statement.bindparams(*bindparams)

        # Compile the statement with the session dialect.
        compiled = statement.compile(dialect=self.session.bind.dialect)

        return compiled

    def get_item(self, sequence_id, eq):
        """
        Returns the item at position ``eq`` in the sequence ``sequence_id``.

        :raises IndexError: if the query does not match exactly one record.
        """
        try:
            filter_args = {self.field_names.sequence_id: sequence_id}
            query = self.filter(**filter_args)
            position_field = getattr(self.record_class, self.field_names.position)
            query = query.filter(position_field == eq)
            result = query.one()
        except (NoResultFound, MultipleResultsFound):
            raise IndexError
        finally:
            self.session.close()
        return self.from_record(result)

    def get_items(self, sequence_id, gt=None, gte=None, lt=None, lte=None, limit=None,
                  query_ascending=True, results_ascending=True):
        """
        Generates items of a sequence, converted from the records returned
        by ``get_records()`` called with the same arguments.

        This is a generator, so nothing is queried until iteration starts.
        """
        records = self.get_records(
            sequence_id=sequence_id,
            gt=gt,
            gte=gte,
            lt=lt,
            lte=lte,
            limit=limit,
            query_ascending=query_ascending,
            results_ascending=results_ascending,
        )
        for item in six.moves.map(self.from_record, records):
            yield item

    def get_records(self, sequence_id, gt=None, gte=None, lt=None, lte=None, limit=None,
                    query_ascending=True, results_ascending=True):
        """
        Returns a list of the records of a sequence.

        :param sequence_id: value matched against the sequence ID field.
        :param gt: if given, only positions greater than this value.
        :param gte: if given, only positions greater than or equal to this value.
        :param lt: if given, only positions less than this value.
        :param lte: if given, only positions less than or equal to this value.
        :param limit: if given, maximum number of records (must be >= 1).
        :param query_ascending: order the query by ascending position if
            true, otherwise by descending position.
        :param results_ascending: if different from ``query_ascending``,
            the list of results is reversed before it is returned.
        """
        assert limit is None or limit >= 1, limit
        try:
            filter_kwargs = {self.field_names.sequence_id: sequence_id}
            query = self.filter(**filter_kwargs)

            position_field = getattr(self.record_class, self.field_names.position)

            if query_ascending:
                query = query.order_by(asc(position_field))
            else:
                query = query.order_by(desc(position_field))

            if gt is not None:
                query = query.filter(position_field > gt)
            if gte is not None:
                query = query.filter(position_field >= gte)
            if lt is not None:
                query = query.filter(position_field < lt)
            if lte is not None:
                query = query.filter(position_field <= lte)

            if limit is not None:
                query = query.limit(limit)

            results = query.all()
        finally:
            self.session.close()

        if results_ascending != query_ascending:
            # This code path is under test, but not otherwise used ATM.
            results.reverse()

        return results

    def filter(self, **kwargs):
        """Returns a query on the record class filtered by the given field values."""
        return self.query.filter_by(**kwargs)

    @property
    def query(self):
        """A new session query on the record class."""
        return self.session.query(self.record_class)

    def all_records(self, start=None, stop=None, *args, **kwargs):
        """
        Returns all records in the table.

        Intended to support getting all application domain events
        in order, especially if the records have contiguous IDs.

        If the record class has an 'id' attribute, the records are ordered
        by ID, and ``start`` and ``stop`` (zero-based, ``stop`` exclusive)
        select a range of IDs. Otherwise ``start`` and ``stop`` are not used.
        Extra positional and keyword arguments are accepted but not used.
        """
        try:
            query = self.query
            if hasattr(self.record_class, 'id'):
                query = query.order_by(asc('id'))
                # NB '+1' because record IDs start from 1.
                if start is not None:
                    query = query.filter(self.record_class.id >= start + 1)
                if stop is not None:
                    query = query.filter(self.record_class.id < stop + 1)
            # Todo: Should some tables with an ID not be ordered by ID?
            # Todo: Which order do other tables have?
            return query.all()
        finally:
            self.session.close()

    def get_max_record_id(self):
        """
        Returns the result of the SQL max function applied to the record
        class 'id' column.
        """
        try:
            return self.session.query(func.max(self.record_class.id)).scalar()
        finally:
            # Close the session here too, as every other query method does.
            self.session.close()

    def delete_record(self, record):
        """
        Permanently removes record from table.

        :raises ProgrammingError: if deleting or committing fails for any
            reason; the session is rolled back first.
        """
        try:
            self.session.delete(record)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise ProgrammingError(e)
        finally:
            self.session.close()