diff --git a/src/pyop/provider.py b/src/pyop/provider.py
index 67f9c2d..4ff703c 100644
--- a/src/pyop/provider.py
+++ b/src/pyop/provider.py
@@ -3,6 +3,7 @@
import logging
import time
import uuid
+from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from urllib.parse import parse_qsl
from urllib.parse import urlparse
@@ -49,14 +50,14 @@ class Provider(object):
def __init__(self, signing_key, configuration_information, authz_state, clients, userinfo, *,
id_token_lifetime=3600, extra_scopes=None):
# type: (jwkest.jwk.Key, Dict[str, Union[str, Sequence[str]]], se_leg_op.authz_state.AuthorizationState,
- # Mapping[str, Mapping[str, Any]], se_leg_op.userinfo.Userinfo, int) -> None
+ # Union[Mapping[str, Mapping[str, Any]], Callable[[], Mapping[str, Mapping[str, Any]]]], se_leg_op.userinfo.Userinfo, int) -> None
"""
Creates a new provider instance.
:param configuration_information: see
"OpenID Connect Discovery 1.0", Section 3
:param clients: see
- "OpenID Connect Dynamic Client Registration 1.0", Section 2
+ "OpenID Connect Dynamic Client Registration 1.0", Section 2 or a callable that returns such a mapping
:param userinfo: read-only interface for user info
:param id_token_lifetime: how long the signed ID Tokens should be valid (in seconds), defaults to 1 hour
"""
@@ -81,7 +82,7 @@ def __init__(self, signing_key, configuration_information, authz_state, clients,
self.authz_state = authz_state
self.stateless = self.authz_state and self.authz_state.stateless
- self.clients = clients
+ self._clients = clients # type: Union[Dict[str, Dict[str, Any]], Callable[[], Dict[str, Dict[str, Any]]]]
self.userinfo = userinfo
self.id_token_lifetime = id_token_lifetime
@@ -101,6 +102,16 @@ def __init__(self, signing_key, configuration_information, authz_state, clients,
self.registration_request_validators.append(
functools.partial(client_preferences_match_provider_capabilities, self))
+ @property
+ def clients(self) -> Dict[str, Dict[str, Any]]:
+ if callable(self._clients):
+ return self._clients()
+ return self._clients
+
+ @clients.setter
+ def clients(self, value): # for backwards compatibility
+ self._clients = value
+
@property
def provider_configuration(self):
"""
diff --git a/tests/pyop/test_provider_clients.py b/tests/pyop/test_provider_clients.py
new file mode 100644
index 0000000..1f57b2c
--- /dev/null
+++ b/tests/pyop/test_provider_clients.py
@@ -0,0 +1,63 @@
+import pytest
+
+from pyop.provider import Provider
+
+
+class MockAuthzState:
+ def __init__(self):
+ self.stateless = False
+
+
+class MockUserinfo:
+ pass
+
+
+class ClientRepo: # A basic ClientRepo with fixed initial clients
+ clients: dict
+
+ def get_all(self):
+ return self.clients
+
+
+class TestProviderClients:
+ def setup_method(self):
+ self.config = {
+ 'issuer': 'https://test',
+ 'authorization_endpoint': 'https://test/auth',
+ 'token_endpoint': 'https://test/token',
+ 'userinfo_endpoint': 'https://test/userinfo',
+ 'jwks_uri': 'https://test/jwks'
+ }
+ self.mock_authz = MockAuthzState()
+ self.mock_userinfo = MockUserinfo()
+
+ def test_clients_dict_direct_modification_and_set(self):
+ clients_dict = {'client1': {'name': 'test'}}
+ provider = Provider(None, self.config, self.mock_authz, clients_dict, self.mock_userinfo)
+
+ provider.clients['client3'] = {'name': 'added'}
+ assert 'client3' in provider.clients
+ assert provider.clients['client3'] == {'name': 'added'}
+
+ new_clients = {'client4': {'name': 'new'}, 'client5': {'name': 'new2'}}
+ provider.clients = new_clients
+ assert provider.clients == new_clients
+
+ def test_clients_callable_indirect_modification(self):
+ repo = ClientRepo()
+ expected_initial = {'client2': {'name': 'callable_test'}}
+ repo.clients = expected_initial
+
+ provider = Provider(None, self.config, self.mock_authz, repo.get_all, self.mock_userinfo)
+ assert provider.clients == expected_initial
+
+ repo.clients['client3'] = {'name': 'added'}
+ assert 'client2' in provider.clients
+ assert provider.clients['client3'] == {'name': 'added'}
+ assert len(provider.clients) == 2
+
+ del repo.clients['client3']
+ assert 'client2' in provider.clients
+ assert 'client3' not in provider.clients
+ assert len(provider.clients) == 1
+