def test_if_encode_raises_exception_with_invalid_data_and_strict_schema():
    """Encoding a model whose strict schema rejects the data raises EncodingError."""
    class StrictSchema(marshmallow.Schema):
        class Meta:
            strict = True

        uuid_field = fields.UUID(required=True)

    # NOTE(review): the source lost its indentation; `type_name` is assumed
    # to live inside `Meta` next to `schema` - confirm.
    class Event(structures.Model):
        class Meta:
            schema = StrictSchema
            type_name = 'Event'

    invalid_event = Event(uuid_field='not an uuid')
    expected_message = "({'uuid_field': ['Not a valid UUID.']}, '')"

    with pytest.raises(exceptions.EncodingError) as raised:
        encoding.encode(invalid_event)

    assert str(raised.value) == expected_message
# Example source code for the Python class Schema()
def test_if_raises_exception_with_invalid_data_and_strict_schema(self):
    """Decoding JSON that a strict schema rejects raises DecodingError."""
    class StrictSchema(marshmallow.Schema):
        class Meta:
            strict = True

        uuid_field = fields.UUID(required=True)

    # NOTE(review): the source lost its indentation; `type_name` is assumed
    # to live inside `Meta` next to `schema` - confirm.
    class Event(structures.Model):
        class Meta:
            schema = StrictSchema
            type_name = 'Event'

    encoded = '{"uuid_field": "not an uuid"}'
    expected_message = "({'uuid_field': ['Not a valid UUID.']}, '')"

    with pytest.raises(exceptions.DecodingError) as raised:
        encoding.decode(type=Event, encoded_data=encoded)

    assert str(raised.value) == expected_message
def test_generate_unmarshall_method_bodies_with_load_from():
    """Compare the generated DictSerializer body for a `load_from` field with a fixed text."""
    class OneFieldSchema(Schema):
        foo = fields.Integer(load_from='bar', allow_none=True)

    jit_context = JitContext(is_serializing=False, use_inliners=False)
    one_field = OneFieldSchema()
    dict_serializer = DictSerializer(jit_context)
    generated = generate_transform_method_body(one_field,
                                               dict_serializer,
                                               jit_context)
    result = str(generated)
    # NOTE(review): the literal below is reproduced exactly as it appears in
    # this file; it looks like it lost its inner indentation - confirm
    # against the real generator output.
    expected = '''\
def DictSerializer(obj):
res = {}
__res_get = res.get
if "foo" in obj:
res["foo"] = _field_foo__deserialize(obj["foo"], "bar", obj)
if "foo" not in res:
if "bar" in obj:
res["foo"] = _field_foo__deserialize(obj["bar"], "bar", obj)
return res'''
    assert expected == result
def test_generate_unmarshall_method_bodies_required():
    """Compare the generated DictSerializer body for a required field with a fixed text."""
    class OneFieldSchema(Schema):
        foo = fields.Integer(required=True)

    jit_context = JitContext(is_serializing=False, use_inliners=False)
    one_field = OneFieldSchema()
    dict_serializer = DictSerializer(jit_context)
    generated = generate_transform_method_body(one_field,
                                               dict_serializer,
                                               jit_context)
    result = str(generated)
    # NOTE(review): the literal below is reproduced exactly as it appears in
    # this file; it looks like it lost its inner indentation - confirm
    # against the real generator output.
    expected = '''\
def DictSerializer(obj):
res = {}
__res_get = res.get
res["foo"] = _field_foo__deserialize(obj["foo"], "foo", obj)
if "foo" not in res:
raise ValueError()
if __res_get("foo", res) is None:
raise ValueError()
return res'''
    assert expected == result
def test_secure_field(app):
    """Load plaintext, a valid secure token and an invalid secure token through SecureField."""
    class SecureSchema(Schema):
        token = SecureField()

    secure_schema = SecureSchema()

    # case 1: plaintext
    loaded = secure_schema.load({'token': 'abc'})
    assert loaded.data['token'] == 'abc'

    # case 2: valid secure token
    payload = {'token': {'secure': SecureToken.encrypt('def')}}
    loaded = secure_schema.load(payload)
    assert loaded.data['token'] == 'def'

    # case 3: invalid secure token
    payload = {'token': {'secure': 'gAAAAABYmoldCp-EQGUKCppiqmVOu2jLrAKUz6E2e4aOMMD8Vu0VKswmJexHX6vUEoxVYKFUlSonPb91QKXZBEZdBezHzJMCHg=='}}  # NOQA
    loaded = secure_schema.load(payload)
    assert loaded.data['token'] == ''
def __call__(self, info):
    """
    Build a renderer that runs the value through `request.render_schema`.

    The parent renderer is created first. The returned callable replaces
    `value` with the output of `request.render_schema.dump(value)` when the
    request carries a `Schema` instance in `render_schema`, then delegates
    to the parent renderer. A dump that raises or reports errors results in
    `HTTPInternalServerError`.
    """
    original_render = super().__call__(info)

    def schema_render(value, system):
        req = system.get('request')
        has_schema = (
            req is not None
            and isinstance(getattr(req, 'render_schema', None), Schema)
        )
        if not has_schema:
            return original_render(value, system)

        try:
            value, errors = req.render_schema.dump(value)
        except Exception:
            # Any failure while dumping is reported as a serialization error.
            errors = True
        if errors:
            raise HTTPInternalServerError(body="Serialization failed.")
        return original_render(value, system)

    return schema_render
def jsonify(self, obj, many=sentinel, *args, **kwargs):
    """Serialize `obj` with this schema and wrap the result in a JSON response.

    :param obj: Object to serialize.
    :param bool many: Whether `obj` should be serialized as an instance
        or as a collection. When left unset, the schema's own `many`
        attribute is used.
    :param args: Additional positional arguments passed to `flask.jsonify`.
    :param kwargs: Additional keyword arguments passed to `flask.jsonify`.

    .. versionchanged:: 0.6.0
        Takes the same arguments as `marshmallow.Schema.dump`. Additional
        keyword arguments are passed to `flask.jsonify`.

    .. versionchanged:: 0.6.3
        The `many` argument for this method defaults to the value of
        the `many` attribute on the Schema. Previously, the `many`
        argument of this method defaulted to False, regardless of the
        value of `Schema.many`.
    """
    effective_many = self.many if many is sentinel else many
    dumped = self.dump(obj, many=effective_many)
    return flask.jsonify(dumped.data, *args, **kwargs)
def _register_deduced_schemas(Base):
def setup_schema_fn():
# Generate missing schemas
for class_ in Base._decl_class_registry.values():
if hasattr(class_, '__tablename__') and not hasattr(class_, '__marshmallow__'):
if class_.__name__.endswith('Schema'):
raise ModelConversionError(
"For safety, setup_schema can not be used when a"
"Model class ends with 'Schema'"
)
class Meta(BaseSchema.Meta):
model = class_
schema_class_name = '%sSchema' % class_.__name__
schema_class = type(
schema_class_name,
(BaseSchema,),
{'Meta': Meta}
)
setattr(class_, '__marshmallow__', schema_class)
return setup_schema_fn
def test_custom_base_schema(self):
    """A schema built with `base_schema_cls` subclasses it and dumps/loads as asserted below."""
    class MyBaseSchema(marshmallow.Schema):
        name = marshmallow.fields.Int()
        age = marshmallow.fields.Int()

    ma_schema_cls = self.User.schema.as_marshmallow_schema(base_schema_cls=MyBaseSchema)
    assert issubclass(ma_schema_cls, MyBaseSchema)

    ma_schema = ma_schema_cls()

    # Dumping drops the extra key without reporting errors.
    dumped = ma_schema.dump({'name': "42", 'age': 42, 'dummy': False})
    assert not dumped.errors
    assert dumped.data == {'name': "42", 'age': 42}

    # Loading the same payload reports the unknown key.
    rejected = ma_schema.load({'name': "42", 'age': 42, 'dummy': False})
    assert rejected.errors == {'_schema': ['Unknown field name dummy.']}

    loaded = ma_schema.load({'name': "42", 'age': 42})
    assert not loaded.errors
    assert loaded.data == {'name': "42", 'age': 42}
def generate_json_schema(cls, schema, context=DEFAULT_DICT):
    """Generate a JSON Schema from a Marshmallow schema.

    Args:
        schema (marshmallow.Schema|str): The Marshmallow schema, or the
            Python path to one, to create the JSON schema for.

    Keyword Args:
        context: Value passed as the ``context`` keyword when ``cls`` is
            instantiated. Defaults to ``DEFAULT_DICT``.

    Returns:
        dict: The JSON schema in dictionary form.

    """
    resolved_schema = cls._get_schema(schema)
    # Generate the JSON Schema
    dumper = cls(context=context)
    return dumper.dump(resolved_schema).data
def _get_schema(cls, schema):
    """Resolve a flexible schema reference into a Marshmallow schema instance.

    Args:
        schema (marshmallow.Schema|str): Either the schema class, an
            instance of a schema, or a Python path to a schema.

    Returns:
        marshmallow.Schema: The desired schema.

    Raises:
        TypeError: This is raised if the provided object isn't
            a Marshmallow schema.

    """
    candidate = schema
    # A string is treated as a Python path and resolved first.
    if isinstance(candidate, string_types):
        candidate = cls._get_object_from_python_path(candidate)
    # A class is instantiated without arguments.
    if isclass(candidate):
        candidate = candidate()
    if isinstance(candidate, Schema):
        return candidate
    raise TypeError("The schema must be a path to a Marshmallow "
                    "schema or a Marshmallow schema.")
def __init__(self, cls):
    """Initialize the options and, when a collection is set, name and schema defaults."""
    self._collection = None
    super(MongoOptions, self).__init__(cls)
    self.name = self.meta and getattr(self.meta, 'name', None)

    if self.collection:
        # Fall back to the collection's name when no explicit name was given.
        self.name = self.name or str(self.collection.name)

        if not cls.Schema:
            # Build a default schema class from the configured meta/fields.
            meta = type('Meta', (object,), self.schema_meta)
            schema_name = self.name.title() + 'Schema'
            schema_attrs = dict({'Meta': meta}, **self.schema)
            cls.Schema = type(schema_name, (MongoSchema,), schema_attrs)
def get_schema(self, resource=None, **kwargs):
    """Instantiate `self.Schema` with `resource` as its `instance` argument.

    Extra keyword arguments are accepted but not used here.
    """
    schema_instance = self.Schema(instance=resource)  # noqa
    return schema_instance
def test_if_encode_raises_exception_with_invalid_data_and_not_strict_schema():
    """Encoding invalid data raises EncodingError even when the schema is not strict."""
    class NotStrictSchema(marshmallow.Schema):
        uuid_field = fields.UUID(required=True)

    # NOTE(review): the source lost its indentation; `type_name` is assumed
    # to live inside `Meta` next to `schema` - confirm.
    class Event(structures.Model):
        class Meta:
            schema = NotStrictSchema
            type_name = 'Event'

    invalid_event = Event(uuid_field='not an uuid')
    expected_message = "({'uuid_field': ['Not a valid UUID.']}, '')"

    with pytest.raises(exceptions.EncodingError) as raised:
        encoding.encode(invalid_event)

    assert str(raised.value) == expected_message
def schema_class(self):
    """Return a schema class with a required UUID field and an optional string field."""
    class Schema(marshmallow.Schema):
        uuid_field = fields.UUID(required=True)
        string_field = fields.String(required=False)

    built_class = Schema
    return built_class
def test_if_raises_exception_with_invalid_data_and_not_strict_schema(self):
    """Decoding invalid JSON raises DecodingError even when the schema is not strict."""
    class NotStrictSchema(marshmallow.Schema):
        uuid_field = fields.UUID(required=True)

    # NOTE(review): the source lost its indentation; `type_name` is assumed
    # to live inside `Meta` next to `schema` - confirm.
    class Event(structures.Model):
        class Meta:
            schema = NotStrictSchema
            type_name = 'Event'

    encoded = '{"uuid_field": "not an uuid"}'
    expected_message = "({'uuid_field': ['Not a valid UUID.']}, '')"

    with pytest.raises(exceptions.DecodingError) as raised:
        encoding.decode(type=Event, encoded_data=encoded)

    assert str(raised.value) == expected_message
def configure_retrieve(self, ns, definition):
    """
    Register a retrieve endpoint.

    The definition's func should be a retrieve function, which must:
    - accept kwargs for path data
    - return an item or falsey

    When the definition has no request schema, an empty `Schema()` is used
    for the query string.

    :param ns: the namespace
    :param definition: the endpoint definition

    """
    request_schema = definition.request_schema or Schema()

    @self.add_route(ns.instance_path, Operation.Retrieve, ns)
    @qs(request_schema)
    @response(definition.response_schema)
    @wraps(definition.func)
    def retrieve(**path_data):
        response_headers = {}
        query_data = load_query_string_data(request_schema)
        func_kwargs = merge_data(path_data, query_data)
        item = require_response_data(definition.func(**func_kwargs))
        # The definition fills in response headers through header_func.
        definition.header_func(response_headers)
        negotiated_format = self.negotiate_response_content(definition.response_formats)
        return dump_response_data(
            definition.response_schema,
            item,
            headers=response_headers,
            response_format=negotiated_format,
        )

    retrieve.__doc__ = "Retrieve a {} by id".format(ns.subject_name)
def configure_retrievefor(self, ns, definition):
    """
    Register a relation endpoint.

    The definition's func should be a retrieve function, which must:
    - accept kwargs for path data and optional request data
    - return an item

    The definition's request_schema will be used to process query string
    arguments, if any; otherwise an empty `Schema()` is used.

    :param ns: the namespace
    :param definition: the endpoint definition

    """
    request_schema = definition.request_schema or Schema()

    @self.add_route(ns.relation_path, Operation.RetrieveFor, ns)
    @qs(request_schema)
    @response(definition.response_schema)
    @wraps(definition.func)
    def retrieve(**path_data):
        response_headers = {}
        query_data = load_query_string_data(request_schema)
        func_kwargs = merge_data(path_data, query_data)
        item = require_response_data(definition.func(**func_kwargs))
        # The definition fills in response headers through header_func.
        definition.header_func(response_headers)
        negotiated_format = self.negotiate_response_content(definition.response_formats)
        return dump_response_data(
            definition.response_schema,
            item,
            headers=response_headers,
            response_format=negotiated_format,
        )

    retrieve.__doc__ = "Retrieve {} relative to a {}".format(pluralize(ns.object_name), ns.subject_name)
def create_upload_func(self, ns, definition, path, operation):
    """
    Build and register the upload view function for `path`/`operation`.

    The view loads query string data, rejects requests without files,
    wraps each non-excluded uploaded file with `temporary_upload` and passes
    the resulting files to the definition's func. A `None` result yields an
    empty 204 response; anything else is dumped with the response schema.

    :param ns: the namespace
    :param definition: the endpoint definition
    :param path: the route path
    :param operation: the operation being registered
    :returns: the (possibly further decorated) view function

    """
    request_schema = definition.request_schema or Schema()
    response_schema = definition.response_schema or Schema()

    @self.add_route(path, operation, ns)
    @wraps(definition.func)
    def upload(**path_data):
        query_data = load_query_string_data(request_schema)

        if not request.files:
            raise BadRequest("No files were uploaded")

        managers = [
            temporary_upload(field_name, file_storage)
            for field_name, file_storage in request.files.items()
            if not self.exclude_func(field_name, file_storage)
        ]
        # NOTE(review): the source lost its indentation; the result handling
        # is assumed to run inside the `with` block - confirm.
        with nested(*managers) as files:
            func_kwargs = merge_data(path_data, query_data)
            response_data = definition.func(files, **func_kwargs)
            if response_data is not None:
                return dump_response_data(response_schema, response_data, operation.value.default_code)
            return "", 204

    # Schema decorators are only applied when the definition declares them.
    if definition.request_schema:
        upload = qs(definition.request_schema)(upload)
    if definition.response_schema:
        upload = response(definition.response_schema)(upload)
    return upload
def make_paginated_list_schema_class(cls, ns, item_schema):
    """
    Generate a schema class that represents a paginated list of items.

    The class is aliased `<subject_name>_list` and exposes a required
    `items` list of nested `item_schema` plus a raw `_links` field.

    """
    list_alias = "{}_list".format(ns.subject_name)
    nested_item = fields.Nested(item_schema)

    class PaginatedListSchema(Schema):
        __alias__ = list_alias
        items = fields.List(nested_item, required=True)
        _links = fields.Raw()

    return PaginatedListSchema
def make_paginated_list_schema_class(cls, ns, item_schema):
    """
    Generate a schema class for an offset/limit paginated list of items.

    The class is aliased `<subject_name>_list`, declares required `offset`,
    `limit`, `count` and `items` fields plus a raw `_links` field, and
    forwards `csv_column_order` from `item_schema` when it defines one.

    """
    list_alias = "{}_list".format(ns.subject_name)

    class PaginatedListSchema(Schema):
        __alias__ = list_alias

        offset = fields.Integer(required=True)
        limit = fields.Integer(required=True)
        count = fields.Integer(required=True)
        items = fields.List(fields.Nested(item_schema), required=True)
        _links = fields.Raw()

        @property
        def csv_column_order(self):
            # None when the item schema does not define a column order.
            column_order = getattr(item_schema, "csv_column_order", None)
            return column_order

    return PaginatedListSchema
def test_offset_limit_page_to_paginated_list():
    """An OffsetLimitPage converts to a paginated list that dumps with self/prev links."""
    graph = create_object_graph(name="example", testing=True)
    ns = Namespace("foo")

    @graph.flask.route("/", methods=["GET"], endpoint="foo.search.v1")
    def search():
        pass

    expected = {
        "offset": 10,
        "limit": 10,
        "count": 0,
        "items": [],
        "_links": {
            "self": {
                "href": "http://localhost/?offset=10&limit=10&foo=bar",
            },
            "prev": {
                "href": "http://localhost/?offset=0&limit=10&foo=bar",
            },
        },
    }

    # NOTE(review): the source lost its indentation; everything below is
    # assumed to run inside the request context - confirm.
    with graph.flask.test_request_context():
        page = OffsetLimitPage(
            offset=10,
            limit=10,
            foo="bar",
        )
        result = [], 0
        paginated_list, headers = page.to_paginated_list(result, _ns=ns, _operation=Operation.Search)

        schema_cls = page.make_paginated_list_schema_class(ns, Schema())
        dumped = schema_cls().dump(paginated_list)
        data = dumped.data
        assert_that(data, is_(equal_to(expected)))
def _to_jsonschema(type_):
    """Translate a return-type declaration into a JSON-schema fragment.

    :param type_: a `marshmallow.Schema` instance, one of the supported
        Python types, ``None``, or a parametrized
        `typing.MutableSequence` subclass.
    :return: dict holding the JSON-schema fragment for `type_`.
    :raises ValueError: when `type_` is not one of the supported kinds.
    """
    if isinstance(type_, marshmallow.Schema):
        return _jsonschema.dump_schema(type_)
    elif type_ in six.integer_types:
        return {'type': 'number', 'format': 'integer'}
    elif type_ == float:
        return {'type': 'number', 'format': 'float'}
    elif type_ == decimal.Decimal:
        return {'type': 'string', 'format': 'decimal'}
    elif type_ == uuid.UUID:
        return {'type': 'string', 'format': 'uuid'}
    elif type_ == datetime.datetime:
        return {'type': 'string', 'format': 'date-time'}
    elif type_ == datetime.date:
        return {'type': 'string', 'format': 'date'}
    elif type_ == datetime.time:
        return {'type': 'string', 'format': 'time'}
    elif type_ == dict:
        return {'type': 'object'}
    elif type_ == six.text_type or type_ == six.binary_type:
        return {'type': 'string'}
    elif type_ is None:
        return {'type': 'null'}
    elif type_ == list:
        return {'type': 'array'}
    elif type_ == bool:
        return {'type': 'boolean'}
    # issubclass() raises TypeError for non-class arguments; checking first
    # lets unsupported plain values reach the ValueError below instead.
    elif isinstance(type_, type) and issubclass(type_, typing.MutableSequence[typing.T]):
        items_type = type_.__parameters__[0]
        # A schema class is instantiated so the recursive call sees an instance.
        if issubclass(items_type, marshmallow.Schema):
            items_type = items_type()
        return {
            'type': 'array',
            'items': _to_jsonschema(items_type),
        }
    else:
        # Wrap in a 1-tuple so a tuple-valued `type_` formats correctly.
        raise ValueError('unsupported return type: %s' % (type_,))
def load_data(self, data):
    """
    Deserialize data through `self.load` and raise an ApiValidationError
    if there are any errors.

    :param data: the payload handed to `self.load`
    :return: the deserialized data
    """
    loaded, errors = self.load(data)
    if not errors:
        return loaded
    raise ApiValidationError(errors)
def simple_schema():
    """Return an instance of a schema with a string `key` and an integer `value` (default 0)."""
    class InstanceSchema(Schema):
        key = fields.String()
        value = fields.Integer(default=0)

    instance = InstanceSchema()
    return instance
def nested_circular_ref_schema():
    """Return an instance of a schema whose `me` field nests the schema itself by name."""
    class NestedStringSchema(Schema):
        key = fields.String()
        # The nested schema is referenced by its class name.
        me = fields.Nested('NestedStringSchema')

    instance = NestedStringSchema()
    return instance
def nested_schema():
    """Return an instance of a three-level nested schema using `only`, `exclude` and `many`."""
    class GrandChildSchema(Schema):
        bar = fields.String()
        raz = fields.String()

    class SubSchema(Schema):
        name = fields.String()
        value = fields.Nested(GrandChildSchema)

    class NestedSchema(Schema):
        key = fields.String()
        value = fields.Nested(SubSchema, only=('name', 'value.bar'))
        values = fields.Nested(SubSchema, exclude=('value', ), many=True)

    instance = NestedSchema()
    return instance
def optimized_schema():
    """Return an instance of a schema that declares `jit_options` in its Meta."""
    class OptimizedSchema(Schema):
        class Meta:
            jit_options = dict(
                no_callable_fields=True,
                expected_marshal_type='object',
            )

        key = fields.String()
        value = fields.Integer(default=0, as_string=True)

    instance = OptimizedSchema()
    return instance
def schema():
    """Return an instance of an ordered schema mixing attribute, method and load-only fields."""
    class BasicSchema(Schema):
        class Meta:
            ordered = True

        # `foo` reads from an attribute named '@#'.
        foo = fields.Integer(attribute='@#')
        bar = fields.String()
        raz = fields.Method('raz_')
        meh = fields.String(load_only=True)
        blargh = fields.Boolean()

        def raz_(self, obj):
            greeting = 'Hello!'
            return greeting

    instance = BasicSchema()
    return instance
def test_generate_unmarshall_method_bodies():
    """Compare the generated method bodies for a one-field schema with a fixed text."""
    class OneFieldSchema(Schema):
        foo = fields.Integer()

    jit_context = JitContext(is_serializing=False, use_inliners=False)
    one_field = OneFieldSchema()
    result = generate_method_bodies(one_field, jit_context)
    # NOTE(review): the literal below is reproduced exactly as it appears in
    # this file; it looks like it lost its inner indentation - confirm
    # against the real generator output.
    expected = '''\
def InstanceSerializer(obj):
res = {}
__res_get = res.get
res["foo"] = _field_foo__deserialize(obj.foo, "foo", obj)
if __res_get("foo", res) is None:
raise ValueError()
return res
def DictSerializer(obj):
res = {}
__res_get = res.get
if "foo" in obj:
res["foo"] = _field_foo__deserialize(obj["foo"], "foo", obj)
if __res_get("foo", res) is None:
raise ValueError()
return res
def HybridSerializer(obj):
res = {}
__res_get = res.get
try:
value = obj["foo"]
except (KeyError, AttributeError, IndexError, TypeError):
value = obj.foo
res["foo"] = _field_foo__deserialize(value, "foo", obj)
if __res_get("foo", res) is None:
raise ValueError()
return res'''
    assert expected == result