366 lines
12 KiB
Python
366 lines
12 KiB
Python
"""Build factory instances."""
|
|
|
|
import collections
|
|
|
|
from . import enums, errors, utils
|
|
|
|
# Lightweight record pairing a top-level declaration with its field name and
# the nested ("deep") parameters collected for it.
DeclarationWithContext = collections.namedtuple(
    'DeclarationWithContext',
    'name declaration context',
)
|
|
|
|
|
|
class DeclarationSet:
    """A set of declarations, including the recursive parameters.

    Attributes:
        declarations (dict(name => declaration)): the top-level declarations
        contexts (dict(name => dict(subfield => value))): the nested parameters related
            to a given top-level declaration

    This object behaves similarly to a dict mapping a top-level declaration name to a
    DeclarationWithContext, containing field name, declaration object and extra context.
    """

    def __init__(self, initial=None):
        """Create the set, optionally ingesting an initial mapping.

        Args:
            initial (dict(name, declaration) or None): entries handed to
                `update()`; a falsy value means "start empty".
        """
        self.declarations = {}
        # defaultdict: looking up the context of a field that has none yields
        # (and stores) an empty dict instead of raising.
        self.contexts = collections.defaultdict(dict)
        self.update(initial or {})

    @classmethod
    def split(cls, entry):
        """Split a declaration name into a (declaration, subpath) tuple.

        Only the first occurrence of the splitter is considered; a name
        without any splitter has a subpath of None.

        Examples:
            >>> DeclarationSet.split('foo__bar')
            ('foo', 'bar')
            >>> DeclarationSet.split('foo')
            ('foo', None)
            >>> DeclarationSet.split('foo__bar__baz')
            ('foo', 'bar__baz')
        """
        root, separator, subpath = entry.partition(enums.SPLITTER)
        if not separator:
            # No splitter found: the whole entry is a top-level name.
            return (entry, None)
        # Always return a tuple, as documented (str.split would give a list).
        return (root, subpath)

    @classmethod
    def join(cls, root, subkey):
        """Rebuild a full declaration name from its components.

        for every string x, we have `join(*split(x)) == x`.
        """
        if subkey is None:
            return root
        return enums.SPLITTER.join((root, subkey))

    def copy(self):
        """Return a new set of the same class, rebuilt from `as_dict()`."""
        return self.__class__(self.as_dict())

    def update(self, values):
        """Add new declarations to this set.

        Names without a subpath replace the top-level declaration; names with
        a subpath are stored as context of their root declaration.

        Args:
            values (dict(name, declaration)): the declarations to ingest.

        Raises:
            errors.InvalidDeclarationError: if, once all values are stored,
                some context refers to a root that has no top-level
                declaration. Entries stored before the check are not rolled
                back.
        """
        for k, v in values.items():
            root, sub = self.split(k)
            if sub is None:
                self.declarations[root] = v
            else:
                self.contexts[root][sub] = v

        # Validate only after ingesting everything, so that a context may
        # appear before its root declaration within the same `values` dict.
        extra_context_keys = set(self.contexts) - set(self.declarations)
        if extra_context_keys:
            raise errors.InvalidDeclarationError(
                "Received deep context for unknown fields: %r (known=%r)" % (
                    {
                        self.join(root, sub): v
                        for root in extra_context_keys
                        for sub, v in self.contexts[root].items()
                    },
                    sorted(self.declarations),
                )
            )

    def filter(self, entries):
        """Filter a set of declarations: keep only those related to this object.

        This will keep:
        - Declarations that 'override' the current ones
        - Declarations that are parameters to current ones

        Returns:
            list: the entries whose root name is a known top-level declaration.
        """
        return [
            entry for entry in entries
            if self.split(entry)[0] in self.declarations
        ]

    def sorted(self):
        """Return the declaration names as ordered by `utils.sort_ordered_objects`.

        The helper receives the names plus a getter mapping each name to its
        declaration object.
        """
        return utils.sort_ordered_objects(
            self.declarations,
            getter=lambda entry: self.declarations[entry],
        )

    def __contains__(self, key):
        """Tell whether `key` is a known top-level declaration name."""
        return key in self.declarations

    def __getitem__(self, key):
        """Return the DeclarationWithContext for a top-level name.

        Raises:
            KeyError: if `key` is not a top-level declaration.
        """
        return DeclarationWithContext(
            name=key,
            declaration=self.declarations[key],
            context=self.contexts[key],
        )

    def __iter__(self):
        """Iterate over the top-level declaration names."""
        return iter(self.declarations)

    def values(self):
        """Yield each declaration, with its context."""
        for name in self:
            yield self[name]

    def _items(self):
        """Yield (key, value) pairs, suitable for our __init__."""
        for name in self.declarations:
            yield name, self.declarations[name]
            for subkey, value in self.contexts[name].items():
                yield self.join(name, subkey), value

    def as_dict(self):
        """Return a dict() suitable for our __init__."""
        return dict(self._items())

    def __repr__(self):
        return '<DeclarationSet: %r>' % self.as_dict()
|
|
|
|
|
|
def parse_declarations(decls, base_pre=None, base_post=None):
    """Dispatch extra declarations between pre- and post-declaration sets.

    Args:
        decls (dict(name, value)): the extra declarations to dispatch.
        base_pre: existing pre-declarations; copied, never modified in place.
            A falsy value means "start from an empty DeclarationSet".
        base_post: existing post-declarations; handled like `base_pre`.

    Returns:
        tuple: (pre_declarations, post_declarations).

    Raises:
        errors.InvalidDeclarationError: if a value whose builder phase is
            POST_INSTANTIATION uses the name of an existing pre-declaration.
    """
    pre_declarations = base_pre.copy() if base_pre else DeclarationSet()
    post_declarations = base_post.copy() if base_post else DeclarationSet()

    # Inject extra declarations, splitting between known-to-be-post and undetermined
    extra_post = {}
    extra_maybenonpost = {}
    for k, v in decls.items():
        if enums.get_builder_phase(v) == enums.BuilderPhase.POST_INSTANTIATION:
            if k in pre_declarations:
                # Conflict: PostGenerationDeclaration with the same
                # name as a BaseDeclaration
                raise errors.InvalidDeclarationError(
                    "PostGenerationDeclaration %s=%r shadows declaration %r"
                    % (k, v, pre_declarations[k])
                )
            extra_post[k] = v
        elif k in post_declarations:
            # Passing in a scalar value to a PostGenerationDeclaration
            # Set it as `key__`
            magic_key = post_declarations.join(k, '')
            extra_post[magic_key] = v
        else:
            extra_maybenonpost[k] = v

    # Start with adding new post-declarations
    post_declarations.update(extra_post)

    # Names related to a post-declaration; kept as a set so that each
    # membership test below is O(1) instead of a list scan.
    post_overrides = set(post_declarations.filter(extra_maybenonpost))

    # Partition the undetermined entries in a single pass.
    post_context = {}
    pre_extra = {}
    for k, v in extra_maybenonpost.items():
        if k in post_overrides:
            post_context[k] = v
        else:
            pre_extra[k] = v

    # Fill in extra post-declaration context
    post_declarations.update(post_context)

    # Anything else is pre_declarations
    pre_declarations.update(pre_extra)

    return pre_declarations, post_declarations
|
|
|
|
|
|
class BuildStep:
    """State of one build: its builder, sequence and resolved attributes.

    Attributes:
        builder: the builder that created this step
        sequence: the sequence value used for this step
        attributes (dict): resolved field values, filled by `resolve()`
        parent_step (BuildStep or None): the step this one was recursed from
        stub (Resolver or None): the resolver created by `resolve()`
    """

    def __init__(self, builder, sequence, parent_step=None):
        self.builder = builder
        self.sequence = sequence
        self.attributes = {}
        self.parent_step = parent_step
        self.stub = None

    def resolve(self, declarations):
        """Resolve every declared field and store the values in `attributes`.

        Args:
            declarations: iterable of field names, also handed to the Resolver
                as its declarations.
        """
        self.stub = Resolver(
            declarations=declarations,
            step=self,
            sequence=self.sequence,
        )

        for field_name in declarations:
            # Attribute access on the stub triggers the actual computation.
            self.attributes[field_name] = getattr(self.stub, field_name)

    @property
    def chain(self):
        """Tuple of stubs, from this step up to the root step."""
        parent_chain = self.parent_step.chain if self.parent_step else ()
        return (self.stub,) + parent_chain

    def recurse(self, factory, declarations, force_sequence=None):
        """Build an instance of another factory as a child of this step.

        Args:
            factory: the factory class to build; must subclass base.BaseFactory.
            declarations: extra declarations passed to the sub-builder.
            force_sequence: forwarded to the sub-builder's `build()`.

        Raises:
            errors.AssociatedClassError: if `factory` is not a BaseFactory
                subclass (including when it is not a class at all).
        """
        # Imported here rather than at module level.
        from . import base

        # issubclass() raises TypeError for non-class arguments; check for a
        # class first so such values get the dedicated error as well.
        is_factory = isinstance(factory, type) and issubclass(factory, base.BaseFactory)
        if not is_factory:
            raise errors.AssociatedClassError(
                "%r: Attempting to recursing into a non-factory object %r"
                % (self, factory))
        builder = self.builder.recurse(factory._meta, declarations)
        return builder.build(parent_step=self, force_sequence=force_sequence)

    def __repr__(self):
        return f"<BuildStep for {self.builder!r}>"
|
|
|
|
|
|
class StepBuilder:
    """A factory instantiation step.

    Attributes:
        - factory_meta: the meta object of the factory being built
        - extras: the passed-in kwargs for this branch
        - strategy: the strategy to use
        - force_init_sequence: value of the '__sequence' extra, removed from
          `extras` at construction time (None when absent)
    """

    def __init__(self, factory_meta, extras, strategy):
        self.factory_meta = factory_meta
        self.strategy = strategy
        self.extras = extras
        # '__sequence' is taken out of the extras dict itself.
        self.force_init_sequence = extras.pop('__sequence', None)

    def build(self, parent_step=None, force_sequence=None):
        """Build a factory instance."""
        # TODO: Handle "batch build" natively
        meta = self.factory_meta
        pre, post = parse_declarations(
            self.extras,
            base_pre=meta.pre_declarations,
            base_post=meta.post_declarations,
        )

        # Sequence priority: explicit argument, then the '__sequence' extra,
        # then a fresh value from the factory meta.
        sequence = force_sequence
        if sequence is None:
            sequence = self.force_init_sequence
        if sequence is None:
            sequence = meta.next_sequence()

        step = BuildStep(builder=self, sequence=sequence, parent_step=parent_step)
        step.resolve(pre)

        args, kwargs = meta.prepare_arguments(step.attributes)
        instance = meta.instantiate(step=step, args=args, kwargs=kwargs)

        # Evaluate post-declarations in the order given by post.sorted().
        postgen_results = {}
        for name in post.sorted():
            entry = post[name]
            postgen_results[name] = entry.declaration.evaluate_post(
                instance=instance,
                step=step,
                overrides=entry.context,
            )
        meta.use_postgeneration_results(
            instance=instance,
            step=step,
            results=postgen_results,
        )
        return instance

    def recurse(self, factory_meta, extras):
        """Recurse into a sub-factory call."""
        builder_class = self.__class__
        return builder_class(factory_meta, extras, strategy=self.strategy)

    def __repr__(self):
        return "<StepBuilder({!r}, strategy={!r})>".format(
            self.factory_meta, self.strategy,
        )
|
|
|
|
|
|
class Resolver:
    """Resolve a set of declarations.

    Attributes are set at instantiation time, values are computed lazily.

    Attributes:
        __initialized (bool): whether this object's __init__ has run. If set,
            setting any attribute will be prevented.
        __declarations (dict): maps attribute name to their declaration
        __values (dict): maps attribute name to computed value
        __pending (str list): names of the attributes whose value is being
            computed. This allows to detect cyclic lazy attribute definition.
        __step (BuildStep): the BuildStep related to this resolver.
            This allows to have the value of a field depend on the value of
            another field
    """

    # Class-level default, so that __setattr__ works while __init__ runs.
    __initialized = False

    def __init__(self, declarations, step, sequence):
        # NOTE: `sequence` is accepted but not used by this class.
        self.__declarations = declarations
        self.__step = step

        self.__values = {}
        self.__pending = []

        # From now on, __setattr__ rejects any assignment.
        self.__initialized = True

    @property
    def factory_parent(self):
        """Stub of the parent step, or None when there is no parent step."""
        if self.__step.parent_step:
            return self.__step.parent_step.stub
        return None

    def __repr__(self):
        return '<Resolver for %r>' % self.__step

    def __getattr__(self, name):
        """Retrieve an attribute's value.

        The value is computed on first access and cached afterwards; asking
        for an attribute that is currently being computed is an error.
        """
        if name in self.__pending:
            raise errors.CyclicDefinitionError(
                "Cyclic lazy attribute definition for %r; cycle found in %r." %
                (name, self.__pending))

        if name in self.__values:
            return self.__values[name]

        if name not in self.__declarations:
            raise AttributeError(
                "The parameter %r is unknown. Evaluated attributes are %r, "
                "definitions are %r." % (name, self.__values, self.__declarations))

        entry = self.__declarations[name]
        result = entry.declaration
        # Only values in the ATTRIBUTE_RESOLUTION phase are evaluated; anything
        # else is used as-is.
        if enums.get_builder_phase(result) == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
            self.__pending.append(name)
            try:
                result = result.evaluate_pre(
                    instance=self,
                    step=self.__step,
                    overrides=entry.context,
                )
            finally:
                # Always unwind the pending stack, even on failure.
                last = self.__pending.pop()
                assert name == last

        self.__values[name] = result
        return result

    def __setattr__(self, name, value):
        """Prevent setting attributes once __init__ is done."""
        if self.__initialized:
            raise AttributeError('Setting of object attributes is not allowed')
        return super().__setattr__(name, value)
|