# Source code for eventsourcing.cryptography

from __future__ import annotations

import os
from base64 import b64decode, b64encode
from typing import TYPE_CHECKING

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from eventsourcing.persistence import Cipher

if TYPE_CHECKING:
    from eventsourcing.utils import Environment


class AESCipher(Cipher):
    """Cipher strategy that uses AES cipher (in GCM mode) from the
    Python cryptography package.

    The ciphertext produced by :meth:`encrypt`, and expected by
    :meth:`decrypt`, is laid out as::

        nonce (NONCE_SIZE bytes) + tag (TAG_SIZE bytes) + encrypted data
    """

    CIPHER_KEY = "CIPHER_KEY"
    KEY_SIZES = (16, 24, 32)
    # Length in bytes of the random per-message GCM nonce (96 bits).
    NONCE_SIZE = 12
    # Length in bytes of the GCM authentication tag that AESGCM appends
    # to its output.
    TAG_SIZE = 16

    @staticmethod
    def create_key(num_bytes: int) -> str:
        """Creates AES cipher key, with length num_bytes.

        :param num_bytes: An int value, either 16, 24, or 32.
        :return: The generated key, base64-encoded as a UTF-8 ``str``.
        :raises ValueError: If ``num_bytes`` is not one of ``KEY_SIZES``.
        """
        AESCipher.check_key_size(num_bytes)
        # AESGCM.generate_key() takes the key length in bits.
        key = AESGCM.generate_key(num_bytes * 8)
        return b64encode(key).decode("utf8")

    @staticmethod
    def check_key_size(num_bytes: int) -> None:
        """Raise ``ValueError`` unless ``num_bytes`` is a valid AES key size.

        :param num_bytes: Key length in bytes.
        :raises ValueError: If ``num_bytes`` is not one of ``KEY_SIZES``.
        """
        if num_bytes not in AESCipher.KEY_SIZES:
            msg = f"Invalid key size: {num_bytes} not in {AESCipher.KEY_SIZES}"
            raise ValueError(msg)

    @staticmethod
    def random_bytes(num_bytes: int) -> bytes:
        """Return ``num_bytes`` bytes from the OS randomness source."""
        return os.urandom(num_bytes)

    def __init__(self, environment: Environment):
        """Initialises AES cipher with the key read from ``environment``.

        The key is looked up under ``CIPHER_KEY`` and must be 16, 24, or 32
        bytes encoded as base64. It is stored, decoded, on ``self.key``.

        :param environment: Mapping-like settings object; its ``get()`` is
            called with ``CIPHER_KEY``.
        :raises OSError: If ``CIPHER_KEY`` is missing or empty.
        :raises ValueError: If the value is not decodable base64 (including
            ``binascii.Error``, a ``ValueError`` subclass), or if it decodes
            to a key whose length is not one of ``KEY_SIZES``.
        """
        cipher_key = environment.get(self.CIPHER_KEY)
        if not cipher_key:
            msg = f"'{self.CIPHER_KEY}' not in env"
            raise OSError(msg)
        # Note: b64decode() (without validate=True) discards characters
        # outside the base64 alphabet before decoding.
        key = b64decode(cipher_key.encode("utf8"))
        AESCipher.check_key_size(len(key))
        self.key = key

    def encrypt(self, plaintext: bytes) -> bytes:
        """Return ciphertext for given plaintext.

        A fresh random nonce is generated for every call.

        :param plaintext: Bytes to encrypt.
        :return: ``nonce + tag + encrypted data``.
        """
        # Construct AES-GCM cipher, with 96-bit nonce.
        aesgcm = AESGCM(self.key)
        nonce = AESCipher.random_bytes(self.NONCE_SIZE)
        res = aesgcm.encrypt(nonce, plaintext, None)
        # AESGCM returns the encrypted data with the tag appended; put the
        # tag at the front for compatibility with eventsourcing.crypto.AESCipher.
        tag = res[-self.TAG_SIZE :]
        encrypted = res[: -self.TAG_SIZE]
        return nonce + tag + encrypted

    def decrypt(self, ciphertext: bytes) -> bytes:
        """Return plaintext for given ciphertext.

        :param ciphertext: Bytes laid out as ``nonce + tag + encrypted data``,
            as produced by :meth:`encrypt`.
        :return: The decrypted, authenticated plaintext.
        :raises ValueError: If the ciphertext is too short to contain a nonce
            and tag, or if the authentication tag does not verify (wrong key
            or tampered/damaged data).
        """
        nonce_end = self.NONCE_SIZE
        tag_end = nonce_end + self.TAG_SIZE

        # Split out the nonce, tag, and encrypted data.
        nonce = ciphertext[:nonce_end]
        if len(nonce) != self.NONCE_SIZE:
            msg = "Damaged cipher text: invalid nonce length"
            raise ValueError(msg)
        # Expect tag at the front (immediately after the nonce).
        tag = ciphertext[nonce_end:tag_end]
        if len(tag) != self.TAG_SIZE:
            msg = "Damaged cipher text: invalid tag length"
            raise ValueError(msg)
        encrypted = ciphertext[tag_end:]

        # Decrypt and verify. AESGCM expects the tag appended to the data,
        # so restore that order before calling it.
        aesgcm = AESGCM(self.key)
        try:
            plaintext = aesgcm.decrypt(nonce, encrypted + tag, None)
        except InvalidTag as e:
            msg = "Invalid cipher tag"
            raise ValueError(msg) from e
        return plaintext