From 2ab2bcce9467e15e0d67fb19d613b378ec2faa08 Mon Sep 17 00:00:00 2001 From: Nathanael Mortensen Date: Tue, 18 Nov 2025 10:01:04 -0800 Subject: [PATCH] Support gRPC Headers from ClientOptions and correct Scope Handling Allow additional headers to be added via ClientOptions, similar to TChannel. Additionally support removing headers by setting built-in headers values to null or an empty string. Refactor OpenTracingInterceptor to a separate file and correct Scope handling to resolve Context mismatches. Leaving the Scope open during initialization and then closing it during onClose is incorrect and can cause issues with other interceptors. It's difficult to test this change in isolation. --- .../proto/serviceclient/GrpcServiceStubs.java | 138 ++--------- .../serviceclient/OpenTracingInterceptor.java | 231 ++++++++++++++++++ .../Thrift2ProtoAdapterTest.java | 52 +++- 3 files changed, 301 insertions(+), 120 deletions(-) create mode 100644 src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java diff --git a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java index 451d8868a..fdf5a2bab 100644 --- a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java +++ b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java @@ -16,8 +16,6 @@ package com.uber.cadence.internal.compatibility.proto.serviceclient; import com.google.common.base.Strings; -import com.google.protobuf.ByteString; -import com.uber.cadence.api.v1.*; import com.uber.cadence.api.v1.DomainAPIGrpc; import com.uber.cadence.api.v1.MetaAPIGrpc; import com.uber.cadence.api.v1.MetaAPIGrpc.MetaAPIBlockingStub; @@ -32,7 +30,6 @@ import com.uber.cadence.api.v1.WorkflowAPIGrpc.WorkflowAPIBlockingStub; import com.uber.cadence.api.v1.WorkflowAPIGrpc.WorkflowAPIFutureStub; import com.uber.cadence.internal.Version; -import com.uber.cadence.internal.tracing.TracingPropagator; import com.uber.cadence.serviceclient.ClientOptions; import com.uber.cadence.serviceclient.auth.IAuthorizationProvider; import io.grpc.*; @@ -41,13 +38,9 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.propagation.TextMapPropagator; import io.opentelemetry.context.propagation.TextMapSetter; -import io.opentracing.Scope; -import io.opentracing.Span; import io.opentracing.Tracer; import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; -import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; @@ -116,6 +109,7 @@ final class GrpcServiceStubs implements IGrpcServiceStubs { if (!Strings.isNullOrEmpty(options.getIsolationGroup())) { headers.put(ISOLATION_GROUP_HEADER_KEY, options.getIsolationGroup()); } + mergeHeaders(headers, options.getHeaders()); Channel interceptedChannel = ClientInterceptors.intercept( @@ -205,117 +199,7 @@ public void start(Listener responseListener, Metadata headers) { } private ClientInterceptor newOpenTracingInterceptor(Tracer tracer) { - return new ClientInterceptor() { - private final TracingPropagator tracingPropagator = new TracingPropagator(tracer); - private final String OPERATIONFORMAT = "cadence-%s"; - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return new ForwardingClientCall.SimpleForwardingClientCall( - next.newCall(method, callOptions)) { - - @Override - public void start(Listener responseListener, Metadata headers) { - Span span = - tracingPropagator.spanByServiceMethod( - String.format(OPERATIONFORMAT, method.getBareMethodName())); - Scope scope = tracer.activateSpan(span); - super.start( - new ForwardingClientCallListener.SimpleForwardingClientCallListener( - responseListener) { - @Override - public void onClose(Status status, Metadata trailers) { - try { - super.onClose(status, trailers); - } finally { - span.finish(); - scope.close(); - } - } - }, - headers); - } - - @SuppressWarnings("unchecked") - @Override - public void sendMessage(ReqT message) { - if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecution") - && message instanceof StartWorkflowExecutionRequest) { - StartWorkflowExecutionRequest request = (StartWorkflowExecutionRequest) message; - Header newHeader = addTracingHeaders(request.getHeader()); - - // cast should not throw error as we are using the builder - message = (ReqT) request.toBuilder().setHeader(newHeader).build(); - } else if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecutionAsync") - && message instanceof StartWorkflowExecutionAsyncRequest) { - StartWorkflowExecutionAsyncRequest request = - (StartWorkflowExecutionAsyncRequest) message; - Header newHeader = addTracingHeaders(request.getRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setRequest(request.getRequest().toBuilder().setHeader(newHeader)) - .build(); - } else if (Objects.equals( - method.getBareMethodName(), "SignalWithStartWorkflowExecution") - && message instanceof SignalWithStartWorkflowExecutionRequest) { - SignalWithStartWorkflowExecutionRequest request = - (SignalWithStartWorkflowExecutionRequest) message; - Header newHeader = addTracingHeaders(request.getStartRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setStartRequest( - request.getStartRequest().toBuilder().setHeader(newHeader)) - .build(); - } else if (Objects.equals( - method.getBareMethodName(), "SignalWithStartWorkflowExecutionAsync") - && message instanceof SignalWithStartWorkflowExecutionAsyncRequest) { - SignalWithStartWorkflowExecutionAsyncRequest request = - (SignalWithStartWorkflowExecutionAsyncRequest) message; - Header newHeader = - addTracingHeaders(request.getRequest().getStartRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setRequest( - request - .getRequest() - .toBuilder() - .setStartRequest( - request - .getRequest() - .getStartRequest() - .toBuilder() - .setHeader(newHeader))) - .build(); - } - super.sendMessage(message); - } - - private Header addTracingHeaders(Header header) { - Map headers = new HashMap<>(); - tracingPropagator.inject(headers); - Header.Builder headerBuilder = header.toBuilder(); - headers.forEach( - (k, v) -> - headerBuilder.putFields( - k, Payload.newBuilder().setData(ByteString.copyFrom(v)).build())); - return headerBuilder.build(); - } - }; - } - }; + return new OpenTracingInterceptor(tracer); } private ClientInterceptor newTracingInterceptor() { @@ -488,4 +372,22 @@ public ClientCall interceptCall( return next.newCall(method, callOptions.withDeadlineAfter(duration, TimeUnit.MILLISECONDS)); } } + + private static void mergeHeaders(Metadata metadata, Map headers) { + if (headers == null) { + return; + } + for (Map.Entry entry : headers.entrySet()) { + Metadata.Key key = Metadata.Key.of(entry.getKey(), Metadata.ASCII_STRING_MARSHALLER); + // Allow headers to overwrite any defaults + if (metadata.containsKey(key)) { + metadata.removeAll(key); + } + // Only replace it if they specify a value. + // This allows for removing headers + if (!Strings.isNullOrEmpty(entry.getValue())) { + metadata.put(key, entry.getValue()); + } + } + } } diff --git a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java new file mode 100644 index 000000000..553e49b03 --- /dev/null +++ b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java @@ -0,0 +1,231 @@ +/* + * Modifications Copyright (c) 2017-2021 Uber Technologies Inc. + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.uber.cadence.internal.compatibility.proto.serviceclient; + +import com.google.protobuf.ByteString; +import com.uber.cadence.api.v1.Header; +import com.uber.cadence.api.v1.Payload; +import com.uber.cadence.api.v1.SignalWithStartWorkflowExecutionAsyncRequest; +import com.uber.cadence.api.v1.SignalWithStartWorkflowExecutionRequest; +import com.uber.cadence.api.v1.StartWorkflowExecutionAsyncRequest; +import com.uber.cadence.api.v1.StartWorkflowExecutionRequest; +import com.uber.cadence.internal.tracing.TracingPropagator; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.Tracer; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.Nullable; + +final class OpenTracingInterceptor implements ClientInterceptor { + private static final String OPERATION_FORMAT = "cadence-%s"; + private final Tracer tracer; + private final TracingPropagator tracingPropagator; + + OpenTracingInterceptor(Tracer tracer) { + this.tracer = tracer; + this.tracingPropagator = new TracingPropagator(tracer); + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + + Span span = + tracingPropagator.spanByServiceMethod( + String.format(OPERATION_FORMAT, method.getBareMethodName())); + try (Scope ignored = tracer.activateSpan(span)) { + return new OpenTracingClientCall<>(next, method, callOptions, span); + } + } + + private class OpenTracingClientCall + extends ForwardingClientCall.SimpleForwardingClientCall { + + private final AtomicBoolean finished = new AtomicBoolean(); + private final MethodDescriptor method; + private final Span span; + + public OpenTracingClientCall( + Channel next, MethodDescriptor method, CallOptions callOptions, Span span) { + super(next.newCall(method, callOptions)); + this.method = method; + this.span = span; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + try (Scope ignored = tracer.activateSpan(span)) { + super.start( + new ForwardingClientCallListener.SimpleForwardingClientCallListener( + responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + try { + super.onClose(status, trailers); + } finally { + finishSpan(); + } + } + }, + headers); + } + } + + @Override + public void request(int numMessages) { + try (Scope ignored = tracer.activateSpan(span)) { + super.request(numMessages); + } + } + + @Override + public void setMessageCompression(boolean enabled) { + try (Scope ignored = tracer.activateSpan(span)) { + super.setMessageCompression(enabled); + } + } + + @Override + public boolean isReady() { + try (Scope ignored = tracer.activateSpan(span)) { + return super.isReady(); + } + } + + @Override + public Attributes getAttributes() { + try (Scope ignored = tracer.activateSpan(span)) { + return super.getAttributes(); + } + } + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + try (Scope ignored = tracer.activateSpan(span)) { + super.cancel(message, cause); + } finally { + finishSpan(); + } + } + + @Override + public void halfClose() { + try (Scope ignored = tracer.activateSpan(span)) { + super.halfClose(); + } + } + + @Override + public void sendMessage(ReqT message) { + try (Scope ignored = tracer.activateSpan(span)) { + message = replaceMessage(message); + super.sendMessage(message); + } + } + + private void finishSpan() { + // Some combination of cancel and onClose can be called so ensure we only finish once + if (finished.compareAndSet(false, true)) { + span.finish(); + } + } + + @SuppressWarnings("unchecked") + private ReqT replaceMessage(ReqT message) { + if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecution") + && message instanceof StartWorkflowExecutionRequest) { + StartWorkflowExecutionRequest request = (StartWorkflowExecutionRequest) message; + Header newHeader = addTracingHeaders(request.getHeader()); + + // cast should not throw error as we are using the builder + message = (ReqT) request.toBuilder().setHeader(newHeader).build(); + } else if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecutionAsync") + && message instanceof StartWorkflowExecutionAsyncRequest) { + StartWorkflowExecutionAsyncRequest request = (StartWorkflowExecutionAsyncRequest) message; + Header newHeader = addTracingHeaders(request.getRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setRequest(request.getRequest().toBuilder().setHeader(newHeader)) + .build(); + } else if (Objects.equals(method.getBareMethodName(), "SignalWithStartWorkflowExecution") + && message instanceof SignalWithStartWorkflowExecutionRequest) { + SignalWithStartWorkflowExecutionRequest request = + (SignalWithStartWorkflowExecutionRequest) message; + Header newHeader = addTracingHeaders(request.getStartRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setStartRequest(request.getStartRequest().toBuilder().setHeader(newHeader)) + .build(); + } else if (Objects.equals(method.getBareMethodName(), "SignalWithStartWorkflowExecutionAsync") + && message instanceof SignalWithStartWorkflowExecutionAsyncRequest) { + SignalWithStartWorkflowExecutionAsyncRequest request = + (SignalWithStartWorkflowExecutionAsyncRequest) message; + Header newHeader = addTracingHeaders(request.getRequest().getStartRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setRequest( + request + .getRequest() + .toBuilder() + .setStartRequest( + request + .getRequest() + .getStartRequest() + .toBuilder() + .setHeader(newHeader))) + .build(); + } + + return message; + } + + private Header addTracingHeaders(Header header) { + Map headers = new HashMap<>(); + tracingPropagator.inject(headers); + Header.Builder headerBuilder = header.toBuilder(); + headers.forEach( + (k, v) -> + headerBuilder.putFields( + k, Payload.newBuilder().setData(ByteString.copyFrom(v)).build())); + return headerBuilder.build(); + } + }; +} diff --git a/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java b/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java index d28b69804..5098145f8 100644 --- a/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java +++ b/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java @@ -25,6 +25,7 @@ import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; +import com.google.common.collect.ImmutableMap; import com.uber.cadence.AccessDeniedError; import com.uber.cadence.RefreshWorkflowTasksRequest; import com.uber.cadence.SignalWithStartWorkflowExecutionAsyncRequest; @@ -84,6 +85,13 @@ public class Thrift2ProtoAdapterTest { private static final Metadata.Key AUTHORIZATION_HEADER_KEY = Metadata.Key.of("cadence-authorization", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key EXPECTED_HEADER_KEY = + Metadata.Key.of("rpc-service", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVED_HEADER_KEY = + Metadata.Key.of("rpc-caller", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADDED_HEADER_KEY = + Metadata.Key.of("from-options", Metadata.ASCII_STRING_MARSHALLER); + private static final String ADDED_HEADER_VALUE = "added-value"; private static final StatusRuntimeException GRPC_ACCESS_DENIED = new StatusRuntimeException(Status.PERMISSION_DENIED); @@ -107,11 +115,14 @@ public void setup() { (Logger) LoggerFactory.getLogger( "com.uber.cadence.internal.compatibility.proto.serviceclient.GrpcServiceStubs"); + Map headers = + ImmutableMap.of(REMOVED_HEADER_KEY.name(), "", ADDED_HEADER_KEY.name(), ADDED_HEADER_VALUE); logger.setLevel(Level.TRACE); client = new Thrift2ProtoAdapter( IGrpcServiceStubs.newInstance( ClientOptions.newBuilder() + .setHeaders(headers) .setAuthorizationProvider("foo"::getBytes) .setGRPCChannel(clientChannel) .build())); @@ -119,6 +130,7 @@ public void setup() { new Thrift2ProtoAdapter( IGrpcServiceStubs.newInstance( ClientOptions.newBuilder() + .setHeaders(headers) .setAuthorizationProvider("foo"::getBytes) .setTracer(tracer) .setGRPCChannel(clientChannel) @@ -1020,7 +1032,9 @@ public Server createServer(ServiceDescriptor... descriptors) { } serverBuilder.addService( ServerInterceptors.intercept( - serviceDefinition.build(), new AuthHeaderValidatingInterceptor())); + serviceDefinition.build(), + new AuthHeaderValidatingInterceptor(), + new HeaderValidatingInterceptor())); } return serverBuilder.build().start(); } catch (IOException e) { @@ -1058,7 +1072,41 @@ private static class AuthHeaderValidatingInterceptor implements ServerIntercepto public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { if (!headers.containsKey(AUTHORIZATION_HEADER_KEY)) { - call.close(Status.INVALID_ARGUMENT, new Metadata()); + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Missing auth header: " + AUTHORIZATION_HEADER_KEY.name()), + new Metadata()); + } + return next.startCall(call, headers); + } + } + + private static class HeaderValidatingInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + if (!headers.containsKey(EXPECTED_HEADER_KEY)) { + call.close( + Status.INVALID_ARGUMENT.withDescription("Missing " + EXPECTED_HEADER_KEY.name()), + new Metadata()); + } + String addedHeaderValue = headers.get(ADDED_HEADER_KEY); + if (!ADDED_HEADER_VALUE.equals(addedHeaderValue)) { + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Incorrect value for " + + ADDED_HEADER_KEY.name() + + "; got " + + addedHeaderValue + + " instead of " + + ADDED_HEADER_VALUE), + new Metadata()); + } + if (headers.containsKey(REMOVED_HEADER_KEY)) { + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Unexpected header " + REMOVED_HEADER_KEY.name()), + new Metadata()); } return next.startCall(call, headers); }