322 lines
11 KiB
Python
322 lines
11 KiB
Python
|
# Copyright: See the LICENSE file.
|
|||
|
|
|||
|
|
|||
|
"""factory_boy extensions for use with the Django framework."""
|
|||
|
|
|||
|
|
|||
|
import functools
|
|||
|
import io
|
|||
|
import logging
|
|||
|
import os
|
|||
|
|
|||
|
from django.core import files as django_files
|
|||
|
from django.db import IntegrityError
|
|||
|
|
|||
|
from . import base, declarations, errors
|
|||
|
|
|||
|
# Module logger; mute_signals uses it to report, at debug level, which signal
# handlers it disables and restores.
logger = logging.getLogger('factory.generate')


DEFAULT_DB_ALIAS = 'default'  # Same as django.db.DEFAULT_DB_ALIAS


# Cache of lazily imported Django callables. _lazy_load_get_model() stores
# django's get_model here under the 'get_model' key on first use.
_LAZY_LOADS = {}
|
|||
|
|
|||
|
|
|||
|
def get_model(app, model):
    """Look up a Django model class through django's own ``get_model``.

    The django loader is imported lazily: it is fetched from the
    ``_LAZY_LOADS`` cache, which is filled on the first call.

    Args:
        app: first argument forwarded to django's ``get_model``.
        model: second argument forwarded to django's ``get_model``.

    Returns:
        Whatever django's ``get_model(app, model)`` returns.
    """
    try:
        loader = _LAZY_LOADS['get_model']
    except KeyError:
        # First use: import django's loader, then read it back from the cache.
        _lazy_load_get_model()
        loader = _LAZY_LOADS['get_model']

    return loader(app, model)
|
|||
|
|
|||
|
|
|||
|
def _lazy_load_get_model():
    """Import django's ``get_model`` on demand and cache it.

    The import is deferred to call time because get_model loads
    django.conf.settings, which may fail if the settings haven't been
    configured yet.

    The loader is stored in ``_LAZY_LOADS`` under the ``'get_model'`` key.
    """
    from django.apps import apps as app_registry

    _LAZY_LOADS['get_model'] = app_registry.get_model
|
|||
|
|
|||
|
|
|||
|
class DjangoOptions(base.FactoryOptions):
    """Factory options extended with the Django-specific ``Meta`` entries."""

    def _build_default_options(self):
        """Return the parent option defaults plus the Django ones.

        Adds ``django_get_or_create`` (default ``()``) and ``database``
        (default ``DEFAULT_DB_ALIAS``); both are declared with
        ``inherit=True``.
        """
        inherited = super()._build_default_options()
        django_specific = [
            base.OptionDefault('django_get_or_create', (), inherit=True),
            base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
        ]
        return inherited + django_specific

    def _get_counter_reference(self):
        """Return the factory whose sequence counter this factory uses.

        Falls back to the parent's choice, except when that choice is the
        base factory, the base factory targets an abstract model and this
        factory targets a concrete one: then this factory is returned.
        """
        reference = super()._get_counter_reference()

        if reference == self.base_factory:
            base_model = self.base_factory._meta.model
            if base_model is not None and base_model._meta.abstract:
                own_model = self.model
                if own_model is not None and not own_model._meta.abstract:
                    # Target factory is for an abstract model, yet we're for
                    # another, concrete subclass => don't reuse the counter.
                    return self.factory

        return reference

    def get_model_class(self):
        """Return the model, resolving a dotted string on first access.

        When ``self.model`` is a string containing a dot, it is split at the
        first dot, looked up with ``get_model`` and the result is stored back
        on ``self.model``.
        """
        declared = self.model
        if isinstance(declared, str) and '.' in declared:
            app_label, _sep, model_name = declared.partition('.')
            self.model = get_model(app_label, model_name)

        return self.model
|
|||
|
|
|||
|
|
|||
|
class DjangoModelFactory(base.Factory):
    """Factory for Django models.

    This makes sure that the 'sequence' field of created objects is a new id.

    Possible improvement: define a new 'attribute' type, AutoField, which would
    handle those for non-numerical primary keys.
    """

    _options_class = DjangoOptions

    # Parameters of the generation in progress, recorded by _generate().
    # Defaults to None so that _get_or_create() can tell it was reached
    # without going through _generate() instead of failing on a missing
    # attribute.
    _original_params = None

    class Meta:
        abstract = True  # Optional, but explicit.

    @classmethod
    def _load_model_class(cls, definition):
        """Resolve a model definition to a model class.

        Args:
            definition: either a string containing a dot, which is split at
                the first dot and looked up with ``get_model``, or any other
                object, which is returned unchanged.

        Returns:
            The result of ``get_model`` for dotted strings, otherwise
            ``definition`` itself.
        """
        if isinstance(definition, str) and '.' in definition:
            app, model = definition.split('.', 1)
            return get_model(app, model)

        return definition

    @classmethod
    def _get_manager(cls, model_class):
        """Return the manager used to create or fetch ``model_class`` objects.

        Args:
            model_class: the model class attached to this factory.

        Returns:
            ``model_class.objects`` (or ``model_class._default_manager`` when
            there is no ``objects`` attribute), bound to
            ``cls._meta.database`` via ``using()`` when that database is not
            the default alias.

        Raises:
            errors.AssociatedClassError: if ``model_class`` is None.
        """
        if model_class is None:
            raise errors.AssociatedClassError(
                f"No model set on {cls.__module__}.{cls.__name__}.Meta")

        try:
            manager = model_class.objects
        except AttributeError:
            # When inheriting from an abstract model with a custom
            # manager, the class has no 'objects' field.
            manager = model_class._default_manager

        if cls._meta.database != DEFAULT_DB_ALIAS:
            manager = manager.using(cls._meta.database)
        return manager

    @classmethod
    def _generate(cls, strategy, params):
        """Record ``params`` on the class, then delegate to the parent."""
        # Original params are used in _get_or_create if it cannot build an
        # object initially due to an IntegrityError being raised
        cls._original_params = params
        return super()._generate(strategy, params)

    @classmethod
    def _get_or_create(cls, model_class, *args, **kwargs):
        """Create an instance of the model through objects.get_or_create.

        The fields listed in ``Meta.django_get_or_create`` are used as lookup
        keys; every other keyword argument is passed as ``defaults``.

        Args:
            model_class: the model class to fetch or create.
            *args: positional arguments forwarded to ``get_or_create``.
            **kwargs: field values for the instance.

        Returns:
            The fetched or newly created instance.

        Raises:
            errors.FactoryError: if a field listed in
                ``django_get_or_create`` has no value in ``kwargs``.
            IntegrityError: if ``get_or_create`` fails and no existing row
                can be found from the originally supplied parameters.
        """
        manager = cls._get_manager(model_class)

        assert 'defaults' not in cls._meta.django_get_or_create, (
            "'defaults' is a reserved keyword for get_or_create "
            "(in %s._meta.django_get_or_create=%r)"
            % (cls, cls._meta.django_get_or_create))

        key_fields = {}
        for field in cls._meta.django_get_or_create:
            if field not in kwargs:
                raise errors.FactoryError(
                    "django_get_or_create - "
                    "Unable to find initialization value for '%s' in factory %s" %
                    (field, cls.__name__))
            key_fields[field] = kwargs.pop(field)
        key_fields['defaults'] = kwargs

        try:
            instance, _created = manager.get_or_create(*args, **key_fields)
        except IntegrityError as e:
            if cls._original_params is None:
                # Reached without _generate() having recorded any params:
                # there is nothing to build a fallback lookup from, so
                # surface the IntegrityError itself.
                raise e

            # Retry with only the lookup fields the caller passed explicitly.
            get_or_create_params = {
                lookup: value
                for lookup, value in cls._original_params.items()
                if lookup in cls._meta.django_get_or_create
            }
            if not get_or_create_params:
                raise e

            try:
                instance = manager.get(**get_or_create_params)
            except manager.model.DoesNotExist:
                # Original params are not a valid lookup and triggered a create(),
                # that resulted in an IntegrityError. Follow Django’s behavior.
                raise e

        return instance

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Create an instance of the model, and save it to the database.

        Uses ``_get_or_create`` when ``Meta.django_get_or_create`` is
        non-empty, otherwise calls ``create()`` on the manager.
        """
        if cls._meta.django_get_or_create:
            return cls._get_or_create(model_class, *args, **kwargs)

        manager = cls._get_manager(model_class)
        return manager.create(*args, **kwargs)

    @classmethod
    def _after_postgeneration(cls, instance, create, results=None):
        """Save again the instance if creating and at least one hook ran.

        Args:
            instance: the generated object.
            create: truthy when the object is being created.
            results: post-generation results; nothing is saved when this is
                empty or None.
        """
        if create and results:
            # Some post-generation hooks ran, and may have modified us.
            instance.save()
|
|||
|
|
|||
|
|
|||
|
class FileField(declarations.BaseDeclaration):
    """Helper to fill in django.db.models.FileField from a Factory."""

    DEFAULT_FILENAME = 'example.dat'

    def _make_data(self, params):
        """Return the raw payload: ``params['data']``, or ``b''`` if absent."""
        return params.get('data', b'')

    def _make_content(self, params):
        """Build the file name and content described by ``params``.

        Content comes from at most one of ``from_path`` (a path read in
        binary mode), ``from_file`` (an object wrapped in a django ``File``)
        or ``from_func`` (a callable whose result is wrapped in a django
        ``File``). With none of them, ``_make_data`` supplies the bytes.

        Args:
            params: mapping of the options passed to the declaration.

        Returns:
            A ``(filename, content)`` tuple. ``filename`` is
            ``params['filename']`` when given, else the basename of the
            source path, else ``DEFAULT_FILENAME``.

        Raises:
            ValueError: if more than one of the three sources is non-empty.
        """
        from_path = params.get('from_path')
        from_file = params.get('from_file')
        from_func = params.get('from_func')

        provided = sum(1 for source in (from_path, from_file, from_func) if source)
        if provided > 1:
            raise ValueError(
                "At most one argument from 'from_file', 'from_path', and 'from_func' should "
                "be non-empty when calling factory.django.FileField."
            )

        source_path = ''
        if from_path:
            source_path = from_path
            with open(source_path, 'rb') as handle:
                content = django_files.base.ContentFile(handle.read())
        elif from_file:
            content = django_files.File(from_file)
            source_path = content.name
        elif from_func:
            content = django_files.File(from_func())
            source_path = content.name
        else:
            content = django_files.base.ContentFile(self._make_data(params))

        # source_path may be empty or None when the source has no name.
        fallback_name = os.path.basename(source_path) if source_path else self.DEFAULT_FILENAME
        return params.get('filename', fallback_name), content

    def evaluate(self, instance, step, extra):
        """Fill in the field.

        Returns a django ``File`` wrapping the underlying file object of the
        content built from ``extra``, under the computed filename.
        """
        name, content = self._make_content(extra)
        return django_files.File(content.file, name)
|
|||
|
|
|||
|
|
|||
|
class ImageField(FileField):
    """FileField variant whose default payload is a generated image."""

    DEFAULT_FILENAME = 'example.jpg'

    def _make_data(self, params):
        """Render a single-colour image and return its encoded bytes.

        Options read from ``params``: ``width`` (default 100), ``height``
        (defaults to the width), ``color`` (default ``'blue'``), ``format``
        (default ``'JPEG'``) and ``palette`` (default ``'RGB'``).
        """
        # ImageField (both django's and factory_boy's) require PIL.
        # Try to import it along one of its known installation paths.
        from PIL import Image

        width = params.get('width', 100)
        size = (width, params.get('height', width))
        fill = params.get('color', 'blue')
        encoding = params.get('format', 'JPEG')
        mode = params.get('palette', 'RGB')

        buffer = io.BytesIO()
        with Image.new(mode, size, fill) as image:
            image.save(buffer, format=encoding)
        return buffer.getvalue()
|
|||
|
|
|||
|
|
|||
|
class mute_signals:
    """Temporarily disables and then restores any django signals.

    Usable as a context manager, or as a decorator on a factory class or on
    a plain callable.

    Args:
        *signals (django.dispatch.dispatcher.Signal): any django signals

    Examples:
        with mute_signals(pre_init):
            user = UserFactory.build()
            ...

        @mute_signals(pre_save, post_save)
        class UserFactory(factory.Factory):
            ...

        @mute_signals(post_save)
        def generate_users():
            UserFactory.create_batch(10)
    """

    def __init__(self, *signals):
        self.signals = signals
        # Maps each muted signal to the receivers it had when muted.
        self.paused = {}

    def __enter__(self):
        """Detach the receivers of every signal, remembering them."""
        for sig in self.signals:
            handlers = sig.receivers
            logger.debug('mute_signals: Disabling signal handlers %r', handlers)

            # Note that we're using implementation details of
            # django.signals, since arguments to signal.connect()
            # are lost in signal.receivers
            self.paused[sig] = handlers
            sig.receivers = []

    def __exit__(self, exc_type, exc_value, traceback):
        """Re-attach the remembered receivers and reset django's cache."""
        for sig, handlers in self.paused.items():
            logger.debug('mute_signals: Restoring signal handlers %r', handlers)

            sig.receivers += handlers
            with sig.lock:
                # Django uses some caching for its signals.
                # Since we're bypassing signal.connect and signal.disconnect,
                # we have to keep messing with django's internals.
                sig.sender_receivers_cache.clear()
        self.paused = {}

    def copy(self):
        """Return a fresh mute_signals for the same signals."""
        return mute_signals(*self.signals)

    def __call__(self, callable_obj):
        """Decorate a factory class or a plain callable.

        For a factory class, its ``_create`` and ``_generate`` classmethods
        are replaced by muting wrappers and the class itself is returned.
        Any other callable is wrapped in a function that mutes the signals
        for the duration of each call.
        """
        if not isinstance(callable_obj, base.FactoryMetaClass):
            @functools.wraps(callable_obj)
            def muted_call(*args, **kwargs):
                # A mute_signals() object is not reentrant; use a copy every time.
                with self.copy():
                    return callable_obj(*args, **kwargs)
            return muted_call

        for method_name in ('_create', '_generate'):
            # Retrieve __func__, the *actual* callable object.
            raw_method = getattr(callable_obj, method_name).__func__
            setattr(callable_obj, method_name, self.wrap_method(raw_method))
        return callable_obj

    def wrap_method(self, method):
        """Return ``method`` as a classmethod that runs with signals muted."""
        @classmethod
        @functools.wraps(method)
        def muted_method(*args, **kwargs):
            # A mute_signals() object is not reentrant; use a copy every time.
            with self.copy():
                return method(*args, **kwargs)
        return muted_method
|