[APIFlask] Add flag to allow disabling of validation on input

认领

1 Like

虽然不是我认领的任务,但是作为flask的重度使用者,并且想给apiflask做一点贡献,自己看了以下原 issue , 兵把 apiflaskinput 逻辑梳理了以下:

flowchart
A[APIScaffold.input] --> B[FlaskParser.use_args]
B --> C{is_async_function}
C --> |True| D[webargs.core.Parser.async_parse]
C --> |False| E[webargs.core.Parser.parse]
D --> F[webargs.core.Parser._process_location_data]
E --> F[webargs.core.Parser._process_location_data]
F --> G[webargs.core.Parser._validate_arguments]

然后下面是相关类的继承关系:

classDiagram
    `webargs.core.Parser` <|-- `webargs.core.FlaskParser`
    `webargs.core.FlaskParser` <|-- `apiflask.scaffold.FlaskParser`

而上一个流程图中,进行参数验证的,也就是最后一个关键方法: _validate_argumentswebargs.core.Parser 中,如果我们需要修改这个的话,避免不了一些重复的重写。

我想不到其他的方式,不知道大家有没有其他不一样的思路和实现方式,可以相互讨论一下吗?

3 Likes

看起来可以先贡献给上游(webargs),添加关闭验证的功能。或许先在 webargs 创建 issue 讨论一下。

@ywang 这个任务开始了没?如果暂时没有时间的话或许可以交给 @jennier0107

@greyli,目前为止我仅针对validation组件做了初步的调试以探索改造方案,尚未开始编码。后续1个月也没时间继续完成(上班11-11-6,周末经常有其他事 :cold_sweat:),若@jennier0107 有空自愿移交给他完成,感谢为开源做贡献~

1 Like

没事,保重身体。万恶的资本主义 :stuck_out_tongue:

之前我漏掉了一步:

//webargs.core.Parser
    def parse(
        self,
        argmap: ArgMap,
        req: Request | None = None,
        *,
        location: str | None = None,
        unknown: str | None = _UNKNOWN_DEFAULT_PARAM,
        validate: ValidateArg = None,
        error_status_code: int | None = None,
        error_headers: typing.Mapping[str, str] | None = None,
    ) -> typing.Any:
       
        data, req, location, validators, schema = self._prepare_for_parse(
            argmap, req, location, unknown, validate
        )
        try:
            location_data = self._load_location_data(
                schema=schema, req=req, location=location
            )
            data = self._process_location_data(
                location_data, schema, req, location, unknown, validators
            )
        except ma.exceptions.ValidationError as error:
            self._on_validation_error(
                error,
                req,
                schema,
                location,
                error_status_code=error_status_code,
                error_headers=error_headers,
            )
            raise ValueError(
                "_on_validation_error hook did not raise an exception"
            ) from error
        return data

async_parse 是 parse 方法的一个变体,专门处理async 函数
这一行 data = self._process_location_data( location_data, schema, req, location, unknown, validators ) 是数据解析和验证的核心

//webargs.core.Parser
    def _process_location_data(
        self,
        location_data: typing.Any,
        schema: ma.Schema,
        req: Request,
        location: str,
        unknown: str | None,
        validators: CallableList,
    ) -> typing.Any:
        unknown = (
            unknown
            if unknown != _UNKNOWN_DEFAULT_PARAM
            else (
                self.unknown
                if self.unknown != _UNKNOWN_DEFAULT_PARAM
                else self.DEFAULT_UNKNOWN_BY_LOCATION.get(location)
            )
        )
        load_kwargs: dict[str, typing.Any] = {"unknown": unknown} if unknown else {}
        preprocessed_data = self.pre_load(
            location_data, schema=schema, req=req, location=location
        )
        data = schema.load(preprocessed_data, **load_kwargs)
        self._validate_arguments(data, validators)
        return data

数据的解析和验证主要进行在下面两行代码:

data = schema.load(preprocessed_data, **load_kwargs)
self._validate_arguments(data, validators)

schema.loadmarshmallow 提供的方法, 而 self._validate_arguments(data, validators) 是对 marshmallow 中再进一步对数据的验证,比如

class UserSchema(Schema):
    name = fields.Str(validate=validate.Length(min=1))
    permission = fields.Str(validate=validate.OneOf(["read", "write", "admin"]))
    age = fields.Int(validate=validate.Range(min=18, max=40))

检查字段数据的大小,还有范围等等。
而 marshmallow 并没有提供相关的参数在进行 .loadvalidate 时,只拿到数据,而跳过验证。不过提供了一个错误类型:

class ValidationError(MarshmallowError):
    """Raised when validation fails on a field or schema.

    Validators and custom fields should raise this exception.

    :param message: An error message, list of error messages, or dict of
        error messages. If a dict, the keys are subitems and the values are error messages.
    :param field_name: Field name to store the error on.
        If `None`, the error is stored as schema-level error.
    :param data: Raw input data.
    :param valid_data: Valid (de)serialized data.
    """

    def __init__(
        self,
        message: str | list | dict,
        field_name: str = SCHEMA,
        data: typing.Mapping[str, typing.Any]
        | typing.Iterable[typing.Mapping[str, typing.Any]]
        | None = None,
        valid_data: list[dict[str, typing.Any]] | dict[str, typing.Any] | None = None,
        **kwargs,
    ):

data 是原数据(在docstring表明是 Raw input data
也就是说只能通过该错误进一步拿到原数据了
所以我最终考虑的是,在apiflask.scaffold.FlaskParser(BaseFlaskParser) 中重写

以下是个简单的dmeo :

app.input() 处理逻辑

class APIScaffold:
    ...
     def input(
        self,
        schema: SchemaType,
        location: str = 'json',
        arg_name: str | None = None,
        schema_name: str | None = None,
        example: t.Any | None = None,
        examples: dict[str, t.Any] | None = None,
        skip_validation: bool = False,
        **kwargs: t.Any,
    ) -> t.Callable[[DecoratedType], DecoratedType]:

        ...
        return use_args(
                schema, location=location, arg_name=arg_name or f'{location}_data',
                skip_validation=skip_validation, **kwargs
            )(f)

`FlaskParser(BaseFlaskParser)` 的处理逻辑:

```python
class FlaskParser(BaseFlaskParser):
    """Overwrite the default `webargs.FlaskParser.handle_error`.

    Update the default status code and the error description from related
    configuration variables.
    """

    USE_ARGS_POSITIONAL = False
    SKIP_VALIDATION = False

   ...

    def use_args(
        self,
        argmap: ArgMap,
        req: Request | None = None,
        *,
        location: str | None = None,
        unknown: str | None = _UNKNOWN_DEFAULT_PARAM,
        as_kwargs: bool = False,
        arg_name: str | None = None,
        validate: ValidateArg = None,
        error_status_code: int | None = None,
        error_headers: typing.Mapping[str, str] | None = None,
        skip_validation: bool = False,
        **kwargs
    ) -> typing.Callable[..., typing.Callable]:
        self.SKIP_VALIDATION = skip_validation
        return super().use_args(
            argmap, req, location=location, unknown=unknown, as_kwargs=as_kwargs,
            arg_name=arg_name, validate=validate, error_status_code=error_status_code,
            error_headers=error_headers
        )

    def _process_location_data(
        self,
        location_data: typing.Any,
        schema: ma.Schema,
        req: Request,
        location: str,
        unknown: str | None,
        validators: CallableList,
    ) -> typing.Any:
        if self.SKIP_VALIDATION:
            try:
                super()._process_location_data(
                    location_data,
                    schema,
                    req, location, unknown, validators
                )
            except MarshmallowValidationError as me:
                return me.data
        else:
            return super()._process_location_data(
                location_data,
                schema,
                req, location, unknown, validators
            )

希望大家能够提提意见
还有可以的话 @greyli, 可以把任务分配给我

2 Likes

Sorry 这两天在打黑神话,刚刚才认真看你的代码……我觉得实现没问题,不过可以尝试先贡献到 webargs。

好的,我再考虑一下。
我代码差不多实现好了,昨天在编写测试的时候,注意到漏了一个点。
就是添加的flag, 应该是只跳过标记 validate=False对应的视图函数的输入参数的验证,其他的视图函数参数验证应该不会被影响到

@app.patch('/pets/<int:pet_id>')
@app.input(PetIn)  # -> json_data
def update_pet(pet_id, json_data):
    return json_data

@app.patch('/pets/no_validation/<int:pet_id>')
@app.input(PetIn, skip_validation=True)  # -> json_data
def update_pet_with_no_validation(pet_id, json_data):
    return json_data

对于上面的代码,按理来说 update_pet 参数验证失败时应该给出对应的信息,但是会拿到正常的原始输入数据,因为 update_pet_with_no_validation 这一段标记为 skip_validation=True 之后,初始化之后的 FlaskParser 中的 self.SKIP_VALIDATIONTrue ,所以所有的没有被标记为 skip_validation=True 的视图函数也会受到影响

parser: FlaskParser = FlaskParser()
use_args: t.Callable = parser.use_args

是因为这段代码的存在,parser 因为是单例,所以我在某个视图函数中的标记,会扩散到整个视图函数的参数验证中。

还有最后一个想法,就是我看了 marshmallow 的文档,代码以及相关 api, 也在网上看了不少文章。其实对于一个 schema ,如果只想拿到原始数据,而不进行 validation 或者 deserialization 有一个更加直接的方法:

class PetIn(Schema):
    name = String(required=True, validate=Length(0, 10))
    category = String(required=True, validate=OneOf(['dog', 'cat']))
    
    def load(self, data, *args, **kwargs): 
        return data

就是定义 schema 时,顺便重写一下 load 方法,并直接返回 data 。数据验证是通过 marshmallow 来进行的,跳过验证这一个步骤,是不是放到定义 Schema 这一步更好更高效?

skip_validation 可以作为参数直接传到 _process_location_data 吗?作为 self.SKIP_VALIDATION 传递确实不合适。

为 schema 跳过验证是另一个维度了。可以先看看以视图函数作为入口来控制是不是能实现。

FlaskParser(BaseFlaskParser)APIScaffold 依赖的是 :

parser: FlaskParser = FlaskParser()
use_args: t.Callable = parser.use_args

没法直接传递过去,先通过重写 use_args 方法,接受额外的变量,然后在 _process_location_data 中判断 self.SKIP_VALIDATION 的值,再进行其他的操作

突然想起来,有没有可能这样实现:

  • 创建一个 load 函数来根据 location 从 request 对象获取对应的数据
  • 如果用户设置了 skip_validation,我们直接从 load 函数获取数据传给视图函数,不进入 use_args 的调用逻辑

这一点我确实没能想得到,不过我明天要加班,我周日仔细的看一下 :ok_hand:

1 Like

最近加班稍多,有点拖久了。
我把input这一块的代码,重新看了一遍,然后基于辉哥你给出的的想法,把主要的逻辑弄了一下:

def raw_load(location: str = 'json', arg_name: str | None = None) -> dict[t.Any, t.Any]:
    if location == 'json':
        return request.get_json()
    if location == 'form':
        return request.form.to_dict()
    if location in ('files', 'form_and_files'):
        files = request.files.to_dict()
        form = request.form.to_dict()
        return {**form, **files}
    if location == 'json_or_form':
        pass
    if location == 'path':
        return {
            arg_name: request.view_args.get(arg_name)
        } if arg_name is not None else {}
    if location == 'query':
        return request.args.to_dict()
    raise RuntimeError(f'Unsupported location: {location}')

我想确定一下,自己有没有跑偏

没事,不着急。实现差不多是这样。

已提交 pr.

1 Like

上一个pr我漏掉了一处的type checking error, 重新修改之后更新了pr