fix: handle deepcopy of openapi objects (#9735) (#9735)

- Add __deepcopy__ and __copy__ to OpenApiModel
- pass discriminator inside deepcopy if exists
- add test cases for deepcopy of models
This commit is contained in:
Mike Marchetti
2021-09-09 12:16:59 -04:00
committed by GitHub
parent 477e2365c7
commit 9464999d9c
8 changed files with 202 additions and 13 deletions

View File

@@ -13,3 +13,22 @@
def __getattr__(self, attr):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.{{#attrNoneIfUnset}}get{{/attrNoneIfUnset}}{{^attrNoneIfUnset}}__getitem__{{/attrNoneIfUnset}}(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst

View File

@@ -1,6 +1,7 @@
{{>partial_header}}
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -223,8 +224,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -9,6 +9,7 @@
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -186,6 +187,26 @@ class OpenApiModel(object):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.__getitem__(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst
def __new__(cls, *args, **kwargs):
# this function uses the discriminator to
# pick a new schema/class to instantiate because a discriminator
@@ -295,8 +316,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -9,6 +9,7 @@
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -186,6 +187,26 @@ class OpenApiModel(object):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.__getitem__(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst
def __new__(cls, *args, **kwargs):
# this function uses the discriminator to
# pick a new schema/class to instantiate because a discriminator
@@ -295,8 +316,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -9,6 +9,7 @@
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -186,6 +187,26 @@ class OpenApiModel(object):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.__getitem__(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst
def __new__(cls, *args, **kwargs):
# this function uses the discriminator to
# pick a new schema/class to instantiate because a discriminator
@@ -295,8 +316,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -9,6 +9,7 @@
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -186,6 +187,26 @@ class OpenApiModel(object):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.__getitem__(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst
def __new__(cls, *args, **kwargs):
# this function uses the discriminator to
# pick a new schema/class to instantiate because a discriminator
@@ -295,8 +316,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -9,6 +9,7 @@
from datetime import date, datetime # noqa: F401
from copy import deepcopy
import inspect
import io
import os
@@ -186,6 +187,26 @@ class OpenApiModel(object):
"""get the value of an attribute using dot notation: `instance.attr`"""
return self.__getitem__(attr)
def __copy__(self):
cls = self.__class__
if self.get("_spec_property_naming", False):
return cls._new_from_openapi_data(**self.__dict__)
else:
return new_cls.__new__(cls, **self.__dict__)
def __deepcopy__(self, memo):
cls = self.__class__
if self.get("_spec_property_naming", False):
new_inst = cls._new_from_openapi_data()
else:
new_inst = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(new_inst, k, deepcopy(v, memo))
return new_inst
def __new__(cls, *args, **kwargs):
# this function uses the discriminator to
# pick a new schema/class to instantiate because a discriminator
@@ -295,8 +316,13 @@ class OpenApiModel(object):
self_inst = super(OpenApiModel, cls).__new__(cls)
self_inst.__init__(*args, **kwargs)
if kwargs.get("_spec_property_naming", False):
# when true, implies new is from deserialization
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
else:
new_inst = new_cls.__new__(new_cls, *args, **kwargs)
new_inst.__init__(*args, **kwargs)
return new_inst

View File

@@ -0,0 +1,34 @@
from copy import deepcopy
import unittest
from petstore_api.model.mammal import Mammal
from petstore_api.model.triangle import Triangle
class TestCopy(unittest.TestCase):
"""TestCopy unit test stubs"""
def setUp(self):
pass
def tearDown(self):
pass
def testDeepCopyOneOf(self):
"""test deepcopy"""
obj = deepcopy(Mammal(class_name="whale"))
assert id(deepcopy(obj)) != id(obj)
assert deepcopy(obj) == obj
def testDeepCopyAllOf(self):
"""test deepcopy"""
obj = Triangle(shape_type="Triangle", triangle_type="EquilateralTriangle", foo="blah")
assert id(deepcopy(obj)) != id(obj)
assert deepcopy(obj) == obj
obj = Triangle._new_from_openapi_data(shape_type="Triangle", triangle_type="EquilateralTriangle", foo="blah")
assert id(deepcopy(obj)) != id(obj)
assert deepcopy(obj) == obj
if __name__ == '__main__':
unittest.main()