How to secure FastAPI API endpoints with JWT token-based authorization?

Question:

I am fairly new to FastAPI in Python. I am building an API backend framework that needs JWT token-based authorization. I know how to generate JWT tokens, but I am not sure how to integrate them with API methods in FastAPI. Any pointers would be really appreciated.

Answers:

Integrating it into API methods is easy with `Depends` and response models.

So let me provide an example. Imagine you are deploying your ML model and want to add some security. In your case, you have already created the token part:

TL;DR

# Model for the authenticated user; its fields are elided in this sketch.
class User(BaseModel):
    ...
...
# OAuth2 bearer security scheme; tokenUrl names the token-issuing endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
...
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the user that owns *token*; decoding and lookup are elided.

    You created this function so that it depends on oauth2_scheme, which
    supplies the token.
    """
...
# Path operation guarded by the user dependency.
# NOTE(review): get_current_active_user is not defined in this sketch (only
# get_current_user is); presumably it lives in the elided code.
@app.get("/users/me/models/")
async def read_own_items(current_user: User = Depends(get_current_active_user)):
    ...

A fuller example

Pydantic schemas

# Response body for GET /login: the URL the client should open.
class Url(BaseModel):
    url: str


# The models below are placeholders whose fields are elided in this answer.
# Request body of POST /authorize.
class AuthorizationResponse(BaseModel):
    ...


# The user as exposed by the API.
class User(BaseModel):
    ...


# Not referenced elsewhere in this example.
class AuthUser(BaseModel):
    ...


# Response body of POST /authorize.
# NOTE(review): verify_authorization builds Token(access_token=...,
# token_type=..., user=...), but no fields are declared here; pydantic ignores
# undeclared keyword arguments by default, so these fields should be declared.
class Token(BaseModel):
    ...

Your app

# Provider authorization page; GET /login returns it so the client can log in.
LOGIN_URL = "https://example.com/login/oauth/authorize"
# NOTE(review): `app` is the application object used by the `@app.get`
# decorators below, not a base-URL string, so this f-string embeds the
# object's repr. Presumably a configured base URL was intended — confirm.
REDIRECT_URL = f"{app}/auth/app"
...  # elided setup
# GET /login: return LOGIN_URL with the query parameters appended.
@app.get("/login")
def get_login_url() -> Url:
    return Url(url=f"{LOGIN_URL}?{urlencode(some_params_here)}")

# POST /authorize: check the authorization response and issue a bearer token
# (verification logic elided).
@app.post("/authorize")
async def verify_authorization(body: AuthorizationResponse, db: Session = Depends(some_database_fetch)) -> Token:
    # NOTE(review): `user=User` passes the User class itself rather than a
    # User instance; presumably the authenticated user was meant.
    return Token(access_token=access_token, token_type="bearer", user=User)

def create_access_token(*, data: User, expire_time: int = None) -> bytes:
    """Encode *data* into a JWT (encoding elided in this sketch).

    NOTE(review): `expire_time` defaults to None, so its annotation should be
    Optional[int]; its units are not shown. Whether the encoder returns bytes
    or str depends on the JWT library used — confirm the return annotation.
    """
    return encoded_jwt

def get_user_from_header(*, authorization: str = Header(None)) -> User: # from fastapi import Header
    """Dependency: turn the JWT in the `authorization` header into a User.

    Decoding is elided in this sketch. Because of `Header(None)` the header is
    optional and `authorization` may be None, which the elided code must
    handle.
    """
    return token_data   #Token data = User(**payload)

# GET /me: return the caller's profile. The handler is annotated to return a
# DbUser, but response_model=User makes FastAPI serialize it through User.
@app.get("/me", response_model=User)
def read_profile(user: User = Depends(get_user_from_header), db: Session = Depends(some_database_fetch),) -> DbUser:
    return db_user

Summary of example above

  1. We created a LOGIN_URL, then a Pydantic schema for that URL.
  2. Then we created an /authorize endpoint so the backend can check the authorization response and get everything it needs from the User API.
  3. We wrote a function, create_access_token, that creates a simple JWT token.
  4. Using the JWT token we just created, we built a dependency, get_user_from_header, to use in private endpoints.

Sebastián Ramírez (creator of FastAPI) has a great video that shows how you can add basic auth to your app: FastAPI – Basic HTTP Auth.

FastAPI also has great documentation on OAuth2 with JWT tokens (the "oauth2-jwt" tutorial).

For some real world example, fastapi-users has a perfect JWT authentication backend.

Answered By: Yagiz Degirmenci

With some help from my friend and colleague, I was able to solve this problem, and I want to share the solution with the community. This is what it looks like now:

Python code:

import json

import os

import datetime

from fastapi import HTTPException, Header

from urllib.request import urlopen

from jose import jwt

from jose import exceptions as JoseExceptions

from utils import logger

# Identity-provider settings; each can be overridden through the environment.
AUTH0_DOMAIN = os.environ.get(
    'AUTH0_DOMAIN', 'https://<domain>/<tenant-id>/')

# Expected `iss` claim of incoming tokens. The variable used to be read under
# the misspelled name AUTO0_ISSUER, so setting AUTH0_ISSUER had no effect; the
# old name is still honoured as a fallback for existing deployments.
AUTH0_ISSUER = os.environ.get(
    'AUTH0_ISSUER',
    os.environ.get('AUTO0_ISSUER', 'https://sts.windows.net/<tenant>/'))

# Expected `aud` claim of incoming tokens.
AUTH0_API_AUDIENCE = os.environ.get(
    'AUTH0_API_AUDIENCE', '<audience url>')

# OpenID Connect discovery document used to locate the signing keys.
AZURE_OPENID_CONFIG = os.environ.get(
    'AZURE_OPENID_CONFIG', 'https://login.microsoftonline.com/common/.well-known/openid-configuration')


def get_token_auth_header(authorization):
    """Extract the raw token from a ``Bearer <token>`` header value.

    Args:
        authorization: The header value, e.g. ``"Bearer eyJ..."``. The scheme
            name is matched case-insensitively; surrounding whitespace is
            ignored.

    Returns:
        The token part of the header value.

    Raises:
        HTTPException: 401 if the value is empty, does not start with the
            Bearer scheme, has no token, or has extra whitespace-separated
            parts.
    """
    parts = authorization.split()

    # Guard the empty / whitespace-only case first: indexing parts[0] there
    # would raise IndexError and surface as a 500 instead of a 401.
    if not parts or parts[0].lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail='Authorization header must start with Bearer')
    if len(parts) == 1:
        raise HTTPException(
            status_code=401,
            detail='Authorization token not found')
    if len(parts) > 2:
        raise HTTPException(
            status_code=401,
            detail='Authorization header must be Bearer token')

    return parts[1]


def get_payload(unverified_header, token, jwks_properties):
    """Verify *token* against a key set and return its decoded claims.

    Args:
        unverified_header: The token's header from
            ``jwt.get_unverified_header``; accepted but not used here.
        token: The raw encoded JWT.
        jwks_properties: Mapping with ``"jwks"`` (passed to ``jwt.decode`` as
            the key) and ``"algorithms"`` (the allowed signing algorithms).

    Returns:
        The claims returned by ``jwt.decode``, checked against
        ``AUTH0_API_AUDIENCE`` and ``AUTH0_ISSUER``.

    Raises:
        HTTPException: 401 when the token is expired, has invalid claims
            (e.g. audience or issuer), or fails for any other reason.
    """
    def unauthorized(detail):
        return HTTPException(status_code=401, detail=detail)

    try:
        return jwt.decode(
            token,
            key=jwks_properties["jwks"],
            algorithms=jwks_properties["algorithms"],  # ["RS256"] typically
            audience=AUTH0_API_AUDIENCE,
            issuer=AUTH0_ISSUER,
        )
    except jwt.ExpiredSignatureError:
        raise unauthorized('Authorization token expired')
    except jwt.JWTClaimsError:
        raise unauthorized('Incorrect claims, check the audience and issuer.')
    except Exception:
        # Every other failure (malformed token, bad signature, unusable key
        # set, missing jwks_properties entry, ...) is reported as unparseable.
        raise unauthorized('Unable to parse authentication token')


class AzureJWKS:
    """In-memory cache of an OpenID Connect provider's signing keys (JWKS).

    The OpenID configuration document at ``openid_config`` supplies the
    ``jwks_uri`` and the list of supported ID-token signing algorithms; the
    key set itself is downloaded from ``jwks_uri``. Both are re-fetched by
    :meth:`get_jwks` once they are older than ``cache_hours``.

    NOTE(review): refreshing uses blocking ``urlopen`` calls with no timeout;
    when ``get_jwks`` is called from an ``async`` dependency, a refresh blocks
    the event loop, and an unresponsive endpoint can hang the request.
    """

    def __init__(self, openid_config: str = AZURE_OPENID_CONFIG):
        self.openid_url = openid_config
        self._jwks = None
        self._signing_algorithms = []
        # A timestamp far in the past makes the first get_jwks() call refresh.
        self._last_updated = datetime.datetime(2000, 1, 1, 12, 0, 0)

    def _refresh_cache(self):
        """Download the OpenID configuration and JWKS, then update the cache.

        Both documents are fetched before any attribute is modified, so a
        failed download leaves the previous cache intact.
        """
        # `with` closes each HTTP response deterministically.
        with urlopen(self.openid_url) as openid_reader:
            azure_config = json.loads(openid_reader.read())
        signing_algorithms = azure_config["id_token_signing_alg_values_supported"]
        jwks_url = azure_config["jwks_uri"]

        with urlopen(jwks_url) as jwks_reader:
            jwks = json.loads(jwks_reader.read())

        self._signing_algorithms = signing_algorithms
        self._jwks = jwks

        logger.info(f"Refreshed jwks config from {jwks_url}.")
        logger.info("Supported token signing algorithms: {}".format(str(self._signing_algorithms)))
        self._last_updated = datetime.datetime.now()

    def get_jwks(self, cache_hours: int = 24):
        """Return the cached key set, refreshing it first if it is stale.

        Args:
            cache_hours: Maximum age, in hours, of the cached data before it
                is downloaded again.

        Returns:
            ``{'jwks': <key set>, 'algorithms': <allowed signing algorithms>}``,
            the shape ``get_payload`` expects for ``jwks_properties``.
        """
        # Refresh only when the cached data is older than cache_hours.
        age = datetime.datetime.now() - self._last_updated
        if age > datetime.timedelta(hours=cache_hours):
            logger.info("jwks config is out of date (last updated at {})".format(str(self._last_updated)))
            self._refresh_cache()
        return {'jwks': self._jwks, 'algorithms': self._signing_algorithms}

# Module-level key cache, reused by every call to require_auth in this process.
jwks_config = AzureJWKS(openid_config=AZURE_OPENID_CONFIG)


async def require_auth(token: str = Header(...)):
    """FastAPI dependency that authenticates a request and returns its claims.

    Args:
        token: Header value of the form ``"Bearer <jwt>"``.
            NOTE(review): FastAPI derives the header name from the parameter
            name, so this reads a header literally called ``token`` rather
            than ``Authorization`` — confirm that is intended.

    Returns:
        The verified token payload produced by ``get_payload``.

    Raises:
        HTTPException: 401 if the header is malformed, the token header cannot
            be decoded, or verification fails or yields an empty payload.

    NOTE(review): ``jwks_config.get_jwks()`` may download keys with blocking
    ``urlopen`` calls, which stalls the event loop inside this ``async``
    function while a refresh is in progress.
    """
    bearer_token = get_token_auth_header(token)

    # Reject tokens whose header segment cannot even be decoded before doing
    # any key lookup.
    try:
        token_header = jwt.get_unverified_header(bearer_token)
    except JoseExceptions.JWTError:
        raise HTTPException(
            status_code=401,
            detail='Unable to decode authorization token headers')

    claims = get_payload(token_header, bearer_token, jwks_config.get_jwks())
    if claims:
        return claims
    raise HTTPException(
        status_code=401,
        detail='Invalid authorization token')

I hope the community gets benefited from this!

Answered By: Aditya Bhattacharya

I found certain improvements that could be made to the accepted answer:

  • If you use the HTTPBearer security scheme, the format of the Authorization header is validated automatically, so there is no need for a function like get_token_auth_header from the accepted answer. Moreover, the generated docs end up being very clear and explanatory with regard to authentication:

(Screenshot of the generated API docs omitted.)

  • When you decode the token, you can catch every exception that descends from the JOSEError class and return its message. This avoids catching each specific exception and writing a custom message for it.
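  • Bonus: in the jwt.decode method, you can use the options argument to skip validation of claims you don't want to check. Never disable verify_signature outside of local experiments, though: without signature verification, any client can forge a token that passes the check.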

Sample snippet, using this project layout:

/endpoints
          - hello.py
          - __init__.py
dependency.py
main.py
# dependency.py script
from jose import jwt
from jose.exceptions import JOSEError
from fastapi import HTTPException, Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

security = HTTPBearer()

async def has_access(credentials: HTTPAuthorizationCredentials= Depends(security)):
    """
        Validate the bearer token sent with the request.

        HTTPBearer parses the ``Authorization: Bearer <token>`` header. The
        token's signature is verified with the shared secret from the
        ``JWT_SECRET_KEY`` environment variable, using the algorithm named by
        ``JWT_ALGORITHM`` (default ``HS256``). Audience and issuer checks stay
        disabled, as in the original example.

        Returns:
            The decoded token claims, so routes can reuse them.

        Raises:
            HTTPException: 401 with the JOSE error message if the token is
                malformed, expired, or its signature does not match.
    """
    import os  # local import: this snippet's import block does not include os

    token = credentials.credentials

    try:
        # Disabling verify_signature (as the original example did) accepts any
        # well-formed token, including forged ones, so the guard would
        # authenticate nothing. Verify the signature and pin the accepted
        # algorithm list.
        # The "secret" fallback only keeps the demo runnable: set
        # JWT_SECRET_KEY in any real deployment.
        payload = jwt.decode(
            token,
            key=os.environ.get("JWT_SECRET_KEY", "secret"),
            algorithms=[os.environ.get("JWT_ALGORITHM", "HS256")],
            options={"verify_aud": False, "verify_iss": False},
        )
    except JOSEError as e:  # base class of every python-jose error
        raise HTTPException(
            status_code=401,
            detail=str(e)) from e
    return payload
# main.py script
from fastapi import FastAPI, Depends
from endpoints import hello
from dependency import has_access

app = FastAPI()

# Dependencies applied to every route of a router mounted with them: each
# request must pass has_access first.
PROTECTED = [Depends(has_access)]

# Mount the hello router under /hello with the protection above.
app.include_router(
    hello.router,
    prefix="/hello",
    dependencies=PROTECTED,
)
# hello.py script
from fastapi import APIRouter

router = APIRouter()


# The "" path is relative to the prefix this router is mounted under, so this
# handles the prefix itself; `name` arrives as a query parameter.
@router.get("")
async def say_hi(name: str):
    greeting = "Hi " + name
    return greeting

By taking advantage of all the mentioned features, you end up building an API with security super fast 🙂

Answered By: onofricamila
Categories: questions Tags: , , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.