1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import logging
15- from typing import TYPE_CHECKING
15+ from typing import TYPE_CHECKING , Awaitable , Callable
1616
1717from aiohttp import web
1818
2626
2727logger = logging .getLogger (__name__ )
2828
29- _media_path_regexp = r"/{media_path:.+}"
29+ _MEDIA_PATH_REGEXP = r"/{media_path:.+}"
30+
31+ _CORS_HEADERS = {
32+ "Access-Control-Allow-Origin" : "*" ,
33+ "Access-Control-Allow-Methods" : "GET, POST, OPTIONS" ,
34+ "Access-Control-Allow-Headers" : "Origin, X-Requested-With, Content-Type, Accept, Authorization" ,
35+ }
36+
37+
38+ @web .middleware
39+ async def simple_cors_middleware (
40+ request : web .Request ,
41+ handler : Callable [[web .Request ], Awaitable [web .StreamResponse ]],
42+ ) -> web .StreamResponse :
43+ """A simple aiohttp middleware that adds CORS headers to responses, and handles
44+ OPTIONS requests.
45+
46+ Args:
47+ request: The request to handle.
48+ handler: The handler for this request.
49+
50+ Returns:
51+ A response with CORS headers.
52+ """
53+ if request .method == "OPTIONS" :
54+ # We don't register routes for OPTIONS requests, therefore the handler we're given
55+ # in this case just raises a 405 Method Not Allowed status using an exception.
56+ # Because we actually want to return a 200 OK with additional headers, we ignore
57+ # the handler and just return a new response.
58+ response = web .StreamResponse (
59+ status = 200 ,
60+ headers = _CORS_HEADERS ,
61+ )
62+ return response
63+
64+ # Run the request's handler and append CORS headers to it.
65+ response = await handler (request )
66+ response .headers .update (_CORS_HEADERS )
67+ return response
3068
3169
3270class HTTPServer :
@@ -53,14 +91,14 @@ def _build_app(self) -> web.Application:
5391
5492 app .add_routes (
5593 [
56- web .get ("/scan" + _media_path_regexp , scan_handler .handle_plain ),
94+ web .get ("/scan" + _MEDIA_PATH_REGEXP , scan_handler .handle_plain ),
5795 web .post ("/scan_encrypted" , scan_handler .handle_encrypted ),
5896 web .get (
59- "/download" + _media_path_regexp , download_handler .handle_plain
97+ "/download" + _MEDIA_PATH_REGEXP , download_handler .handle_plain
6098 ),
6199 web .post ("/download_encrypted" , download_handler .handle_encrypted ),
62100 web .get (
63- "/thumbnail" + _media_path_regexp ,
101+ "/thumbnail" + _MEDIA_PATH_REGEXP ,
64102 thumbnail_handler .handle_thumbnail ,
65103 ),
66104 web .get (
@@ -73,9 +111,13 @@ def _build_app(self) -> web.Application:
73111 # Then we create a root application, and define the app we previously created as
74112 # a subapp on the base path for the content scanner API.
75113 root = web .Application (
76- # Apply the "normalize path" middleware to handle trailing slashes. This will
77- # also apply the middleware to subapps.
78- middlewares = [web .normalize_path_middleware ()],
114+ # Apply middlewares. This will also apply to subapps.
115+ middlewares = [
116+ # Handle trailing slashes.
117+ web .normalize_path_middleware (),
118+ # Handler CORS.
119+ simple_cors_middleware ,
120+ ],
79121 )
80122 root .add_subapp ("/_matrix/media_proxy/unstable" , app )
81123
0 commit comments