@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
329329 def __init__ (self , hs : "HomeServer" ):
330330 super ().__init__ ()
331331 self .hs = hs
332+ self ._auth = hs .get_auth ()
332333 self .server_name = hs .hostname
333334 self .registration_handler = hs .get_registration_handler ()
334335 self .ratelimiter = FederationRateLimiter (
@@ -361,7 +362,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
361362 if self .inhibit_user_in_use_error :
362363 return 200 , {"available" : True }
363364
364- ip = request . getClientAddress (). host
365+ ip = self . _auth . get_ip_address_from_request ( request )
365366 with self .ratelimiter .ratelimit (ip ) as wait_deferred :
366367 await wait_deferred
367368
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
395396 def __init__ (self , hs : "HomeServer" ):
396397 super ().__init__ ()
397398 self .hs = hs
399+ self ._auth = hs .get_auth ()
398400 self .store = hs .get_datastores ().main
399401 self .ratelimiter = Ratelimiter (
400402 store = self .store ,
@@ -403,7 +405,8 @@ def __init__(self, hs: "HomeServer"):
403405 )
404406
405407 async def on_GET (self , request : Request ) -> Tuple [int , JsonDict ]:
406- await self .ratelimiter .ratelimit (None , (request .getClientAddress ().host ,))
408+ ip_address = self ._auth .get_ip_address_from_request (request )
409+ await self .ratelimiter .ratelimit (None , (ip_address ,))
407410
408411 if not self .hs .config .registration .enable_registration :
409412 raise SynapseError (
@@ -456,7 +459,7 @@ def __init__(self, hs: "HomeServer"):
456459 async def on_POST (self , request : SynapseRequest ) -> Tuple [int , JsonDict ]:
457460 body = parse_json_object_from_request (request )
458461
459- client_addr = request . getClientAddress (). host
462+ client_addr = self . auth . get_ip_address_from_request ( request )
460463
461464 await self .ratelimiter .ratelimit (None , client_addr , update = False )
462465
@@ -916,7 +919,7 @@ def __init__(self, hs: "HomeServer"):
916919 async def on_POST (self , request : SynapseRequest ) -> Tuple [int , JsonDict ]:
917920 body = parse_json_object_from_request (request )
918921
919- client_addr = request . getClientAddress (). host
922+ client_addr = self . auth . get_ip_address_from_request ( request )
920923
921924 await self .ratelimiter .ratelimit (None , client_addr , update = False )
922925
0 commit comments