from __future__ import annotations
import os
from base64 import b64decode, b64encode
from typing import TYPE_CHECKING
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from eventsourcing.persistence import Cipher
if TYPE_CHECKING:
from eventsourcing.utils import Environment
[docs]
class AESCipher(Cipher):
    """Cipher strategy that uses AES cipher (in GCM mode)
    from the Python cryptography package.

    Ciphertext produced by :meth:`encrypt` (and expected by :meth:`decrypt`)
    is laid out as::

        nonce (12 bytes) + tag (16 bytes) + encrypted data

    The key is read from the environment under the name :attr:`CIPHER_KEY`
    and must decode from base64 to 16, 24, or 32 bytes.
    """

    # Name under which the base64-encoded key is looked up in the environment.
    CIPHER_KEY = "CIPHER_KEY"
    # Permitted raw key lengths, in bytes.
    KEY_SIZES = (16, 24, 32)
    # Length of the random nonce prepended to each ciphertext (96 bits).
    _NONCE_SIZE = 12
    # Length of the GCM authentication tag that follows the nonce.
    _TAG_SIZE = 16

    @staticmethod
    def create_key(num_bytes: int) -> str:
        """Creates AES cipher key, with length num_bytes.

        :param num_bytes: An int value, either 16, 24, or 32.
        :return: The generated key, encoded as base64 text.
        :raises ValueError: If ``num_bytes`` is not in :attr:`KEY_SIZES`.
        """
        AESCipher.check_key_size(num_bytes)
        # AESGCM.generate_key() takes the key length in bits, not bytes.
        key = AESGCM.generate_key(num_bytes * 8)
        return b64encode(key).decode("utf8")

    @staticmethod
    def check_key_size(num_bytes: int) -> None:
        """Raise ``ValueError`` unless ``num_bytes`` is an allowed key length.

        :param num_bytes: Key length in bytes.
        :raises ValueError: If ``num_bytes`` is not in :attr:`KEY_SIZES`.
        """
        if num_bytes not in AESCipher.KEY_SIZES:
            msg = f"Invalid key size: {num_bytes} not in {AESCipher.KEY_SIZES}"
            raise ValueError(msg)

    @staticmethod
    def random_bytes(num_bytes: int) -> bytes:
        """Return ``num_bytes`` random bytes obtained from :func:`os.urandom`."""
        return os.urandom(num_bytes)

    def __init__(self, environment: Environment):
        """Initialises AES cipher with the key held in ``environment``.

        :param environment: Configuration from which the key is read with
            ``environment.get(CIPHER_KEY)``; the value must be 16, 24, or 32
            bytes encoded as base64 text.
        :raises OSError: If the key is missing or empty.
        :raises ValueError: If the key cannot be base64-decoded
            (``binascii.Error``), or does not decode to an allowed length.
        """
        cipher_key = environment.get(self.CIPHER_KEY)
        if not cipher_key:
            msg = f"'{self.CIPHER_KEY}' not in env"
            raise OSError(msg)
        key = b64decode(cipher_key.encode("utf8"))
        AESCipher.check_key_size(len(key))
        self.key = key

    def encrypt(self, plaintext: bytes) -> bytes:
        """Return ciphertext for given plaintext.

        A fresh random nonce is generated on every call. The result is
        ``nonce + tag + encrypted``.
        """
        # Construct AES-GCM cipher, with 96-bit nonce.
        aesgcm = AESGCM(self.key)
        nonce = AESCipher.random_bytes(AESCipher._NONCE_SIZE)
        res = aesgcm.encrypt(nonce, plaintext, None)
        # The tag is taken from the end of the AESGCM output and moved to the
        # front, for compatibility with eventsourcing.crypto.AESCipher.
        tag = res[-AESCipher._TAG_SIZE :]
        encrypted = res[: -AESCipher._TAG_SIZE]
        return nonce + tag + encrypted

    def decrypt(self, ciphertext: bytes) -> bytes:
        """Return plaintext for given ciphertext.

        :param ciphertext: Bytes laid out as ``nonce + tag + encrypted``,
            as produced by :meth:`encrypt`.
        :raises ValueError: If the ciphertext is too short to contain a
            nonce and tag, or if tag verification fails.
        """
        nonce_size = AESCipher._NONCE_SIZE
        header_size = nonce_size + AESCipher._TAG_SIZE
        # Split out the nonce, tag, and encrypted data.
        nonce = ciphertext[:nonce_size]
        if len(nonce) != nonce_size:
            msg = "Damaged cipher text: invalid nonce length"
            raise ValueError(msg)
        # Expect tag at the front (directly after the nonce).
        tag = ciphertext[nonce_size:header_size]
        if len(tag) != AESCipher._TAG_SIZE:
            msg = "Damaged cipher text: invalid tag length"
            raise ValueError(msg)
        encrypted = ciphertext[header_size:]
        aesgcm = AESGCM(self.key)
        # Decrypt and verify. AESGCM.decrypt() expects the tag appended.
        try:
            plaintext = aesgcm.decrypt(nonce, encrypted + tag, None)
        except InvalidTag as e:
            msg = "Invalid cipher tag"
            raise ValueError(msg) from e
        return plaintext