From 203a03392d2ef524657a91cd50dd520c888be738 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 7 Nov 2025 15:34:08 -0500 Subject: [PATCH 1/3] chore: Support x-ld-envid in updates --- ldclient/impl/datasourcev2/polling.py | 35 ++++++++++++-------- ldclient/impl/datasourcev2/streaming.py | 44 +++++++++++++++---------- ldclient/impl/datasystem/config.py | 15 --------- ldclient/impl/util.py | 15 ++++----- ldclient/integrations/test_datav2.py | 12 ++++--- 5 files changed, 64 insertions(+), 57 deletions(-) diff --git a/ldclient/impl/datasourcev2/polling.py b/ldclient/impl/datasourcev2/polling.py index a1a67702..307d8276 100644 --- a/ldclient/impl/datasourcev2/polling.py +++ b/ldclient/impl/datasourcev2/polling.py @@ -32,6 +32,8 @@ from ldclient.impl.http import _http_factory from ldclient.impl.repeating_task import RepeatingTask from ldclient.impl.util import ( + _LD_ENVID_HEADER, + _LD_FD_FALLBACK_HEADER, UnsuccessfulResponseException, _Fail, _headers, @@ -117,6 +119,13 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: while self._stop.is_set() is False: result = self._requester.fetch(ss.selector()) if isinstance(result, _Fail): + fallback = None + envid = None + + if result.headers is not None: + fallback = result.headers.get(_LD_FD_FALLBACK_HEADER) == 'true' + envid = result.headers.get(_LD_ENVID_HEADER) + if isinstance(result.exception, UnsuccessfulResponseException): error_info = DataSourceErrorInfo( kind=DataSourceErrorKind.ERROR_RESPONSE, @@ -127,28 +136,28 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: ), ) - fallback = result.exception.headers.get("X-LD-FD-Fallback") == 'true' if fallback: yield Update( state=DataSourceState.OFF, error=error_info, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) break status_code = result.exception.status if is_http_error_recoverable(status_code): - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.INTERRUPTED, error=error_info, + environment_id=envid, ) continue - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.OFF, error=error_info, + environment_id=envid, ) break @@ -159,19 +168,18 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: message=result.error, ) - # TODO(fdv2): Go has a designation here to handle JSON decoding separately. - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.INTERRUPTED, error=error_info, + environment_id=envid, ) else: (change_set, headers) = result.value yield Update( state=DataSourceState.VALID, change_set=change_set, - environment_id=headers.get("X-LD-EnvID"), - revert_to_fdv1=headers.get('X-LD-FD-Fallback') == 'true' + environment_id=headers.get(_LD_ENVID_HEADER), + revert_to_fdv1=headers.get(_LD_FD_FALLBACK_HEADER) == 'true' ) if self._event.wait(self._poll_interval): @@ -208,7 +216,7 @@ def _poll(self, ss: SelectorStore) -> BasisResult: (change_set, headers) = result.value - env_id = headers.get("X-LD-EnvID") + env_id = headers.get(_LD_ENVID_HEADER) if not isinstance(env_id, str): env_id = None @@ -273,14 +281,14 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: ), retries=1, ) + headers = response.headers if response.status >= 400: return _Fail( - f"HTTP error {response}", UnsuccessfulResponseException(response.status, response.headers) + f"HTTP error {response}", UnsuccessfulResponseException(response.status), + headers=headers, ) - headers = response.headers - if response.status == 304: return _Success(value=(ChangeSetBuilder.no_changes(), headers)) @@ -304,6 +312,7 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: return _Fail( error=changeset_result.error, exception=changeset_result.exception, + headers=headers, # type: ignore ) @@ -438,7 +447,7 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: if response.status >= 400: return _Fail( - f"HTTP error {response}", UnsuccessfulResponseException(response.status, response.headers) + f"HTTP error {response}", UnsuccessfulResponseException(response.status) ) headers = response.headers diff --git a/ldclient/impl/datasourcev2/streaming.py b/ldclient/impl/datasourcev2/streaming.py index e8637174..2c83bda3 100644 --- a/ldclient/impl/datasourcev2/streaming.py +++ b/ldclient/impl/datasourcev2/streaming.py @@ -38,6 +38,8 @@ ) from ldclient.impl.http import HTTPFactory, _http_factory from ldclient.impl.util import ( + _LD_ENVID_HEADER, + _LD_FD_FALLBACK_HEADER, http_error_message, is_http_error_recoverable, log @@ -58,7 +60,6 @@ STREAMING_ENDPOINT = "/sdk/stream" - SseClientBuilder = Callable[[Config, SelectorStore], SSEClient] @@ -154,7 +155,9 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: if action.error is None: continue - (update, should_continue) = self._handle_error(action.error) + envid = action.headers.get(_LD_ENVID_HEADER) if action.headers is not None else None + + (update, should_continue) = self._handle_error(action.error, envid) if update is not None: yield update @@ -162,13 +165,17 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: break continue + envid = None if isinstance(action, Start) and action.headers is not None: - fallback = action.headers.get('X-LD-FD-Fallback') == 'true' + fallback = action.headers.get(_LD_FD_FALLBACK_HEADER) == 'true' + envid = action.headers.get(_LD_ENVID_HEADER) + if fallback: self._record_stream_init(True) yield Update( state=DataSourceState.OFF, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) break @@ -176,7 +183,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: continue try: - update = self._process_message(action, change_set_builder) + update = self._process_message(action, change_set_builder, envid) if update is not None: self._record_stream_init(False) self._connection_attempt_start_time = None @@ -187,7 +194,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: ) self._sse.interrupt() - (update, should_continue) = self._handle_error(e) + (update, should_continue) = self._handle_error(e, envid) if update is not None: yield update if not should_continue: @@ -204,7 +211,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: DataSourceErrorKind.UNKNOWN, 0, time(), str(e) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) self._sse.close() @@ -226,7 +233,7 @@ def _record_stream_init(self, failed: bool): # pylint: disable=too-many-return-statements def _process_message( - self, msg: Event, change_set_builder: ChangeSetBuilder + self, msg: Event, change_set_builder: ChangeSetBuilder, envid: Optional[str] ) -> Optional[Update]: """ Processes a single message from the SSE stream and returns an Update @@ -247,7 +254,7 @@ def _process_message( change_set_builder.expect_changes() return Update( state=DataSourceState.VALID, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) return None @@ -293,13 +300,13 @@ def _process_message( return Update( state=DataSourceState.VALID, change_set=change_set, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) log.info("Unexpected event found in stream: %s", msg.event) return None - def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: + def _handle_error(self, error: Exception, envid: Optional[str]) -> Tuple[Optional[Update], bool]: """ This method handles errors that occur during the streaming process. @@ -328,7 +335,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: DataSourceErrorKind.INVALID_DATA, 0, time(), str(error) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) return (update, True) @@ -344,11 +351,15 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: str(error), ) - if error.headers is not None and error.headers.get("X-LD-FD-Fallback") == 'true': + if envid is None and error.headers is not None: + envid = error.headers.get(_LD_ENVID_HEADER) + + if error.headers is not None and error.headers.get(_LD_FD_FALLBACK_HEADER) == 'true': update = Update( state=DataSourceState.OFF, error=error_info, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) return (update, False) @@ -364,7 +375,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: ), error=error_info, revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) if not is_recoverable: @@ -386,7 +397,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: DataSourceErrorKind.UNKNOWN, 0, time(), str(error) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) # no stacktrace here because, for a typical connection error, it'll # just be a lengthy tour of urllib3 internals @@ -411,5 +422,4 @@ def __init__(self, config: Config): def build(self) -> StreamingDataSource: """Builds a StreamingDataSource instance with the configured parameters.""" - # TODO(fdv2): Add in the other controls here. return StreamingDataSource(self._config) diff --git a/ldclient/impl/datasystem/config.py b/ldclient/impl/datasystem/config.py index d3b34a7a..eadc6f0e 100644 --- a/ldclient/impl/datasystem/config.py +++ b/ldclient/impl/datasystem/config.py @@ -210,18 +210,3 @@ def persistent_store(store: FeatureStore) -> ConfigBuilder: although it will keep it up-to-date. """ return default().data_store(store, DataStoreMode.READ_WRITE) - - -# TODO(fdv2): Implement these methods -# -# WithEndpoints configures the data system with custom endpoints for -# LaunchDarkly's streaming and polling synchronizers. This method is not -# necessary for most use-cases, but can be useful for testing or custom -# network configurations. -# -# Any endpoint that is not specified (empty string) will be treated as the -# default LaunchDarkly SaaS endpoint for that service. - -# WithRelayProxyEndpoints configures the data system with a single endpoint -# for LaunchDarkly's streaming and polling synchronizers. The endpoint -# should be Relay Proxy's base URI, for example http://localhost:8123. diff --git a/ldclient/impl/util.py b/ldclient/impl/util.py index 81054f4b..54caf9de 100644 --- a/ldclient/impl/util.py +++ b/ldclient/impl/util.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Generic, Optional, TypeVar, Union +from typing import Any, Dict, Generic, Mapping, Optional, TypeVar, Union from urllib.parse import urlparse, urlunparse from ldclient.impl.http import _base_headers @@ -35,6 +35,9 @@ def timedelta_millis(delta: timedelta) -> float: # Compiled regex pattern for valid characters in application values and SDK keys _VALID_CHARACTERS_REGEX = re.compile(r"[^a-zA-Z0-9._-]") +_LD_ENVID_HEADER = 'X-LD-EnvID' +_LD_FD_FALLBACK_HEADER = 'X-LD-FD-Fallback' + def validate_application_info(application: dict, logger: logging.Logger) -> dict: return { @@ -117,23 +120,18 @@ def __str__(self, *args, **kwargs): class UnsuccessfulResponseException(Exception): - def __init__(self, status, headers={}): + def __init__(self, status): super(UnsuccessfulResponseException, self).__init__("HTTP error %d" % status) self._status = status - self._headers = headers @property def status(self): return self._status - @property - def headers(self): - return self._headers - def throw_if_unsuccessful_response(resp): if resp.status >= 400: - raise UnsuccessfulResponseException(resp.status, resp.headers) + raise UnsuccessfulResponseException(resp.status) def is_http_error_recoverable(status): @@ -290,6 +288,7 @@ class _Success(Generic[T]): class _Fail(Generic[E]): error: E exception: Optional[Exception] = None + headers: Optional[Mapping[str, Any]] = None # TODO(breaking): Replace the above Result class with an improved generic diff --git a/ldclient/integrations/test_datav2.py b/ldclient/integrations/test_datav2.py index 744264f2..a2da52db 100644 --- a/ldclient/integrations/test_datav2.py +++ b/ldclient/integrations/test_datav2.py @@ -551,17 +551,21 @@ class TestDataV2: :: from ldclient.impl.datasystem import config as datasystem_config + from ldclient.integrations.test_datav2 import TestDataV2 + td = TestDataV2.data_source() td.update(td.flag('flag-key-1').variation_for_all(True)) # Configure the data system with TestDataV2 as both initializer and synchronizer data_config = datasystem_config.custom() - data_config.initializers([lambda: td.build_initializer()]) - data_config.synchronizers(lambda: td.build_synchronizer()) + data_config.initializers([td.build_initializer]) + data_config.synchronizers(td.build_synchronizer) - # TODO(fdv2): This will be integrated with the main Config in a future version - # For now, TestDataV2 is primarily intended for unit testing scenarios + config = Config( + sdk_key, + datasystem_config=data_config.build(), + ) # flags can be updated at any time: td.update(td.flag('flag-key-1'). From cc6e742ce300c0fadb0c9249ecfd7ae8da2fff0b Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 7 Nov 2025 15:46:12 -0500 Subject: [PATCH 2/3] Add headers to fdv1 polling implementation --- ldclient/impl/datasourcev2/polling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ldclient/impl/datasourcev2/polling.py b/ldclient/impl/datasourcev2/polling.py index 307d8276..e5415039 100644 --- a/ldclient/impl/datasourcev2/polling.py +++ b/ldclient/impl/datasourcev2/polling.py @@ -445,13 +445,13 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: retries=1, ) + headers = response.headers if response.status >= 400: return _Fail( - f"HTTP error {response}", UnsuccessfulResponseException(response.status) + f"HTTP error {response}", UnsuccessfulResponseException(response.status), + headers=headers ) - headers = response.headers - if response.status == 304: return _Success(value=(ChangeSetBuilder.no_changes(), headers)) @@ -475,6 +475,7 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: return _Fail( error=changeset_result.error, exception=changeset_result.exception, + headers=headers, ) From d3cb296437bdfc28749ebca262f574ad6874941f Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 14 Nov 2025 15:05:19 -0500 Subject: [PATCH 3/3] Address review feedback from cursor --- ldclient/impl/datasourcev2/streaming.py | 2 +- .../datasourcev2/test_polling_synchronizer.py | 168 ++++++++++++++- .../test_streaming_synchronizer.py | 193 +++++++++++++++++- 3 files changed, 357 insertions(+), 6 deletions(-) diff --git a/ldclient/impl/datasourcev2/streaming.py b/ldclient/impl/datasourcev2/streaming.py index 2c83bda3..57359f16 100644 --- a/ldclient/impl/datasourcev2/streaming.py +++ b/ldclient/impl/datasourcev2/streaming.py @@ -147,6 +147,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: self._running = True self._connection_attempt_start_time = time() + envid = None for action in self._sse.all: if isinstance(action, Fault): # If the SSE client detects the stream has closed, then it will @@ -165,7 +166,6 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: break continue - envid = None if isinstance(action, Start) and action.headers is not None: fallback = action.headers.get(_LD_FD_FALLBACK_HEADER) == 'true' envid = action.headers.get(_LD_ENVID_HEADER) diff --git a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py index 3410a1e6..471a0e2b 100644 --- a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py @@ -20,7 +20,7 @@ Selector, ServerIntent ) -from ldclient.impl.util import UnsuccessfulResponseException, _Fail, _Success +from ldclient.impl.util import UnsuccessfulResponseException, _Fail, _Success, _LD_ENVID_HEADER, _LD_FD_FALLBACK_HEADER from ldclient.interfaces import DataSourceErrorKind, DataSourceState from ldclient.testing.mock_components import MockSelectorStore @@ -304,3 +304,169 @@ def test_unrecoverable_error_shuts_down(): assert False, "Expected StopIteration" except StopIteration: pass + + +def test_envid_from_success_headers(): + """Test that environment ID is captured from successful polling response headers""" + change_set = ChangeSetBuilder.no_changes() + headers = {_LD_ENVID_HEADER: 'test-env-polling-123'} + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.error is None + assert valid.revert_to_fdv1 is False + assert valid.environment_id == 'test-env-polling-123' + + +def test_envid_from_success_with_changeset(): + """Test that environment ID is captured from polling response with actual changes""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + builder.add_put( + version=100, kind=ObjectKind.FLAG, key="flag-key", obj={"key": "flag-key"} + ) + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers = {_LD_ENVID_HEADER: 'test-env-456'} + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.environment_id == 'test-env-456' + assert valid.change_set is not None + assert len(valid.change_set.changes) == 1 + + +def test_envid_from_fallback_headers(): + """Test that environment ID is captured when fallback header is present on success""" + change_set = ChangeSetBuilder.no_changes() + headers = { + _LD_ENVID_HEADER: 'test-env-fallback', + _LD_FD_FALLBACK_HEADER: 'true' + } + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.revert_to_fdv1 is True + assert valid.environment_id == 'test-env-fallback' + + +def test_envid_from_error_headers_recoverable(): + """Test that environment ID is captured from error response headers for recoverable errors""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + builder.add_delete(version=101, kind=ObjectKind.FLAG, key="flag-key") + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers_success = {_LD_ENVID_HEADER: 'test-env-success'} + polling_result: PollingResult = _Success(value=(change_set, headers_success)) + + headers_error = {_LD_ENVID_HEADER: 'test-env-408'} + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=408), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure, polling_result])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + interrupted = next(sync) + valid = next(sync) + + assert interrupted.state == DataSourceState.INTERRUPTED + assert interrupted.environment_id == 'test-env-408' + assert interrupted.error is not None + assert interrupted.error.status_code == 408 + + assert valid.state == DataSourceState.VALID + assert valid.environment_id == 'test-env-success' + + +def test_envid_from_error_headers_unrecoverable(): + """Test that environment ID is captured from error response headers for unrecoverable errors""" + headers_error = {_LD_ENVID_HEADER: 'test-env-401'} + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=401), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + off = next(sync) + + assert off.state == DataSourceState.OFF + assert off.environment_id == 'test-env-401' + assert off.error is not None + assert off.error.status_code == 401 + + +def test_envid_from_error_with_fallback(): + """Test that environment ID and fallback are captured from error response""" + headers_error = { + _LD_ENVID_HEADER: 'test-env-503', + _LD_FD_FALLBACK_HEADER: 'true' + } + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=503), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + off = next(sync) + + assert off.state == DataSourceState.OFF + assert off.revert_to_fdv1 is True + assert off.environment_id == 'test-env-503' + + +def test_envid_from_generic_error_with_headers(): + """Test that environment ID is captured from generic errors with headers""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers_success = {} + polling_result: PollingResult = _Success(value=(change_set, headers_success)) + + headers_error = {_LD_ENVID_HEADER: 'test-env-generic'} + _failure = _Fail(error="generic error for test", headers=headers_error) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure, polling_result])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + interrupted = next(sync) + valid = next(sync) + + assert interrupted.state == DataSourceState.INTERRUPTED + assert interrupted.environment_id == 'test-env-generic' + assert interrupted.error is not None + assert interrupted.error.kind == DataSourceErrorKind.NETWORK_ERROR + + assert valid.state == DataSourceState.VALID diff --git a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py index 90c7037e..c90d293c 100644 --- a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py @@ -6,7 +6,7 @@ from typing import Iterable, List, Optional import pytest -from ld_eventsource.actions import Action +from ld_eventsource.actions import Action, Start from ld_eventsource.http import HTTPStatusError from ld_eventsource.sse_client import Event, Fault @@ -30,6 +30,7 @@ Selector, ServerIntent ) +from ldclient.impl.util import _LD_ENVID_HEADER, _LD_FD_FALLBACK_HEADER from ldclient.interfaces import DataSourceErrorKind, DataSourceState from ldclient.testing.mock_components import MockSelectorStore @@ -416,10 +417,14 @@ def test_invalid_json_decoding(events): # pylint: disable=redefined-outer-name def test_stops_on_unrecoverable_status_code( events, ): # pylint: disable=redefined-outer-name + error = HTTPStatusError(401) + error.headers = None + fault = Fault(error=error) + fault.headers = None builder = list_sse_client( [ # This will generate an error but the stream should continue - Fault(error=HTTPStatusError(401)), + fault, # We send these valid combinations to ensure the stream is NOT # being processed after the 401. events[EventName.SERVER_INTENT], @@ -445,12 +450,22 @@ def test_stops_on_unrecoverable_status_code( def test_continues_on_recoverable_status_code( events, ): # pylint: disable=redefined-outer-name + error1 = HTTPStatusError(400) + error1.headers = None + fault1 = Fault(error=error1) + fault1.headers = None + + error2 = HTTPStatusError(408) + error2.headers = None + fault2 = Fault(error=error2) + fault2.headers = None + builder = list_sse_client( [ # This will generate an error but the stream should continue - Fault(error=HTTPStatusError(400)), + fault1, events[EventName.SERVER_INTENT], - Fault(error=HTTPStatusError(408)), + fault2, # We send these valid combinations to ensure the stream will # continue to be processed. events[EventName.SERVER_INTENT], @@ -478,3 +493,173 @@ def test_continues_on_recoverable_status_code( assert updates[2].change_set.selector.version == 300 assert updates[2].change_set.selector.state == "p:SOMETHING:300" assert updates[2].change_set.intent_code == IntentCode.TRANSFER_FULL + + +def test_envid_from_start_action(events): # pylint: disable=redefined-outer-name + """Test that environment ID is captured from Start action headers""" + start_action = Start() + start_action.headers = {_LD_ENVID_HEADER: 'test-env-123'} + + builder = list_sse_client( + [ + start_action, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id == 'test-env-123' + + +def test_envid_preserved_across_events(events): # pylint: disable=redefined-outer-name + """Test that environment ID is preserved across multiple events after being set on Start""" + start_action = Start() + start_action.headers = {_LD_ENVID_HEADER: 'test-env-456'} + + builder = list_sse_client( + [ + start_action, + events[EventName.SERVER_INTENT], + events[EventName.PUT_OBJECT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id == 'test-env-456' + assert updates[0].change_set is not None + assert len(updates[0].change_set.changes) == 1 + + +def test_envid_from_fallback_header(): + """Test that environment ID is captured when fallback header is present""" + start_action = Start() + start_action.headers = { + _LD_ENVID_HEADER: 'test-env-fallback', + _LD_FD_FALLBACK_HEADER: 'true' + } + + builder = list_sse_client([start_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].revert_to_fdv1 is True + assert updates[0].environment_id == 'test-env-fallback' + + +def test_envid_from_fault_action(): + """Test that environment ID is captured from Fault action headers""" + error = HTTPStatusError(401) + error.headers = {_LD_ENVID_HEADER: 'test-env-fault'} + fault_action = Fault(error=error) + fault_action.headers = {_LD_ENVID_HEADER: 'test-env-fault'} + + builder = list_sse_client([fault_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].environment_id == 'test-env-fault' + assert updates[0].error is not None + assert updates[0].error.status_code == 401 + + +def test_envid_from_fault_with_fallback(): + """Test that environment ID and fallback are captured from Fault action""" + error = HTTPStatusError(503) + error.headers = { + _LD_ENVID_HEADER: 'test-env-503', + _LD_FD_FALLBACK_HEADER: 'true' + } + fault_action = Fault(error=error) + fault_action.headers = { + _LD_ENVID_HEADER: 'test-env-503', + _LD_FD_FALLBACK_HEADER: 'true' + } + + builder = list_sse_client([fault_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].revert_to_fdv1 is True + assert updates[0].environment_id == 'test-env-503' + + +def test_envid_from_recoverable_fault(events): # pylint: disable=redefined-outer-name + """Test that environment ID is captured from recoverable Fault and preserved in subsequent events""" + error = HTTPStatusError(400) + error.headers = {_LD_ENVID_HEADER: 'test-env-400'} + fault_action = Fault(error=error) + fault_action.headers = {_LD_ENVID_HEADER: 'test-env-400'} + + builder = list_sse_client( + [ + fault_action, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 2 + # First update from the fault + assert updates[0].state == DataSourceState.INTERRUPTED + assert updates[0].environment_id == 'test-env-400' + + # Second update should preserve the envid + assert updates[1].state == DataSourceState.VALID + assert updates[1].environment_id == 'test-env-400' + + +def test_envid_missing_when_no_headers(): + """Test that environment ID is None when no headers are present""" + start_action = Start() + start_action.headers = None + + server_intent = ServerIntent( + payload=Payload( + id="id", + target=300, + code=IntentCode.TRANSFER_NONE, + reason="up-to-date", + ) + ) + intent_event = Event( + event=EventName.SERVER_INTENT, + data=json.dumps(server_intent.to_dict()), + ) + + builder = list_sse_client([start_action, intent_event]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id is None