diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ActiveClusterMonitor.java b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ActiveClusterMonitor.java index b1f8884f9..01dd4168d 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ActiveClusterMonitor.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ActiveClusterMonitor.java @@ -62,7 +62,8 @@ public ActiveClusterMonitor( public void start() { log.info("Running cluster monitor with connection task delay of %s", taskDelay); - scheduledExecutor.scheduleAtFixedRate(() -> { + @SuppressWarnings("unused") + var unused = scheduledExecutor.scheduleAtFixedRate(() -> { try { log.info("Getting stats for all active clusters"); List activeClusters = diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterMetricsStatsExporter.java b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterMetricsStatsExporter.java index 23d81e0fc..1a68ca934 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterMetricsStatsExporter.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterMetricsStatsExporter.java @@ -61,7 +61,8 @@ public ClusterMetricsStatsExporter(GatewayBackendManager gatewayBackendManager, public void start() { log.debug("Running periodic metric refresh with interval of %s", refreshInterval); - scheduledExecutor.scheduleAtFixedRate(() -> { + @SuppressWarnings("unused") + var unused = scheduledExecutor.scheduleAtFixedRate(() -> { try { updateClustersMetricRegistry(); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsJdbcMonitor.java b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsJdbcMonitor.java index c5b827318..3e81e7a11 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsJdbcMonitor.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsJdbcMonitor.java @@ -21,6 +21,7 @@ import io.trino.gateway.ha.config.ProxyBackendConfiguration; import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.sql.Connection; import java.sql.DriverManager; @@ -69,7 +70,7 @@ public ClusterStats monitor(ProxyBackendConfiguration backend) ClusterStats.Builder clusterStats = ClusterStatsMonitor.getClusterStatsBuilder(backend); String jdbcUrl; try { - URL parsedUrl = new URL(url); + URL parsedUrl = URI.create(url).toURL(); jdbcUrl = String .format("jdbc:trino://%s:%s/system", parsedUrl.getHost(), diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsMetricsMonitor.java b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsMetricsMonitor.java index 3064b78b0..70c0c56d1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsMetricsMonitor.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/clustermonitor/ClusterStatsMetricsMonitor.java @@ -13,6 +13,7 @@ */ package io.trino.gateway.ha.clustermonitor; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.http.client.HttpClient; @@ -190,7 +191,7 @@ public Map handle(Request request, Response response) String responseBody = new String(response.getInputStream().readAllBytes(), UTF_8); Map metrics = Arrays.stream(responseBody.split("\n")) .filter(line -> !line.startsWith("#")) - .collect(toImmutableMap(s -> s.split(" ")[0], s -> s.split(" ")[1])); + .collect(toImmutableMap(s -> Splitter.on(' ').splitToList(s).get(0), s -> Splitter.on(' ').splitToList(s).get(1))); if (!metrics.keySet().containsAll(requiredKeys)) { throw new UnexpectedResponseException( format("Request is missing required keys: \n%s\nin response: '%s'", String.join("\n", requiredKeys), responseBody), diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java index afd314a48..600f1e50f 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java @@ -13,6 +13,7 @@ */ package io.trino.gateway.ha.handler; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; import com.google.common.io.CharStreams; import io.airlift.log.Logger; @@ -103,15 +104,15 @@ public static Optional extractQueryIdIfPresent(String path, String query } if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) { path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), ""); - String[] tokens = path.split("/"); - if (tokens.length >= 2) { - if (tokens.length >= 3 && QUERY_STATE_PATH.contains(tokens[1])) { - if (tokens.length >= 4 && tokens[2].equals(PARTIAL_CANCEL_PATH)) { - return Optional.of(tokens[3]); + List tokens = Splitter.on('/').splitToList(path); + if (tokens.size() >= 2) { + if (tokens.size() >= 3 && QUERY_STATE_PATH.contains(tokens.get(1))) { + if (tokens.size() >= 4 && tokens.get(2).equals(PARTIAL_CANCEL_PATH)) { + return Optional.of(tokens.get(3)); } - return Optional.of(tokens[2]); + return Optional.of(tokens.get(2)); } - return Optional.of(tokens[1]); + return Optional.of(tokens.get(1)); } } else if (path.startsWith(TRINO_UI_PATH)) { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/JdbcConnectionManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/JdbcConnectionManager.java index 2506b3f2a..fda05a017 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/JdbcConnectionManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/persistence/JdbcConnectionManager.java @@ -104,7 +104,8 @@ private static URI getUriWithRoutingGroupDatabase(String routingGroupDatabase, i private void startCleanUps() { - executorService.scheduleWithFixedDelay( + @SuppressWarnings("unused") + var unused = executorService.scheduleWithFixedDelay( () -> { log.info("Performing query history cleanup task"); long created = System.currentTimeMillis() - TimeUnit.HOURS.toMillis(this.configuration.getQueryHistoryHoursRetention()); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/GatewayWebAppResource.java b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/GatewayWebAppResource.java index a0d83ad03..eb2fb972b 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/GatewayWebAppResource.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/GatewayWebAppResource.java @@ -70,7 +70,7 @@ @Path("/webapp") public class GatewayWebAppResource { - private static final LocalDateTime START_TIME = LocalDateTime.now(ZoneId.systemDefault()); + private final LocalDateTime startTime = LocalDateTime.now(ZoneId.systemDefault()); private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"); private final GatewayBackendManager gatewayBackendManager; private final QueryHistoryManager queryHistoryManager; @@ -167,7 +167,7 @@ public Response getDistribution(QueryDistributionRequest query) state -> state.trinoStatus() == TrinoStatus.HEALTHY, Collectors.collectingAndThen(Collectors.counting(), Long::intValue))); Integer latestHour = query.latestHour(); - Long ts = System.currentTimeMillis() - (latestHour * 60 * 60 * 1000); + Long ts = System.currentTimeMillis() - (latestHour * 60 * 60 * 1000L); List lineChart = queryHistoryManager.findDistribution(ts); lineChart.forEach(qh -> qh.setName(urlToNameMap.get(qh.getBackendUrl()))); Map> lineChartMap = lineChart.stream().collect(Collectors.groupingBy(DistributionResponse.LineChart::getName)); @@ -192,7 +192,7 @@ public Response getDistribution(QueryDistributionRequest query) distributionResponse.setTotalQueryCount(totalQueryCount); distributionResponse.setAverageQueryCountSecond(totalQueryCount / (latestHour * 60d * 60d)); distributionResponse.setAverageQueryCountMinute(totalQueryCount / (latestHour * 60d)); - ZonedDateTime zonedLocalTime = START_TIME.atZone(ZoneId.systemDefault()); + ZonedDateTime zonedLocalTime = startTime.atZone(ZoneId.systemDefault()); ZonedDateTime utcTime = zonedLocalTime.withZoneSameInstant(ZoneOffset.UTC); distributionResponse.setStartTime(utcTime.format(formatter)); return Response.ok(Result.ok(distributionResponse)).build(); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java index 971d6fb32..d2483e166 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java @@ -152,7 +152,7 @@ private String computeSignature() public int compareTo(GatewayCookie o) { int priorityDelta = unsignedGatewayCookie.getPriority() - o.getPriority(); - return priorityDelta != 0 ? priorityDelta : (int) (unsignedGatewayCookie.getTs() - o.getTs()); + return priorityDelta != 0 ? priorityDelta : unsignedGatewayCookie.getTs().compareTo(o.getTs()); } public Cookie toCookie() diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StatementUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StatementUtils.java index 066fb6750..6bca7db7f 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StatementUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StatementUtils.java @@ -195,8 +195,8 @@ private StatementUtils() {} public static String getResourceGroupQueryType(Statement statement) { - if (statement instanceof ExplainAnalyze) { - return getResourceGroupQueryType(((ExplainAnalyze) statement).getStatement()); + if (statement instanceof ExplainAnalyze explainAnalyze) { + return getResourceGroupQueryType(explainAnalyze.getStatement()); } StatementTypeInfo statementTypeInfo = STATEMENT_QUERY_TYPES.get(statement.getClass()); if (statementTypeInfo != null) { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java index e82401015..9f6d03523 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java @@ -19,13 +19,11 @@ import java.util.List; import java.util.Optional; -import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; public class StochasticRoutingManager extends BaseRoutingManager { - private static final Random RANDOM = new Random(); - @Inject public StochasticRoutingManager( GatewayBackendManager gatewayBackendManager, @@ -41,7 +39,7 @@ protected Optional selectBackend(List getPreparedStatements(Enumeration headers) return preparedStatementsMapBuilder.build(); } while (headers.hasMoreElements()) { - String[] preparedStatementsArray = headers.nextElement().split(","); + Iterable preparedStatementsArray = Splitter.on(',').split(headers.nextElement()); for (String preparedStatement : preparedStatementsArray) { - String[] nameValue = preparedStatement.split("="); - if (nameValue.length != 2) { + List nameValue = Splitter.on('=').splitToList(preparedStatement); + if (nameValue.size() != 2) { throw new RequestParsingException(format("preparedStatement must be formatted as name=value, but is %s", preparedStatement)); } - preparedStatementsMapBuilder.put(URLDecoder.decode(nameValue[0], UTF_8), URLDecoder.decode(decodePreparedStatementFromHeader(nameValue[1]), UTF_8)); + preparedStatementsMapBuilder.put(URLDecoder.decode(nameValue.get(0), UTF_8), URLDecoder.decode(decodePreparedStatementFromHeader(nameValue.get(1)), UTF_8)); } } return preparedStatementsMapBuilder.build(); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java index 449a198f1..1153a855a 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import com.google.common.base.Splitter; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; @@ -54,13 +55,10 @@ public class TrinoRequestUser public static final String TRINO_USER_HEADER_NAME = "X-Trino-User"; public static final String TRINO_UI_TOKEN_NAME = "Trino-UI-Token"; public static final String TRINO_SECURE_UI_TOKEN_NAME = "__Secure-Trino-ID-Token"; - - private Optional user = Optional.empty(); - private Optional userInfo = Optional.empty(); - private static final Logger log = Logger.get(TrinoRequestUser.class); - private final Optional> userInfoCache; + private Optional user = Optional.empty(); + private Optional userInfo = Optional.empty(); private TrinoRequestUser(ContainerRequestContext request, String userField, Optional> userInfoCache) { @@ -155,11 +153,11 @@ private Optional extractUserFromAuthorizationHeader(String header, Strin if (header.contains("Basic")) { try { - return Optional.of(new String(Base64.getDecoder().decode(header.split(" ")[1]), StandardCharsets.UTF_8).split(":")[0]); + return Optional.of(Splitter.on(':').splitToStream(new String(Base64.getDecoder().decode(Splitter.on(' ').splitToStream(header).skip(1).findFirst().get()), StandardCharsets.UTF_8)).findFirst().get()); } catch (IllegalArgumentException e) { log.error(e, "Authorization: Basic header contains invalid base64"); - log.debug("Invalid header value: " + header.split(" ")[1]); + log.debug("Invalid header value: " + Splitter.on(' ').splitToStream(header).skip(1).findFirst().get()); return Optional.empty(); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbKeyProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbKeyProvider.java index b4df89031..bc636b6bb 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbKeyProvider.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbKeyProvider.java @@ -78,13 +78,13 @@ public LbKeyProvider(SelfSignKeyPairConfiguration keypairConfig) RSAPrivateKey getRsaPrivateKey() { - return (this.privateKey instanceof RSAPrivateKey) - ? (RSAPrivateKey) this.privateKey : null; + return (this.privateKey instanceof RSAPrivateKey rSAPrivateKey) + ? rSAPrivateKey : null; } RSAPublicKey getRsaPublicKey() { - return (this.publicKey instanceof RSAPublicKey) - ? (RSAPublicKey) this.publicKey : null; + return (this.publicKey instanceof RSAPublicKey rSAPublicKey) + ? rSAPublicKey : null; } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbTokenUtil.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbTokenUtil.java index b753fadb4..fe9a8d66d 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbTokenUtil.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/LbTokenUtil.java @@ -47,7 +47,8 @@ public static boolean validateToken(String idToken, RSAPublicKey publicKey, Stri audiences.ifPresent(auds -> verification.withAnyOfAudience(auds.toArray(new String[0]))); - verification.build().verify(idToken); + // Add clock skew tolerance for containerized environments + verification.acceptLeeway(10).build().verify(idToken); } catch (Exception exc) { log.error(exc, "Could not validate token."); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/clustermonitor/TestClusterMetricsStatsExporter.java b/gateway-ha/src/test/java/io/trino/gateway/ha/clustermonitor/TestClusterMetricsStatsExporter.java index 50532e2a3..848bbb8be 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/clustermonitor/TestClusterMetricsStatsExporter.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/clustermonitor/TestClusterMetricsStatsExporter.java @@ -48,14 +48,14 @@ void testMetricsRegistrationForNewCluster() sleepUninterruptibly(2, SECONDS); verify(statsExporter.exporter()).exportWithGeneratedName( - argThat(stats -> stats instanceof ClusterMetricsStats && ((ClusterMetricsStats) stats).getClusterName().equals(clusterName1)), + argThat(stats -> stats instanceof ClusterMetricsStats clusterMetricsStats && clusterMetricsStats.getClusterName().equals(clusterName1)), eq(ClusterMetricsStats.class), eq(clusterName1)); // Wait for next update where cluster is added sleepUninterruptibly(2, SECONDS); verify(statsExporter.exporter()).exportWithGeneratedName( - argThat(stats -> stats instanceof ClusterMetricsStats && ((ClusterMetricsStats) stats).getClusterName().equals(clusterName2)), + argThat(stats -> stats instanceof ClusterMetricsStats clusterMetricsStats && clusterMetricsStats.getClusterName().equals(clusterName2)), eq(ClusterMetricsStats.class), eq(clusterName2)); } } @@ -74,7 +74,7 @@ public void testMetricsUnregistrationForRemovedCluster() sleepUninterruptibly(2, SECONDS); verify(statsExporter.exporter()).exportWithGeneratedName( - argThat(stats -> stats instanceof ClusterMetricsStats && ((ClusterMetricsStats) stats).getClusterName().equals(clusterName)), + argThat(stats -> stats instanceof ClusterMetricsStats clusterMetricsStats && clusterMetricsStats.getClusterName().equals(clusterName)), eq(ClusterMetricsStats.class), eq(clusterName)); // Wait for next update where cluster is removed diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestResourceGroupsManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestResourceGroupsManager.java index 1e1716358..4c27d3982 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestResourceGroupsManager.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestResourceGroupsManager.java @@ -80,7 +80,7 @@ void testReadResourceGroup() assertThat(resourceGroups.get(0).getName()).isEqualTo("admin"); assertThat(resourceGroups.get(0).getHardConcurrencyLimit()).isEqualTo(20); assertThat(resourceGroups.get(0).getMaxQueued()).isEqualTo(200); - assertThat(resourceGroups.get(0).getJmxExport()).isEqualTo(Boolean.TRUE); + assertThat(resourceGroups.get(0).getJmxExport()).isEqualTo(true); assertThat(resourceGroups.get(0).getSoftMemoryLimit()).isEqualTo("80%"); } @@ -128,21 +128,21 @@ void testUpdateResourceGroup() assertThat(resourceGroups.get(0).getName()).isEqualTo("admin"); assertThat(resourceGroups.get(0).getHardConcurrencyLimit()).isEqualTo(50); assertThat(resourceGroups.get(0).getMaxQueued()).isEqualTo(50); - assertThat(resourceGroups.get(0).getJmxExport()).isEqualTo(Boolean.FALSE); + assertThat(resourceGroups.get(0).getJmxExport()).isEqualTo(false); assertThat(resourceGroups.get(0).getSoftMemoryLimit()).isEqualTo("20%"); assertThat(resourceGroups.get(1).getResourceGroupId()).isEqualTo(2L); assertThat(resourceGroups.get(1).getName()).isEqualTo("user"); assertThat(resourceGroups.get(1).getHardConcurrencyLimit()).isEqualTo(10); assertThat(resourceGroups.get(1).getMaxQueued()).isEqualTo(100); - assertThat(resourceGroups.get(1).getJmxExport()).isEqualTo(Boolean.TRUE); + assertThat(resourceGroups.get(1).getJmxExport()).isEqualTo(true); assertThat(resourceGroups.get(1).getSoftMemoryLimit()).isEqualTo("50%"); assertThat(resourceGroups.get(2).getResourceGroupId()).isEqualTo(3L); assertThat(resourceGroups.get(2).getName()).isEqualTo("localization-eng"); assertThat(resourceGroups.get(2).getHardConcurrencyLimit()).isEqualTo(50); assertThat(resourceGroups.get(2).getMaxQueued()).isEqualTo(70); - assertThat(resourceGroups.get(2).getJmxExport()).isEqualTo(Boolean.TRUE); + assertThat(resourceGroups.get(2).getJmxExport()).isEqualTo(true); assertThat(resourceGroups.get(2).getSoftMemoryLimit()).isEqualTo("20%"); assertThat(resourceGroups.get(2).getSoftConcurrencyLimit()).isEqualTo(Integer.valueOf(20)); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingRulesManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingRulesManager.java index 668374c17..3419a3a65 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingRulesManager.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingRulesManager.java @@ -128,7 +128,8 @@ void testConcurrentUpdateRoutingRule() ExecutorService executorService = Executors.newFixedThreadPool(2); - executorService.submit(() -> + @SuppressWarnings("unused") + var unused1 = executorService.submit(() -> { try { routingRulesManager.updateRoutingRule(routingRule1); @@ -138,7 +139,8 @@ void testConcurrentUpdateRoutingRule() } }); - executorService.submit(() -> + @SuppressWarnings("unused") + var unused2 = executorService.submit(() -> { try { routingRulesManager.updateRoutingRule(routingRule2); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java index 2ba71405e..436b80d48 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java @@ -70,6 +70,59 @@ final class TestOIDC private static final String DSN = "postgres://hydra:mysecretpassword@hydra-db:5432/hydra?sslmode=disable"; private static final int ROUTER_PORT = 21001 + (int) (Math.random() * 1000); + public static void setupInsecureSsl(OkHttpClient.Builder clientBuilder) + throws Exception + { + X509TrustManager trustAllCerts = new X509TrustManager() + { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + { + throw new UnsupportedOperationException("checkClientTrusted should not be called"); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + { + // skip validation of server certificate + } + + @Override + public X509Certificate[] getAcceptedIssuers() + { + return new X509Certificate[0]; + } + }; + + SSLContext sslContext = SSLContext.getInstance("SSL"); + sslContext.init(null, new TrustManager[] {trustAllCerts}, new SecureRandom()); + + clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustAllCerts); + clientBuilder.hostnameVerifier((hostname, session) -> true); + } + + private static String extractRedirectURL(String body) + throws JsonProcessingException + { + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode jsonNode = objectMapper.readTree(body); + return jsonNode.get("data").asText(); + } + + private static OkHttpClient createOkHttpClient(Optional cookieJar) + throws Exception + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() + .followRedirects(true) + .cookieJar(cookieJar.orElseGet(() -> { + CookieManager cookieManager = new CookieManager(); + cookieManager.setCookiePolicy(CookiePolicy.ACCEPT_ALL); + return new JavaNetCookieJar(cookieManager); + })); + setupInsecureSsl(httpClientBuilder); + return httpClientBuilder.build(); + } + @BeforeAll void setup() throws Exception @@ -91,6 +144,7 @@ void setup() .withStartupCheckStrategy(new OneShotStartupCheckStrategy()); migrationContainer.start(); + @SuppressWarnings("deprecation") FixedHostPortGenericContainer hydraConsent = new FixedHostPortGenericContainer<>("python:3.10.1-alpine") .withFixedExposedPort(3000, 3000) .withNetwork(network) @@ -101,6 +155,7 @@ void setup() .waitingFor(Wait.forHttp("/healthz").forPort(3000).forStatusCode(200)); hydraConsent.start(); + @SuppressWarnings("deprecation") FixedHostPortGenericContainer hydra = new FixedHostPortGenericContainer<>(HYDRA_IMAGE) .withFixedExposedPort(4444, 4444) .withFixedExposedPort(4445, 4445) @@ -124,13 +179,14 @@ void setup() .withStrategy(Wait.forLogMessage(".*Setting up http server on :4444.*", 1)) .withStrategy(Wait.forLogMessage(".*Setting up http server on :4445.*", 1))) .withStartupTimeout(java.time.Duration.ofMinutes(3)); + hydra.start(); String clientId = "trino_client_id"; String clientSecret = "trino_client_secret"; String tokenEndpointAuthMethod = "client_secret_basic"; String audience = "trino_client_id"; String callbackUrl = format("https://localhost:%s/oidc/callback", ROUTER_PORT); - GenericContainer clientCreatingContainer = new GenericContainer(HYDRA_IMAGE) + GenericContainer clientCreatingContainer = new GenericContainer<>(HYDRA_IMAGE) .withNetwork(network) .dependsOn(hydra) .withCommand("clients", "create", @@ -146,7 +202,7 @@ void setup() "--callbacks", callbackUrl); clientCreatingContainer.start(); - PostgreSQLContainer gatewayBackendDatabase = createPostgreSqlContainer(); + PostgreSQLContainer gatewayBackendDatabase = createPostgreSqlContainer(); gatewayBackendDatabase.start(); URL resource = HaGatewayTestUtils.class.getClassLoader().getResource("auth/localhost.jks"); @@ -216,41 +272,10 @@ private Request.Builder uiCall() .post(RequestBody.create("", null)); } - public static void setupInsecureSsl(OkHttpClient.Builder clientBuilder) - throws Exception - { - X509TrustManager trustAllCerts = new X509TrustManager() - { - @Override - public void checkClientTrusted(X509Certificate[] chain, String authType) - { - throw new UnsupportedOperationException("checkClientTrusted should not be called"); - } - - @Override - public void checkServerTrusted(X509Certificate[] chain, String authType) - { - // skip validation of server certificate - } - - @Override - public X509Certificate[] getAcceptedIssuers() - { - return new X509Certificate[0]; - } - }; - - SSLContext sslContext = SSLContext.getInstance("SSL"); - sslContext.init(null, new TrustManager[] {trustAllCerts}, new SecureRandom()); - - clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustAllCerts); - clientBuilder.hostnameVerifier((hostname, session) -> true); - } - public static class BadCookieJar implements CookieJar { - private JavaNetCookieJar cookieJar; + private final JavaNetCookieJar cookieJar; public BadCookieJar() { @@ -281,26 +306,4 @@ public List loadForRequest(HttpUrl url) } } } - - private static String extractRedirectURL(String body) - throws JsonProcessingException - { - ObjectMapper objectMapper = new ObjectMapper(); - JsonNode jsonNode = objectMapper.readTree(body); - return jsonNode.get("data").asText(); - } - - private static OkHttpClient createOkHttpClient(Optional cookieJar) - throws Exception - { - OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() - .followRedirects(true) - .cookieJar(cookieJar.orElseGet(() -> { - CookieManager cookieManager = new CookieManager(); - cookieManager.setCookiePolicy(CookiePolicy.ACCEPT_ALL); - return new JavaNetCookieJar(cookieManager); - })); - setupInsecureSsl(httpClientBuilder); - return httpClientBuilder.build(); - } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java b/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java index e02a3cf7f..ea6f0537c 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/util/QueryRequestMock.java @@ -116,6 +116,13 @@ public int read() { return byteArrayInputStream.read(); } + + @Override + public int read(byte[] b, int off, int len) + throws IOException + { + return byteArrayInputStream.read(b, off, len); + } }); when(mockRequest.getReader()).thenReturn(new BufferedReader(new StringReader(query)));