393 lines
12 KiB
Python
393 lines
12 KiB
Python
# -*- test-case-name: twisted.test.test_persisted -*-
|
|
# Copyright (c) Twisted Matrix Laboratories.
|
|
# See LICENSE for details.
|
|
|
|
"""
|
|
Different styles of persisted objects.
|
|
"""
|
|
|
|
import copy
|
|
import copyreg as copy_reg
|
|
import inspect
|
|
import pickle
|
|
import types
|
|
from io import StringIO as _cStringIO
|
|
from typing import Dict
|
|
|
|
from twisted.python import log, reflect
|
|
from twisted.python.compat import _PYPY
|
|
|
|
# Maps the name of a module that has moved to its replacement; consulted by
# unpickleModule when a pickled module reference is loaded.
# NOTE(review): unpickleModule passes the mapped value to __import__, which
# suggests the values are module names (str) rather than module objects --
# confirm whether this annotation is accurate.
oldModules: Dict[str, types.ModuleType] = {}
|
|
|
|
|
|
# Alias for pickle.PicklingError; _pickleFunction raises it when asked to
# reduce a lambda.
_UniversalPicklingError = pickle.PicklingError
|
|
|
|
|
|
def pickleMethod(method):
    """
    Support function for copy_reg to pickle method refs.

    @param method: The bound method to reduce.

    @return: a 2-tuple of L{unpickleMethod} and the arguments to call it
        with: the method's name, the instance it is bound to, and the class
        of that instance.
    """
    methodName = method.__name__
    instance = method.__self__
    reconstructionArgs = (methodName, instance, instance.__class__)
    return (unpickleMethod, reconstructionArgs)
|
|
|
|
|
|
def _methodFunction(classObject, methodName):
|
|
"""
|
|
Retrieve the function object implementing a method name given the class
|
|
it's on and a method name.
|
|
|
|
@param classObject: A class to retrieve the method's function from.
|
|
@type classObject: L{type}
|
|
|
|
@param methodName: The name of the method whose function to retrieve.
|
|
@type methodName: native L{str}
|
|
|
|
@return: the function object corresponding to the given method name.
|
|
@rtype: L{types.FunctionType}
|
|
"""
|
|
methodObject = getattr(classObject, methodName)
|
|
return methodObject
|
|
|
|
|
|
def unpickleMethod(im_name, im_self, im_class):
    """
    Support function for copy_reg to unpickle method refs.

    @param im_name: The name of the method.
    @type im_name: native L{str}

    @param im_self: The instance that the method was present on.
    @type im_self: L{object}

    @param im_class: The class where the method was declared.
    @type im_class: L{type} or L{None}

    @return: the attribute C{im_name} of C{im_class} when C{im_self} is
        L{None}; otherwise that attribute bound to C{im_self}.

    @raise AttributeError: if the method can be found neither on
        C{im_class} nor on the current class of C{im_self}.
    """
    if im_self is None:
        # No instance to bind to: hand back whatever the class exposes.
        return getattr(im_class, im_name)

    try:
        function = _methodFunction(im_class, im_name)
    except AttributeError:
        log.msg("Method", im_name, "not on class", im_class)
        # Attempt a last-ditch fix before giving up.  If classes have changed
        # around since we pickled this method, we may still be able to get it
        # by looking on the instance's current class.
        currentClass = im_self.__class__
        if currentClass is im_class:
            # Already looked there; nothing else to try.
            raise
        return unpickleMethod(im_name, im_self, currentClass)

    return types.MethodType(function, im_self)
|
|
|
|
|
|
# Register pickleMethod as the reduction function for bound methods.
copy_reg.pickle(types.MethodType, pickleMethod)
|
|
|
|
|
|
def _pickleFunction(f):
    """
    Reduce, in the sense of L{pickle}'s C{object.__reduce__} special method, a
    function object into its constituent parts.

    @param f: The function to reduce.
    @type f: L{types.FunctionType}

    @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
        its arguments, a 1-tuple of the function's fully qualified name.
    @rtype: 2-tuple of C{callable, native string}

    @raise _UniversalPicklingError: if C{f} is a lambda.
    """
    isLambda = f.__name__ == "<lambda>"
    if isLambda:
        raise _UniversalPicklingError(f"Cannot pickle lambda function: {f}")
    # Module path and qualified name, joined by a dot.
    fullyQualifiedName = f.__module__ + "." + f.__qualname__
    return (_unpickleFunction, (fullyQualifiedName,))
|
|
|
|
|
|
def _unpickleFunction(fullyQualifiedName):
    """
    Convert a function name into a function by importing it.

    This is a synonym for L{twisted.python.reflect.namedAny}, but imported
    locally to avoid circular imports, and also to provide a persistent name
    that can be stored (and deprecated) independently of C{namedAny}.

    @param fullyQualifiedName: The fully qualified name of a function.
    @type fullyQualifiedName: native C{str}

    @return: A function object imported from the given location.
    @rtype: L{types.FunctionType}
    """
    from twisted.python.reflect import namedAny as resolveName

    function = resolveName(fullyQualifiedName)
    return function
|
|
|
|
|
|
# Register _pickleFunction as the reduction function for plain functions.
copy_reg.pickle(types.FunctionType, _pickleFunction)
|
|
|
|
|
|
def pickleModule(module):
    """
    Support function for copy_reg to pickle module refs.

    @param module: The module to reduce.
    @type module: L{types.ModuleType}

    @return: a 2-tuple of L{unpickleModule} and a 1-tuple holding the
        module's name.
    """
    reconstructionArgs = (module.__name__,)
    return (unpickleModule, reconstructionArgs)
|
|
|
|
|
|
def unpickleModule(name):
    """
    Support function for copy_reg to unpickle module refs.

    If C{name} is a key of L{oldModules}, the move is logged and the mapped
    replacement is used instead.  The replacement may be either a module
    name, which is then imported, or an already imported module object, which
    is returned as is.

    @param name: The name the module was pickled under.
    @type name: native L{str}

    @return: the module to use for C{name}.
    @rtype: L{types.ModuleType}

    @raise ImportError: if the module to load cannot be imported.
    """
    if name in oldModules:
        log.msg("Module has moved: %s" % name)
        replacement = oldModules[name]
        log.msg(replacement)
        # oldModules is annotated as mapping to module objects, which
        # __import__ cannot accept; hand such a value back directly.
        if isinstance(replacement, types.ModuleType):
            return replacement
        name = replacement
    # A non-empty fromlist makes __import__ return the named (sub)module
    # itself rather than the top-level package.
    return __import__(name, {}, {}, "x")
|
|
|
|
|
|
# Register pickleModule as the reduction function for module objects.
copy_reg.pickle(types.ModuleType, pickleModule)
|
|
|
|
|
|
def pickleStringO(stringo):
    """
    Reduce the given cStringO.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.

    @param stringo: The string output to pickle.
    @type stringo: C{cStringIO.OutputType}

    @return: a 2-tuple of L{unpickleStringO} and its arguments: the buffer's
        full contents and its current position.
    """
    contents = stringo.getvalue()
    position = stringo.tell()
    return (unpickleStringO, (contents, position))
|
|
|
|
|
|
def unpickleStringO(val, sek):
    """
    Convert the output of L{pickleStringO} into an appropriate type for the
    current python version.  This may be called on Python 3 and will convert a
    cStringIO into an L{io.StringIO}.

    @param val: The content of the file.  It is written to an
        L{io.StringIO}, so it has to be text.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object holding C{val}, positioned at C{sek}.
    @rtype: L{io.StringIO}
    """
    stream = _cStringIO()
    stream.write(val)
    # Restore the position the original buffer was at when it was pickled.
    stream.seek(sek)
    return stream
|
|
|
|
|
|
def pickleStringI(stringi):
    """
    Reduce the given cStringI.

    This is only called on Python 2, because the cStringIO module only exists
    on Python 2.

    @param stringi: The string input to pickle.
    @type stringi: C{cStringIO.InputType}

    @return: a 2-tuple of (C{unpickleStringI}, (bytes, pointer))
    @rtype: 2-tuple of (function, (bytes, int))
    """
    contents = stringi.getvalue()
    position = stringi.tell()
    return (unpickleStringI, (contents, position))
|
|
|
|
|
|
def unpickleStringI(val, sek):
    """
    Convert the output of L{pickleStringI} into an appropriate type for the
    current Python version.

    This may be called on Python 3 and will convert a cStringIO into an
    L{io.StringIO}.

    @param val: The content of the file.  It is used as the initial value of
        an L{io.StringIO}, so it has to be text.
    @type val: native L{str}

    @param sek: The seek position of the file.
    @type sek: L{int}

    @return: a file-like object holding C{val}, positioned at C{sek}.
    @rtype: L{io.StringIO}
    """
    stream = _cStringIO(val)
    # Restore the position the original buffer was at when it was pickled.
    stream.seek(sek)
    return stream
|
|
|
|
|
|
class Ephemeral:
    """
    This type of object is never persisted; if possible, even references to it
    are eliminated.
    """

    def __reduce__(self):
        """
        Serialize any subclass of L{Ephemeral} in a way which replaces it with
        L{Ephemeral} itself.

        @return: a 2-tuple of L{Ephemeral} and an empty argument tuple.
        """
        noArguments = ()
        return (Ephemeral, noArguments)

    def __getstate__(self):
        """
        Log a warning that an ephemeral object is being serialized and,
        except on PyPy, log every object the garbage collector reports as
        referring to it.

        @return: L{None}; no state is kept.
        """
        log.msg("WARNING: serializing ephemeral %s" % self)
        if _PYPY:
            return None

        import gc

        # Only walk the referrers when this gc module can report them.
        if not getattr(gc, "get_referrers", None):
            return None
        for referrer in gc.get_referrers(self):
            log.msg(f" referred to by {referrer}")
        return None

    def __setstate__(self, state):
        """
        Log a warning and turn this instance into a plain L{Ephemeral}.

        @param state: Ignored.
        """
        originalClass = self.__class__
        log.msg("WARNING: unserializing ephemeral %s" % originalClass)
        self.__class__ = Ephemeral
|
|
|
|
|
|
# Versioned instances awaiting an upgrade, keyed by id(instance).  Entries
# are added by Versioned.__setstate__; doUpgrade rebinds this to an empty
# dictionary once every pending instance has been processed.
versionedsToUpgrade: Dict[int, "Versioned"] = {}
|
|
# Marks, by id(instance), the objects requireUpgrade has already started
# upgrading (the value is always 1); rebound to an empty dictionary by
# doUpgrade.
upgraded: Dict[int, int] = {}
|
|
|
|
|
|
def doUpgrade():
    """
    Upgrade every instance recorded in L{versionedsToUpgrade}, then rebind
    both L{versionedsToUpgrade} and L{upgraded} to fresh empty dictionaries.
    """
    global versionedsToUpgrade, upgraded
    # Work from a snapshot of the pending instances.
    pending = list(versionedsToUpgrade.values())
    for instance in pending:
        requireUpgrade(instance)
    versionedsToUpgrade = {}
    upgraded = {}
|
|
|
|
|
|
def requireUpgrade(obj):
    """
    Require that a Versioned instance be upgraded completely first.

    @param obj: The instance to upgrade.

    @return: C{obj} if this call performed the upgrade; L{None} if C{obj}
        was not pending or had already been marked as upgraded.
    """
    key = id(obj)
    if key not in versionedsToUpgrade or key in upgraded:
        return None
    # Mark before upgrading so that a nested request for the same object
    # does not start a second upgrade.
    upgraded[key] = 1
    obj.versionUpgrade()
    return obj
|
|
|
|
|
|
def _aybabtu(c):
    """
    Get all of the parent classes of C{c}, not including C{c} itself, which are
    strict subclasses of L{Versioned}.

    @param c: a class

    @returns: list of classes, in method resolution order
    """
    # These two must never appear in the result.
    excluded = (c, Versioned)
    ancestors = []
    for klass in inspect.getmro(c):
        if klass in excluded or klass in ancestors:
            continue
        if issubclass(klass, Versioned):
            ancestors.append(klass)
    return ancestors
|
|
|
|
|
|
class Versioned:
    """
    This type of object is persisted with versioning information.

    I have a single class attribute, the int persistenceVersion.  After I am
    unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
    will be called for each version upgrade I must undergo.

    For example, if I serialize an instance of a Foo(Versioned) at version 4
    and then unserialize it when the code is at version 9, the calls::

      self.upgradeToVersion5()
      self.upgradeToVersion6()
      self.upgradeToVersion7()
      self.upgradeToVersion8()
      self.upgradeToVersion9()

    will be made.  If any of these methods are undefined, a warning message
    will be printed.
    """

    # Current version of the persisted state.
    persistenceVersion = 0
    # Names of attributes that __getstate__ leaves out of the state.
    persistenceForgets = ()

    def __setstate__(self, state):
        """
        Register this instance as needing an upgrade and adopt C{state} as
        its instance dictionary.

        @param state: The unpickled state.
        @type state: L{dict}
        """
        versionedsToUpgrade[id(self)] = self
        self.__dict__ = state

    def __getstate__(self, dict=None):
        """
        Get state, adding a version number to it on its way out.

        @param dict: The state to start from.  When L{None}, the instance
            dictionary is used.  The object passed in is not modified; a
            shallow copy is.
        @type dict: L{dict} or L{None}

        @return: a new dictionary without the attributes named by any
            C{persistenceForgets}, and with one
            C{"<qualified class name>.persistenceVersion"} entry for every
            class in the hierarchy which defines C{persistenceVersion}.
        @rtype: L{dict}
        """
        # Test for None explicitly: an empty dictionary supplied by the
        # caller is falsy, but must not be replaced by self.__dict__.
        dct = copy.copy(self.__dict__ if dict is None else dict)
        bases = _aybabtu(self.__class__)
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        for base in bases:
            if "persistenceForgets" in base.__dict__:
                for slot in base.persistenceForgets:
                    dct.pop(slot, None)
            if "persistenceVersion" in base.__dict__:
                dct[
                    f"{reflect.qual(base)}.persistenceVersion"
                ] = base.persistenceVersion
        return dct

    def versionUpgrade(self):
        """
        (internal) Do a version upgrade.

        For each class in the hierarchy, superclasses first, call the
        C{upgradeToVersionN} methods defined directly on that class for
        every version between the persisted one and the current one,
        logging a warning for each missing method.

        @raise AssertionError: if a persisted version is newer than the
            corresponding class's current C{persistenceVersion}.
        """
        bases = _aybabtu(self.__class__)
        # put the bases in order so superclasses' persistenceVersion methods
        # will be called first.
        bases.reverse()
        bases.append(self.__class__)  # don't forget me!!
        # first let's look for old-skool versioned's
        if "persistenceVersion" in self.__dict__:
            # Hacky heuristic: if more than one class subclasses Versioned,
            # we'll assume that the higher version number wins for the older
            # class, so we'll consider the attribute the version of the older
            # class.  There are obviously possibly times when this will
            # eventually be an incorrect assumption, but hopefully old-school
            # persistenceVersion stuff won't make it that far into multiple
            # classes inheriting from Versioned.
            pver = self.__dict__["persistenceVersion"]
            del self.__dict__["persistenceVersion"]
            highestVersion = 0
            highestBase = None
            for base in bases:
                if "persistenceVersion" not in base.__dict__:
                    continue
                if base.persistenceVersion > highestVersion:
                    highestBase = base
                    highestVersion = base.persistenceVersion
            if highestBase:
                self.__dict__[
                    "%s.persistenceVersion" % reflect.qual(highestBase)
                ] = pver
        for base in bases:
            # ugly hack, but it's what the user expects, really
            if (
                Versioned not in base.__bases__
                and "persistenceVersion" not in base.__dict__
            ):
                continue
            currentVers = base.persistenceVersion
            pverName = "%s.persistenceVersion" % reflect.qual(base)
            # A missing entry means the state predates versioning of this
            # class, which counts as version 0.
            persistVers = self.__dict__.get(pverName) or 0
            if persistVers:
                del self.__dict__[pverName]
            assert persistVers <= currentVers, "Sorry, can't go backwards in time."
            while persistVers < currentVers:
                persistVers = persistVers + 1
                # Look only at methods defined directly on this class, so an
                # inherited upgrader is not run once per subclass.
                method = base.__dict__.get("upgradeToVersion%s" % persistVers, None)
                if method:
                    log.msg(
                        "Upgrading %s (of %s @ %s) to version %s"
                        % (
                            reflect.qual(base),
                            reflect.qual(self.__class__),
                            id(self),
                            persistVers,
                        )
                    )
                    method(self)
                else:
                    log.msg(
                        "Warning: cannot upgrade {} to version {}".format(
                            base, persistVers
                        )
                    )
|