import pickle
from base64 import b64decode, b64encode
from django.core import checks
from django.db import models
from django.utils.encoding import force_bytes
from django.utils.translation import ugettext_lazy as _
from django_cryptography.core.signing import SignatureExpired
from django_cryptography.utils.crypto import FernetBytes
# Maps a base field class -> its dynamically created Encrypted<Name> subclass,
# so get_encrypted_field() always hands back the same class object.
FIELD_CACHE = {}


class _ExpiredType:
    """Type of the :data:`Expired` sentinel (there is exactly one instance)."""
    __slots__ = ()

    def __repr__(self):
        return 'Expired'

    def __reduce__(self):
        # Returning a string makes pickle/copy resolve the module-level
        # ``Expired`` global instead of creating a new object, so the
        # sentinel keeps its identity when a field value is pickled again
        # (e.g. re-saving a model whose encrypted value has expired).
        return 'Expired'


Expired = _ExpiredType()
"""Represents an expired encryption value."""
class PickledField(models.Field):
    """
    A model field that stores any picklable Python object.

    Values are pickled into a binary column when written and unpickled when
    read back. The field is always non-editable.
    """
    description = _("Pickled data")
    empty_values = [None, b'']
    supported_lookups = ('exact', 'in', 'isnull')

    def __init__(self, *args, **kwargs):
        # A pickle blob has no sensible form widget, so editing is disabled
        # unconditionally.
        kwargs['editable'] = False
        super().__init__(*args, **kwargs)

    def _dump(self, value):
        """Serialize *value* to pickle bytes."""
        return pickle.dumps(value)

    def _load(self, value):
        """Rebuild a Python object from pickle bytes."""
        # NOTE(review): pickle.loads can execute arbitrary code when given
        # crafted bytes; it must only ever see data this application wrote.
        return pickle.loads(value)

    def deconstruct(self):
        """Drop ``editable`` because ``__init__`` always forces it."""
        name, path, args, kwargs = super().deconstruct()
        kwargs.pop('editable')
        return name, path, args, kwargs

    def get_internal_type(self):
        """Store the pickled bytes in a binary column."""
        return "BinaryField"

    def get_default(self):
        """Return the field default, mapping an empty string to ``b''``."""
        fallback = super().get_default()
        return b'' if fallback == '' else fallback

    def get_lookup(self, lookup_name):
        """Allow only the lookups listed in ``supported_lookups``."""
        if lookup_name in self.supported_lookups:
            return super().get_lookup(lookup_name)
        return None

    def get_transform(self, lookup_name):
        """Allow only the transforms listed in ``supported_lookups``."""
        if lookup_name in self.supported_lookups:
            return super().get_transform(lookup_name)
        return None

    def get_db_prep_value(self, value, connection, prepared=False):
        """Pickle the prepared value and wrap it as a DB binary object."""
        prepped = super().get_db_prep_value(value, connection, prepared)
        if prepped is None:
            return None
        return connection.Database.Binary(self._dump(prepped))

    def from_db_value(self, value, *args, **kwargs):
        """Unpickle the raw column value; ``None`` passes through."""
        if value is None:
            return None
        return self._load(force_bytes(value))

    def value_to_string(self, obj):
        """Pickled data is serialized as base64"""
        pickled = self._dump(self.value_from_object(obj))
        return b64encode(pickled).decode('ascii')

    def to_python(self, value):
        """
        Decode a base64 string (as produced by ``value_to_string``) back
        into a Python object; any non-string value is returned unchanged.
        """
        # NOTE(review): every ``str`` is treated as base64 pickle data, so a
        # stored value that is itself a plain string would be misread here.
        if not isinstance(value, str):
            return value
        return self._load(b64decode(force_bytes(value)))
class EncryptedMixin:
    """
    A field mixin storing encrypted data

    Values are pickled, then encrypted with :class:`FernetBytes` on write,
    and decrypted and unpickled on read.

    :param bytes key: This is an optional argument.
        Allows for specifying an instance specific encryption key.
    :param int ttl: This is an optional argument.
        The amount of time in seconds that a value can be stored for. If the
        time to live of the data has passed, it will become unreadable.
        The expired value will return an :class:`Expired` object.
    """
    # Only null checks are exposed; comparing against ciphertext in SQL is
    # presumably not meaningful.
    supported_lookups = ('isnull', )

    def __init__(self, *args, **kwargs):
        # key/ttl are consumed here so the wrapped base field never sees them.
        self.key = kwargs.pop('key', None)
        self.ttl = kwargs.pop('ttl', None)
        self._fernet = FernetBytes(self.key)
        super(EncryptedMixin, self).__init__(*args, **kwargs)

    @property
    def description(self):
        """The base field's description prefixed with "Encrypted"."""
        return _('Encrypted %s') % super(EncryptedMixin, self).description

    def _dump(self, value):
        """Pickle *value* and encrypt the resulting bytes."""
        return self._fernet.encrypt(pickle.dumps(value))

    def _load(self, value):
        """
        Decrypt and unpickle *value*.

        If decryption raises ``SignatureExpired`` (the ``ttl`` has elapsed),
        the :data:`Expired` sentinel is returned instead of raising.
        """
        try:
            return pickle.loads(self._fernet.decrypt(value, self.ttl))
        except SignatureExpired:
            return Expired

    def check(self, **kwargs):
        """Run the base field checks and reject related (FK/M2M) bases."""
        errors = super(EncryptedMixin, self).check(**kwargs)
        if getattr(self, 'remote_field', None):
            errors.append(
                checks.Error(
                    'Base field for encrypted cannot be a related field.',
                    hint=None,
                    obj=self,
                    id='encrypted.E002'))
        return errors

    def clone(self):
        """
        Return an unbound copy of this field that keeps ``key`` and ``ttl``.
        """
        name, path, args, kwargs = super(EncryptedMixin, self).deconstruct()
        # Index 1 in the MRO means this is the class generated by
        # get_encrypted_field(); a larger index means it was subclassed.
        if not self.__class__.__mro__.index(EncryptedMixin) > 1:
            return encrypt(
                self.base_class(*args, **kwargs), self.key, self.ttl)
        # The base deconstruct() knows nothing about key/ttl, so they must be
        # passed explicitly or the clone would silently use the default key
        # and no expiry.
        kwargs.update(key=self.key, ttl=self.ttl)
        return self.__class__(*args, **kwargs)

    def deconstruct(self):
        """
        Describe this field for migrations.

        Directly generated classes deconstruct to ``encrypt(<base field>)``
        (plus ``ttl`` when set). Note that ``key`` is never included.
        """
        name, path, args, kwargs = super(EncryptedMixin, self).deconstruct()
        # Determine if the class that subclassed us has been subclassed.
        if not self.__class__.__mro__.index(EncryptedMixin) > 1:
            path = "%s.%s" % (encrypt.__module__, encrypt.__name__)
            args = [self.base_class(*args, **kwargs)]
            kwargs = {}
        if self.ttl is not None:
            kwargs['ttl'] = self.ttl
        return name, path, args, kwargs

    def get_lookup(self, lookup_name):
        """Allow only the lookups listed in ``supported_lookups``."""
        if lookup_name not in self.supported_lookups:
            return None
        return super(EncryptedMixin, self).get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        """Allow only the transforms listed in ``supported_lookups``."""
        if lookup_name not in self.supported_lookups:
            return None
        return super(EncryptedMixin, self).get_transform(lookup_name)

    def get_internal_type(self):
        """Ciphertext is always stored in a binary column."""
        return "BinaryField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """Encrypt the value and wrap it as a DB binary object."""
        # Calls models.Field's implementation directly, skipping any
        # get_db_prep_value override defined by the wrapped base field class.
        value = models.Field.get_db_prep_value(self, value, connection,
                                               prepared)
        if value is not None:
            return connection.Database.Binary(self._dump(value))
        return value

    # Likewise bypass any get_db_prep_save override of the base field class.
    get_db_prep_save = models.Field.get_db_prep_save

    def from_db_value(self, value, *args, **kwargs):
        """Decrypt the raw column value; ``None`` passes through."""
        if value is not None:
            return self._load(force_bytes(value))
        return value
def get_encrypted_field(base_class):
    """
    Return the encrypted variant of the field class *base_class*.

    The variant is built once as ``Encrypted<BaseName>`` with bases
    ``(EncryptedMixin, base_class)``, then memoized in ``FIELD_CACHE`` so
    every call for the same base returns the identical class object.

    :type base_class: models.Field[T]
    :rtype: models.Field[EncryptedMixin, T]
    """
    # Must be a field *class*, not a field instance.
    assert not isinstance(base_class, models.Field)
    class_name = 'Encrypted' + base_class.__name__
    try:
        return FIELD_CACHE[base_class]
    except KeyError:
        pass
    encrypted_cls = type(class_name, (EncryptedMixin, base_class),
                         {'base_class': base_class})
    FIELD_CACHE[base_class] = encrypted_cls
    return encrypted_cls
def encrypt(base_field, key=None, ttl=None):
    """
    A decorator for creating encrypted model fields.

    Accepts either a field *class*, returning its encrypted subclass (as a
    class decorator, where ``key``/``ttl`` must not be given), or a field
    *instance*, returning a new encrypted field built from that instance's
    constructor arguments.

    :type base_field: models.Field[T]
    :param bytes key: This is an optional argument.
        Allows for specifying an instance specific encryption key.
    :param int ttl: This is an optional argument.
        The amount of time in seconds that a value can be stored for. If the
        time to live of the data has passed, it will become unreadable.
        The expired value will return an :class:`Expired` object.
    :rtype: models.Field[EncryptedMixin, T]
    """
    if isinstance(base_field, models.Field):
        # Rebuild the field from its own constructor arguments, adding the
        # encryption options consumed by EncryptedMixin.__init__.
        _, _, field_args, field_kwargs = base_field.deconstruct()
        field_kwargs['key'] = key
        field_kwargs['ttl'] = ttl
        encrypted_cls = get_encrypted_field(type(base_field))
        return encrypted_cls(*field_args, **field_kwargs)
    # Class-decorator form: per-instance options make no sense here.
    assert key is None
    assert ttl is None
    return get_encrypted_field(base_field)