diff --git a/src/pyop/provider.py b/src/pyop/provider.py index 67f9c2d..4ff703c 100644 --- a/src/pyop/provider.py +++ b/src/pyop/provider.py @@ -3,6 +3,7 @@ import logging import time import uuid +from typing import Any, Callable, Dict, List, Mapping, Optional, Union from urllib.parse import parse_qsl from urllib.parse import urlparse @@ -49,14 +50,14 @@ class Provider(object): def __init__(self, signing_key, configuration_information, authz_state, clients, userinfo, *, id_token_lifetime=3600, extra_scopes=None): # type: (jwkest.jwk.Key, Dict[str, Union[str, Sequence[str]]], se_leg_op.authz_state.AuthorizationState, - # Mapping[str, Mapping[str, Any]], se_leg_op.userinfo.Userinfo, int) -> None + # Union[Mapping[str, Mapping[str, Any]], Callable[[], Mapping[str, Mapping[str, Any]]]], se_leg_op.userinfo.Userinfo, int) -> None """ Creates a new provider instance. :param configuration_information: see "OpenID Connect Discovery 1.0", Section 3 :param clients: see - "OpenID Connect Dynamic Client Registration 1.0", Section 2 + "OpenID Connect Dynamic Client Registration 1.0", Section 2 or a callable that returns such a mapping :param userinfo: read-only interface for user info :param id_token_lifetime: how long the signed ID Tokens should be valid (in seconds), defaults to 1 hour """ @@ -81,7 +82,7 @@ def __init__(self, signing_key, configuration_information, authz_state, clients, self.authz_state = authz_state self.stateless = self.authz_state and self.authz_state.stateless - self.clients = clients + self._clients = clients # type: Union[Dict[str, Dict[str, Any]], Callable[[], Dict[str, Dict[str, Any]]]] self.userinfo = userinfo self.id_token_lifetime = id_token_lifetime @@ -101,6 +102,16 @@ def __init__(self, signing_key, configuration_information, authz_state, clients, self.registration_request_validators.append( functools.partial(client_preferences_match_provider_capabilities, self)) + @property + def clients(self) -> Dict[str, Dict[str, Any]]: + if callable(self._clients): + return self._clients() + return self._clients + + @clients.setter + def clients(self, value): # for backwards compatibility + self._clients = value + @property def provider_configuration(self): """ diff --git a/tests/pyop/test_provider_clients.py b/tests/pyop/test_provider_clients.py new file mode 100644 index 0000000..1f57b2c --- /dev/null +++ b/tests/pyop/test_provider_clients.py @@ -0,0 +1,63 @@ +import pytest + +from pyop.provider import Provider + + +class MockAuthzState: + def __init__(self): + self.stateless = False + + +class MockUserinfo: + pass + + +class ClientRepo: # A basic ClientRepo with fixed initial clients + clients: dict + + def get_all(self): + return self.clients + + +class TestProviderClients: + def setup_method(self): + self.config = { + 'issuer': 'https://test', + 'authorization_endpoint': 'https://test/auth', + 'token_endpoint': 'https://test/token', + 'userinfo_endpoint': 'https://test/userinfo', + 'jwks_uri': 'https://test/jwks' + } + self.mock_authz = MockAuthzState() + self.mock_userinfo = MockUserinfo() + + def test_clients_dict_direct_modification_and_set(self): + clients_dict = {'client1': {'name': 'test'}} + provider = Provider(None, self.config, self.mock_authz, clients_dict, self.mock_userinfo) + + provider.clients['client3'] = {'name': 'added'} + assert 'client3' in provider.clients + assert provider.clients['client3'] == {'name': 'added'} + + new_clients = {'client4': {'name': 'new'}, 'client5': {'name': 'new2'}} + provider.clients = new_clients + assert provider.clients == new_clients + + def test_clients_callable_indirect_modification(self): + repo = ClientRepo() + expected_initial = {'client2': {'name': 'callable_test'}} + repo.clients = expected_initial + + provider = Provider(None, self.config, self.mock_authz, repo.get_all, self.mock_userinfo) + assert provider.clients == expected_initial + + repo.clients['client3'] = {'name': 'added'} + assert 'client2' in provider.clients + assert provider.clients['client3'] == {'name': 'added'} + assert len(provider.clients) == 2 + + del repo.clients['client3'] + assert 'client2' in provider.clients + assert 'client3' not in provider.clients + assert len(provider.clients) == 1 +