Source code for eventsourcing.infrastructure.django.manager

from typing import Any, Dict, Iterable, Optional, Sequence
from uuid import UUID

import django.db
from django.db import connection, transaction

from eventsourcing.exceptions import ProgrammingError
from eventsourcing.infrastructure.base import (
    EVENT_NOT_NOTIFIABLE,
    SQLRecordManager,
    TrackingKwargs,
)


class DjangoRecordManager(SQLRecordManager):
    """Record manager that reads and writes records through Django.

    Event records are inserted with raw SQL through a database cursor when
    ``contiguous_record_ids`` is set, and saved through the ORM otherwise.
    Reads go through the ``objects`` manager of ``record_class``.
    """

    # SQL WHERE-clause fragment with placeholders for the application name and
    # pipeline ID. Not referenced in this class; presumably consumed by the
    # base class when it builds statements -- TODO confirm.
    _where_application_name_tmpl = " WHERE application_name = %s AND pipeline_id = %s"

    def write_records(
        self,
        records: Iterable[Any],
        tracking_kwargs: Optional[TrackingKwargs] = None,
        orm_objs_pending_save: Optional[Sequence[Any]] = None,
        orm_objs_pending_delete: Optional[Sequence[Any]] = None,
    ) -> None:
        """Write records to the database inside a single atomic block.

        A tracking record is inserted first when ``tracking_kwargs`` is given.
        Event records are then written, and finally ``save()`` / ``delete()``
        is called on the pending ORM objects, all within the same
        ``transaction.atomic`` block.

        When ``contiguous_record_ids`` is set, event records are inserted with
        raw SQL: the "insert select max" statement if no record carries a
        notification ID, otherwise the "insert values" statement. When it is
        not set, each record is saved through the ORM.

        :param records: record objects to write (materialised into a list).
        :param tracking_kwargs: values for the tracking record, looked up by
            the names in ``tracking_record_field_names``; nothing is inserted
            when this is empty or None.
        :param orm_objs_pending_save: objects to ``save()`` in the transaction.
        :param orm_objs_pending_delete: objects to ``delete()`` in the
            transaction.
        :raises ProgrammingError: if only some records have notification IDs,
            if IDs are missing but required, or if a notification ID is
            neither an int nor the not-notifiable marker.

        A ``django.db.IntegrityError`` raised while writing is passed to
        ``raise_record_integrity_error()``.
        """
        if not isinstance(records, list):
            records = list(records)

        # Prepare tracking params.
        if tracking_kwargs:
            tracking_params = [
                tracking_kwargs[c] for c in self.tracking_record_field_names
            ]
        else:
            tracking_params = None

        record_class = self.record_class
        notification_id_name = self.notification_id_name
        has_application_name = hasattr(record_class, "application_name")
        has_pipeline_id = hasattr(record_class, "pipeline_id")

        use_insert_select_max_statement = False
        event_insert_statement = self.insert_values

        if notification_id_name:
            # This is a bit complicated, but basically the idea
            # is to support two alternatives:
            #  - either use the "insert select max" statement
            #  - or use the "insert values" statement
            # The "insert values" statement depends on the
            # notification IDs being provided by the application
            # and the "insert select max" creates these IDs in
            # the query.
            # The "insert values" statement provides the opportunity
            # to have some notification IDs being null, but the
            # "insert select max" statement doesn't.
            # Therefore, using "insert select max" should only be
            # used when all the given records have null value
            # for the notification IDs. And so the alternative
            # usage needs to provide true values for all the
            # notification IDs. These true values can involve
            # a reserved "event-not-notifiable" value which
            # indicates there shouldn't be a notification ID
            # for this record (so that it doesn't appear in the
            # "get notifications" query).
            all_ids = {getattr(r, notification_id_name) for r in records}
            if None in all_ids:
                if len(all_ids) > 1:
                    # Either all or zero records must have IDs.
                    raise ProgrammingError("Only some records have IDs")

                elif self.contiguous_record_ids:
                    # Do an "insert select max" from existing.
                    use_insert_select_max_statement = True
                    event_insert_statement = self.insert_select_max

                elif has_application_name:
                    # Can't allow auto-incrementing ID if table has field
                    # application_name. We need values and don't have them.
                    raise ProgrammingError("record ID not set when required")

        if self.contiguous_record_ids:
            # Field lookups do not depend on the record, so resolve them once.
            meta = record_class._meta  # type: ignore
            columns = [(name, meta.get_field(name)) for name in self.field_names]
            has_notification_id_field = bool(notification_id_name) and hasattr(
                record_class, notification_id_name
            )

            all_event_record_params = []
            for record in records:
                # Get values from record obj, prepared for the database.
                # List of params, because dict doesn't work with Django
                # and SQLite.
                params = [
                    field.get_db_prep_value(getattr(record, name), connection)
                    for name, field in columns
                ]

                # Notification logs fields, to be inserted with event fields.
                index_of_pipeline_id_param = None
                if has_application_name:
                    params.append(self.application_name)
                if has_pipeline_id:
                    params.append(self.pipeline_id)
                    index_of_pipeline_id_param = len(params) - 1
                if hasattr(record, "causal_dependencies"):
                    params.append(record.causal_dependencies)

                if use_insert_select_max_statement:
                    # Where clause fields.
                    if has_application_name:
                        params.append(self.application_name)
                    if has_pipeline_id:
                        params.append(self.pipeline_id)

                elif has_notification_id_field:
                    notification_id = getattr(record, notification_id_name)
                    if notification_id == EVENT_NOT_NOTIFIABLE:
                        # Store null instead of the reserved marker value.
                        notification_id = None
                        # Also null the pipeline ID param, but only when the
                        # record class has that column: indexing the list
                        # with None would raise TypeError.
                        if index_of_pipeline_id_param is not None:
                            params[index_of_pipeline_id_param] = None
                    elif notification_id is not None and not isinstance(
                        notification_id, int
                    ):
                        raise ProgrammingError(
                            "%s must be an %s not %s: %s"
                            % (
                                notification_id_name,
                                int,
                                type(notification_id),
                                record.__dict__,
                            )
                        )
                    params.append(notification_id)

                all_event_record_params.append(params)
        else:
            all_event_record_params = None

        try:
            # NOTE(review): the cursor below comes from the default
            # ``django.db.connection`` while the atomic block is opened on
            # ``record_class.objects.db``. These coincide only when that alias
            # is the default database -- confirm for multi-database setups.
            with transaction.atomic(record_class.objects.db):  # type: ignore
                with connection.cursor() as cursor:
                    # Insert tracking record.
                    if tracking_params is not None:
                        cursor.execute(self.insert_tracking_record, tracking_params)

                    if all_event_record_params is not None:
                        for event_record_params in all_event_record_params:
                            # Use cursor to execute event insert statement.
                            cursor.execute(event_insert_statement, event_record_params)
                    else:
                        # This can only work for simple models, without
                        # application_name and pipeline_id, because it relies
                        # on the auto-incrementing ID.
                        # Todo: If it's faster, change to use an
                        #  "insert_values" raw query.
                        # Save record objects.
                        for record in records:
                            record.save()

                    # Call 'save()' on each of the ORM objects pending save.
                    if orm_objs_pending_save:
                        for orm_obj in orm_objs_pending_save:
                            orm_obj.save()

                    # Call 'delete()' on each of the ORM objects pending delete.
                    if orm_objs_pending_delete:
                        for orm_obj in orm_objs_pending_delete:
                            orm_obj.delete()

        except django.db.IntegrityError as e:
            self.raise_record_integrity_error(e)

    def make_placeholder(self, _: str) -> str:
        """Return the SQL parameter placeholder, "%s", for any field name."""
        return "%s"  # doesn't involve field name

    def get_record_table_name(self, record_class: type) -> str:
        """Return the database table name of the given Django record class."""
        return record_class._meta.db_table  # type: ignore

    def get_record(self, sequence_id: UUID, position: int) -> Any:
        """Return the first record matching the sequence ID and position.

        The lookup indexes the filtered queryset with ``[0]``, so an
        ``IndexError`` propagates when no record matches.
        """
        kwargs = {
            self.field_names.sequence_id: sequence_id,
            self.field_names.position: position,
        }
        records = self.record_class.objects.filter(**kwargs)  # type: ignore
        # Todo: try/except for native error here, call self.raise_index_error()
        return records.all()[0]

    def get_records(
        self,
        sequence_id: UUID,
        gt: Optional[int] = None,
        gte: Optional[int] = None,
        lt: Optional[int] = None,
        lte: Optional[int] = None,
        limit: Optional[int] = None,
        query_ascending: bool = True,
        results_ascending: bool = True,
    ) -> Sequence[Any]:
        """Return records of one sequence, optionally bounded by position.

        :param sequence_id: ID of the sequence whose records are selected.
        :param gt: select positions greater than this value.
        :param gte: select positions greater than or equal to this value.
        :param lt: select positions less than this value.
        :param lte: select positions less than or equal to this value.
        :param limit: maximum number of records to select (at least 1).
        :param query_ascending: whether the query orders by ascending position.
        :param results_ascending: whether results are returned in ascending
            order; when it differs from ``query_ascending`` the selected
            records are reversed and returned as a list.
        :return: the selected records, each with its state value cast to
            ``bytes``.
        """
        assert limit is None or limit >= 1, limit

        filter_kwargs = {self.field_names.sequence_id: sequence_id}
        objects = self.record_class.objects.filter(**filter_kwargs)  # type: ignore

        if hasattr(self.record_class, "application_name"):
            objects = objects.filter(application_name=self.application_name)

        position_field_name = self.field_names.position

        if query_ascending:
            objects = objects.order_by(position_field_name)
        else:
            objects = objects.order_by("-" + position_field_name)

        # Apply each position bound that was given, as a field lookup.
        for lookup, bound in (("gt", gt), ("gte", gte), ("lt", lt), ("lte", lte)):
            if bound is not None:
                arg = "{}__{}".format(position_field_name, lookup)
                objects = objects.filter(**{arg: bound})

        if limit is not None:
            objects = objects[:limit]

        records = objects.all()

        if results_ascending != query_ascending:
            # This code path is under test, but not otherwise used ATM.
            records = list(records)
            records.reverse()

        for record in records:
            self._cast_state_to_bytes(record)

        return records

    def get_notification_records(
        self,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        *args: Any,
        **kwargs: Any
    ) -> Iterable:
        """Yield records ordered by notification ID, optionally in a range.

        Records are restricted to this manager's ``application_name`` and
        ``pipeline_id`` when the record class has those attributes.

        :param start: when given, only records with a notification ID of at
            least ``start + 1`` are selected.
        :param stop: when given, only records with a notification ID less
            than ``stop + 1`` are selected.
        :return: iterator over the selected records, each with its state
            value cast to ``bytes``.
        """
        filter_kwargs = {}
        # Todo: Also support sequencing by 'position' if items are sequenced by
        #  timestamp?
        if start is not None:
            filter_kwargs["%s__gte" % self.notification_id_name] = start + 1
        if stop is not None:
            filter_kwargs["%s__lt" % self.notification_id_name] = stop + 1
        objects = self.record_class.objects.filter(**filter_kwargs)  # type: ignore

        if hasattr(self.record_class, "application_name"):
            objects = objects.filter(application_name=self.application_name)
        if hasattr(self.record_class, "pipeline_id"):
            objects = objects.filter(pipeline_id=self.pipeline_id)

        objects = objects.order_by("%s" % self.notification_id_name)
        for record in objects.all():
            self._cast_state_to_bytes(record)
            yield record

    def delete_record(self, record: Any) -> None:
        """
        Permanently removes record from table.
        """
        record.delete()

    def get_max_notification_id(self) -> int:
        """Return the highest notification ID among this manager's records.

        Records are restricted to this manager's ``application_name`` and
        ``pipeline_id`` when the record class has those attributes.

        :return: the highest notification ID, or 0 when no record has one.
        """
        assert self.notification_id_name
        try:
            objects = self.record_class.objects  # type: ignore
            if hasattr(self.record_class, "application_name"):
                objects = objects.filter(application_name=self.application_name)
            if hasattr(self.record_class, "pipeline_id"):
                objects = objects.filter(pipeline_id=self.pipeline_id)
            # Records may be stored with a null notification ID, and where
            # nulls sort depends on the database, so exclude them before
            # asking for the latest.
            objects = objects.filter(
                **{"%s__isnull" % self.notification_id_name: False}
            )
            latest = objects.latest(self.notification_id_name)
            return getattr(latest, self.notification_id_name)
        except self.record_class.DoesNotExist:
            return 0

    def get_max_tracking_record_id(self, upstream_application_name: str) -> int:
        """Return the highest tracked notification ID for an upstream application.

        Tracking records are selected by this manager's ``application_name``
        and ``pipeline_id`` together with the given upstream application name.

        :return: the highest notification ID, or 0 when there is no matching
            tracking record.
        """
        notification_id = 0
        assert self.tracking_record_class is not None
        try:
            objects = self.tracking_record_class.objects  # type: ignore
            objects = objects.filter(application_name=self.application_name)
            objects = objects.filter(
                upstream_application_name=upstream_application_name
            )
            objects = objects.filter(pipeline_id=self.pipeline_id)
            notification_id = objects.latest("notification_id").notification_id
        except self.tracking_record_class.DoesNotExist:  # type: ignore
            pass
        return notification_id

    def has_tracking_record(
        self, upstream_application_name: str, pipeline_id: int, notification_id: int
    ) -> bool:
        """Return whether a matching tracking record exists.

        The record must match this manager's ``application_name`` and the
        given upstream application name, pipeline ID and notification ID.
        """
        objects = self.tracking_record_class.objects  # type: ignore
        objects = objects.filter(application_name=self.application_name)
        objects = objects.filter(upstream_application_name=upstream_application_name)
        objects = objects.filter(pipeline_id=pipeline_id)
        objects = objects.filter(notification_id=notification_id)
        # Only existence matters, so there is no need to count the rows.
        return objects.exists()

    def all_sequence_ids(self) -> Iterable[UUID]:
        """Yield the distinct sequence ID values found in the record table."""
        sequence_id_fieldname = self.field_names.sequence_id
        values_queryset = self.record_class.objects.values(  # type: ignore
            sequence_id_fieldname
        ).distinct()
        for values in values_queryset:
            yield values[sequence_id_fieldname]

    def get_field_kwargs(self, item: object) -> Dict[str, Any]:
        """Return the base class's field kwargs, with memoryview state as bytes."""
        kwargs = super().get_field_kwargs(item)
        state_fieldname = self.field_names.state
        state = kwargs[state_fieldname]
        if isinstance(state, memoryview):
            kwargs[state_fieldname] = bytes(state)
        return kwargs

    def _cast_state_to_bytes(self, record: Any) -> None:
        """Replace the record's state value, in place, with a ``bytes`` copy."""
        # Django returns memoryview objects from PostgreSQL, so need to cast.
        state_fieldname = self.field_names.state
        setattr(record, state_fieldname, bytes(getattr(record, state_fieldname)))