Is it possible to run custom code before Swagger validations in a python/flask server stub?

Question:

I’m using Swagger Editor (OpenAPI 2) to create Flask APIs in Python. When you define a model in Swagger and use it as the schema for a request body, Swagger validates the body before handing control to your code in the X_controller.py files.

I want to add some code that runs before that validation happens (to print logs for debugging). Swagger only prints errors like the following to stdout, and they are not useful when you have a lot of fields (I need the key that isn’t valid).

https://host/path validation error: False is not of type 'string'
10.255.0.2 - - [20/May/2020:20:20:20 +0000] "POST /path HTTP/1.1" 400 116 "-" "GuzzleHttp/7"

I know that technically you can remove the validations from Swagger and do them manually in your code, but I want to keep using this feature; when it works, it’s awesome.

Any ideas on how to do this, or any alternative way to log the requests, are welcome.

Asked By: Samuel O.D.


Answers:

After some time studying the matter, this is what I learnt.

First, let’s take a look at how a python-flask server generated with Swagger Editor works.

Swagger Editor generates the server stub through Swagger Codegen, using the definition written in Swagger Editor. The stub returned by Codegen uses the Connexion framework on top of Flask to handle all HTTP requests and responses, including validation against the Swagger definition (swagger.yaml).

Connexion makes it easy to develop python-flask servers because a lot of functionality you would otherwise have to write yourself, such as parameter validation, is already built in. All we need to do is replace (in this case, modify) its validators.

There are three validators:

  • ParameterValidator
  • RequestBodyValidator
  • ResponseValidator

Connexion uses these validators by default, but we can easily replace them in the __main__.py file, as we’ll see.
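
To make the mechanism concrete, here is a simplified sketch that does not use Connexion at all; every class and function name in it is hypothetical, not part of Connexion's API. A validator is a class whose __call__ wraps the handler and can return an error before the handler runs, and the app picks the validator class from a dictionary by role, so swapping a single entry is enough to run custom code, such as logging, before validation.

```python
import functools
import logging

logger = logging.getLogger(__name__)


# Hypothetical, simplified validators; Connexion's real classes differ.
class RequireJsonObject:
    def __call__(self, function):
        @functools.wraps(function)
        def wrapper(body):
            if not isinstance(body, dict):
                return (415, 'JSON object expected')  # handler is never called
            return function(body)
        return wrapper


class LogThenRequireJsonObject(RequireJsonObject):
    def __call__(self, function):
        validated = super().__call__(function)  # reuse the parent's checks

        @functools.wraps(function)
        def wrapper(body):
            logger.debug('incoming body: %r', body)  # runs before validation
            return validated(body)
        return wrapper


DEFAULT_VALIDATOR_MAP = {'body': RequireJsonObject}


def build_handler(handler, validator_map=None):
    # The "app" only asks the map for its body validator and applies it
    validator_map = validator_map or DEFAULT_VALIDATOR_MAP
    return validator_map['body']()(handler)


def echo(body):
    return (200, body)


default_handler = build_handler(echo)
custom_handler = build_handler(echo, {'body': LogThenRequireJsonObject})
```

In Connexion the roles are 'parameter', 'body' and 'response', and the wrapped function receives a request object rather than a bare body, but the idea of choosing the wrapper class from a map is the same.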

Our goal is to replace the default logs and the default error response with custom ones. I’m using a custom Error model and a function called error_response() to prepare error responses, and Loguru to log the errors (not mandatory; you can keep the original logger).

Looking at the code of the Connexion validators, we can see that most of it can be reused; we only need to modify:

  • RequestBodyValidator: __call__() and validate_schema()
  • ParameterValidator: __call__()

So we only need to create two new classes that extend the original ones, then copy and modify those methods.
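
The extend-and-override approach can be sketched with a simplified stand-in for RequestBodyValidator. SimpleBodyValidator, DetailedBodyValidator and create_user are hypothetical names, and the real class validates against the JSON schema from swagger.yaml rather than a dict of Python types. The subclass replaces only validate_schema, so the inherited __call__ still short-circuits on errors but now reports which key failed.

```python
import functools


# Hypothetical stand-in for RequestBodyValidator: checks Python types
# instead of a JSON schema.
class SimpleBodyValidator:
    def __init__(self, field_types):
        self.field_types = field_types  # e.g. {'name': str}

    def __call__(self, function):
        @functools.wraps(function)
        def wrapper(body):
            error = self.validate_schema(body)
            if error is not None:
                return error  # short-circuit: the handler never runs
            return function(body)
        return wrapper

    def validate_schema(self, body):
        for key, expected in self.field_types.items():
            if not isinstance(body.get(key), expected):
                return (400, 'validation error')  # says nothing about the key
        return None


class DetailedBodyValidator(SimpleBodyValidator):
    # Only validate_schema is overridden; __call__ is inherited unchanged
    def validate_schema(self, body):
        for key, expected in self.field_types.items():
            value = body.get(key)
            if not isinstance(value, expected):
                return (400, f'{key!r}: {value!r} is not of type {expected.__name__!r}')
        return None


def create_user(body):
    return (201, body['name'])
```

For example, DetailedBodyValidator({'name': str})(create_user)({'name': False}) returns a 400 whose message names the 'name' key, whereas the base class only says "validation error".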

Be careful when copying and pasting: this code is based on connexion==1.1.15. If you are on a different version, base your classes on that version’s source.

In a new file custom_validators.py we need:

import json
import functools
from flask import Response
from loguru import logger
from jsonschema import ValidationError
from connexion.utils import all_json, is_null
from connexion.exceptions import ExtraParameterProblem
from swagger_server.models import Error
from connexion.decorators.validation import ParameterValidator, RequestBodyValidator


def error_response(response: Error) -> Response:
    # Serialize the custom Error model into a JSON Flask response
    # (flask.Response is the default response_class of a Flask app)
    return Response(
        response=json.dumps(response.to_dict(), default=str),
        status=response.status,
        mimetype='application/json')


class CustomParameterValidator(ParameterValidator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, function):
        """
        :type function: types.FunctionType
        :rtype: types.FunctionType
        """

        @functools.wraps(function)
        def wrapper(request):

            if self.strict_validation:
                query_errors = self.validate_query_parameter_list(request)
                formdata_errors = self.validate_formdata_parameter_list(request)

                if formdata_errors or query_errors:
                    raise ExtraParameterProblem(formdata_errors, query_errors)

            for param in self.parameters.get('query', []):
                error = self.validate_query_parameter(param, request)
                if error:
                    response = error_response(Error(status=400, description=f'Error: {error}'))
                    return self.api.get_response(response)

            for param in self.parameters.get('path', []):
                error = self.validate_path_parameter(param, request)
                if error:
                    response = error_response(Error(status=400, description=f'Error: {error}'))
                    return self.api.get_response(response)

            for param in self.parameters.get('header', []):
                error = self.validate_header_parameter(param, request)
                if error:
                    response = error_response(Error(status=400, description=f'Error: {error}'))
                    return self.api.get_response(response)

            for param in self.parameters.get('formData', []):
                error = self.validate_formdata_parameter(param, request)
                if error:
                    response = error_response(Error(status=400, description=f'Error: {error}'))
                    return self.api.get_response(response)

            return function(request)

        return wrapper


class CustomRequestBodyValidator(RequestBodyValidator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, function):
        """
        :type function: types.FunctionType
        :rtype: types.FunctionType
        """

        @functools.wraps(function)
        def wrapper(request):
            if all_json(self.consumes):
                data = request.json

                if data is None and len(request.body) > 0 and not self.is_null_value_valid:
                    # the body has contents that were not parsed as JSON
                    return error_response(Error(
                        status=415,
                        description="Invalid Content-type ({content_type}), JSON data was expected".format(content_type=request.headers.get("Content-Type", ""))
                    ))
                
                error = self.validate_schema(data, request.url)
                if error and not self.has_default:
                    return error

            response = function(request)
            return response

        return wrapper

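    def validate_schema(self, data, url):
        if self.is_null_value_valid and is_null(data):
            return None

        try:
            self.validator.validate(data)
        except ValidationError as exception:
            # exception.path holds the keys/indices leading to the invalid value
            # (e.g. "address.zip"); exception.validator_value is only the failed
            # schema constraint (e.g. 'string' for a type check), not the key
            attribute = '.'.join(str(part) for part in exception.path) or '<root>'
            description = f'Validation error. Attribute "{attribute}" returned this error: "{exception.message}"'
            logger.error(description)
            return error_response(Error(
                status=400,
                description=description
            ))

        return None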
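
jsonschema's ValidationError carries a path attribute (a deque of the keys and list indices leading to the failing value), which is what identifies the offending field, including in nested bodies. A small helper such as the hypothetical format_error_path below could turn it into a readable string for the log message and the Error description; it is a sketch, not part of jsonschema or Connexion.

```python
from collections import deque


def format_error_path(path):
    """Render an error path of keys and list indices as e.g. 'items[0].name'.

    Hypothetical helper; pass it the path attribute of a jsonschema ValidationError.
    """
    rendered = ''
    for part in path:
        if isinstance(part, int):
            rendered += f'[{part}]'   # list index
        elif rendered:
            rendered += f'.{part}'    # nested key
        else:
            rendered = str(part)      # first key
    return rendered or '<root>'       # empty path: the body itself failed


print(format_error_path(deque(['items', 0, 'name'])))  # items[0].name
```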

Once we have our validators, we map them to the Flask app in __main__.py using validator_map:

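import connexion
from pathlib import Path
from connexion.decorators.response import ResponseValidator
from swagger_server import encoder
# assumes custom_validators.py lives inside the swagger_server package
from swagger_server.custom_validators import CustomParameterValidator, CustomRequestBodyValidator

validator_map = {
    'parameter': CustomParameterValidator,
    'body': CustomRequestBodyValidator,
    'response': ResponseValidator,  # Connexion's default, unchanged
}

app = connexion.App(__name__, specification_dir='./swagger/', validator_map=validator_map)
app.app.json_encoder = encoder.JSONEncoder
app.add_api(Path('swagger.yaml'), arguments={'title': 'MyApp'})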

If you also need to customize the response validator, which I didn’t change in this example, create a child class of ResponseValidator and put it in the validator_map dictionary in __main__.py.

Connexion docs:
https://connexion.readthedocs.io/en/latest/request.html

Answered By: Samuel O.D.

Forgive me for repeating an answer first posted at https://stackoverflow.com/a/73051652/1630244

Have you tried Flask’s before_request hook? Connexion exposes the underlying Flask app as conn_app.app, so you can register the hook there. Here’s an example that logs the headers and body before Connexion validates the body:

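import connexion
import logging
from flask import request

logging.basicConfig(level=logging.DEBUG)  # make the debug messages visible
logger = logging.getLogger(__name__)
conn_app = connexion.FlaskApp(__name__)

@conn_app.app.before_request
def before_request():
    # Runs before the view function, i.e. before Connexion's validators
    for h in request.headers:
        logger.debug('header %s', h)
    logger.debug('data %s', request.get_data())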
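
If you would rather capture the raw request even before Flask starts handling it, a small WSGI middleware around the Flask app is another option (not taken from the answers above). The sketch uses only the standard library; RequestBodyLogger is a hypothetical name, and it assumes the client sends a Content-Length header, so chunked bodies are passed through without being logged.

```python
import io
import logging

logger = logging.getLogger(__name__)


class RequestBodyLogger:
    """Hypothetical WSGI middleware that logs the raw request before Flask sees it."""

    def __init__(self, app):
        self.app = app  # the wrapped WSGI callable, e.g. a Flask app's wsgi_app

    def __call__(self, environ, start_response):
        try:
            length = int(environ.get('CONTENT_LENGTH') or 0)
        except ValueError:
            length = 0

        if length > 0:
            body = environ['wsgi.input'].read(length)
            # The stream can only be read once, so hand the app a fresh copy
            environ['wsgi.input'] = io.BytesIO(body)
        else:
            body = b''  # no Content-Length (e.g. chunked): leave the stream untouched

        logger.debug('%s %s body=%r',
                     environ.get('REQUEST_METHOD'), environ.get('PATH_INFO'), body)
        return self.app(environ, start_response)
```

With the Connexion app from the example above, it would be installed as conn_app.app.wsgi_app = RequestBodyLogger(conn_app.app.wsgi_app).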
Answered By: chrisinmtown