from functools import reduce
from typing import Any, Callable, Dict, Iterable, Optional, Type
from uuid import UUID
from eventsourcing.domain.model.entity import TVersionedEntity, TVersionedEvent
from eventsourcing.exceptions import RepositoryKeyError
from eventsourcing.infrastructure.base import (
AbstractEntityRepository,
AbstractEventStore,
AbstractRecordManager,
AbstractSnapshop,
)
from eventsourcing.infrastructure.snapshotting import (
AbstractSnapshotStrategy,
entity_from_snapshot,
)
from eventsourcing.whitehead import SEntity
class EventSourcedRepository(
    AbstractEntityRepository[TVersionedEntity, TVersionedEvent],
):
    """
    Repository that reconstitutes entities by projecting their domain events,
    read from an event store, onto an initial state (either nothing, or the
    state recovered from the latest suitable snapshot). Reconstituted
    entities can optionally be kept in an in-memory cache by __getitem__.
    """

    # The page size by which events are retrieved. If this
    # value is set to a positive integer, the events of
    # the entity will be retrieved in pages, using a series
    # of queries, rather than with one potentially large query.
    __page_size__: Optional[int] = None

    def __init__(
        self,
        event_store: AbstractEventStore,
        use_cache: bool = False,
        snapshot_strategy: Optional[AbstractSnapshotStrategy] = None,
        mutator_func: Optional[
            Callable[
                [Optional[TVersionedEntity], TVersionedEvent],
                Optional[TVersionedEntity],
            ]
        ] = None,
        **kwargs: Any
    ):
        """
        :param event_store: Event store from which domain events are read.
        :param use_cache: Whether __getitem__ keeps reconstituted entities
            in an in-memory cache.
        :param snapshot_strategy: Optional strategy used to get and take
            snapshots. When None, entities are always projected from their
            first event and take_snapshot() does nothing.
        :param mutator_func: Optional function applied cumulatively to the
            events to evolve the state. Defaults to the mutate() method.
        :param kwargs: Accepted for compatibility; not used by this class.
        """
        self._event_store: AbstractEventStore = event_store
        self._snapshot_strategy = snapshot_strategy
        self._mutator_func = mutator_func or self.mutate
        super(EventSourcedRepository, self).__init__()
        # NB If you use the cache, make sure to del entities
        # when records fail to write otherwise the cache will
        # give an entity that is ahead of the event records,
        # and writing more records will give a broken sequence.
        self._cache: Dict[UUID, Optional[TVersionedEntity]] = {}
        self._use_cache = use_cache

    @property
    def event_store(
        self,
    ) -> AbstractEventStore[TVersionedEvent, AbstractRecordManager]:
        """
        Returns event store object used by this repository.
        """
        return self._event_store

    @property
    def use_cache(self) -> bool:
        """
        Whether __getitem__ uses the in-memory entity cache.
        """
        return self._use_cache

    @use_cache.setter
    def use_cache(self, value: bool) -> None:
        self._use_cache = value
        # Disabling the cache discards everything in it, so that stale
        # entities are not served if the cache is enabled again later.
        if not self._use_cache:
            self._cache.clear()

    def __contains__(self, entity_id: UUID) -> bool:
        """
        Returns a boolean value according to whether entity with given ID exists.

        Always reconstitutes the entity via get_entity(); the cache is not
        consulted.
        """
        return self.get_entity(entity_id) is not None

    def __getitem__(self, entity_id: UUID) -> TVersionedEntity:
        """
        Returns entity with given ID.

        When the cache is enabled, a cached entity is returned if there is
        one; otherwise the entity is reconstituted and, if found, cached.

        :param entity_id: ID of entity in the repository.
        :raises RepositoryKeyError: If the entity is not found.
        """
        entity: Optional[TVersionedEntity]
        if self._use_cache:
            try:
                # Get entity from the cache.
                entity = self._cache[entity_id]
            except KeyError:
                # Reconstitute the entity.
                entity = self.get_entity(entity_id)
                # Only cache entities that were actually found. Caching a
                # "not found" result would hide an entity subsequently
                # created with this ID until the cache is cleared.
                if entity is not None:
                    self._cache[entity_id] = entity
        else:
            entity = self.get_entity(entity_id)

        # Never created or already discarded?
        if entity is None:
            raise RepositoryKeyError(entity_id)

        return entity

    def get_entity(
        self, entity_id: UUID, at: Optional[int] = None
    ) -> Optional[TVersionedEntity]:
        """
        Returns entity with given ID, optionally at a version.
        Returns None if entity not found.

        :param entity_id: ID of the entity.
        :param at: Optional version; only events (and snapshots) with a
            version less than or equal to this are used.
        """
        # Get a snapshot (None if none exist).
        if self._snapshot_strategy is not None:
            snapshot = self._snapshot_strategy.get_snapshot(entity_id, lte=at)
        else:
            snapshot = None

        # Decide the initial state of the entity, and the
        # version of the last item applied to the entity.
        if snapshot is None:
            initial_state = None
            gt = None
        else:
            initial_state = entity_from_snapshot(snapshot)
            gt = snapshot.originator_version

        # Obtain and return current state.
        return self.get_and_project_events(
            entity_id, gt=gt, lte=at, initial_state=initial_state
        )

    def get_and_project_events(
        self,
        entity_id: UUID,
        gt: Optional[int] = None,
        gte: Optional[int] = None,
        lt: Optional[int] = None,
        lte: Optional[int] = None,
        limit: Optional[int] = None,
        initial_state: Optional[TVersionedEntity] = None,
        query_descending: bool = False,
    ) -> Optional[TVersionedEntity]:
        """
        Reconstitutes requested domain entity from domain events found in
        event store.

        The events are always projected in ascending order, whatever order
        they are queried in.
        """
        # Decide if query is in ascending order.
        # - A "speed up" for when events are stored in descending order (e.g.
        #   in Cassandra) and it is faster to get them in that order.
        # - This isn't useful when 'until' or 'after' or 'limit' are set,
        #   because the inclusiveness or exclusiveness of until and after
        #   and the end of the stream that is truncated by limit both depend on
        #   the direction of the query. Also paging backwards isn't useful, because
        #   all the events are needed eventually, so it would probably slow things
        #   down. Paging is intended to support replaying longer event streams, and
        #   only makes sense to work in ascending order.
        if (
            gt is None
            and gte is None
            and lt is None
            and lte is None
            and limit is None
            and self.__page_size__ is None
        ):
            is_ascending = False
        else:
            is_ascending = not query_descending

        # Get entity's domain events from the event store.
        domain_events = self.event_store.iter_events(
            originator_id=entity_id,
            gt=gt,
            gte=gte,
            lt=lt,
            lte=lte,
            limit=limit,
            is_ascending=is_ascending,
            page_size=self.__page_size__,
        )

        # The events must be replayed in ascending order.
        if not is_ascending:
            domain_events = reversed(list(domain_events))

        # Project the domain events onto the initial state.
        return self.project_events(initial_state, domain_events)

    def project_events(
        self,
        initial_state: Optional[TVersionedEntity],
        domain_events: Iterable[TVersionedEvent],
    ) -> Optional[TVersionedEntity]:
        """
        Evolves initial_state using the domain_events and a mutator function.

        Applies a mutator function cumulatively to a sequence of domain
        events, so as to mutate the initial value to a mutated value.

        This class's mutate() method is used as the default mutator function, but
        custom behaviour can be introduced by passing in a 'mutator_func' argument
        when constructing this class, or by overriding the mutate() method.
        """
        return reduce(self._mutator_func, domain_events, initial_state)

    @staticmethod
    def mutate(
        initial: Optional[TVersionedEntity], event: TVersionedEvent
    ) -> Optional[TVersionedEntity]:
        """
        Default mutator function, which uses __mutate__()
        method on event object to mutate initial state.

        :param initial: Initial state to be mutated by this function.
        :param event: Event that causes the initial state to be mutated.
        :return: Returns the mutated state.
        """
        # There is no object to check when there is no initial state yet
        # (e.g. projecting from the first event without a snapshot).
        if initial is not None:
            event.__check_obj__(initial)
        return event.__mutate__(initial)

    # Todo: Does this method belong on this class?
    def take_snapshot(
        self, entity_id: UUID, lt: Optional[int] = None, lte: Optional[int] = None
    ) -> Optional[AbstractSnapshop]:
        """
        Takes a snapshot of the entity as it existed after the most recent
        event, optionally less than, or less than or equal to, a particular position.

        Returns None if there is no snapshot strategy or no event to snapshot.
        If an existing snapshot is already at the latest event's version, that
        snapshot is returned instead of taking a new one.
        """
        if self._snapshot_strategy is None:
            return None

        # Get the latest event (optionally until a particular position).
        latest_event = self.event_store.get_most_recent_event(
            entity_id, lt=lt, lte=lte
        )
        if latest_event is None:
            # Nothing to snapshot.
            return None

        # Look for a snapshot taken before or at the entity version of the
        # latest event. Please note, the snapshot might have a smaller version
        # number than the latest event if events occurred since the latest
        # snapshot was taken.
        latest_snapshot = self._snapshot_strategy.get_snapshot(
            entity_id, lt=lt, lte=lte
        )
        latest_version = latest_event.originator_version

        if (
            latest_snapshot is not None
            and latest_snapshot.originator_version == latest_version
        ):
            # If up-to-date snapshot exists, there's nothing to do.
            return latest_snapshot

        # Otherwise recover entity state from latest snapshot (if any).
        if latest_snapshot is not None:
            initial_state = entity_from_snapshot(latest_snapshot)
            gt: Optional[int] = latest_snapshot.originator_version
        else:
            initial_state = None
            gt = None

        # Fast-forward entity state to latest version.
        entity = self.get_and_project_events(
            entity_id=entity_id,
            gt=gt,
            lte=latest_version,
            initial_state=initial_state,
        )

        # Take snapshot from entity.
        return self._snapshot_strategy.take_snapshot(
            entity_id, entity, latest_version
        )

    def get_instance_of(
        self, instance_class: Type[SEntity], entity_id: UUID, at: Optional[int] = None
    ) -> Optional[SEntity]:
        """
        Returns the entity with the given ID (optionally at version 'at') if
        it is an instance of instance_class, otherwise None (including when
        the entity is not found).
        """
        entity = self.get_entity(entity_id, at=at)
        if isinstance(entity, instance_class):
            return entity
        return None