Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def response_hook(span, instance, response):
_format_command_args,
)
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.instrumentation.utils import (
is_instrumentation_enabled,
unwrap,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

Expand Down Expand Up @@ -179,9 +182,12 @@ def _instrument(
response_hook: _ResponseHookT = None,
):
def _traced_execute_command(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return func(*args, **kwargs)

query = _format_command_args(args)
name = _build_span_name(instance, args)

with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
Expand All @@ -197,6 +203,10 @@ def _traced_execute_command(func, instance, args, kwargs):
return response

def _traced_execute_pipeline(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return func(*args, **kwargs)

(
command_stack,
resource,
Expand Down Expand Up @@ -248,6 +258,10 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
)

async def _async_traced_execute_command(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return await func(*args, **kwargs)

query = _format_command_args(args)
name = _build_span_name(instance, args)

Expand All @@ -266,6 +280,10 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return await func(*args, **kwargs)

(
command_stack,
resource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.utils import suppress_instrumentation
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind

Expand Down Expand Up @@ -61,6 +62,40 @@ def test_not_recording(self):
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

def test_suppress_instrumentation_no_span(self):
redis_client = redis.Redis()

with mock.patch.object(redis_client, "connection"):
redis_client.get("key")
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

with suppress_instrumentation():
with mock.patch.object(redis_client, "connection"):
redis_client.get("key")
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 0)

def test_suppress_async_instrumentation_no_span(self):
redis_client = redis.asyncio.Redis()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

with suppress_instrumentation():
with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 0)

def test_instrument_uninstrument(self):
redis_client = redis.Redis()

Expand Down