Key management

This commit is contained in:
2021-10-16 14:58:48 +02:00
parent 1d39c27690
commit b1564d52e3
3 changed files with 112 additions and 42 deletions

View File

@@ -1,5 +1,6 @@
__version__ = '0.0.0' __version__ = '0.0.0'
import gc
from .cryptbase import AES256CTREncryptor from .cryptbase import AES256CTREncryptor
from .cryptbase import CryptoContainer from .cryptbase import CryptoContainer
@@ -21,9 +22,9 @@ try:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'key' in kwargs: if 'key' in kwargs:
self.__key = kwargs['key'] self.encryptor = AES256CTREncryptor(bytearray.fromhex(kwargs['key']))
del kwargs['key'] del kwargs['key']
self.encryptor = AES256CTREncryptor(bytes.fromhex(self.__key)) gc.collect()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_db_prep_value(self, value, *args, **kwargs): def get_db_prep_value(self, value, *args, **kwargs):
@@ -64,9 +65,9 @@ try:
impl = types.TEXT impl = types.TEXT
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.__key = kwargs['key'] self.encryptor = AES256CTREncryptor(bytearray.fromhex(kwargs['key']))
del kwargs['key'] del kwargs['key']
self.encryptor = AES256CTREncryptor(bytes.fromhex(self.__key)) gc.collect()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):

View File

@@ -26,22 +26,52 @@ class Encryptor(ABC):
class AES256CTREncryptor(Encryptor): class AES256CTREncryptor(Encryptor):
"""Encryptor using AES with 256-bit long keys using CTR mode.""" """Encryptor using AES with 256-bit long keys using CTR mode."""
def __init__(self, key: bytes): KEY_LENGTH = 32
if not isinstance(key, bytes): IV_LENGTH = 16
def __init__(self, key: bytearray):
if not isinstance(key, bytearray):
raise ValueError() raise ValueError()
if not len(key) == 32: if not len(key) == AES256CTREncryptor.KEY_LENGTH:
raise ValueError() raise ValueError()
self.__key = key self.__protect = bytearray(secrets.token_bytes(AES256CTREncryptor.KEY_LENGTH))
self.__key = AES256CTREncryptor.__obfuscate(key, self.__protect)
def __del__(self):
self.clear()
def __enter__(self):
return self
def __exit__(self, *args):
self.clear()
def clear(self):
if '_AES256CTREncryptor__key' in self.__dict__.keys():
for i in range(len(self.__key)):
self.__key[i] = 0
if '_AES256CTREncryptor__protect' in self.__dict__.keys():
for i in range(len(self.__protect)):
self.__protect[i] = 0
@staticmethod
def __obfuscate(secret: bytearray, key: bytearray) -> bytearray:
if not len(secret) == len(key):
raise ValueError()
result = bytearray(len(key))
for i in range(len(secret)):
result[i] = secret[i] ^ key[i]
return result
def encrypt(self, plaintext: bytes) -> Tuple[bytes, bytes]: def encrypt(self, plaintext: bytes) -> Tuple[bytes, bytes]:
iv = secrets.token_bytes(16) iv = secrets.token_bytes(AES256CTREncryptor.IV_LENGTH)
cipher = Cipher(algorithms.AES(self.__key), modes.CTR(iv)) cipher = Cipher(algorithms.AES(AES256CTREncryptor.__obfuscate(self.__key, self.__protect)), modes.CTR(iv))
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
ciphertext = encryptor.update(plaintext) + encryptor.finalize() ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return ciphertext, iv return ciphertext, iv
def decrypt(self, ciphertext: bytes, iv: bytes) -> bytes: def decrypt(self, ciphertext: bytes, iv: bytes) -> bytes:
cipher = Cipher(algorithms.AES(self.__key), modes.CTR(iv)) cipher = Cipher(algorithms.AES(AES256CTREncryptor.__obfuscate(self.__key, self.__protect)), modes.CTR(iv))
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
return decryptor.update(ciphertext) + decryptor.finalize() return decryptor.update(ciphertext) + decryptor.finalize()
@@ -63,7 +93,7 @@ class CryptoContainer:
if 'encryptor' not in kwargs: if 'encryptor' not in kwargs:
raise ValueError() raise ValueError()
if not isinstance(kwargs['encryptor'], Encryptor): if not isinstance(kwargs['encryptor'], Encryptor):
raise ValueError() raise ValueError(type(kwargs['encryptor']))
if 'plaintext' in kwargs: if 'plaintext' in kwargs:
self.__plaintext = kwargs['plaintext'] self.__plaintext = kwargs['plaintext']

View File

@@ -6,20 +6,47 @@ import string
import cryptbase import cryptbase
aes_n = 256
key_bytes = int(aes_n / 8)
population = string.ascii_letters + string.digits + string.punctuation
length = 2048
def test_main(): def test_main():
assert cryptbase.__version__ == '0.0.0' assert cryptbase.__version__ == '0.0.0'
def test_crypto(): def test_string_key():
aes_n = 256 try:
key_bytes = int(aes_n / 8) cryptbase.cryptbase.AES256CTREncryptor(secrets.token_bytes(key_bytes))
population = string.ascii_letters + string.digits + string.punctuation assert False
length = 2048 except ValueError:
pass
def test_clear():
encryptor = cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes)))
assert '_AES256CTREncryptor__key' in encryptor.__dict__.keys()
assert '_AES256CTREncryptor__protect' in encryptor.__dict__.keys()
encryptor.clear()
for i in range(key_bytes):
assert encryptor.__dict__['_AES256CTREncryptor__key'][i] == 0
assert encryptor.__dict__['_AES256CTREncryptor__protect'][i] == 0
def test_contextmgr():
with cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) as encryptor:
e = encryptor
assert '_AES256CTREncryptor__key' in encryptor.__dict__.keys()
assert '_AES256CTREncryptor__protect' in encryptor.__dict__.keys()
for i in range(key_bytes):
assert e.__dict__['_AES256CTREncryptor__key'][i] == 0
assert e.__dict__['_AES256CTREncryptor__protect'][i] == 0
def test_crypto():
for _ in range(16): for _ in range(16):
key = secrets.token_bytes(key_bytes) encryptor = cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes)))
encryptor = cryptbase.cryptbase.AES256CTREncryptor(key)
for _ in range(32): for _ in range(32):
plaintext = ''.join(random.choices(population, k=length)).encode('ascii') plaintext = ''.join(random.choices(population, k=length)).encode('ascii')
assert plaintext == encryptor.decrypt(*encryptor.encrypt(plaintext)) assert plaintext == encryptor.decrypt(*encryptor.encrypt(plaintext))
@@ -46,15 +73,27 @@ def test_encryptor():
pass pass
def test_containter_encrypt(): def test_container_bad_key_length():
aes_n = 256 aes_n = 128
key_bytes = int(aes_n / 8) key_bytes = int(aes_n / 8)
population = string.ascii_letters + string.digits + string.punctuation try:
length = 2048 cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes)))
assert False
except ValueError:
pass
aes_n = 512
key_bytes = int(aes_n / 8)
try:
cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes)))
assert False
except ValueError:
pass
def test_containter_encrypt():
for _ in range(16): for _ in range(16):
key = secrets.token_bytes(key_bytes) with cryptbase.cryptbase.AES256CTREncryptor(bytearray(secrets.token_bytes(key_bytes))) as encryptor:
encryptor = cryptbase.cryptbase.AES256CTREncryptor(key)
for _ in range(32): for _ in range(32):
plaintext = ''.join(random.choices(population, k=length)) plaintext = ''.join(random.choices(population, k=length))
container = cryptbase.cryptbase.CryptoContainer(encryptor=encryptor, plaintext=plaintext) container = cryptbase.cryptbase.CryptoContainer(encryptor=encryptor, plaintext=plaintext)