From b1564d52e3d0e543f25bd63a6e46de660382f59c Mon Sep 17 00:00:00 2001 From: Alexandr Mansurov Date: Sat, 16 Oct 2021 14:58:48 +0200 Subject: [PATCH] Key management --- src/cryptbase/__init__.py | 9 ++-- src/cryptbase/cryptbase.py | 46 +++++++++++++++--- tests/test_cryptbase.py | 99 ++++++++++++++++++++++++++------------ 3 files changed, 112 insertions(+), 42 deletions(-) diff --git a/src/cryptbase/__init__.py b/src/cryptbase/__init__.py index 3c7f8a2..229263d 100644 --- a/src/cryptbase/__init__.py +++ b/src/cryptbase/__init__.py @@ -1,5 +1,6 @@ __version__ = '0.0.0' +import gc from .cryptbase import AES256CTREncryptor from .cryptbase import CryptoContainer @@ -21,9 +22,9 @@ try: def __init__(self, *args, **kwargs): if 'key' in kwargs: - self.__key = kwargs['key'] + self.encryptor = AES256CTREncryptor(bytearray.fromhex(kwargs['key'])) del kwargs['key'] - self.encryptor = AES256CTREncryptor(bytes.fromhex(self.__key)) + gc.collect() super().__init__(*args, **kwargs) def get_db_prep_value(self, value, *args, **kwargs): @@ -64,9 +65,9 @@ try: impl = types.TEXT def __init__(self, *args, **kwargs): - self.__key = kwargs['key'] + self.encryptor = AES256CTREncryptor(bytearray.fromhex(kwargs['key'])) del kwargs['key'] - self.encryptor = AES256CTREncryptor(bytes.fromhex(self.__key)) + gc.collect() super().__init__(*args, **kwargs) def process_bind_param(self, value, dialect): diff --git a/src/cryptbase/cryptbase.py b/src/cryptbase/cryptbase.py index af4ff62..2ce86c7 100644 --- a/src/cryptbase/cryptbase.py +++ b/src/cryptbase/cryptbase.py @@ -26,22 +26,52 @@ class Encryptor(ABC): class AES256CTREncryptor(Encryptor): """Encryptor using AES with 256-bit long keys using CTR mode.""" - def __init__(self, key: bytes): - if not isinstance(key, bytes): + KEY_LENGTH = 32 + IV_LENGTH = 16 + + def __init__(self, key: bytearray): + if not isinstance(key, bytearray): raise ValueError() - if not len(key) == 32: + if not len(key) == AES256CTREncryptor.KEY_LENGTH: raise ValueError() - self.__key = key + self.__protect = bytearray(secrets.token_bytes(AES256CTREncryptor.KEY_LENGTH)) + self.__key = AES256CTREncryptor.__obfuscate(key, self.__protect) + + def __del__(self): + self.clear() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.clear() + + def clear(self): + if '_AES256CTREncryptor__key' in self.__dict__.keys(): + for i in range(len(self.__key)): + self.__key[i] = 0 + if '_AES256CTREncryptor__protect' in self.__dict__.keys(): + for i in range(len(self.__protect)): + self.__protect[i] = 0 + + @staticmethod + def __obfuscate(secret: bytearray, key: bytearray) -> bytearray: + if not len(secret) == len(key): + raise ValueError() + result = bytearray(len(key)) + for i in range(len(secret)): + result[i] = secret[i] ^ key[i] + return result def encrypt(self, plaintext: bytes) -> Tuple[bytes, bytes]: - iv = secrets.token_bytes(16) - cipher = Cipher(algorithms.AES(self.__key), modes.CTR(iv)) + iv = secrets.token_bytes(AES256CTREncryptor.IV_LENGTH) + cipher = Cipher(algorithms.AES(AES256CTREncryptor.__obfuscate(self.__key, self.__protect)), modes.CTR(iv)) encryptor = cipher.encryptor() ciphertext = encryptor.update(plaintext) + encryptor.finalize() return ciphertext, iv def decrypt(self, ciphertext: bytes, iv: bytes) -> bytes: - cipher = Cipher(algorithms.AES(self.__key), modes.CTR(iv)) + cipher = Cipher(algorithms.AES(AES256CTREncryptor.__obfuscate(self.__key, self.__protect)), modes.CTR(iv)) decryptor = cipher.decryptor() return decryptor.update(ciphertext) + decryptor.finalize() @@ -63,7 +93,7 @@ class CryptoContainer: if 'encryptor' not in kwargs: raise ValueError() if not isinstance(kwargs['encryptor'], Encryptor): - raise ValueError() + raise ValueError(type(kwargs['encryptor'])) if 'plaintext' in kwargs: self.__plaintext = kwargs['plaintext'] diff --git a/tests/test_cryptbase.py b/tests/test_cryptbase.py index 32fbc16..3b4e4c5 100644 --- a/tests/test_cryptbase.py +++ b/tests/test_cryptbase.py @@ -6,20 +6,47 @@ import string import cryptbase +aes_n = 256 +key_bytes = int(aes_n / 8) +population = string.ascii_letters + string.digits + string.punctuation +length = 2048 + def test_main(): assert cryptbase.__version__ == '0.0.0' -def test_crypto(): - aes_n = 256 - key_bytes = int(aes_n / 8) - population = string.ascii_letters + string.digits + string.punctuation - length = 2048 +def test_string_key(): + try: + cryptbase.cryptbase.AES256CTREncryptor(secrets.token_bytes(key_bytes)) + assert False + except ValueError: + pass + +def test_clear(): + encryptor = cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) + assert '_AES256CTREncryptor__key' in encryptor.__dict__.keys() + assert '_AES256CTREncryptor__protect' in encryptor.__dict__.keys() + encryptor.clear() + for i in range(key_bytes): + assert encryptor.__dict__['_AES256CTREncryptor__key'][i] == 0 + assert encryptor.__dict__['_AES256CTREncryptor__protect'][i] == 0 + + +def test_contextmgr(): + with cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) as encryptor: + e = encryptor + assert '_AES256CTREncryptor__key' in encryptor.__dict__.keys() + assert '_AES256CTREncryptor__protect' in encryptor.__dict__.keys() + for i in range(key_bytes): + assert e.__dict__['_AES256CTREncryptor__key'][i] == 0 + assert e.__dict__['_AES256CTREncryptor__protect'][i] == 0 + + +def test_crypto(): for _ in range(16): - key = secrets.token_bytes(key_bytes) - encryptor = cryptbase.cryptbase.AES256CTREncryptor(key) + encryptor = cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) for _ in range(32): plaintext = ''.join(random.choices(population, k=length)).encode('ascii') assert plaintext == encryptor.decrypt(*encryptor.encrypt(plaintext)) @@ -46,36 +73,48 @@ def test_encryptor(): pass -def test_containter_encrypt(): - aes_n = 256 +def test_container_bad_key_length(): + aes_n = 128 key_bytes = int(aes_n / 8) - population = string.ascii_letters + string.digits + string.punctuation - length = 2048 + try: + cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) + assert False + except ValueError: + pass + aes_n = 512 + key_bytes = int(aes_n / 8) + try: + cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) + assert False + except ValueError: + pass + + +def test_containter_encrypt(): for _ in range(16): - key = secrets.token_bytes(key_bytes) - encryptor = cryptbase.cryptbase.AES256CTREncryptor(key) - for _ in range(32): - plaintext = ''.join(random.choices(population, k=length)) - container = cryptbase.cryptbase.CryptoContainer(encryptor=encryptor, plaintext=plaintext) + with cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) as encryptor: + for _ in range(32): + plaintext = ''.join(random.choices(population, k=length)) + container = cryptbase.cryptbase.CryptoContainer(encryptor=encryptor, plaintext=plaintext) - assert container.plaintext == plaintext - assert str(container) == '@'.join([container.ciphertext, container.iv]) + assert container.plaintext == plaintext + assert str(container) == '@'.join([container.ciphertext, container.iv]) - cipher_b64 = container.ciphertext - iv_b64 = container.iv + cipher_b64 = container.ciphertext + iv_b64 = container.iv - ciphertext = base64.b64decode(cipher_b64) - iv = base64.b64decode(iv_b64) - decrypted = encryptor.decrypt(ciphertext, iv).decode('utf-8') - assert type(plaintext) == type(decrypted) - assert len(plaintext) == len(decrypted) - assert plaintext == decrypted + ciphertext = base64.b64decode(cipher_b64) + iv = base64.b64decode(iv_b64) + decrypted = encryptor.decrypt(ciphertext, iv).decode('utf-8') + assert type(plaintext) == type(decrypted) + assert len(plaintext) == len(decrypted) + assert plaintext == decrypted - decontainer = cryptbase.cryptbase.CryptoContainer.from_encrypted(str(container), encryptor) - assert decontainer.plaintext == plaintext - assert decontainer.iv == iv_b64 - assert decontainer.ciphertext == cipher_b64 + decontainer = cryptbase.cryptbase.CryptoContainer.from_encrypted(str(container), encryptor) + assert decontainer.plaintext == plaintext + assert decontainer.iv == iv_b64 + assert decontainer.ciphertext == cipher_b64 def test_container_init():