Skip to content

Instantly share code, notes, and snippets.

@jrepp
Created October 2, 2024 01:43
Show Gist options
  • Select an option

  • Save jrepp/fbe9a7c3207e255718afabd662f1c3c3 to your computer and use it in GitHub Desktop.

Select an option

Save jrepp/fbe9a7c3207e255718afabd662f1c3c3 to your computer and use it in GitHub Desktop.
# derived from https://pypi.org/project/fastapi-proxiedheadersmiddleware/
#
# This modified version solves an important problem when operating behind a TLS
# terminating gateway: it copies X-Forwarded-Proto into the request's internal
# scheme.
#
# If you are using CORS middleware you will want to install this, because the
# redirects from CORS will otherwise point your Location header at an http://
# endpoint (they respect the scheme of the FastAPI instance itself rather than
# that of the upstream proxy).
#
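# Just use this module in your project and add it ahead of CORS. With
# `app.add_middleware` that means adding it *after* CORSMiddleware, as shown
# below: Starlette runs the most recently added middleware first.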
#
# This is how you would install this middleware:
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
# app.add_middleware(ProxiedHeadersMiddleware)
from typing import List, Tuple
from starlette.types import ASGIApp, Receive, Scope, Send
# Type alias for a header list in (name, value) byte-string tuple form, the
# same shape ProxiedHeadersMiddleware writes back into scope["headers"].
# NOTE(review): not referenced anywhere in this module; presumably kept for
# importers — confirm before removing.
Headers = List[Tuple[bytes, bytes]]
class ProxiedHeadersMiddleware:
    """
    A middleware that modifies the request to ensure that FastAPI uses the
    X-Forwarded-* headers when creating URLs used to reference this application.

    For ``http`` and ``websocket`` scopes:

    * ``remap_headers`` folds X-Forwarded-Host / X-Forwarded-Prefix into the
      ``host`` header;
    * X-Forwarded-Proto replaces ``scope["scheme"]`` so that a TLS-terminating
      gateway yields https (or wss) URLs.

    Other scope types (e.g. ``lifespan``) are passed through unchanged.

    We are very permissive in allowing all X-Forwarded-* headers to be used, as
    we know that this API will be published behind the API Gateway, and is
    therefore not prone to redirect hijacking. If the app is ever reachable
    without that gateway, clients can spoof these headers.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only "http" and "websocket" scopes carry headers and a scheme; other
        # scopes are forwarded untouched instead of having a header list added.
        if scope["type"] in ("http", "websocket"):
            headers = self._remap_scope_headers(scope)
            # Replace the scheme if behind TLS termination.
            forwarded_proto = headers.get(b"x-forwarded-proto")
            if forwarded_proto is not None:
                scheme = self._forwarded_scheme(forwarded_proto, scope["type"])
                if scheme is not None:
                    scope["scheme"] = scheme
        await self.app(scope, receive, send)

    @staticmethod
    def _remap_scope_headers(scope: Scope) -> dict:
        """
        Run ``remap_headers`` over ``scope["headers"]`` and return the result.

        ``remap_headers`` works on a name -> value dict, which can hold only one
        value per name. Only the names it added, changed or removed are written
        back into the scope, so other repeated headers (several ``cookie`` or
        ``x-forwarded-for`` lines, say) reach the app intact. The returned dict
        is the single-valued (last value wins) view after remapping.
        """
        raw_headers = [(name, value) for name, value in scope.get("headers", [])]
        before = dict(raw_headers)
        after = dict(before)
        remap_headers(after)
        changed = {
            name
            for name in before.keys() | after.keys()
            if before.get(name) != after.get(name)
        }
        untouched = [(name, value) for name, value in raw_headers if name not in changed]
        rewritten = [(name, value) for name, value in after.items() if name in changed]
        # Rewrite to the tuple-based headers format.
        scope["headers"] = untouched + rewritten
        return after

    @staticmethod
    def _forwarded_scheme(forwarded_proto: bytes, scope_type: str):
        """
        Return the scheme to use for an X-Forwarded-Proto value, or None.

        A chain of proxies may send a list such as ``b"https, http"``; the first
        entry is the one the client used. Values other than http/https/ws/wss
        are ignored rather than copied into generated URLs. The result is
        mapped to ``ws``/``wss`` for websocket scopes and to ``http``/``https``
        otherwise.
        """
        # latin-1 decodes any byte string (unlike ascii), so a malformed header
        # cannot raise UnicodeDecodeError and fail the request.
        scheme = forwarded_proto.decode("latin-1").split(",")[0].strip().lower()
        if scheme not in ("http", "https", "ws", "wss"):
            return None
        if scope_type == "websocket":
            return {"http": "ws", "https": "wss"}.get(scheme, scheme)
        return {"ws": "http", "wss": "https"}.get(scheme, scheme)
def remap_headers(source_headers: dict) -> None:
    """
    Map X-Forwarded-Host to host and fold X-Forwarded-Prefix into host.

    ``source_headers`` maps header names to values, both ``bytes`` with
    lower-case names (the keys looked up below). It is modified in place:

    * ``x-forwarded-host`` replaces ``host``. If a chain of proxies produced a
      comma-separated list, only the first entry (the host the client asked
      for) is used. An empty value leaves ``host`` alone.
    * ``x-forwarded-prefix`` is appended to ``host`` as ``/<prefix>``, with
      surplus leading/trailing slashes removed. The intent is presumably that
      URLs built from ``host`` carry the public path prefix. If there is no
      ``host`` header, the prefix header is left in place instead of raising
      ``KeyError``.

    Each forwarded header is removed once it has been applied.
    """
    forwarded_host = source_headers.pop(b"x-forwarded-host", None)
    if forwarded_host is not None:
        first_host = forwarded_host.split(b",")[0].strip()
        if first_host:
            source_headers[b"host"] = first_host

    host = source_headers.get(b"host")
    if host is None or b"x-forwarded-prefix" not in source_headers:
        return
    # Normalise "/api/", "api" and "/api" alike; "/" alone adds nothing.
    prefix = source_headers.pop(b"x-forwarded-prefix").strip().strip(b"/")
    if prefix:
        source_headers[b"host"] = host + b"/" + prefix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment