diff --git a/MODULE.bazel b/MODULE.bazel index 803eacba297..42c3ed23e03 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,6 +1,6 @@ module( name = "grpc-java", - version = "1.81.0-SNAPSHOT", # CURRENT_GRPC_VERSION + version = "1.82.0-SNAPSHOT", # CURRENT_GRPC_VERSION compatibility_level = 0, repo_name = "io_grpc_grpc_java", ) @@ -47,6 +47,7 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ ] # GRPC_DEPS_END +bazel_dep(name = "abseil-cpp", version = "20250512.1") bazel_dep(name = "bazel_jar_jar", version = "0.1.11.bcr.1") bazel_dep(name = "bazel_skylib", version = "1.7.1") bazel_dep(name = "googleapis", version = "0.0.0-20240326-1c8d509c5", repo_name = "com_google_googleapis") diff --git a/README.md b/README.md index b0f7a6a14af..8e6620c927e 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.80.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.80.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.81.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.81.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -56,34 +56,34 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.80.0 + 1.81.0 runtime io.grpc grpc-protobuf - 1.80.0 + 1.81.0 io.grpc grpc-stub - 1.80.0 + 1.81.0 ``` Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.80.0' -implementation 'io.grpc:grpc-protobuf:1.80.0' -implementation 'io.grpc:grpc-stub:1.80.0' +runtimeOnly 'io.grpc:grpc-netty-shaded:1.81.0' +implementation 'io.grpc:grpc-protobuf:1.81.0' +implementation 'io.grpc:grpc-stub:1.81.0' ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.80.0' -implementation 'io.grpc:grpc-protobuf-lite:1.80.0' -implementation 'io.grpc:grpc-stub:1.80.0' +implementation 'io.grpc:grpc-okhttp:1.81.0' +implementation 'io.grpc:grpc-protobuf-lite:1.81.0' +implementation 'io.grpc:grpc-stub:1.81.0' ``` For [Bazel](https://bazel.build), you can either @@ -91,7 +91,7 @@ For [Bazel](https://bazel.build), you can either (with the GAVs from above), or use `@io_grpc_grpc_java//api` et al (see below). [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.80.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.81.0 Development snapshots are available in [Sonatypes's snapshot repository](https://central.sonatype.com/repository/maven-snapshots/). @@ -123,7 +123,7 @@ For protobuf-based codegen integrated with the Maven build system, you can use com.google.protobuf:protoc:3.25.8:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.80.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.81.0:exe:${os.detected.classifier} @@ -153,7 +153,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.80.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0' } } generateProtoTasks { @@ -186,7 +186,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.80.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0' } } generateProtoTasks { diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index 18151e88aba..2dd52fe7f21 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -55,6 +55,12 @@ public final class EquivalentAddressGroup { */ public static final Attributes.Key ATTR_LOCALITY_NAME = Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** + * The backend service associated with this EquivalentAddressGroup. + */ + @Attr + static final Attributes.Key ATTR_BACKEND_SERVICE = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.BACKEND_SERVICE"); /** * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is diff --git a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java index d4bed4d81bc..cd171208af7 100644 --- a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java @@ -26,4 +26,10 @@ private InternalEquivalentAddressGroup() {} * twice that of another endpoint, it is intended to receive twice the load. */ public static final Attributes.Key ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT; + + /** + * The backend service associated with this EquivalentAddressGroup. + */ + public static final Attributes.Key ATTR_BACKEND_SERVICE = + EquivalentAddressGroup.ATTR_BACKEND_SERVICE; } diff --git a/api/src/main/java/io/grpc/QueryParams.java b/api/src/main/java/io/grpc/QueryParams.java new file mode 100644 index 00000000000..31bc2e0e6da --- /dev/null +++ b/api/src/main/java/io/grpc/QueryParams.java @@ -0,0 +1,289 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Splitter; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A parser and mutable container class for {@code application/x-www-form-urlencoded}-style URL + * parameters as conceived by + * RFC 1866 Section 8.2.1. + * + *

For example, a URI like {@code "http://who?name=John+Doe&role=admin&role=user&active"} has: + * + *

    + *
  • A key {@code name} with value {@code John Doe} + *
  • A key {@code role} with value {@code admin} + *
  • A second key named {@code role} with value {@code user} + *
  • "Lone" key {@code active} without a value. + *
+ * + *

This class is meant to be used with {@link io.grpc.Uri}. For example: + * + *

{@code
+ * Uri uri = Uri.parse("http://who?name=John+Doe&role=admin&role=user&active");
+ * QueryParams params = QueryParams.fromRawQuery(uri.getRawQuery());
+ * params.asList().removeIf(e -> "role".equals(e.getKey()) && "admin".equals(e.getValue()));
+ *
+ * Uri modifiedUri = uri.toBuilder().setRawQuery(params.toRawQuery()).build();
+ * }
+ * + *

Note that the empty collection is encoded as a null raw query string, which means "absent" to + * {@link io.grpc.Uri.Builder#setRawQuery}. An empty string query component (""), on the other hand, + * is modeled as an instance of QueryParams containing a single lone (empty) key. It must be this + * way if we are to simultaneously 1) support lone keys, 2) have parse/toRawQuery round-trip + * transparency, and 3) never fail to parse a valid RFC 3986 query component. + * + *

This container and its {@link Entry} take the same position as {@link io.grpc.Uri} on + * equality: raw keys and values must match exactly to be equal. Most callers won't care about how + * keys and values are encoded on the wire and will work with the getters for cooked keys and values + * instead. + * + *

Instances are not safe for concurrent access by multiple threads, including by way of the + * {@link #asList()} view method. + */ +@Internal +public final class QueryParams { + + private static final String UTF_8 = "UTF-8"; + private final List entries = new ArrayList<>(); + + /** Creates a new, empty {@code QueryParams} instance. */ + public QueryParams() {} + + /** + * Parses a raw query string into a {@code QueryParams} instance. + * + *

The input is split on {@code '&'} and each parameter is parsed as either a key/value pair + * (if it contains an equals sign) or a "lone" key (if it does not). + * + *

No valid RFC 3986 query component will fail to parse. For example, {@code ===} is parsed as + * a single parameter with "" as the key and "==" as the value. {@code &&&} is parsed as three + * lone keys named "". And so on. If {@code rawQuery} is not a valid RFC 3986 query component, the + * behavior is undefined. But if you are starting with a {@link io.grpc.Uri}, passing the value + * returned by {@link io.grpc.Uri#getRawQuery()} is always well-defined and will never fail. + * + *

Calling {@link #toRawQuery()} on the returned object is guaranteed to return exactly {@code + * rawQuery}. + * + * @param rawQuery the raw query component to parse, or null to return an empty container + * @return a new instance of {@code QueryParams} representing the input + */ + public static QueryParams fromRawQuery(@Nullable String rawQuery) { + QueryParams params = new QueryParams(); + if (rawQuery != null) { + for (String part : Splitter.on('&').split(rawQuery)) { + int equalsIndex = part.indexOf('='); + if (equalsIndex == -1) { + params.entries.add(Entry.forRawLoneKey(part)); + } else { + String rawKey = part.substring(0, equalsIndex); + String rawValue = part.substring(equalsIndex + 1); + params.entries.add(Entry.forRawKeyValue(rawKey, rawValue)); + } + } + } + return params; + } + + /** + * Returns a mutable list view of the query parameters. + * + * @return the mutable list of entries + */ + public List asList() { + return entries; + } + + /** + * Returns the "raw" query string representation of these parameters, suitable for passing to the + * {@link io.grpc.Uri.Builder#setRawQuery} method. + * + * @return the raw query string + */ + @Nullable + public String toRawQuery() { + if (entries.isEmpty()) { + return null; + } + StringBuilder resultBuilder = new StringBuilder(); + boolean first = true; + for (Entry entry : entries) { + if (!first) { + resultBuilder.append('&'); + } + entry.appendToRawQueryStringBuilder(resultBuilder); + first = false; + } + return resultBuilder.toString(); + } + + @Override + public String toString() { + return entries.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof QueryParams)) { + return false; + } + QueryParams other = (QueryParams) o; + return entries.equals(other.entries); + } + + @Override + public int hashCode() { + return entries.hashCode(); + } + + /** A single query parameter entry. */ + public static final class Entry { + private final String rawKey; + @Nullable private final String rawValue; + private final String key; + @Nullable private final String value; + + private Entry(String rawKey, @Nullable String rawValue, String key, @Nullable String value) { + this.rawKey = checkNotNull(rawKey, "rawKey"); + this.rawValue = rawValue; + this.key = checkNotNull(key, "key"); + this.value = value; + } + + /** + * Returns the key. + * + *

Any characters that needed URL encoding have already been decoded. + */ + public String getKey() { + return key; + } + + /** + * Returns the value, or {@code null} if this is a "lone" key. + * + *

Any characters that needed URL encoding have already been decoded. + */ + @Nullable + public String getValue() { + return value; + } + + /** Returns {@code true} if this entry has a value, {@code false} if it is a "lone" key. */ + public boolean hasValue() { + return value != null; + } + + /** + * Creates a new key/value pair entry. + * + *

Both key and value can contain any character. They will be URL encoded for you if + * necessary. + */ + public static Entry forKeyValue(String key, String value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + return new Entry(encode(key), encode(value), key, value); + } + + /** + * Creates a new query parameter with a "lone" key. + * + *

'key' can contain any character. It will be URL encoded for you later, as necessary. + * + * @param key the decoded key, must not be null + * @return a new {@code Entry} + */ + public static Entry forLoneKey(String key) { + checkNotNull(key, "key"); + return new Entry(encode(key), null, key, null); + } + + static Entry forRawKeyValue(String rawKey, String rawValue) { + checkNotNull(rawKey, "rawKey"); + checkNotNull(rawValue, "rawValue"); + return new Entry(rawKey, rawValue, decode(rawKey), decode(rawValue)); + } + + static Entry forRawLoneKey(String rawKey) { + checkNotNull(rawKey, "rawKey"); + return new Entry(rawKey, null, decode(rawKey), null); + } + + void appendToRawQueryStringBuilder(StringBuilder sb) { + sb.append(rawKey); + if (rawValue != null) { + sb.append('=').append(rawValue); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Entry)) { + return false; + } + Entry entry = (Entry) o; + return Objects.equals(rawKey, entry.rawKey) && Objects.equals(rawValue, entry.rawValue); + } + + @Override + public int hashCode() { + return Objects.hash(rawKey, rawValue); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + appendToRawQueryStringBuilder(sb); + return sb.toString(); + } + } + + private static String decode(String s) { + try { + // TODO: Use URLDecoder.decode(String, Charset) when available + return URLDecoder.decode(s, UTF_8); + } catch (UnsupportedEncodingException impossible) { + throw new AssertionError("UTF-8 is not supported", impossible); + } + } + + private static String encode(String s) { + try { + // TODO: Use URLEncoder.encode(String, Charset) when available + return URLEncoder.encode(s, UTF_8); + } catch (UnsupportedEncodingException impossible) { + throw new AssertionError("UTF-8 is not supported", impossible); + } + } +} diff --git a/api/src/main/java/io/grpc/Uri.java b/api/src/main/java/io/grpc/Uri.java index 9f8a5a87848..42cc48044e9 100644 --- a/api/src/main/java/io/grpc/Uri.java +++ b/api/src/main/java/io/grpc/Uri.java @@ -546,24 +546,18 @@ public String getRawPath() { return path; } - /** - * Returns the percent-decoded "query" component of this URI, or null if not present. - * - *

NB: This method assumes the query was encoded as UTF-8, although RFC 3986 doesn't specify an - * encoding. - * - *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in - * the output. Callers who want to detect and handle errors in some other way should call {@link - * #getRawQuery()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. - */ - @Nullable - public String getQuery() { - return percentDecodeAssumedUtf8(query); - } - /** * Returns the query component of this URI in its originally parsed, possibly percent-encoded - * form, without any leading '?' character. + * form, without any leading '?' character, or null if not present. + * + *

The query component can only be read in its raw form. That’s because virtually everyone uses + * query as a container for structured data, with some additional layer of encoding not present in + * RFC-3986. Like 'application/x-www-form-urlencoded', which encodes key/value pairs like so: + * ?k1=v1&k2=v+2. The encoding of these containers always has characters that take on + * a special delimiter meaning when not percent-encoded and a literal meaning when they are (like + * '&', '=' and '+' above). Since it matters whether a character was percent encoded or not, + * offering a '#getQuery()' method that percent-decodes everything like we do for other components + * would be error-prone. */ @Nullable public String getRawQuery() { @@ -776,10 +770,20 @@ public Builder setRawPath(String path) { } /** - * Specifies the query component of the new URI (not including the leading '?'). + * Specifies the query component of the new URI, possibly percent-encoded, exactly as it will + * appear in the string form of the built URI. * - *

Query can contain any string of codepoints. Codepoints that can't be encoded literally - * will be percent-encoded for you as UTF-8. + *

'query' must only contain codepoints from RFC 3986's "query" character class. Any other + * characters must be percent-encoded using UTF-8. Do not include the leading '?' delimiter. + * + *

The query component can only be provided in its raw form. That’s because virtually + * everyone uses query as a container for structured data, with some additional layer of + * encoding not present in RFC-3986. Like 'application/x-www-form-urlencoded', which encodes + * key/value pairs like so: ?k1=v1&k2=v+2. The encoding of these containers always + * has characters that take on a special delimiter meaning when not percent-encoded and a + * literal meaning when they are (like '&', '=' and '+' above). Since 'query' must have already + * been carefully percent-encoded externally, a '#setQuery(String)' method that percent-encodes + * an assumed-cooked string would be error-prone. * *

This field is optional. * @@ -787,14 +791,10 @@ public Builder setRawPath(String path) { * @return this, for fluent building */ @CanIgnoreReturnValue - public Builder setQuery(@Nullable String query) { - this.query = percentEncode(query, queryChars); - return this; - } - - @CanIgnoreReturnValue - Builder setRawQuery(String query) { - checkPercentEncodedArg(query, "query", queryChars); + public Builder setRawQuery(@Nullable String query) { + if (query != null) { + checkPercentEncodedArg(query, "query", queryChars); + } this.query = query; return this; } diff --git a/api/src/test/java/io/grpc/QueryParamsTest.java b/api/src/test/java/io/grpc/QueryParamsTest.java new file mode 100644 index 00000000000..2def165a170 --- /dev/null +++ b/api/src/test/java/io/grpc/QueryParamsTest.java @@ -0,0 +1,274 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.QueryParams.Entry; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link QueryParams}. */ +@RunWith(JUnit4.class) +public class QueryParamsTest { + + @Test + public void emptyInstance() { + QueryParams params = new QueryParams(); + assertThat(params.asList()).isEmpty(); + assertThat(params.toRawQuery()).isNull(); + } + + @Test + public void parseNull_yieldsEmptyInstance() { + QueryParams params = QueryParams.fromRawQuery(null); + assertThat(params.asList()).isEmpty(); + assertThat(params.toRawQuery()).isNull(); + } + + @Test + public void parseEmptyString_yieldsSingleLoneKey() { + QueryParams params = QueryParams.fromRawQuery(""); + assertThat(params.toRawQuery()).isEmpty(); + assertThat(params.asList()).isNotEmpty(); + Entry entry = params.asList().get(0); + assertThat(entry).isNotNull(); + assertThat(entry.getKey()).isEmpty(); + assertThat(entry.hasValue()).isFalse(); + assertThat(entry.getValue()).isNull(); + } + + @Test + public void parseNormalPairs() { + QueryParams params = QueryParams.fromRawQuery("a=b&c=d"); + assertThat(params.toRawQuery()).isEqualTo("a=b&c=d"); + + QueryParams.Entry a = params.asList().get(0); + assertThat(a.getKey()).isEqualTo("a"); + assertThat(a.hasValue()).isTrue(); + assertThat(a.getValue()).isEqualTo("b"); + + QueryParams.Entry c = params.asList().get(1); + assertThat(c.getKey()).isEqualTo("c"); + assertThat(c.getValue()).isEqualTo("d"); + } + + @Test + public void parseLoneKey() { + QueryParams params = QueryParams.fromRawQuery("a&b"); + assertThat(params.toRawQuery()).isEqualTo("a&b"); + + QueryParams.Entry a = params.asList().get(0); + assertThat(a.getKey()).isEqualTo("a"); + assertThat(a.hasValue()).isFalse(); + + QueryParams.Entry b = params.asList().get(1); + assertThat(b.getKey()).isEqualTo("b"); + assertThat(b.hasValue()).isFalse(); + } + + @Test + public void parseEmptyKeysAndValues() { + QueryParams params = QueryParams.fromRawQuery("=&="); + assertThat(params.toRawQuery()).isEqualTo("=&="); + + assertThat(params.asList()).hasSize(2); + assertThat(params.asList().get(0).getKey()).isEmpty(); + assertThat(params.asList().get(0).hasValue()).isTrue(); + assertThat(params.asList().get(0).getValue()).isEmpty(); + assertThat(params.asList().get(1).getKey()).isEmpty(); + assertThat(params.asList().get(1).hasValue()).isTrue(); + assertThat(params.asList().get(1).getValue()).isEmpty(); + } + + @Test + public void roundTripPreservesEncodingOfSpaces() { + // Spaces can be encoded as + or %20. + QueryParams params = QueryParams.fromRawQuery("a+b=c%20d"); + assertThat(params.asList().get(0).getKey()).isEqualTo("a b"); + assertThat(params.asList().get(0).getValue()).isEqualTo("c d"); + assertThat(params.toRawQuery()).isEqualTo("a+b=c%20d"); + } + + @Test + public void roundTripPreservesCaseOfHexDigits() { + // Percent encoding can use upper or lower case. + QueryParams params = QueryParams.fromRawQuery("%4A%4a=%4B%4b"); + assertThat(params.asList().get(0).getKey()).isEqualTo("JJ"); + assertThat(params.asList().get(0).getValue()).isEqualTo("KK"); + assertThat(params.toRawQuery()).isEqualTo("%4A%4a=%4B%4b"); + } + + @Test + public void asListMethod() { + QueryParams params = new QueryParams(); + params.asList().add(QueryParams.Entry.forKeyValue("a b", "c d")); + params.asList().add(QueryParams.Entry.forLoneKey("e f")); + + // URLEncoder encodes spaces as + + assertThat(params.toRawQuery()).isEqualTo("a+b=c+d&e+f"); + } + + @Test + public void parseInvalidPercentEncodingThrows() { + assertThrows(IllegalArgumentException.class, () -> QueryParams.fromRawQuery("a=%GH")); + } + + @Test + public void parseInvalidKeyValueEncodingSucceeds() { + QueryParams params = QueryParams.fromRawQuery("===="); + assertThat(params.asList()) + .containsExactly(Entry.forRawKeyValue("", "===")) + .inOrder(); + assertThat(params.toRawQuery()).isEqualTo("===="); + } + + @Test + public void uriIntegration_canBuild() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("a", "b")); + params.asList().add(Entry.forKeyValue("c", "d")); + + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("example.com") + .setRawQuery(params.toRawQuery()) + .build(); + + assertThat(uri.toString()).isEqualTo("http://example.com?a=b&c=d"); + assertThat(uri.getRawQuery()).isEqualTo("a=b&c=d"); + } + + @Test + public void uriIntegration_canBuildEmpty() { + QueryParams params = new QueryParams(); + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("example.com") + .setRawQuery(params.toRawQuery()) + .build(); + + assertThat(uri.toString()).isEqualTo("http://example.com"); + assertThat(uri.getRawQuery()).isNull(); + } + + @Test + public void uriIntegration_canParse() { + Uri uri = Uri.create("http://example.com?a=b&c=d&e"); + QueryParams params = QueryParams.fromRawQuery(uri.getRawQuery()); + + assertThat(params.asList()) + .containsExactly( + Entry.forKeyValue("a", "b"), Entry.forKeyValue("c", "d"), Entry.forLoneKey("e")) + .inOrder(); + } + + @Test + public void keysAndValuesWithCharactersNeedingUrlEncoding() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("a=b", "c&d")); + params.asList().add(Entry.forKeyValue("e+f", "g h")); + + assertThat(params.toRawQuery()).isEqualTo("a%3Db=c%26d&e%2Bf=g+h"); + + QueryParams roundTripped = QueryParams.fromRawQuery(params.toRawQuery()); + assertThat(roundTripped).isEqualTo(params); + } + + @Test + public void keysAndValuesWithCodePointsOutsideAsciiRange() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("€", "𐐷")); + + assertThat(params.toRawQuery()).isEqualTo("%E2%82%AC=%F0%90%90%B7"); + + QueryParams roundTripped = QueryParams.fromRawQuery(params.toRawQuery()); + assertThat(roundTripped).isEqualTo(params); + } + + @Test + public void toStringMethod() { + QueryParams params = new QueryParams(); + assertThat(params.toString()).isEqualTo("[]"); + + params.asList().add(Entry.forKeyValue("a", "b")); + assertThat(params.toString()).isEqualTo("[a=b]"); + + params.asList().add(Entry.forLoneKey("c")); + assertThat(params.toString()).isEqualTo("[a=b, c]"); + + params.asList().add(Entry.forKeyValue("d=e", "f&g")); + assertThat(params.toString()).isEqualTo("[a=b, c, d%3De=f%26g]"); + } + + @Test + public void entryProperties() { + Entry keyValue = Entry.forKeyValue("key", "val"); + assertThat(keyValue.getKey()).isEqualTo("key"); + assertThat(keyValue.getValue()).isEqualTo("val"); + assertThat(keyValue.hasValue()).isTrue(); + + Entry loneKey = Entry.forLoneKey("key"); + assertThat(loneKey.getKey()).isEqualTo("key"); + assertThat(loneKey.getValue()).isNull(); + assertThat(loneKey.hasValue()).isFalse(); + } + + @Test + public void equalsAndHashCode_container() { + QueryParams params1 = new QueryParams(); + QueryParams params2 = new QueryParams(); + + // Empty instances are equal + assertThat(params1).isEqualTo(params2); + assertThat(params1.hashCode()).isEqualTo(params2.hashCode()); + + params1.asList().add(Entry.forKeyValue("a", "b")); + params1.asList().add(Entry.forLoneKey("c")); + + params2.asList().add(Entry.forKeyValue("a", "b")); + params2.asList().add(Entry.forLoneKey("c")); + + // Identical parameters in identical order are equal + assertThat(params1).isEqualTo(params2); + assertThat(params1.hashCode()).isEqualTo(params2.hashCode()); + + // Order matters. + QueryParams params3 = new QueryParams(); + params3.asList().add(Entry.forLoneKey("c")); + params3.asList().add(Entry.forKeyValue("a", "b")); + assertThat(params1).isNotEqualTo(params3); + } + + @Test + public void equalsAndHashCode_entry() { + // Raw matches are equal. + assertThat(Entry.forRawKeyValue("a+b", "c")).isEqualTo(Entry.forRawKeyValue("a+b", "c")); + assertThat(Entry.forRawKeyValue("a+b", "c").hashCode()) + .isEqualTo(Entry.forRawKeyValue("a+b", "c").hashCode()); + + // Spaces encoding matters. + and %20 are not equal. + assertThat(Entry.forRawKeyValue("a+b", "c")).isNotEqualTo(Entry.forRawKeyValue("a%20b", "c")); + + // Case of hex digits matter: %4A vs %4a are not equal raw keys. + assertThat(Entry.forRawKeyValue("a", "%4A")).isNotEqualTo(Entry.forRawKeyValue("a", "%4a")); + } +} diff --git a/api/src/test/java/io/grpc/UriTest.java b/api/src/test/java/io/grpc/UriTest.java index a1bd550696f..71ec1749b7d 100644 --- a/api/src/test/java/io/grpc/UriTest.java +++ b/api/src/test/java/io/grpc/UriTest.java @@ -42,7 +42,7 @@ public void parse_allComponents() throws URISyntaxException { assertThat(uri.getPort()).isEqualTo(443); assertThat(uri.getRawPort()).isEqualTo("0443"); assertThat(uri.getPath()).isEqualTo("/path"); - assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getRawQuery()).isEqualTo("query"); assertThat(uri.getFragment()).isEqualTo("fragment"); assertThat(uri.toString()).isEqualTo("scheme://user@host:0443/path?query#fragment"); assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. @@ -56,7 +56,7 @@ public void parse_noAuthority() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("scheme"); assertThat(uri.getAuthority()).isNull(); assertThat(uri.getPath()).isEqualTo("/path"); - assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getRawQuery()).isEqualTo("query"); assertThat(uri.getFragment()).isEqualTo("fragment"); assertThat(uri.toString()).isEqualTo("scheme:/path?query#fragment"); assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. @@ -102,7 +102,7 @@ public void parse_noQuery() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("scheme"); assertThat(uri.getAuthority()).isEqualTo("authority"); assertThat(uri.getPath()).isEqualTo("/path"); - assertThat(uri.getQuery()).isNull(); + assertThat(uri.getRawQuery()).isNull(); assertThat(uri.getFragment()).isEqualTo("fragment"); assertThat(uri.toString()).isEqualTo("scheme://authority/path#fragment"); } @@ -113,7 +113,7 @@ public void parse_noFragment() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("scheme"); assertThat(uri.getAuthority()).isEqualTo("authority"); assertThat(uri.getPath()).isEqualTo("/path"); - assertThat(uri.getQuery()).isEqualTo("query"); + assertThat(uri.getRawQuery()).isEqualTo("query"); assertThat(uri.getFragment()).isNull(); assertThat(uri.toString()).isEqualTo("scheme://authority/path?query"); assertThat(uri.isAbsolute()).isTrue(); @@ -125,7 +125,7 @@ public void parse_emptyPathWithAuthority() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("scheme"); assertThat(uri.getAuthority()).isEqualTo("authority"); assertThat(uri.getPath()).isEmpty(); - assertThat(uri.getQuery()).isNull(); + assertThat(uri.getRawQuery()).isNull(); assertThat(uri.getFragment()).isNull(); assertThat(uri.toString()).isEqualTo("scheme://authority"); assertThat(uri.isAbsolute()).isTrue(); @@ -139,7 +139,7 @@ public void parse_rootless() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("mailto"); assertThat(uri.getAuthority()).isNull(); assertThat(uri.getPath()).isEqualTo("ceo@company.com"); - assertThat(uri.getQuery()).isEqualTo("subject=raise"); + assertThat(uri.getRawQuery()).isEqualTo("subject=raise"); assertThat(uri.getFragment()).isNull(); assertThat(uri.toString()).isEqualTo("mailto:ceo@company.com?subject=raise"); assertThat(uri.isAbsolute()).isTrue(); @@ -153,7 +153,7 @@ public void parse_emptyPath() throws URISyntaxException { assertThat(uri.getScheme()).isEqualTo("scheme"); assertThat(uri.getAuthority()).isNull(); assertThat(uri.getPath()).isEmpty(); - assertThat(uri.getQuery()).isNull(); + assertThat(uri.getRawQuery()).isNull(); assertThat(uri.getFragment()).isNull(); assertThat(uri.toString()).isEqualTo("scheme:"); assertThat(uri.isAbsolute()).isTrue(); @@ -165,7 +165,7 @@ public void parse_emptyPath() throws URISyntaxException { public void parse_emptyQuery() { Uri uri = Uri.create("scheme:?"); assertThat(uri.getScheme()).isEqualTo("scheme"); - assertThat(uri.getQuery()).isEmpty(); + assertThat(uri.getRawQuery()).isEmpty(); } @Test @@ -322,7 +322,6 @@ public void parse_decoding() throws URISyntaxException { assertThat(uri.getPort()).isEqualTo(1234); assertThat(uri.getPath()).isEqualTo("/p ath"); assertThat(uri.getRawPath()).isEqualTo("/p%20ath"); - assertThat(uri.getQuery()).isEqualTo("q uery"); assertThat(uri.getRawQuery()).isEqualTo("q%20uery"); assertThat(uri.getFragment()).isEqualTo("f ragment"); assertThat(uri.getRawFragment()).isEqualTo("f%20ragment"); @@ -336,9 +335,8 @@ public void parse_decodingNonAscii() throws URISyntaxException { @Test public void parse_decodingPercent() throws URISyntaxException { - Uri uri = Uri.parse("s://a/p%2520ath?q%25uery#f%25ragment"); + Uri uri = Uri.parse("s://a/p%2520ath#f%25ragment"); assertThat(uri.getPath()).isEqualTo("/p%20ath"); - assertThat(uri.getQuery()).isEqualTo("q%uery"); assertThat(uri.getFragment()).isEqualTo("f%ragment"); } @@ -420,7 +418,7 @@ public void toString_percentEncoding() throws URISyntaxException { .setScheme("s") .setHost("a b") .setPath("/p ath") - .setQuery("q uery") + .setRawQuery("q%20uery") .setFragment("f ragment") .build(); assertThat(uri.toString()).isEqualTo("s://a%20b/p%20ath?q%20uery#f%20ragment"); @@ -440,7 +438,6 @@ public void parse_transparentRoundTrip_ipLiteral() { assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); assertThat(uri.getPathSegments()).containsExactly("JK//"); assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); - assertThat(uri.getQuery()).isEqualTo("LM"); assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); assertThat(uri.getFragment()).isEqualTo("NO"); } @@ -459,7 +456,6 @@ public void parse_transparentRoundTrip_regName() { assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); assertThat(uri.getPathSegments()).containsExactly("JK//"); assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); - assertThat(uri.getQuery()).isEqualTo("LM"); assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); assertThat(uri.getFragment()).isEqualTo("NO"); } @@ -529,7 +525,7 @@ public void builder_encodingWithAllowedReservedChars() throws URISyntaxException .setUserInfo("u@") .setHost("a[]") .setPath("/p:/@") - .setQuery("q/?") + .setRawQuery("q/?") .setFragment("f/?") .build(); assertThat(uri.toString()).isEqualTo("s://u%40@a%5B%5D/p:/@?q/?#f/?"); @@ -600,7 +596,7 @@ public void builder_normalizesCaseWhereAppropriate() { .setScheme("hTtP") // #section-3.1 says producers (Builder) should normalize to lower. .setHost("aBc") // #section-3.2.2 says producers (Builder) should normalize to lower. .setPath("/CdE") // #section-6.2.2.1 says the rest are assumed to be case-sensitive - .setQuery("fGh") + .setRawQuery("fGh") .setFragment("IjK") .build(); assertThat(uri.toString()).isEqualTo("http://abc/CdE?fGh#IjK"); @@ -621,12 +617,32 @@ public void builder_canClearAllOptionalFields() { .setPath("") .setUserInfo(null) .setPort(-1) - .setQuery(null) + .setRawQuery(null) .setFragment(null) .build(); assertThat(uri.toString()).isEqualTo("http:"); } + @Test + public void builder_setRawQuery() { + Uri uri = Uri.newBuilder().setScheme("http").setHost("host").setRawQuery("%61=b&c=%64").build(); + assertThat(uri.getRawQuery()).isEqualTo("%61=b&c=%64"); + assertThat(uri.toString()).isEqualTo("http://host?%61=b&c=%64"); + } + + @Test + public void builder_setRawQuery_null() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("host") + .setRawQuery("a=b") + .setRawQuery(null) + .build(); + assertThat(uri.getRawQuery()).isNull(); + assertThat(uri.toString()).isEqualTo("http://host"); + } + @Test public void builder_canClearAuthorityComponents() { Uri uri = Uri.create("s://user@host:80/path").toBuilder().setRawAuthority(null).build(); @@ -692,7 +708,7 @@ public void toString_percentEncodingLiteralPercent() throws URISyntaxException { .setScheme("s") .setHost("a") .setPath("/p%20ath") - .setQuery("q%uery") + .setRawQuery("q%25uery") .setFragment("f%ragment") .build(); assertThat(uri.toString()).isEqualTo("s://a/p%2520ath?q%25uery#f%25ragment"); diff --git a/api/src/testFixtures/java/io/grpc/StatusMatcher.java b/api/src/testFixtures/java/io/grpc/StatusMatcher.java index f464b2d709d..08e9fffb013 100644 --- a/api/src/testFixtures/java/io/grpc/StatusMatcher.java +++ b/api/src/testFixtures/java/io/grpc/StatusMatcher.java @@ -26,7 +26,7 @@ */ public final class StatusMatcher implements ArgumentMatcher { public static StatusMatcher statusHasCode(ArgumentMatcher codeMatcher) { - return new StatusMatcher(codeMatcher, null); + return new StatusMatcher(codeMatcher, null, null); } public static StatusMatcher statusHasCode(Status.Code code) { @@ -35,17 +35,20 @@ public static StatusMatcher statusHasCode(Status.Code code) { private final ArgumentMatcher codeMatcher; private final ArgumentMatcher descriptionMatcher; + private final ArgumentMatcher causeMatcher; private StatusMatcher( ArgumentMatcher codeMatcher, - ArgumentMatcher descriptionMatcher) { + ArgumentMatcher descriptionMatcher, + ArgumentMatcher causeMatcher) { this.codeMatcher = checkNotNull(codeMatcher, "codeMatcher"); this.descriptionMatcher = descriptionMatcher; + this.causeMatcher = causeMatcher; } public StatusMatcher andDescription(ArgumentMatcher descriptionMatcher) { checkState(this.descriptionMatcher == null, "Already has a description matcher"); - return new StatusMatcher(codeMatcher, descriptionMatcher); + return new StatusMatcher(codeMatcher, descriptionMatcher, causeMatcher); } public StatusMatcher andDescription(String description) { @@ -56,11 +59,21 @@ public StatusMatcher andDescriptionContains(String substring) { return andDescription(new StringContainsMatcher(substring)); } + public StatusMatcher andCause(ArgumentMatcher causeMatcher) { + checkState(this.causeMatcher == null, "Already has a cause matcher"); + return new StatusMatcher(codeMatcher, descriptionMatcher, causeMatcher); + } + + public StatusMatcher andCause(Throwable cause) { + return andCause(new EqualsMatcher<>(cause)); + } + @Override public boolean matches(Status status) { return status != null && codeMatcher.matches(status.getCode()) - && (descriptionMatcher == null || descriptionMatcher.matches(status.getDescription())); + && (descriptionMatcher == null || descriptionMatcher.matches(status.getDescription())) + && (causeMatcher == null || causeMatcher.matches(status.getCause())); } @Override @@ -72,6 +85,10 @@ public String toString() { sb.append(", description="); sb.append(descriptionMatcher); } + if (causeMatcher != null) { + sb.append(", cause="); + sb.append(causeMatcher); + } sb.append("}"); return sb.toString(); } diff --git a/api/src/testFixtures/java/io/grpc/StatusSubject.java b/api/src/testFixtures/java/io/grpc/StatusSubject.java new file mode 100644 index 00000000000..0b00df96140 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusSubject.java @@ -0,0 +1,68 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Fact.fact; + +import com.google.common.truth.FailureMetadata; +import com.google.common.truth.Subject; +import javax.annotation.Nullable; + +/** Propositions for {@link Status} subjects. */ +public final class StatusSubject extends Subject { + + private static final Subject.Factory statusFactory = new Factory(); + + public static Subject.Factory status() { + return statusFactory; + } + + private final Status actual; + + private StatusSubject(FailureMetadata metadata, @Nullable Status subject) { + super(metadata, subject); + this.actual = subject; + } + + /** Fails if the subject is not OK. */ + public void isOk() { + if (actual == null) { + failWithActual("expected to be OK but was", "null"); + } else if (!actual.isOk()) { + failWithoutActual( + fact("expected to be OK but was", actual.getCode()), + fact("description", actual.getDescription()), + fact("cause", actual.getCause())); + } + } + + /** Fails if the subject does not have the given code. */ + public void hasCode(Status.Code expectedCode) { + if (actual == null) { + failWithActual("expected to have code " + expectedCode + " but was", "null"); + } else { + check("getCode()").that(actual.getCode()).isEqualTo(expectedCode); + } + } + + private static final class Factory implements Subject.Factory { + @Override + public StatusSubject createSubject(FailureMetadata metadata, @Nullable Status that) { + return new StatusSubject(metadata, that); + } + } +} diff --git a/binder/build.gradle b/binder/build.gradle index 0da3f97ceee..7e7d4810e98 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -20,6 +20,13 @@ android { testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } lintOptions { abortOnError = false } + buildTypes { + debug { + testCoverageEnabled true // For robolectric unit tests. + enableUnitTestCoverage true // For tests that run on an emulator. + } + } + publishing { singleVariant('release') { withSourcesJar() @@ -54,6 +61,7 @@ dependencies { testImplementation project(':grpc-testing') testImplementation project(':grpc-inprocess') testImplementation testFixtures(project(':grpc-core')) + testImplementation testFixtures(project(':grpc-api')) androidTestAnnotationProcessor libraries.auto.value androidTestImplementation project(':grpc-testing') @@ -133,3 +141,31 @@ afterEvaluate { components.release.withVariantsFromConfiguration(configurations.releaseTestFixturesVariantReleaseApiPublication) { skip() } components.release.withVariantsFromConfiguration(configurations.releaseTestFixturesVariantReleaseRuntimePublication) { skip() } } + +tasks.withType(Test) { + // Robolectric modifies classes in memory at runtime, so they lack a java.security.CodeSource + // URL to their on-disk location. By default, JaCoCo ignores classes without this property. + // Overriding this allows Robolectric tests to be instrumented. + jacoco.includeNoLocationClasses = true + // Don't instrument certain JDK internals protected from modification by JEP 403's "strong + // encapsulation." Avoids IllegalAccessError, InvalidClassException and similar at runtime. + jacoco.excludes = ["jdk.internal.**"] +} + +// Android projects don't automatically get a coverage report task. We must +// register one manually here and wire it up to AGP's test tasks. +tasks.register("jacocoTestReport", JacocoReport) { + dependsOn "testDebugUnitTest" + + reports { + // For codecov.io and coveralls. + xml.required = true + // Use the same output location as the other subprojects. + html.outputLocation = layout.buildDirectory.dir("reports/jacoco/test/html") + } + + sourceDirectories.from = android.sourceSets.main.java.srcDirs + classDirectories.from = fileTree(dir: layout.buildDirectory.dir("intermediates/javac/debug/classes"), + excludes: ['**/R.class', '**/R$*.class', '**/BuildConfig.class', '**/Manifest*.*', '**/*Test*.*', 'android/**/*.*']) + executionData.from = tasks.named("testDebugUnitTest").map { it.jacoco.destinationFile } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java index bef1eefd43e..58e7d7e2b31 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java @@ -279,7 +279,7 @@ public synchronized ClientStream newStream( } @Override - protected void unregisterInbound(Inbound inbound) { + protected void unregisterInbound(Inbound inbound) { if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { clientTransportListener.transportInUse(false); } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java index 96685a2f8bd..f913775fcbe 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -70,6 +70,7 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. private final LeakSafeOneWayBinder hostServiceBinder; private final BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker; private final InboundParcelablePolicy inboundParcelablePolicy; + private final OneWayBinderProxy.Decorator clientBinderDecorator; @GuardedBy("this") private ServerListener listener; @@ -92,6 +93,7 @@ private BinderServer(Builder builder) { ImmutableList.copyOf(checkNotNull(builder.streamTracerFactories, "streamTracerFactories")); this.serverPolicyChecker = BinderInternal.createPolicyChecker(builder.serverSecurityPolicy); this.inboundParcelablePolicy = builder.inboundParcelablePolicy; + this.clientBinderDecorator = builder.clientBinderDecorator; hostServiceBinder = new LeakSafeOneWayBinder(this); } @@ -183,7 +185,7 @@ public synchronized boolean handleTransaction(int code, Parcel parcel) { executorServicePool, attrsBuilder.build(), streamTracerFactories, - OneWayBinderProxy.IDENTITY_DECORATOR, + clientBinderDecorator, callbackBinder); transport.start(listener.transportCreated(transport)); return true; @@ -225,6 +227,7 @@ public static class Builder { SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); ServerSecurityPolicy serverSecurityPolicy = SecurityPolicies.serverInternalOnly(); InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + OneWayBinderProxy.Decorator clientBinderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; public BinderServer build() { return new BinderServer(this); @@ -295,5 +298,19 @@ public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcela checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); return this; } + + /** + * Sets the {@link OneWayBinderProxy.Decorator} to be applied to this server's "client Binders". + * + *

Tests can use this to capture post-setup transactions from server to client. The specified + * decorator will be applied every time a client connects. The decorated result will be used for + * all subsequent transactions to this client from the new ServerTransport. + * + *

Optional, {@link OneWayBinderProxy#IDENTITY_DECORATOR} is the default. + */ + public Builder setClientBinderDecorator(OneWayBinderProxy.Decorator clientBinderDecorator) { + this.clientBinderDecorator = checkNotNull(clientBinderDecorator); + return this; + } } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java index b8ab5e9f843..784d833bdf5 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java @@ -146,7 +146,7 @@ public synchronized void shutdownNow(Status reason) { @Override @Nullable @GuardedBy("this") - protected Inbound createInbound(int callId) { + protected Inbound createInbound(int callId) { return new Inbound.ServerInbound(this, attributes, callId); } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 1592f6977df..30b8735ac68 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -163,7 +163,7 @@ protected enum TransportState { @GuardedBy("this") private final LeakSafeOneWayBinder incomingBinder; - protected final ConcurrentHashMap> ongoingCalls; + protected final ConcurrentHashMap> ongoingCalls; protected final OneWayBinderProxy.Decorator binderDecorator; @GuardedBy("this") @@ -318,13 +318,13 @@ final void shutdownInternal(Status shutdownStatus, boolean forceTerminate) { incomingBinder.detach(); setState(TransportState.SHUTDOWN_TERMINATED); sendShutdownTransaction(); - ArrayList> calls = new ArrayList<>(ongoingCalls.values()); + ArrayList> calls = new ArrayList<>(ongoingCalls.values()); ongoingCalls.clear(); ArrayList> futuresToCancel = new ArrayList<>(ownedFutures); ownedFutures.clear(); scheduledExecutorService.execute( () -> { - for (Inbound inbound : calls) { + for (Inbound inbound : calls) { synchronized (inbound) { inbound.closeAbnormal(shutdownStatus); } @@ -392,7 +392,7 @@ protected synchronized void sendPing(int id) throws StatusException { } } - protected void unregisterInbound(Inbound inbound) { + protected void unregisterInbound(Inbound inbound) { unregisterCall(inbound.callId); } @@ -481,13 +481,13 @@ private boolean handleTransactionInternal(int code, Parcel parcel) { } } else { int size = parcel.dataSize(); - Inbound inbound = ongoingCalls.get(code); + Inbound inbound = ongoingCalls.get(code); if (inbound == null) { synchronized (this) { if (!isShutdown()) { inbound = createInbound(code); if (inbound != null) { - Inbound existing = ongoingCalls.put(code, inbound); + Inbound existing = ongoingCalls.put(code, inbound); // Can't happen as only one invocation of handleTransaction() is running at a time. Verify.verify(existing == null, "impossible appearance of %s", existing); } @@ -519,7 +519,7 @@ protected void restrictIncomingBinderToCallsFrom(int allowedCallingUid) { @Nullable @GuardedBy("this") - protected Inbound createInbound(int callId) { + protected Inbound createInbound(int callId) { return null; } @@ -566,7 +566,7 @@ final void handleAcknowledgedBytes(long numBytes) { Iterator i = callIdsToNotifyWhenReady.iterator(); while (isReady() && i.hasNext()) { - Inbound inbound = ongoingCalls.get(i.next()); + Inbound inbound = ongoingCalls.get(i.next()); i.remove(); if (inbound != null) { // Calls can be removed out from under us. inbound.onTransportReady(); @@ -598,7 +598,7 @@ private static void checkTransition(TransportState current, TransportState next) } @VisibleForTesting - Map> getOngoingCalls() { + Map> getOngoingCalls() { return ongoingCalls; } diff --git a/binder/src/main/java/io/grpc/binder/internal/BlockPool.java b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java index 3c58abdd80b..985e465ab4b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BlockPool.java +++ b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java @@ -40,7 +40,7 @@ final class BlockPool { * The size of each standard block. (Currently 16k) The block size must be at least as large as * the maximum header list size. */ - private static final int BLOCK_SIZE = Math.max(16 * 1024, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + static final int BLOCK_SIZE = Math.max(16 * 1024, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); /** * Maximum number of blocks to keep around. (Max 128k). This limit is a judgement call. 128k is diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index 9b9dfeef5ce..83fc8273d6f 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -42,9 +42,10 @@ * *

Out-of-order messages are reassembled into their correct order. */ -abstract class Inbound implements StreamListener.MessageProducer { +abstract class Inbound + implements StreamListener.MessageProducer { - protected final BinderTransport transport; + protected final T transport; protected final Attributes attributes; final int callId; @@ -145,7 +146,7 @@ enum State { @GuardedBy("this") private boolean producingMessages; - private Inbound(BinderTransport transport, Attributes attributes, int callId) { + private Inbound(T transport, Attributes attributes, int callId) { this.transport = transport; this.attributes = attributes; this.callId = callId; @@ -399,6 +400,13 @@ private void handleMessageData(int flags, int index, Parcel parcel) throws Statu numBytes = parcel.dataPosition() - startPos; } else { numBytes = parcel.readInt(); + if (numBytes > parcel.dataAvail()) { + throw Status.INTERNAL + .withDescription( + "Message size is larger than remaining parcel size: " + + numBytes + " > " + parcel.dataAvail()) + .asException(); + } block = BlockPool.acquireBlock(numBytes); if (numBytes > 0) { parcel.readByteArray(block); @@ -551,7 +559,7 @@ public synchronized String toString() { // ====================================== // Client-side inbound transactions. - static final class ClientInbound extends Inbound { + static final class ClientInbound extends Inbound { private final boolean countsForInUse; @@ -564,7 +572,10 @@ static final class ClientInbound extends Inbound { private Metadata trailers; ClientInbound( - BinderTransport transport, Attributes attributes, int callId, boolean countsForInUse) { + BinderClientTransport transport, + Attributes attributes, + int callId, + boolean countsForInUse) { super(transport, attributes, callId); this.countsForInUse = countsForInUse; } @@ -608,13 +619,9 @@ protected void deliverCloseAbnormal(Status status) { // ====================================== // Server-side inbound transactions. - static final class ServerInbound extends Inbound { - - private final BinderServerTransport serverTransport; - + static final class ServerInbound extends Inbound { ServerInbound(BinderServerTransport transport, Attributes attributes, int callId) { super(transport, attributes, callId); - this.serverTransport = transport; } @GuardedBy("this") @@ -623,17 +630,16 @@ protected void handlePrefix(int flags, Parcel parcel) throws StatusException { String methodName = parcel.readString(); Metadata headers = MetadataHelper.readMetadata(parcel, attributes); - StatsTraceContext statsTraceContext = - serverTransport.createStatsTraceContext(methodName, headers); + StatsTraceContext statsTraceContext = transport.createStatsTraceContext(methodName, headers); Outbound.ServerOutbound outbound = - new Outbound.ServerOutbound(serverTransport, callId, statsTraceContext); + new Outbound.ServerOutbound(transport, callId, statsTraceContext); ServerStream stream; if ((flags & TransactionUtils.FLAG_EXPECT_SINGLE_MESSAGE) != 0) { stream = new SingleMessageServerStream(this, outbound, attributes); } else { stream = new MultiMessageServerStream(this, outbound, attributes); } - Status status = serverTransport.startStream(stream, methodName, headers); + Status status = transport.startStream(stream, methodName, headers); if (status.isOk()) { checkNotNull(listener); // Is it ok to assume this will happen synchronously? if (transport.isReady()) { diff --git a/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java index 8282f5e1025..63c47bf4f19 100644 --- a/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java @@ -18,8 +18,10 @@ import static android.os.IBinder.FLAG_ONEWAY; import static android.os.Process.myUid; +import static com.google.common.truth.Truth.assertAbout; import static com.google.common.truth.Truth.assertThat; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.grpc.StatusSubject.status; import static io.grpc.binder.internal.BinderTransport.REMOTE_UID; import static io.grpc.binder.internal.BinderTransport.SETUP_TRANSPORT; import static io.grpc.binder.internal.BinderTransport.SHUTDOWN_TRANSPORT; @@ -47,15 +49,20 @@ import com.google.common.collect.ImmutableList; import com.google.common.truth.TruthJUnit; import io.grpc.Attributes; +import io.grpc.CallOptions; import io.grpc.InternalChannelz.SocketStats; +import io.grpc.Metadata; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.binder.AndroidComponentAddress; import io.grpc.binder.ApiConstants; import io.grpc.binder.AsyncSecurityPolicy; import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.internal.OneWayBinderProxies.*; import io.grpc.binder.internal.SettableAsyncSecurityPolicy.AuthRequest; import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListenerBase; import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ConnectionClientTransport; @@ -66,7 +73,9 @@ import io.grpc.internal.MockServerTransportListener; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; +import java.io.InputStream; import java.util.List; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import org.junit.Before; @@ -124,6 +133,8 @@ public final class RobolectricBinderTransportTest extends AbstractTransportTest ServiceInfo serviceInfo; private int nextServerAddress; + private BlockingBinderDecorator blockingDecorator = + new BlockingBinderDecorator<>(); @Parameter(value = 0) public boolean preAuthServersParam; @@ -167,27 +178,34 @@ public void requestRealisticBindServiceBehavior() { shadowOf(application).setUnbindServiceCallsOnServiceDisconnected(false); } - @Override - protected InternalServer newServer(List streamTracerFactories) { + BinderServer.Builder newServerBuilder() { AndroidComponentAddress listenAddr = AndroidComponentAddress.forBindIntent( new Intent() .setClassName(serviceInfo.packageName, serviceInfo.name) .setAction("io.grpc.action.BIND." + nextServerAddress++)); - BinderServer binderServer = - new BinderServer.Builder() - .setListenAddress(listenAddr) - .setExecutorPool(serverExecutorPool) - .setExecutorServicePool(executorServicePool) - .setStreamTracerFactories(streamTracerFactories) - .build(); + return new BinderServer.Builder() + .setListenAddress(listenAddr) + .setExecutorPool(serverExecutorPool) + .setExecutorServicePool(executorServicePool) + .setStreamTracerFactories(List.of()); + } + void registerServerWithRobolectric(BinderServer server) { + AndroidComponentAddress listenAddr = (AndroidComponentAddress) server.getListenSocketAddress(); shadowOf(application.getPackageManager()).addServiceIfNotPresent(listenAddr.getComponent()); shadowOf(application) .setComponentNameAndServiceForBindServiceForIntent( - listenAddr.asBindIntent(), listenAddr.getComponent(), binderServer.getHostBinder()); - return binderServer; + listenAddr.asBindIntent(), listenAddr.getComponent(), server.getHostBinder()); + } + + @Override + protected InternalServer newServer(List streamTracerFactories) { + BinderServer server = + newServerBuilder().setStreamTracerFactories(streamTracerFactories).build(); + registerServerWithRobolectric(server); + return server; } @Override @@ -433,4 +451,248 @@ public void flowControlPushBack() {} @Ignore("See BinderTransportTest#serverAlreadyListening") @Override public void serverAlreadyListening() {} + + @Test + public void singleTxnMsgsDeliveredToServerOutOfOrder() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setBinderDecorator(blockingDecorator) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + blockingDecorator.putNextResult(takeNextBinder(blockingDecorator)); // Endpoint binder. + QueueingOneWayBinderProxy queueingServerProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); // Server binder. + blockingDecorator.putNextResult(queueingServerProxy); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); + stream.writeMessage(methodDescriptor.streamRequest("one")); + stream.writeMessage(methodDescriptor.streamRequest("two")); + stream.halfClose(); + + // Expect one transaction for headers, one for each message, and one for half-close. + QueueingOneWayBinderProxy.Transaction txHeaders = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction txHalfClose = takeNextTransaction(queueingServerProxy); + + // Deliver messages out of order! + queueingServerProxy.deliver(txHeaders); + queueingServerProxy.deliver(tx2); + queueingServerProxy.deliver(tx1); + queueingServerProxy.deliver(txHalfClose); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + serverStreamCreation.stream.request(2); + + // Expect the server to deliver the messages in the order they were originally sent. + InputStream msg1 = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg1)).isEqualTo("one"); + + InputStream msg2 = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg2)).isEqualTo("two"); + + assertThat(serverStreamCreation.listener.awaitHalfClosed(TIMEOUT_MS, MILLISECONDS)).isTrue(); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void msgFragmentsDeliveredToServerOutOfOrder() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setBinderDecorator(blockingDecorator) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + blockingDecorator.putNextResult(takeNextBinder(blockingDecorator)); // Endpoint binder. + QueueingOneWayBinderProxy queueingServerProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); // Server binder. + blockingDecorator.putNextResult(queueingServerProxy); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); + + String largeMessage = newStringOfLength(BlockPool.BLOCK_SIZE + 1); + stream.writeMessage(methodDescriptor.streamRequest(largeMessage)); + stream.halfClose(); + + // Expect the client to split largeMessage into two transactions, plus headers and half-close. + QueueingOneWayBinderProxy.Transaction txHeaders = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction txHalfClose = takeNextTransaction(queueingServerProxy); + + // Deliver fragments out of order! + queueingServerProxy.deliver(txHeaders); + queueingServerProxy.deliver(tx2); + queueingServerProxy.deliver(tx1); + queueingServerProxy.deliver(txHalfClose); + + // Verify that the server reassembles the transactions correctly. + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + serverStreamCreation.stream.request(1); + InputStream msg = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg)).isEqualTo(largeMessage); + + assertThat(serverStreamCreation.listener.awaitHalfClosed(TIMEOUT_MS, MILLISECONDS)).isTrue(); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void singleTxnMsgsDeliveredToClientOutOfOrder() throws Exception { + server = newServerBuilder().setClientBinderDecorator(blockingDecorator).build(); + registerServerWithRobolectric((BinderServer) server); + server.start(serverListener); + + client = newClientTransport(server); + runIfNotNull(client.start(mockClientTransportListener)); + + QueueingOneWayBinderProxy queueingClientProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); + blockingDecorator.putNextResult(queueingClientProxy); + + // Deliver the setup transaction without interference. + queueingClientProxy.deliver(takeNextTransaction(queueingClientProxy)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + stream.start(clientStreamListener); + stream.halfClose(); + stream.request(2); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("one")); + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("two")); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + // Expect one transaction from the server for each message. + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction txClose = takeNextTransaction(queueingClientProxy); + + // Deliver messages to the client out of order! + queueingClientProxy.deliver(tx2); + queueingClientProxy.deliver(tx1); + queueingClientProxy.deliver(txClose); + + // Client should deliver messages to the application in the order sent. + InputStream msg1 = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg1)).isEqualTo("one"); + InputStream msg2 = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg2)).isEqualTo("two"); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void msgFragmentsDeliveredToClientOutOfOrder() throws Exception { + server = newServerBuilder().setClientBinderDecorator(blockingDecorator).build(); + registerServerWithRobolectric((BinderServer) server); + server.start(serverListener); + + client = newClientTransport(server); + runIfNotNull(client.start(mockClientTransportListener)); + + QueueingOneWayBinderProxy queueingClientProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); + blockingDecorator.putNextResult(queueingClientProxy); + + // Deliver the setup transaction without interference. + queueingClientProxy.deliver(takeNextTransaction(queueingClientProxy)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + stream.start(clientStreamListener); + stream.request(1); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + + String largeMessage = newStringOfLength(BlockPool.BLOCK_SIZE + 1); + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse(largeMessage)); + serverStreamCreation.stream.flush(); + + // Expect the client to split largeMessage into two transactions. + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingClientProxy); + + // Deliver them to the client out of order! + queueingClientProxy.deliver(tx2); + queueingClientProxy.deliver(tx1); + + // Client should reassemble the message correctly. + InputStream msg = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg)).isEqualTo(largeMessage); + } + + private static OneWayBinderProxy takeNextBinder( + BlockingBinderDecorator decorator) throws InterruptedException { + OneWayBinderProxy proxy = decorator.takeNextRequest(TIMEOUT_MS, MILLISECONDS); + assertThat(proxy).isNotNull(); + return proxy; + } + + private static QueueingOneWayBinderProxy.Transaction takeNextTransaction( + QueueingOneWayBinderProxy proxy) throws InterruptedException { + QueueingOneWayBinderProxy.Transaction tx = proxy.pollNextTransaction(TIMEOUT_MS, MILLISECONDS); + assertThat(tx).isNotNull(); + return tx; + } + + private static InputStream takeNextMessage(BlockingQueue messageQueue) + throws InterruptedException { + InputStream msg = messageQueue.poll(TIMEOUT_MS, MILLISECONDS); + assertThat(msg).isNotNull(); + return msg; + } + + private static String newStringOfLength(int numChars) { + char[] chars = new char[numChars]; + java.util.Arrays.fill(chars, 'x'); + return new String(chars); + } } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java b/binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java similarity index 67% rename from binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java rename to binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java index 4abdb2c03dd..c7eee06e73a 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java @@ -18,6 +18,7 @@ import android.os.RemoteException; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; /** A collection of {@link OneWayBinderProxy}-related test helpers. */ @@ -42,6 +43,18 @@ public OneWayBinderProxy takeNextRequest() throws InterruptedException { return requests.take(); } + /** + * Returns the next {@link OneWayBinderProxy} that needs decorating, blocking for up to the + * specified timeout if it hasn't yet been provided to {@link #decorate}. + * + *

Follow this with a call to {@link #putNextResult(OneWayBinderProxy)} to provide the result + * of {@link #decorate} and unblock the waiting caller. + */ + public OneWayBinderProxy takeNextRequest(long timeout, TimeUnit unit) + throws InterruptedException { + return requests.poll(timeout, unit); + } + /** Provides the next value to return from {@link #decorate}. */ public void putNextResult(T next) throws InterruptedException { results.put(next); @@ -119,6 +132,49 @@ public void transact(int code, ParcelHolder data) throws RemoteException { } } + /** A {@link OneWayBinderProxy} that queues transactions for a test to deliver manually later. */ + public static final class QueueingOneWayBinderProxy extends OneWayBinderProxy { + public static final class Transaction { + public final int code; + private final ParcelHolder parcel; + + public Transaction(int code, ParcelHolder parcel) { + this.code = code; + this.parcel = parcel; + } + } + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private final OneWayBinderProxy wrapped; + + public QueueingOneWayBinderProxy(OneWayBinderProxy wrapped) { + super(wrapped.getDelegate()); + this.wrapped = wrapped; + } + + @Override + public void transact(int code, ParcelHolder data) throws RemoteException { + queue.add(new Transaction(code, new ParcelHolder(data.release()))); + } + + /** + * Returns the next transaction that was queued in order, waiting up to the specified timeout. + */ + public Transaction pollNextTransaction(long timeout, TimeUnit unit) + throws InterruptedException { + return queue.poll(timeout, unit); + } + + /** + * Delivers a previously queued transaction to its original destination. + * + * @throws IllegalStateException if transaction was already delivered once before + */ + public void deliver(Transaction transaction) throws RemoteException { + wrapped.transact(transaction.code, transaction.parcel); + } + } + // Cannot be instantiated. private OneWayBinderProxies() {} ; diff --git a/build.gradle b/build.gradle index 2cf3439ea76..e65261b0cc4 100644 --- a/build.gradle +++ b/build.gradle @@ -21,7 +21,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.81.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.82.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() diff --git a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base index da2c46904ca..6b670994677 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base +++ b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base @@ -1,4 +1,7 @@ -FROM ubuntu:18.04 +FROM ubuntu:24.04 + +# Redirect to the internal mirror to bypass the Kokoro network block +RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|http://mirror.bazel.build/archive.ubuntu.com/ubuntu/|g' /etc/apt/sources.list RUN export DEBIAN_FRONTEND=noninteractive && \ apt-get update && \ @@ -9,8 +12,8 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ curl \ g++-aarch64-linux-gnu \ g++-powerpc64le-linux-gnu \ - openjdk-8-jdk \ - pkg-config \ + openjdk-11-jdk \ + pkgconf \ && \ rm -rf /var/lib/apt/lists/* diff --git a/compiler/BUILD.bazel b/compiler/BUILD.bazel index e8a0571e134..a9ffe77a55a 100644 --- a/compiler/BUILD.bazel +++ b/compiler/BUILD.bazel @@ -13,6 +13,7 @@ cc_binary( ], visibility = ["//visibility:public"], deps = [ + "@abseil-cpp//absl/strings", "@com_google_protobuf//:protoc_lib", ], ) diff --git a/compiler/src/java_plugin/cpp/java_generator.cpp b/compiler/src/java_plugin/cpp/java_generator.cpp index a81d54791b4..d0f8cdd13d5 100644 --- a/compiler/src/java_plugin/cpp/java_generator.cpp +++ b/compiler/src/java_plugin/cpp/java_generator.cpp @@ -46,6 +46,7 @@ #include #include #include +#include "absl/strings/escaping.h" #include #include #include @@ -1206,7 +1207,7 @@ static void PrintService(const ServiceDescriptor* service, bool disable_version, GeneratedAnnotation generated_annotation) { (*vars)["service_name"] = service->name(); - (*vars)["file_name"] = service->file()->name(); + (*vars)["file_name"] = absl::Utf8SafeCEscape(service->file()->name()); (*vars)["service_class_name"] = ServiceClassName(service); (*vars)["grpc_version"] = ""; #ifdef GRPC_VERSION diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 1c37c9a8af9..0b4924f3e6a 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.81.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.82.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 08eb2fb6ac3..5c65890273c 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.81.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.82.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index b568bb12c46..e0c05ca637e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -206,7 +206,7 @@ public final void start(Listener listener, final Metadata headers) { savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { - listener = delayedListener = new DelayedListener<>(listener); + listener = delayedListener = new DelayedListener<>(this, listener); startHeaders = headers; } } @@ -445,15 +445,33 @@ public void runInContext() { } private static final class DelayedListener extends Listener { + private final DelayedClientCall call; private final Listener realListener; private volatile boolean passThrough; + private volatile Status exceptionStatus; @GuardedBy("this") private List pendingCallbacks = new ArrayList<>(); - public DelayedListener(Listener listener) { + public DelayedListener(DelayedClientCall call, Listener listener) { + this.call = call; this.realListener = listener; } + /** + * Cancels call and schedules onClose() notification. May only be called from within a + * DelayedListener callback dispatch (either queued drain or passThrough). Visibility of the + * write to {@code exceptionStatus} does not rely on a single callback executor; it is a + * {@code volatile} field, and callback queuing/pass-through transitions are coordinated by + * this listener's synchronization so subsequent callbacks observe the updated status. + */ + private void exceptionThrown(Throwable t, String description) { + // onClose() must be delivered exactly once and last. Other callbacks may already be queued + // ahead of realCall's eventual onClose, so we can't call onClose() here. We set the status + // and overwrite the onClose() details when it arrives. + exceptionStatus = Status.CANCELLED.withCause(t).withDescription(description); + call.cancel(description, t); + } + private void delayOrExecute(Runnable runnable) { synchronized (this) { if (!passThrough) { @@ -467,37 +485,75 @@ private void delayOrExecute(Runnable runnable) { @Override public void onHeaders(final Metadata headers) { if (passThrough) { - realListener.onHeaders(headers); + deliverHeaders(headers); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onHeaders(headers); + deliverHeaders(headers); } }); } } + private void deliverHeaders(Metadata headers) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onHeaders(headers); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read headers"); + } + } + @Override public void onMessage(final RespT message) { if (passThrough) { - realListener.onMessage(message); + deliverMessage(message); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onMessage(message); + deliverMessage(message); } }); } } + private void deliverMessage(RespT message) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onMessage(message); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read message."); + } + } + @Override public void onClose(final Status status, final Metadata trailers) { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onClose(status, trailers); + Status effectiveStatus = status; + Metadata effectiveTrailers = trailers; + if (exceptionStatus != null) { + // Ideally status matches exceptionStatus, since exceptionStatus was used to cancel + // the call. However, cancel() may reconstruct a new Status instance, and the cancel + // is racy so this onClose may have already been queued when the cancellation + // occurred. Since other callbacks throw away data if exceptionStatus != null, it is + // semantically essential that we _not_ use a status provided by the server. + effectiveStatus = exceptionStatus; + // Replace trailers to prevent mixing sources of status and trailers. + effectiveTrailers = new Metadata(); + } + try { + realListener.onClose(effectiveStatus, effectiveTrailers); + } catch (RuntimeException ex) { + logger.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } }); } @@ -505,17 +561,28 @@ public void run() { @Override public void onReady() { if (passThrough) { - realListener.onReady(); + deliverOnReady(); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onReady(); + deliverOnReady(); } }); } } + private void deliverOnReady() { + if (exceptionStatus != null) { + return; + } + try { + realListener.onReady(); + } catch (Throwable t) { + exceptionThrown(t, "Failed to call onReady."); + } + } + void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); @@ -535,7 +602,6 @@ void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling runnable.run(); } toRun.clear(); diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index deae5d831b8..c419f028f58 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -219,7 +219,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - public static final String IMPLEMENTATION_VERSION = "1.81.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + public static final String IMPLEMENTATION_VERSION = "1.82.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. @@ -241,6 +241,12 @@ public byte[] parseAsciiString(byte[] serialized) { */ public static final long DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS = TimeUnit.SECONDS.toNanos(20L); + /** + * The default minimum time between client keepalive pings permitted by server. + */ + public static final long DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS + = TimeUnit.MINUTES.toNanos(5L); + /** * The magic keepalive time value that disables keepalive. */ diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java index 5560a1abb6d..7124f2fc88a 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java @@ -223,8 +223,11 @@ private Status validateInitialMetadata(Metadata headers) { } String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY); if (!GrpcUtil.isGrpcContentType(contentType)) { - return GrpcUtil.httpStatusToGrpcStatus(httpStatus) - .augmentDescription("invalid content-type: " + contentType); + Status status = GrpcUtil.httpStatusToGrpcStatus(httpStatus); + if (contentType == null) { + return status.augmentDescription("missing content-type in response headers"); + } + return status.augmentDescription("invalid content-type: " + contentType); } return null; } diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index ce31921e316..00a66b1c1df 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -42,6 +42,7 @@ import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; @@ -606,8 +607,8 @@ public void run() { connectedAddressAttributes = addressIndex.getCurrentEagAttributes(); gotoNonErrorState(READY); subchannelMetrics.recordConnectionAttemptSucceeded(/* target= */ target, - /* backendService= */ getAttributeOrDefault( - addressIndex.getCurrentEagAttributes(), NameResolver.ATTR_BACKEND_SERVICE), + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), EquivalentAddressGroup.ATTR_LOCALITY_NAME), /* securityLevel= */ extractSecurityLevel(addressIndex.getCurrentEagAttributes() @@ -638,8 +639,8 @@ public void run() { addressIndex.reset(); gotoNonErrorState(IDLE); subchannelMetrics.recordDisconnection(/* target= */ target, - /* backendService= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), - NameResolver.ATTR_BACKEND_SERVICE), + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), EquivalentAddressGroup.ATTR_LOCALITY_NAME), /* disconnectError= */ disconnectError.toErrorString(), @@ -647,8 +648,8 @@ public void run() { .get(GrpcAttributes.ATTR_SECURITY_LEVEL))); } else if (pendingTransport == transport) { subchannelMetrics.recordConnectionAttemptFailed(/* target= */ target, - /* backendService= */getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), - NameResolver.ATTR_BACKEND_SERVICE), + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), EquivalentAddressGroup.ATTR_LOCALITY_NAME)); Preconditions.checkState(state.getState() == CONNECTING, @@ -711,6 +712,14 @@ private String getAttributeOrDefault(Attributes attributes, Attributes.Key 0 && fullStreamDecompressor == null + && unprocessed.readableBytes() == 0) { + return false; + } + if (nextFrame == null) { nextFrame = new CompositeReadableBuffer(); } diff --git a/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java index 31c20f6e499..e52eb5e38d4 100644 --- a/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java +++ b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java @@ -19,6 +19,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.URI; public class NameResolverFactoryToProviderFacade extends NameResolverProvider { @@ -34,6 +35,11 @@ public NameResolver newNameResolver(URI targetUri, Args args) { return factory.newNameResolver(targetUri, args); } + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + return factory.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return factory.getDefaultScheme(); diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index dc0709e1fb8..d469fdb33dc 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -151,7 +151,9 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume InternalLogId.allocate("Server", String.valueOf(getListenSocketsIgnoringLifecycle())); // Fork from the passed in context so that it does not propagate cancellation, it only // inherits values. - this.rootContext = Preconditions.checkNotNull(rootContext, "rootContext").fork(); + this.rootContext = Preconditions.checkNotNull(rootContext, "rootContext") + .fork() + .withValue(io.grpc.InternalServer.SERVER_CONTEXT_KEY, ServerImpl.this); this.decompressorRegistry = builder.decompressorRegistry; this.compressorRegistry = builder.compressorRegistry; this.transportFilters = Collections.unmodifiableList( @@ -622,19 +624,7 @@ private void runInternal() { // An extremely short deadline may expire before stream.setListener(jumpListener). // This causes NPE as in issue: https://github.com/grpc/grpc-java/issues/6300 // Delay of setting cancellationListener to context will fix the issue. - final class ServerStreamCancellationListener implements Context.CancellationListener { - @Override - public void cancelled(Context context) { - Status status = statusFromCancelled(context); - if (DEADLINE_EXCEEDED.getCode().equals(status.getCode())) { - // This should rarely get run, since the client will likely cancel the stream - // before the timeout is reached. - stream.cancel(status); - } - } - } - - context.addListener(new ServerStreamCancellationListener(), directExecutor()); + context.addListener(new ServerStreamCancellationListener(stream), directExecutor()); } } @@ -648,8 +638,7 @@ private Context.CancellableContext createContext( Context baseContext = statsTraceCtx - .serverFilterContext(rootContext) - .withValue(io.grpc.InternalServer.SERVER_CONTEXT_KEY, ServerImpl.this); + .serverFilterContext(rootContext); if (timeoutNanos == null) { return baseContext.withCancellation(); @@ -707,6 +696,31 @@ private ServerStreamListener startWrappedCall( } } + /** + * Propagates context cancellation to the ServerStream. + * + *

This is outside of HandleServerCall because that class holds Metadata and other state needed + * only when starting the RPC. The cancellation listener will live for the life of the call, so we + * avoid that useless state being retained. + */ + static final class ServerStreamCancellationListener implements Context.CancellationListener { + private final ServerStream stream; + + ServerStreamCancellationListener(ServerStream stream) { + this.stream = checkNotNull(stream, "stream"); + } + + @Override + public void cancelled(Context context) { + Status status = statusFromCancelled(context); + if (DEADLINE_EXCEEDED.getCode().equals(status.getCode())) { + // This should rarely get run, since the client will likely cancel the stream + // before the timeout is reached. + stream.cancel(status); + } + } + } + @Override public InternalLogId getLogId() { return logId; diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index ff131d29975..0d30e947b0c 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -229,6 +229,232 @@ public void delayedCallsRunUnderContext() throws Exception { assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); } + @Test + public void listenerThrowsInPendingCallback_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + // Deliver onMessage while the wrapping DelayedListener is still buffering, by firing + // it from within realCall.start() — drainPendingCalls has not yet flipped the listener + // to pass-through. The queued onMessage is then drained and throws; the fix must catch + // the throwable and cancel the real call rather than let it escape. + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onMessage(42); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onHeaders(new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onReady(); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference observed = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onClose(Status status, Metadata trailers) { + observed.set(status); + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onClose(Status.DATA_LOSS, new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + assertThat(observed.get().getCode()).isEqualTo(Status.Code.DATA_LOSS); + verify(mockRealCall, never()).cancel(any(), any()); + } + + @Test + public void listenerThrowsInPassThroughOnMessage_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); // drain completes, listener transitions to passThrough + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onMessage(42); // dispatched on passThrough fast path + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onHeaders(new Metadata()); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onReady(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThrough_subsequentCallbacksSwallowedAndOnCloseOverridden() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference lastMessage = new AtomicReference<>(); + final AtomicReference closeStatus = new AtomicReference<>(); + final AtomicReference closeTrailers = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + lastMessage.set(msg); + if (msg == 1) { + throw boom; + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + closeStatus.set(status); + closeTrailers.set(trailers); + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + + realCallListener.onMessage(1); // throws -> exceptionStatus captured + assertThat(lastMessage.get()).isEqualTo(1); + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + + // Later callbacks are swallowed — the listener must not see message 2. + realCallListener.onMessage(2); + assertThat(lastMessage.get()).isEqualTo(1); + + // Transport onClose with OK must be overridden by the captured CANCELLED status. + Metadata serverTrailers = new Metadata(); + serverTrailers.put(Metadata.Key.of("k", Metadata.ASCII_STRING_MARSHALLER), "v"); + realCallListener.onClose(Status.OK, serverTrailers); + assertThat(closeStatus.get().getCode()).isEqualTo(Status.Code.CANCELLED); + assertThat(closeStatus.get().getCause()).isEqualTo(boom); + // Trailers replaced to avoid mixing sources. + assertThat(closeTrailers.get()).isNotSameInstanceAs(serverTrailers); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run(); diff --git a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java index 9d32bf1af7d..66df062a3e0 100644 --- a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java +++ b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java @@ -301,6 +301,24 @@ public void transportTrailersReceived_missingStatusUsesHttpStatus() { assertTrue(statusCaptor.getValue().getDescription().contains("401")); } + @Test + public void transportTrailersReceived_missingContentTypeUsesHttpStatus() { + BaseTransportState state = new BaseTransportState(transportTracer); + state.setListener(mockListener); + Metadata trailers = new Metadata(); + trailers.put(testStatusMashaller, "431"); + + state.transportTrailersReceived(trailers); + + verify(mockListener, never()).headersRead(any(Metadata.class)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); + assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); + assertTrue(statusCaptor.getValue().getDescription().contains("HTTP status code 431")); + assertTrue( + statusCaptor.getValue().getDescription().contains( + "missing content-type in response headers")); + } + @Test public void transportTrailersReceived_missingHttpStatus() { BaseTransportState state = new BaseTransportState(transportTracer); diff --git a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java index 811344da307..4236c091d9c 100644 --- a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java +++ b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java @@ -48,6 +48,7 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalChannelz; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; import io.grpc.LoadBalancer; @@ -1510,7 +1511,7 @@ public void subchannelStateChanges_triggersAttemptFailedMetric() { when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); SocketAddress addr = mock(SocketAddress.class); Attributes eagAttributes = Attributes.newBuilder() - .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) .build(); @@ -1564,7 +1565,7 @@ public void subchannelStateChanges_triggersSuccessAndDisconnectMetrics() { // 2. Setup Subchannel with attributes SocketAddress addr = mock(SocketAddress.class); Attributes eagAttributes = Attributes.newBuilder() - .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) .build(); @@ -1631,6 +1632,45 @@ public void subchannelStateChanges_triggersSuccessAndDisconnectMetrics() { inOrder.verifyNoMoreInteractions(); } + @Test + public void subchannelStateChanges_backendServiceFallsBackToResolutionResultAttr() { + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + fakeClock.runDueTasks(); + + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_succeeded"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); + } + private void assertNoCallbackInvoke() { while (fakeExecutor.runDueTasks() > 0) {} assertEquals(0, callbackInvokes.size()); diff --git a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java index 5d6b88a1392..5d07de32df9 100644 --- a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java @@ -185,7 +185,7 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} protected final ClientStreamTracer[] tracers = new ClientStreamTracer[] { clientStreamTracer1, clientStreamTracer2 }; - private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { + protected final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { new ClientStreamTracer() {} }; diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index 2e90a63c219..105fcecaafe 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -1,4 +1,4 @@ -bazel_dep(name = "grpc-java", version = "1.81.0-SNAPSHOT", repo_name = "io_grpc_grpc_java") # CURRENT_GRPC_VERSION +bazel_dep(name = "grpc-java", version = "1.82.0-SNAPSHOT", repo_name = "io_grpc_grpc_java") # CURRENT_GRPC_VERSION bazel_dep(name = "rules_java", version = "9.3.0") bazel_dep(name = "grpc-proto", version = "0.0.0-20240627-ec30f58", repo_name = "io_grpc_grpc_proto") bazel_dep(name = "protobuf", version = "33.1", repo_name = "com_google_protobuf") diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 67110a78c43..0219e73ff89 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,11 +53,11 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION testImplementation 'junit:junit:4.13.2' testImplementation 'com.google.truth:truth:1.4.5' - testImplementation 'io.grpc:grpc-testing:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index d20bd03d1fc..1e81415e483 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,7 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index 377cb417100..7152add7858 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,7 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index b752bc4ffd3..cc54d274a29 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,7 +53,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/build.gradle b/examples/build.gradle index 688121c677e..0ad62bb9ef0 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index 47268ab6510..3fea622b923 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-debug/build.gradle b/examples/example-debug/build.gradle index 940543a3681..e4edc0704d0 100644 --- a/examples/example-debug/build.gradle +++ b/examples/example-debug/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' dependencies { diff --git a/examples/example-debug/pom.xml b/examples/example-debug/pom.xml index 10734935ee6..ccb9977f679 100644 --- a/examples/example-debug/pom.xml +++ b/examples/example-debug/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-debug https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 1.8 diff --git a/examples/example-dualstack/build.gradle b/examples/example-dualstack/build.gradle index f2947c641cf..f79888831dc 100644 --- a/examples/example-dualstack/build.gradle +++ b/examples/example-dualstack/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' dependencies { diff --git a/examples/example-dualstack/pom.xml b/examples/example-dualstack/pom.xml index f5e720a9128..99c0da77a22 100644 --- a/examples/example-dualstack/pom.xml +++ b/examples/example-dualstack/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-dualstack https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 1.8 diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 489197e5f20..5ab563479a4 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 9fb854629b4..66e0f3be563 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 1.8 diff --git a/examples/example-gcp-csm-observability/build.gradle b/examples/example-gcp-csm-observability/build.gradle index 63c6d20125d..2ddfd995cd3 100644 --- a/examples/example-gcp-csm-observability/build.gradle +++ b/examples/example-gcp-csm-observability/build.gradle @@ -22,7 +22,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' def openTelemetryVersion = '1.56.0' def openTelemetryPrometheusVersion = '1.56.0-alpha' diff --git a/examples/example-gcp-observability/build.gradle b/examples/example-gcp-observability/build.gradle index a41e7cdd629..531a5c2f9de 100644 --- a/examples/example-gcp-observability/build.gradle +++ b/examples/example-gcp-observability/build.gradle @@ -22,7 +22,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 6117b8c32a1..f776de41511 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index ed90d481587..8a3c231e3eb 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 1.8 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index 5614a72742c..36e6f08b3cc 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index 7befaf500c5..2989f61d4a0 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 3.25.8 diff --git a/examples/example-oauth/build.gradle b/examples/example-oauth/build.gradle index 07e51217622..3ad99a51d5d 100644 --- a/examples/example-oauth/build.gradle +++ b/examples/example-oauth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.8' def protocVersion = protobufVersion diff --git a/examples/example-oauth/pom.xml b/examples/example-oauth/pom.xml index 9ce20f2f684..3d88e732829 100644 --- a/examples/example-oauth/pom.xml +++ b/examples/example-oauth/pom.xml @@ -7,13 +7,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-oauth https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 3.25.8 diff --git a/examples/example-opentelemetry/build.gradle b/examples/example-opentelemetry/build.gradle index a24900c0fe5..8515f015c92 100644 --- a/examples/example-opentelemetry/build.gradle +++ b/examples/example-opentelemetry/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' def openTelemetryVersion = '1.56.0' def openTelemetryPrometheusVersion = '1.56.0-alpha' diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index 674c4bdf2f7..65627159c9c 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -16,7 +16,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-reflection/build.gradle b/examples/example-reflection/build.gradle index aa870967135..7c54ea281d5 100644 --- a/examples/example-reflection/build.gradle +++ b/examples/example-reflection/build.gradle @@ -16,7 +16,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle index 7f23c83e0d9..b83d38be5b5 100644 --- a/examples/example-servlet/build.gradle +++ b/examples/example-servlet/build.gradle @@ -15,7 +15,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 456cb8b4f73..4fe0794d62b 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index ff9d01253f5..dfe611e4fe7 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 1.8 diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index e8b3f3dd395..1974c86798e 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.81.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.8' dependencies { diff --git a/examples/pom.xml b/examples/pom.xml index 5375b930b3b..943182b60fe 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.81.0-SNAPSHOT + 1.82.0-SNAPSHOT 3.25.8 3.25.8 diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java index 427c0658531..10ba586ab47 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java @@ -27,6 +27,7 @@ import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; +import io.grpc.QueryParams; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.Uri; @@ -47,7 +48,6 @@ import java.io.Reader; import java.net.HttpURLConnection; import java.net.URI; -import java.net.URISyntaxException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; @@ -81,18 +81,26 @@ final class GoogleCloudToProdNameResolver extends NameResolver { private static HttpConnectionProvider httpConnectionProvider = HttpConnectionFactory.INSTANCE; private static int c2pId = new Random().nextInt(); - private static synchronized BootstrapInfo getBootstrapInfo() + private static synchronized BootstrapInfo getBootstrapInfo(boolean isForcedXds) throws XdsInitializationException, IOException { if (bootstrapInfo != null) { return bootstrapInfo; } - BootstrapInfo bootstrapInfoTmp = - InternalGrpcBootstrapperImpl.parseBootstrap(generateBootstrap()); + BootstrapInfo newInfo; + if (isForcedXds) { + newInfo = InternalGrpcBootstrapperImpl.parseBootstrap( + generateBootstrap("", true)); + } else { + newInfo = InternalGrpcBootstrapperImpl.parseBootstrap( + generateBootstrap( + queryZoneMetadata(METADATA_URL_ZONE), + queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6))); + } // Avoid setting global when testing if (httpConnectionProvider == HttpConnectionFactory.INSTANCE) { - bootstrapInfo = bootstrapInfoTmp; + bootstrapInfo = newInfo; } - return bootstrapInfoTmp; + return newInfo; } private final String authority; @@ -102,7 +110,8 @@ private static synchronized BootstrapInfo getBootstrapInfo() private final MetricRecorder metricRecorder; private final NameResolver delegate; private final boolean usingExecutorResource; - private final String schemeOverride = !isOnGcp ? "dns" : "xds"; + private final boolean forceXds; + private final String schemeOverride; private XdsClientResult xdsClientPool; private XdsClient xdsClient; private Executor executor; @@ -122,6 +131,13 @@ private static synchronized BootstrapInfo getBootstrapInfo() NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); String targetPath = checkNotNull(checkNotNull(targetUri, "targetUri").getPath(), "targetPath"); + Uri grpcUri = Uri.create(targetUri.toString()); + QueryParams queryParams = QueryParams.fromRawQuery(grpcUri.getRawQuery()); + this.forceXds = checkForceXds(queryParams); + this.schemeOverride = (forceXds || isOnGcp) ? "xds" : "dns"; + stripForceXds(queryParams); + String newQuery = queryParams.toRawQuery(); + Preconditions.checkArgument( targetPath.startsWith("/"), "the path component (%s) of the target (%s) must start with '/'", @@ -129,9 +145,15 @@ private static synchronized BootstrapInfo getBootstrapInfo() targetUri); authority = GrpcUtil.checkAuthority(targetPath.substring(1)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); - targetUri = overrideUriScheme(targetUri, schemeOverride); + + Uri.Builder modifiedTargetBuilder = grpcUri.toBuilder().setScheme(schemeOverride); + modifiedTargetBuilder.setRawQuery(newQuery); + if (schemeOverride.equals("xds")) { + modifiedTargetBuilder.setRawAuthority(C2P_AUTHORITY); + } + targetUri = URI.create(modifiedTargetBuilder.build().toString()); + if (schemeOverride.equals("xds")) { - targetUri = overrideUriAuthority(targetUri, C2P_AUTHORITY); args = args.toBuilder() .setArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER, () -> xdsClient) .build(); @@ -155,6 +177,12 @@ private static synchronized BootstrapInfo getBootstrapInfo() Resource executorResource, NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); + QueryParams queryParams = QueryParams.fromRawQuery(targetUri.getRawQuery()); + this.forceXds = checkForceXds(queryParams); + this.schemeOverride = (forceXds || isOnGcp) ? "xds" : "dns"; + stripForceXds(queryParams); + String newQuery = queryParams.toRawQuery(); + Preconditions.checkArgument( targetUri.isPathAbsolute(), "the path component of the target (%s) must start with '/'", @@ -167,6 +195,12 @@ private static synchronized BootstrapInfo getBootstrapInfo() authority = GrpcUtil.checkAuthority(pathSegments.get(0)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); Uri.Builder modifiedTargetBuilder = targetUri.toBuilder().setScheme(schemeOverride); + if (newQuery != null) { + modifiedTargetBuilder.setRawQuery(newQuery); + } else { + modifiedTargetBuilder.setRawQuery(null); + } + if (schemeOverride.equals("xds")) { modifiedTargetBuilder.setRawAuthority(C2P_AUTHORITY); args = @@ -226,7 +260,7 @@ class Resolve implements Runnable { public void run() { BootstrapInfo bootstrapInfo = null; try { - bootstrapInfo = getBootstrapInfo(); + bootstrapInfo = getBootstrapInfo(forceXds); } catch (IOException e) { listener.onError( Status.INTERNAL.withDescription("Unable to get metadata").withCause(e)); @@ -259,16 +293,11 @@ public void run() { executor.execute(new Resolve()); } - @VisibleForTesting - static ImmutableMap generateBootstrap() throws IOException { - return generateBootstrap( - queryZoneMetadata(METADATA_URL_ZONE), - queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6)); - } - - private static ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { + private static ImmutableMap generateBootstrap( + String zone, boolean supportIpv6) { ImmutableMap.Builder nodeBuilder = ImmutableMap.builder(); - nodeBuilder.put("id", "C2P-" + (c2pId & Integer.MAX_VALUE)); + String nodeIdPrefix = isOnGcp ? "C2P-" : "C2P-non-gcp-"; + nodeBuilder.put("id", nodeIdPrefix + (c2pId & Integer.MAX_VALUE)); if (!zone.isEmpty()) { nodeBuilder.put("locality", ImmutableMap.of("zone", zone)); } @@ -373,24 +402,17 @@ static void setC2pId(int c2pId) { GoogleCloudToProdNameResolver.c2pId = c2pId; } - private static URI overrideUriScheme(URI uri, String scheme) { - URI res; - try { - res = new URI(scheme, uri.getAuthority(), uri.getPath(), uri.getQuery(), uri.getFragment()); - } catch (URISyntaxException ex) { - throw new IllegalArgumentException("Invalid scheme: " + scheme, ex); + private static boolean checkForceXds(QueryParams params) { + for (QueryParams.Entry entry : params.asList()) { + if ("force-xds".equals(entry.getKey())) { + return true; + } } - return res; + return false; } - private static URI overrideUriAuthority(URI uri, String authority) { - URI res; - try { - res = new URI(uri.getScheme(), authority, uri.getPath(), uri.getQuery(), uri.getFragment()); - } catch (URISyntaxException ex) { - throw new IllegalArgumentException("Invalid authority: " + authority, ex); - } - return res; + private static void stripForceXds(QueryParams params) { + params.asList().removeIf(entry -> "force-xds".equals(entry.getKey())); } private enum HttpConnectionFactory implements HttpConnectionProvider { diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java index d3d3cfc4bff..bbd3ba3ef05 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java @@ -21,8 +21,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.ChannelLogger; import io.grpc.MetricRecorder; @@ -46,7 +44,6 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; @@ -103,6 +100,8 @@ public void close(Executor instance) {} private final NameResolverRegistry nsRegistry = new NameResolverRegistry(); private final Map delegatedResolver = new HashMap<>(); + private final Map delegatedUri = new HashMap<>(); + private final Map delegatedRfcUri = new HashMap<>(); @Mock private NameResolver.Listener2 mockListener; @@ -187,57 +186,125 @@ public void onGcpAndNoProvidedBootstrap_DelegateToXds() { verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); } - @SuppressWarnings("unchecked") @Test - public void generateBootstrap_ipv6() throws IOException { - Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE), - "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); + public void notOnGcpButForceXds_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds"; + resolver = + enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isNull(); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isNull(); + } + } + + @Test + public void notOnGcpButForceXds_KeyValueTrue_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds=true"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isNull(); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isNull(); + } + } + + + @Test + public void notOnGcpButForceXds_WithMultipleParams_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?foo=bar&force-xds&baz=qux"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("foo=bar&baz=qux"); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isEqualTo("foo=bar&baz=qux"); + } } - @SuppressWarnings("unchecked") @Test - public void generateBootstrap_noIpV6() throws IOException { - responseToIpV6 = null; - Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); + public void notOnGcpButForceXds_WithEncodedAmpersand_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds&foo=bar%26baz"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("foo=bar%26baz"); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getRawQuery()).isEqualTo("foo=bar%26baz"); + } } - @SuppressWarnings("unchecked") @Test - public void emptyResolverMeetadataValue() throws IOException { - responseToIpV6 = ""; - Map bootstrap = GoogleCloudToProdNameResolver.generateBootstrap(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE)); + public void notOnGcpButForceXds_CaseSensitive_DelegateToDns() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?FORCE-XDS"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); + resolver.start(mockListener); + assertThat(delegatedResolver.keySet()).containsExactly("dns"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("dns"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("FORCE-XDS"); + } else { + URI delegatedUriValue = delegatedUri.get("dns"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isEqualTo("FORCE-XDS"); + } } @Test @@ -270,6 +337,18 @@ private FakeNsProvider(String scheme) { @Override public NameResolver newNameResolver(URI targetUri, Args args) { if (scheme.equals(targetUri.getScheme())) { + delegatedUri.put(scheme, targetUri); + NameResolver resolver = mock(NameResolver.class); + delegatedResolver.put(scheme, resolver); + return resolver; + } + return null; + } + + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + delegatedRfcUri.put(scheme, targetUri); NameResolver resolver = mock(NameResolver.class); delegatedResolver.put(scheme, resolver); return resolver; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java index a9ee9382495..4742675416b 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java @@ -45,12 +45,10 @@ import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Random; -import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.Semaphore; @@ -511,27 +509,30 @@ public static List interceptors() { } /** - * Echo the request headers from a client into response headers and trailers. Useful for + * Echo a request header from a client into response headers and trailers. Useful for * testing end-to-end metadata propagation. */ - private static ServerInterceptor echoRequestHeadersInterceptor(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestHeadersInterceptor(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, - final Metadata requestHeaders, + Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { @Override public void sendHeaders(Metadata responseHeaders) { - responseHeaders.merge(requestHeaders, keySet); + responseHeaders.put(key, value); super.sendHeaders(responseHeaders); } @Override public void close(Status status, Metadata trailers) { - trailers.merge(requestHeaders, keySet); + trailers.put(key, value); super.close(status, trailers); } }, requestHeaders); @@ -540,52 +541,48 @@ public void close(Status status, Metadata trailers) { } /** - * Echoes request headers with the specified key(s) from a client into response headers only. + * Echoes request headers with the specified key from a client into response headers only. */ - private static ServerInterceptor echoRequestMetadataInHeaders(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestMetadataInHeaders(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, final Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { @Override public void sendHeaders(Metadata responseHeaders) { - responseHeaders.merge(requestHeaders, keySet); + responseHeaders.put(key, value); super.sendHeaders(responseHeaders); } - - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); - } }, requestHeaders); } }; } /** - * Echoes request headers with the specified key(s) from a client into response trailers only. + * Echoes request headers with the specified key from a client into response trailers only. */ - private static ServerInterceptor echoRequestMetadataInTrailers(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestMetadataInTrailers(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, final Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { - @Override - public void sendHeaders(Metadata responseHeaders) { - super.sendHeaders(responseHeaders); - } - @Override public void close(Status status, Metadata trailers) { - trailers.merge(requestHeaders, keySet); + trailers.put(key, value); super.close(status, trailers); } }, requestHeaders); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 3c9d2bbe184..4ef14b0e933 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -22,6 +22,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; +import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; import static io.grpc.internal.GrpcUtil.SERVER_KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; @@ -113,7 +114,7 @@ public final class NettyServerBuilder extends ForwardingServerBuilder SET_USE_SESSION_TICKETS = diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 163d2023b1c..50097e1922e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.grpc.internal.CertificateUtils.createTrustManager; +import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -137,7 +138,7 @@ public InternalServer buildClientTransportServers( int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; boolean permitKeepAliveWithoutCalls; - long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5); + long permitKeepAliveTimeInNanos = DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; long maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; int maxConcurrentCallsPerConnection = MAX_CONCURRENT_STREAMS; diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java index 63c6f33ff79..ad9af056afc 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java @@ -41,7 +41,7 @@ static final class ServerCredentials extends io.grpc.ServerCredentials { private final ConnectionSpec connectionSpec; ServerCredentials(SSLSocketFactory factory) { - this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC); + this(factory, OkHttpChannelBuilder.INTERNAL_LEGACY_CONNECTION_SPEC); } ServerCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) { diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java index 484cc5673dc..fa9458e2024 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java @@ -354,6 +354,13 @@ int readInt(int firstByte, int prefixMask) throws IOException { if ((b & 0x80) != 0) { // Equivalent to (b >= 128) since b is in [0..255]. result += (b & 0x7f) << shift; shift += 7; + // We can safely store 31 bits, and then next byte will have 7 more bits. While the next + // byte may not have high bits set to cause an overflow, that's only useful for 256+ MiB + // values, which is excessive. This also gives us at least one bit of spare, which is + // necessary to store the carry from the addition. + if (shift >= 28) { + throw new IOException("Varint overflowed"); + } } else { result += b << shift; // Last byte. break; @@ -483,9 +490,14 @@ void writeHeaders(List headerBlock) throw writeByteString(name); writeByteString(value); insertIntoDynamicTable(header); - } else if (name.startsWith(PSEUDO_PREFIX) && !io.grpc.okhttp.internal.framed.Header.TARGET_AUTHORITY.equals(name)) { - // Follow Chromes lead - only include the :authority pseudo header, but exclude all other - // pseudo headers. Literal Header Field without Indexing - Indexed Name. + } else if (name.startsWith(PSEUDO_PREFIX) + && !io.grpc.okhttp.internal.framed.Header.TARGET_AUTHORITY.equals(name) + && !io.grpc.okhttp.internal.framed.Header.TARGET_PATH.equals(name)) { + // Allow :authority and :path pseudo headers to be indexed. Other pseudo headers are not + // indexed. + // This is a departure from the original Chrome-inspired behavior, as gRPC paths + // (ServiceName/MethodName) + // are stable and benefit from indexing. writeInt(headerNameIndex, PREFIX_4_BITS, 0); writeByteString(value); } else { @@ -508,6 +520,9 @@ void writeInt(int value, int prefixMask, int bits) throws IOException { // Write the mask to start a multibyte value. out.writeByte(bits | prefixMask); value -= prefixMask; + if (value > 0xfffffff) { + throw new IOException("Varint would overflow reader"); + } // Write 7 bits at a time 'til we're done. while (value >= 0x80) { diff --git a/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java index 26580f85e54..8635151c6b9 100644 --- a/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java +++ b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java @@ -272,14 +272,18 @@ public void readerEviction() throws IOException { /** * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.2 + * + *

This test mimics the draft example which uses ":path", but since gRPC-Java now indexes + * ":path" for performance, we use ":method" (with a non-static value like "PUT") to verify the + * "Literal Header Field without Indexing - Indexed Name" representation. */ @Test public void literalHeaderFieldWithoutIndexingIndexedName() throws IOException { - List

headerBlock = headerEntries(":path", "/sample/path"); + List
headerBlock = headerEntries(":method", "PUT"); - bytesIn.writeByte(0x04); // == Literal not indexed == - // Indexed name (idx = 4) -> :path - bytesIn.writeByte(0x0c); // Literal value (len = 12) - bytesIn.writeUtf8("/sample/path"); + bytesIn.writeByte(0x02); // == Literal not indexed == + // Indexed name (idx = 2) -> :method + bytesIn.writeByte(0x03); // Literal value (len = 3) + bytesIn.writeUtf8("PUT"); hpackWriter.writeHeaders(headerBlock); assertEquals(bytesIn, bytesOut); @@ -455,7 +459,7 @@ public void theSameHeaderAfterOneIncrementalIndexed() throws IOException { hpackReader.readHeaders(); fail(); } catch (IOException e) { - assertEquals("Header index too large -2147483521", e.getMessage()); + assertEquals("Varint overflowed", e.getMessage()); } } @@ -497,7 +501,7 @@ public void theSameHeaderAfterOneIncrementalIndexed() throws IOException { hpackReader.readHeaders(); fail(); } catch (IOException e) { - assertEquals("Invalid dynamic table size update -2147483648", e.getMessage()); + assertEquals("Varint overflowed", e.getMessage()); } } @@ -856,11 +860,53 @@ private void checkReadThirdRequestWithHuffman() { assertBytes(0xe0 | 31, 154, 10); } - @Test public void max31BitValue() throws IOException { - hpackWriter.writeInt(0x7fffffff, 31, 0); - assertBytes(31, 224, 255, 255, 255, 7); - assertEquals(0x7fffffff, - newReader(byteStream(224, 255, 255, 255, 7)).readInt(31, 31)); + @Test public void max29BitValue() throws IOException { + hpackWriter.writeInt(0x100000fe, 0xff, 0xff); + assertBytes(0xff, 0xff, 0xff, 0xff, 0x7f); + assertEquals(0x100000fe, + newReader(byteStream(0xff, 0xff, 0xff, 0x7f)).readInt(0xff, 0xff)); + } + + @Test public void beyondMax29BitValue() throws IOException { + try { + hpackWriter.writeInt(0x100000ff, 0xff, 0xff); + fail(); + } catch (IOException e) { + assertEquals("Varint would overflow reader", e.getMessage()); + } + try { + newReader(byteStream(0xff, 0xff, 0xff, 0xff, 0x80)).readInt(0xff, 0xff); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } + } + + @Test public void beyondMax29BitValue_smallPrefix() throws IOException { + try { + hpackWriter.writeInt(0x10000001, 1, 1); + fail(); + } catch (IOException e) { + assertEquals("Varint would overflow reader", e.getMessage()); + } + try { + newReader(byteStream(0xff, 0xff, 0xff, 0xff, 0x80)).readInt(1, 1); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } + } + + @Test public void readerAbortsLongVarintsWithZeros() throws IOException { + try { + // The reader should fail before getting to the end, because it will overflow as soon as there + // is a 1 bit, and the only reason to use this many continuations is to eventually have a 1 + // bit. + newReader(byteStream(0x80, 0x80, 0x80, 0x80, 0x80)).readInt(31, 31); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } } @Test public void prefixMask() throws IOException { @@ -1062,14 +1108,29 @@ public void dynamicTableIndexedHeader() throws IOException { } @Test - public void doNotIndexPseudoHeaders() throws IOException { + public void pseudoHeaderIndexing() throws IOException { + // :method is not indexed (unless it's GET or POST, which are in the static table) hpackWriter.writeHeaders(headerEntries(":method", "PUT")); assertBytes(0x02, 3, 'P', 'U', 'T'); assertEquals(0, hpackWriter.dynamicTableHeaderCount); + // :path should now be indexed hpackWriter.writeHeaders(headerEntries(":path", "/okhttp")); - assertBytes(0x04, 7, '/', 'o', 'k', 'h', 't', 't', 'p'); - assertEquals(0, hpackWriter.dynamicTableHeaderCount); + assertBytes(0x44, 7, '/', 'o', 'k', 'h', 't', 't', 'p'); + assertEquals(1, hpackWriter.dynamicTableHeaderCount); + // Second time should be an index + hpackWriter.writeHeaders(headerEntries(":path", "/okhttp")); + assertBytes(0xbe); + assertEquals(1, hpackWriter.dynamicTableHeaderCount); + + // :authority should be indexed + hpackWriter.writeHeaders(headerEntries(":authority", "test.com")); + assertBytes(0x41, 8, 't', 'e', 's', 't', '.', 'c', 'o', 'm'); + assertEquals(2, hpackWriter.dynamicTableHeaderCount); + // Second time should be an index + hpackWriter.writeHeaders(headerEntries(":authority", "test.com")); + assertBytes(0xbe); + assertEquals(2, hpackWriter.dynamicTableHeaderCount); } @Test diff --git a/servlet/jakarta/build.gradle b/servlet/jakarta/build.gradle index bcd904ccaee..5cd213949f4 100644 --- a/servlet/jakarta/build.gradle +++ b/servlet/jakarta/build.gradle @@ -122,6 +122,11 @@ if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { tasks.named("check").configure { dependsOn jetty11Test } + tasks.named("jacocoTestReport").configure { + // Must use executionData(Task...) override. The executionData(Object...) override doesn't + // find execution data correctly for tasks. + executionData jetty11Test.get() + } } if (JavaVersion.current().isJava11Compatible()) { def tomcat10Test = tasks.register('tomcat10Test', Test) { @@ -150,4 +155,9 @@ if (JavaVersion.current().isJava11Compatible()) { tasks.named("check").configure { dependsOn tomcat10Test, undertowTest } + tasks.named("jacocoTestReport").configure { + // Must use executionData(Task...) override. The executionData(Object...) override doesn't + // find execution data correctly for tasks. + executionData tomcat10Test.get(), undertowTest.get() + } } diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 5a59b47c529..f6ee60ab1ef 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -27,6 +27,7 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; @@ -369,6 +370,7 @@ StatusOr edsUpdateToResult( String localityName = localityName(locality); Attributes attr = endpoint.eag().getAttributes().toBuilder() + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, clusterName) .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY, locality) .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, localityName) .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT, diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 01ef3d97b57..3cf28d23578 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -79,10 +79,17 @@ public static final class UpstreamTlsContext extends BaseTlsContext { @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { + this(commonTlsContext, "", false, false); + } + + @VisibleForTesting + public UpstreamTlsContext( + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, + boolean autoSniSanValidation) { super(commonTlsContext); - this.sni = null; - this.autoHostSni = false; - this.autoSniSanValidation = false; + this.sni = sni == null ? "" : sni; + this.autoHostSni = autoHostSni; + this.autoSniSanValidation = autoSniSanValidation; } @VisibleForTesting @@ -122,6 +129,26 @@ public String toString() { + "\nauto_sni_san_validation=" + autoSniSanValidation + "}"; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + UpstreamTlsContext that = (UpstreamTlsContext) o; + return autoHostSni == that.autoHostSni + && autoSniSanValidation == that.autoSniSanValidation + && Objects.equals(commonTlsContext, that.commonTlsContext) + && Objects.equals(sni, that.sni); + } + + @Override + public int hashCode() { + return Objects.hash(commonTlsContext, sni, autoHostSni, autoSniSanValidation); + } } public static final class DownstreamTlsContext extends BaseTlsContext { diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index a8b7e120cca..6744903de35 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -41,6 +41,7 @@ import io.grpc.util.ForwardingSubchannel; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.internal.MetricReportUtils; +import io.grpc.xds.internal.MetricReportUtils.ParsedMetricName; import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; @@ -239,7 +240,7 @@ protected void updateOverallBalancingState() { private SubchannelPicker createReadyPicker(Collection activeList) { WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, - config.metricNamesForComputingUtilization); + config.parsedMetricNamesForComputingUtilization); updateWeight(picker); return picker; } @@ -329,15 +330,15 @@ public void addSubchannel(WrrSubchannel wrrSubchannel) { } public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty, - ImmutableList metricNamesForComputingUtilization) { + ImmutableList parsedMetricNamesForComputingUtilization) { if (orcaReportListener != null && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty - && orcaReportListener.metricNamesForComputingUtilization - .equals(metricNamesForComputingUtilization)) { + && orcaReportListener.parsedMetricNamesForComputingUtilization + .equals(parsedMetricNamesForComputingUtilization)) { return orcaReportListener; } orcaReportListener = - new OrcaReportListener(errorUtilizationPenalty, metricNamesForComputingUtilization); + new OrcaReportListener(errorUtilizationPenalty, parsedMetricNamesForComputingUtilization); return orcaReportListener; } @@ -362,17 +363,17 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { private final float errorUtilizationPenalty; - private final ImmutableList metricNamesForComputingUtilization; + private final ImmutableList parsedMetricNamesForComputingUtilization; OrcaReportListener(float errorUtilizationPenalty, - ImmutableList metricNamesForComputingUtilization) { + ImmutableList parsedMetricNamesForComputingUtilization) { this.errorUtilizationPenalty = errorUtilizationPenalty; - this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; + this.parsedMetricNamesForComputingUtilization = parsedMetricNamesForComputingUtilization; } @Override public void onLoadReport(MetricReport report) { - double utilization = getUtilization(report, metricNamesForComputingUtilization); + double utilization = getUtilization(report); double newWeight = 0; if (utilization > 0 && report.getQps() > 0) { @@ -398,8 +399,8 @@ public void onLoadReport(MetricReport report) { * if application utilization is > 0, it is returned. If neither are present, the CPU * utilization is returned. */ - private double getUtilization(MetricReport report, ImmutableList metricNames) { - OptionalDouble customUtil = getCustomMetricUtilization(report, metricNames); + private double getUtilization(MetricReport report) { + OptionalDouble customUtil = getCustomMetricUtilization(report); if (customUtil.isPresent()) { return customUtil.getAsDouble(); } @@ -411,19 +412,23 @@ private double getUtilization(MetricReport report, ImmutableList metricN } /** - * Returns the maximum utilization value among the specified metric names. + * Returns the maximum utilization value among the parsed metric names. * Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report, - * or if all present metrics are NaN. - * Returns OptionalDouble.of(maxUtil) if at least one non-NaN metric is present. + * or if all present metrics are NaN or non positive. */ - private OptionalDouble getCustomMetricUtilization(MetricReport report, - ImmutableList metricNames) { - return metricNames.stream() - .map(name -> MetricReportUtils.getMetric(report, name)) - .filter(OptionalDouble::isPresent) - .mapToDouble(OptionalDouble::getAsDouble) - .filter(d -> !Double.isNaN(d) && d > 0) - .max(); + private OptionalDouble getCustomMetricUtilization(MetricReport report) { + OptionalDouble max = OptionalDouble.empty(); + for (int i = 0; i < parsedMetricNamesForComputingUtilization.size(); i++) { + OptionalDouble opt = MetricReportUtils.getMetricValue(report, + parsedMetricNamesForComputingUtilization.get(i)); + if (opt.isPresent()) { + double d = opt.getAsDouble(); + if (!Double.isNaN(d) && d > 0 && (!max.isPresent() || d > max.getAsDouble())) { + max = opt; + } + } + } + return max; } } } @@ -446,7 +451,7 @@ private void createAndApplyOrcaListeners() { if (config.enableOobLoadReport) { OrcaOobUtil.setListener(weightedSubchannel, wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty, - config.metricNamesForComputingUtilization), + config.parsedMetricNamesForComputingUtilization), OrcaOobUtil.OrcaReportingConfig.newBuilder() .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build()); } else { @@ -516,7 +521,7 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, float errorUtilizationPenalty, AtomicInteger sequence, - ImmutableList metricNamesForComputingUtilization) { + ImmutableList parsedMetricNamesForComputingUtilization) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; @@ -526,7 +531,7 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { WeightedChildLbState wChild = (WeightedChildLbState) child; pickers.add(wChild.getCurrentPicker()); reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty, - metricNamesForComputingUtilization)); + parsedMetricNamesForComputingUtilization)); } this.pickers = pickers; this.reportListeners = reportListeners; @@ -767,7 +772,7 @@ static final class WeightedRoundRobinLoadBalancerConfig { final long oobReportingPeriodNanos; final long weightUpdatePeriodNanos; final float errorUtilizationPenalty; - final ImmutableList metricNamesForComputingUtilization; + final ImmutableList parsedMetricNamesForComputingUtilization; public static Builder newBuilder() { return new Builder(); @@ -783,7 +788,20 @@ private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, this.oobReportingPeriodNanos = oobReportingPeriodNanos; this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; this.errorUtilizationPenalty = errorUtilizationPenalty; - this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; + + ImmutableList.Builder builder = ImmutableList.builder(); + if (metricNamesForComputingUtilization != null) { + for (int i = 0; i < metricNamesForComputingUtilization.size(); i++) { + String metricName = metricNamesForComputingUtilization.get(i); + ParsedMetricName parsed = MetricReportUtils.ParsedMetricName.parse(metricName); + if (parsed.getMetricType() != MetricReportUtils.MetricType.INVALID) { + builder.add(parsed); + } else { + log.log(Level.FINE, "Invalid custom metric name configured and ignored: " + metricName); + } + } + } + this.parsedMetricNamesForComputingUtilization = builder.build(); } @Override @@ -799,15 +817,15 @@ public boolean equals(Object o) { && this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos // Float.compare considers NaNs equal && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0 - && Objects.equals(this.metricNamesForComputingUtilization, - that.metricNamesForComputingUtilization); + && Objects.equals(this.parsedMetricNamesForComputingUtilization, + that.parsedMetricNamesForComputingUtilization); } @Override public int hashCode() { return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty, - metricNamesForComputingUtilization); + parsedMetricNamesForComputingUtilization); } static final class Builder { diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java index e17b8764a6c..0f9fcf07c9a 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -81,7 +81,7 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod"); Float errorUtilizationPenalty = JsonUtil.getNumberAsFloat(rawConfig, "errorUtilizationPenalty"); List metricNamesForComputingUtilization = JsonUtil.getListOfStrings(rawConfig, - LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION); + "metricNamesForComputingUtilization"); WeightedRoundRobinLoadBalancerConfig.Builder configBuilder = WeightedRoundRobinLoadBalancerConfig.newBuilder(); diff --git a/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java index 919836ddd9c..a0af5974175 100644 --- a/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java +++ b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java @@ -30,6 +30,7 @@ import io.grpc.Status; import io.grpc.StatusOr; import io.grpc.SynchronizationContext; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.RetryingNameResolver; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; @@ -652,13 +653,10 @@ public void onResourceChanged(StatusOr update) { data = update; subscribeToChildren(update.getValue()); } else { - Status status = update.getStatus(); - Status translatedStatus = Status.UNAVAILABLE.withDescription( - String.format("Error retrieving %s: %s. Details: %s%s", - toContextString(), - status.getCode(), - status.getDescription() != null ? status.getDescription() : "", - nodeInfo())); + Status translatedStatus = GrpcUtil.statusWithDetails( + Status.Code.UNAVAILABLE, + "Error retrieving " + toContextString() + nodeInfo(), + update.getStatus()); data = StatusOr.fromStatus(translatedStatus); } diff --git a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java index 59f439d3687..981db516e5b 100644 --- a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java +++ b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java @@ -160,6 +160,31 @@ void adjustResourceSubscription(XdsResourceType resourceType) { } Collection resources = resourceStore.getSubscribedResources(serverInfo, resourceType); + if (resources == null && !adsStream.sentTypes.contains(resourceType)) { + // No subscription for this type on this server, and we have never sent a DiscoveryRequest + // of this type on the current stream — the server has no subscription state to clear. + // + // Per the ResourceStore contract in XdsClient.java, a null return means "no subscription"; + // an empty collection means wildcard subscription, which is a real subscription and must + // not be skipped here. + // + // We track sent types per-stream rather than gating on `versions` because `versions` is + // only populated on ACK. If a watch is canceled after the initial DiscoveryRequest goes + // out but before any response is ACKed, `versions` would still have no entry for the + // type, and gating on it would suppress the empty unsubscribe — leaving the server with + // a stale subscription until the stream resets. + // + // Without this skip, sendDiscoveryRequests() iterates over every globally-subscribed + // resource type when a stream becomes ready and emits an empty DiscoveryRequest for types + // that have no subscription on this server. Per A47 (xDS Federation) servers may be + // authority-specific (e.g. an EDS-only control plane) and reject DiscoveryRequests for + // types they do not handle, tearing down the stream. + // + // Mirrors grpc-go's behavior in + // internal/xds/clients/xdsclient/ads_stream.go:sendExisting, which skips types with no + // subscription. + return; + } if (resources == null) { resources = Collections.emptyList(); } @@ -319,6 +344,11 @@ private class AdsStream implements XdsTransportFactory.EventHandler respNonces = new HashMap<>(); + // Resource types for which a DiscoveryRequest has been sent on this stream. Used by + // adjustResourceSubscription() to decide whether an empty unsubscribe must be sent on the + // wire: the server only has subscription state to clear for types we have actually sent a + // request for on this stream. Cleared implicitly when the stream is replaced. + private final Set> sentTypes = new HashSet<>(); private final StreamingCall call; private final MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); @@ -358,6 +388,7 @@ void sendDiscoveryRequest(XdsResourceType type, String versionInfo, } DiscoveryRequest request = builder.build(); call.sendMessage(request); + sentTypes.add(type); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log(XdsLogLevel.DEBUG, "Sent DiscoveryRequest\n{0}", messagePrinter.print(request)); } diff --git a/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java index 7da9a3ab6d9..4194cab76d3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java @@ -16,10 +16,12 @@ package io.grpc.xds.internal; +import com.google.auto.value.AutoValue; import io.grpc.services.MetricReport; -import java.util.Map; +import java.util.Optional; import java.util.OptionalDouble; + /** * Utilities for parsing and resolving metrics from {@link MetricReport}. */ @@ -27,41 +29,91 @@ public final class MetricReportUtils { private MetricReportUtils() {} + public enum MetricType { + CPU_UTILIZATION, + APPLICATION_UTILIZATION, + MEMORY_UTILIZATION, + UTILIZATION, + NAMED_METRICS, + INVALID + } + + @AutoValue + public abstract static class ParsedMetricName { + public abstract MetricType getMetricType(); + + public abstract Optional getKey(); + + public static ParsedMetricName create(MetricType metricType, Optional key) { + return new AutoValue_MetricReportUtils_ParsedMetricName(metricType, key); + } + + /** + * Pre-parses a custom metric name into a {@link ParsedMetricName}. + * + * @param name The custom metric name to parse. + * @return The parsed metric name. + */ + public static ParsedMetricName parse(String name) { + if (name.equals("cpu_utilization")) { + return create(MetricType.CPU_UTILIZATION, Optional.empty()); + } + if (name.equals("application_utilization")) { + return create(MetricType.APPLICATION_UTILIZATION, Optional.empty()); + } + if (name.equals("mem_utilization")) { + return create(MetricType.MEMORY_UTILIZATION, Optional.empty()); + } + if (name.startsWith("utilization.")) { + return create(MetricType.UTILIZATION, Optional.of(name.substring("utilization.".length()))); + } + if (name.startsWith("named_metrics.")) { + return create(MetricType.NAMED_METRICS, + Optional.of(name.substring("named_metrics.".length()))); + } + return create(MetricType.INVALID, Optional.empty()); + } + + } + /** - * Resolves a metric value from the report based on the given metric name. - * The logic checks for specific prefixes to determine where to look up the metric: - *
    - *
  • "cpu_utilization" -> getCpuUtilization()
  • - *
  • "application_utilization" -> getApplicationUtilization()
  • - *
  • "mem_utilization" -> getMemoryUtilization()
  • - *
  • "utilization." -> lookup in utilizationMetrics
  • - *
  • "named_metrics." -> lookup in namedMetrics
  • - *
+ * Resolves a custom metric value for `parsedMetric` + * Returns OptionalDouble.empty() if the metric is absent or invalid. * * @param report The metric report to query. - * @param metricName The name of the custom metric to look up. - * @return The value of the metric if found, or empty if not found. + * @param parsedMetric The parsed metric to lookup. + * @return The metric value wrapped in an OptionalDouble, or empty if absent. */ - public static OptionalDouble getMetric(MetricReport report, String metricName) { - if (metricName.equals("cpu_utilization")) { - return OptionalDouble.of(report.getCpuUtilization()); - } else if (metricName.equals("application_utilization")) { - return OptionalDouble.of(report.getApplicationUtilization()); - } else if (metricName.equals("mem_utilization")) { - return OptionalDouble.of(report.getMemoryUtilization()); - } else if (metricName.startsWith("utilization.")) { - Map map = report.getUtilizationMetrics(); - Double val = map.get(metricName.substring("utilization.".length())); - if (val != null) { - return OptionalDouble.of(val); - } - } else if (metricName.startsWith("named_metrics.")) { - Map map = report.getNamedMetrics(); - Double val = map.get(metricName.substring("named_metrics.".length())); - if (val != null) { - return OptionalDouble.of(val); - } + + public static OptionalDouble getMetricValue(MetricReport report, ParsedMetricName parsedMetric) { + switch (parsedMetric.getMetricType()) { + case CPU_UTILIZATION: + return OptionalDouble.of(report.getCpuUtilization()); + case APPLICATION_UTILIZATION: + return OptionalDouble.of(report.getApplicationUtilization()); + case MEMORY_UTILIZATION: + return OptionalDouble.of(report.getMemoryUtilization()); + case UTILIZATION: + if (parsedMetric.getKey().isPresent()) { + String key = parsedMetric.getKey().get(); + Double val = report.getUtilizationMetrics().get(key); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + case NAMED_METRICS: + if (parsedMetric.getKey().isPresent()) { + String key = parsedMetric.getKey().get(); + Double val = report.getNamedMetrics().get(key); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + case INVALID: + default: + return OptionalDouble.empty(); } - return OptionalDouble.empty(); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java new file mode 100644 index 00000000000..b8d4eb582fb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import io.grpc.Status; +import io.grpc.StatusException; + +/** + * Exception thrown when a header mutation is disallowed. + */ +public final class HeaderMutationDisallowedException extends StatusException { + + private static final long serialVersionUID = 1L; + + public HeaderMutationDisallowedException(String message) { + super(Status.INTERNAL.withDescription(message)); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java new file mode 100644 index 00000000000..35cab17d928 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.collect.ImmutableList; +import io.grpc.xds.internal.grpcservice.HeaderValueValidationUtils; +import java.util.Collection; +import java.util.Optional; +import java.util.function.Predicate; + +/** + * The HeaderMutationFilter class is responsible for filtering header mutations based on a given set + * of rules. + */ +public class HeaderMutationFilter { + private final Optional mutationRules; + + + + public HeaderMutationFilter(Optional mutationRules) { + this.mutationRules = mutationRules; + } + + /** + * Filters the given header mutations based on the configured rules and returns the allowed + * mutations. + * + * @param mutations The header mutations to filter + * @return The allowed header mutations. + * @throws HeaderMutationDisallowedException if a disallowed mutation is encountered and the rules + * specify that this should be an error. + */ + public HeaderMutations filter(HeaderMutations mutations) + throws HeaderMutationDisallowedException { + ImmutableList allowedHeaders = + filterCollection(mutations.headers(), this::isDisallowed, this::isHeaderMutationAllowed); + ImmutableList allowedHeadersToRemove = + filterCollection(mutations.headersToRemove(), this::isDisallowed, + this::isHeaderMutationAllowed); + return HeaderMutations.create(allowedHeaders, allowedHeadersToRemove); + } + + /** + * A generic helper to filter a collection based on a predicate. + */ + private ImmutableList filterCollection(Collection items, + Predicate isIgnoredPredicate, Predicate isAllowedPredicate) + throws HeaderMutationDisallowedException { + ImmutableList.Builder allowed = ImmutableList.builder(); + for (T item : items) { + boolean isIgnored = isIgnoredPredicate.test(item); + boolean isAllowed = isAllowedPredicate.test(item); + + // TODO(sauravzg): The specification is ambiguous regarding whether system headers + // should be silently ignored or trigger an error when disallowIsError is enabled. + // We default to triggering errors matching Envoy's implementation. + // Ref: https://github.com/grpc/proposal/pull/481#discussion_r3124453674 + if (!isIgnored && isAllowed) { + allowed.add(item); + } else if (disallowIsError()) { + throw new HeaderMutationDisallowedException("Header mutation disallowed"); + } + } + return allowed.build(); + } + + private boolean isDisallowed(String key) { + return HeaderValueValidationUtils.isDisallowed(key); + } + + private boolean isDisallowed(HeaderValueOption option) { + return HeaderValueValidationUtils.isDisallowed(option.header()); + } + + private boolean isHeaderMutationAllowed(HeaderValueOption option) { + return isHeaderMutationAllowed(option.header().key()); + } + + private boolean isHeaderMutationAllowed(String headerName) { + return mutationRules.map(rules -> isHeaderMutationAllowed(headerName, rules)) + .orElse(true); + } + + private boolean isHeaderMutationAllowed(String headerName, + HeaderMutationRulesConfig rules) { + if (rules.disallowExpression().isPresent() + && rules.disallowExpression().get().matcher(headerName).matches()) { + return false; + } + if (rules.allowExpression().isPresent() + && rules.allowExpression().get().matcher(headerName).matches()) { + return true; + } + return !rules.disallowAll(); + } + + private boolean disallowIsError() { + return mutationRules.map(HeaderMutationRulesConfig::disallowIsError).orElse(false); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java new file mode 100644 index 00000000000..a456413c899 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; + +/** A collection of header mutations. */ +@AutoValue +public abstract class HeaderMutations { + + public static HeaderMutations create(ImmutableList headers, + ImmutableList headersToRemove) { + return new AutoValue_HeaderMutations(headers, headersToRemove); + } + + public abstract ImmutableList headers(); + + public abstract ImmutableList headersToRemove(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java new file mode 100644 index 00000000000..e6cdc126f22 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + + +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.logging.Logger; + +/** + * The HeaderMutator provides methods to apply header mutations to a given set of headers based on a + * given set of rules. + */ +public class HeaderMutator { + + private static final Logger logger = Logger.getLogger(HeaderMutator.class.getName()); + + /** + * Creates a new instance of {@code HeaderMutator}. + */ + public static HeaderMutator create() { + return new HeaderMutator(); + } + + HeaderMutator() {} + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + public void applyMutations(final HeaderMutations mutations, Metadata headers) { + // TODO(sauravzg): The specification is not clear on order of header removals and additions. + // in case of conflicts. Copying the order from Envoy here, which does removals at the end. + applyHeaderUpdates(mutations.headers(), headers); + for (String headerToRemove : mutations.headersToRemove()) { + Metadata.Key key = headerToRemove.endsWith(Metadata.BINARY_HEADER_SUFFIX) + ? Metadata.Key.of(headerToRemove, Metadata.BINARY_BYTE_MARSHALLER) + : Metadata.Key.of(headerToRemove, Metadata.ASCII_STRING_MARSHALLER); + headers.discardAll(key); + } + } + + private void applyHeaderUpdates(final Iterable headerOptions, + Metadata headers) { + for (HeaderValueOption headerOption : headerOptions) { + updateHeader(headerOption, headers); + } + } + + private void updateHeader(final HeaderValueOption option, Metadata mutableHeaders) { + HeaderValue header = option.header(); + HeaderAppendAction action = option.appendAction(); + boolean keepEmptyValue = option.keepEmptyValue(); + + if (header.key().endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + if (header.rawValue().isPresent()) { + byte[] value = header.rawValue().get().toByteArray(); + if (value.length > 0 || keepEmptyValue) { + updateHeader(action, Metadata.Key.of(header.key(), Metadata.BINARY_BYTE_MARSHALLER), + value, mutableHeaders); + } + } else { + logger.fine("Missing binary rawValue for header: " + header.key()); + } + } else { + if (header.value().isPresent()) { + String value = header.value().get(); + if (!value.isEmpty() || keepEmptyValue) { + updateHeader(action, Metadata.Key.of(header.key(), Metadata.ASCII_STRING_MARSHALLER), + value, mutableHeaders); + } + } else { + logger.fine("Missing value for header: " + header.key()); + } + } + } + + private void updateHeader(final HeaderAppendAction action, final Metadata.Key key, + final T value, Metadata mutableHeaders) { + switch (action) { + case APPEND_IF_EXISTS_OR_ADD: + mutableHeaders.put(key, value); + break; + case ADD_IF_ABSENT: + if (!mutableHeaders.containsKey(key)) { + mutableHeaders.put(key, value); + } + break; + case OVERWRITE_IF_EXISTS_OR_ADD: + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + break; + case OVERWRITE_IF_EXISTS: + if (mutableHeaders.containsKey(key)) { + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + } + break; + + default: + // Should be unreachable unless there's a proto schema mismatch. + logger.fine("Unknown HeaderAppendAction: " + action); + } + } +} + diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java new file mode 100644 index 00000000000..6cb96da864d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java @@ -0,0 +1,50 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import io.grpc.xds.internal.grpcservice.HeaderValue; + +/** + * Represents a header option to be appended or mutated as part of xDS configuration. + * Avoids direct dependency on Envoy's proto objects. + */ +@AutoValue +public abstract class HeaderValueOption { + + public static HeaderValueOption create( + HeaderValue header, HeaderAppendAction appendAction, boolean keepEmptyValue) { + return new AutoValue_HeaderValueOption(header, appendAction, keepEmptyValue); + } + + public abstract HeaderValue header(); + + public abstract HeaderAppendAction appendAction(); + + public abstract boolean keepEmptyValue(); + + /** + * Defines the action to take when appending headers. + * Mirrors io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction. + */ + public enum HeaderAppendAction { + APPEND_IF_EXISTS_OR_ADD, + ADD_IF_ABSENT, + OVERWRITE_IF_EXISTS_OR_ADD, + OVERWRITE_IF_EXISTS + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 59e114a89ff..e7b27cd644a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -44,7 +44,6 @@ public abstract class DynamicSslContextProvider extends SslContextProvider { @Nullable protected final CertificateValidationContext staticCertificateValidationContext; @Nullable protected AbstractMap.SimpleImmutableEntry sslContextAndTrustManager; - protected boolean autoSniSanValidationDoesNotApply; protected DynamicSslContextProvider( BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { @@ -60,10 +59,6 @@ protected DynamicSslContextProvider( protected abstract CertificateValidationContext generateCertificateValidationContext(); - public void setAutoSniSanValidationDoesNotApply() { - autoSniSanValidationDoesNotApply = true; - } - /** Gets a server or client side SslContextBuilder. */ protected abstract AbstractMap.SimpleImmutableEntry getSslContextBuilderAndTrustManager( diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index e5960dd95e8..94fc423c202 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -60,14 +60,11 @@ public synchronized void updateSslContext( try { if (!shutdown) { if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); - if (tlsContext instanceof UpstreamTlsContext && autoSniSanValidationDoesNotApply) { - ((DynamicSslContextProvider) sslContextProvider).setAutoSniSanValidationDoesNotApply(); - } + sslContextProvider = getSslContextProvider(autoSniSanValidationDoesNotApply); } } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(autoSniSanValidationDoesNotApply); toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @@ -102,11 +99,20 @@ private void releaseSslContextProvider(SslContextProvider toRelease) { } } - private SslContextProvider getSslContextProvider() { - return tlsContext instanceof UpstreamTlsContext - ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) - : tlsContextManager.findOrCreateServerSslContextProvider( - (DownstreamTlsContext) tlsContext); + private SslContextProvider getSslContextProvider(boolean autoSniSanValidationDoesNotApply) { + if (tlsContext instanceof UpstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext = (UpstreamTlsContext) tlsContext; + if (autoSniSanValidationDoesNotApply && upstreamTlsContext.getAutoSniSanValidation()) { + upstreamTlsContext = new UpstreamTlsContext( + upstreamTlsContext.getCommonTlsContext(), + upstreamTlsContext.getSni(), + upstreamTlsContext.getAutoHostSni(), + false); + } + return tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + } + return tlsContextManager.findOrCreateServerSslContextProvider( + (DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index b4b72ae11c6..8984efc9435 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -58,39 +58,25 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP getSslContextBuilderAndTrustManager( CertificateValidationContext certificateValidationContext) throws CertStoreException { - SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + UpstreamTlsContext upstreamTlsContext = (UpstreamTlsContext) tlsContext; + XdsTrustManagerFactory trustManagerFactory; if (savedSpiffeTrustMap != null) { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContext, - autoSniSanValidationDoesNotApply - ? false : ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation())); + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, + upstreamTlsContext.getAutoSniSanValidation()); } else if (savedTrustedRoots != null) { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( + trustManagerFactory = new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContext, - autoSniSanValidationDoesNotApply - ? false : ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation())); + upstreamTlsContext.getAutoSniSanValidation()); } else { // Should be impossible because of the check in CertProviderClientSslContextProviderFactory throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); } - XdsTrustManagerFactory trustManagerFactory; - if (savedSpiffeTrustMap != null) { - trustManagerFactory = new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContext, - ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation()); - sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); - } else { - trustManagerFactory = new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContext, - ((UpstreamTlsContext) tlsContext).getAutoSniSanValidation()); - sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); - } + + SslContextBuilder sslContextBuilder = + GrpcSslContexts.forClient().trustManager(trustManagerFactory); if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index 304124cc7f2..9cb9a867118 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -116,7 +116,7 @@ void checkAndReloadCertificates() { FileTime currentCertTime = Files.getLastModifiedTime(certFile); FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); if (!currentCertTime.equals(lastModifiedTimeCert) - && !currentKeyTime.equals(lastModifiedTimeKey)) { + || !currentKeyTime.equals(lastModifiedTimeKey)) { byte[] certFileContents = Files.readAllBytes(certFile); byte[] keyFileContents = Files.readAllBytes(keyFile); FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 928520aded7..ff4813fe6a8 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -255,9 +255,9 @@ public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { startXdsDepManager(); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " - + "Details: Timed out waiting for resource " + CLUSTER - + " from xDS server nodeID: " + NODE_ID; + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancers).isEmpty(); @@ -311,8 +311,9 @@ public void nonAggregateCluster_resourceRevoked() { controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); assertThat(childBalancer.shutdown).isTrue(); - String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " - + "Details: Resource " + CLUSTER + " does not exist nodeID: " + NODE_ID; + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Resource " + CLUSTER + " does not exist"; Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -515,9 +516,9 @@ public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - String expectedDescription = "Error retrieving CDS resource " + cluster1 + ": NOT_FOUND. " - + "Details: Timed out waiting for resource " + cluster1 + " from xDS server nodeID: " - + NODE_ID; + String expectedDescription = "Error retrieving CDS resource " + cluster1 + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Timed out waiting for resource " + cluster1 + " from xDS server"; Status status = Status.UNAVAILABLE.withDescription(expectedDescription); assertPickerStatus(pickerCaptor.getValue(), status); assertThat(childBalancers).isEmpty(); diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index c6e5db08526..a508da34f88 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -60,6 +60,7 @@ import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; @@ -317,14 +318,20 @@ public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { // LOCALITY1 are equally weighted. assertThat(addr1.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); + assertThat(addr1.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr2.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); + assertThat(addr2.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); assertThat(addr3.getAddresses()) .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); + assertThat(addr3.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); @@ -690,8 +697,8 @@ public void onlyEdsClusters_resourceNeverExist_returnErrorPicker() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " - + "Details: Timed out waiting for resource " + CLUSTER + " from xDS server nodeID: node-id"; + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); } @@ -713,8 +720,8 @@ public void cdsMissing_handledDirectly() { assertThat(childBalancers).hasSize(0); // no child LB policy created verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " - + "Details: Timed out waiting for resource " + CLUSTER + " from xDS server nodeID: node-id"; + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); assertPicker(pickerCaptor.getValue(), expectedError, null); @@ -744,8 +751,8 @@ public void cdsRevoked_handledDirectly() { controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - String expectedDescription = "Error retrieving CDS resource " + CLUSTER + ": NOT_FOUND. " - + "Details: Resource " + CLUSTER + " does not exist nodeID: node-id"; + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Resource " + CLUSTER + " does not exist"; Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); assertThat(childBalancer.shutdown).isTrue(); @@ -760,8 +767,8 @@ public void edsMissing_failsRpcs() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); String expectedDescription = "Error retrieving EDS resource " + EDS_SERVICE_NAME - + ": NOT_FOUND. Details: Timed out waiting for resource " + EDS_SERVICE_NAME - + " from xDS server nodeID: node-id"; + + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + EDS_SERVICE_NAME + " from xDS server"; Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); } @@ -920,6 +927,8 @@ public void onlyLogicalDnsCluster_endpointsResolved() { Arrays.asList(new EquivalentAddressGroup(Arrays.asList( newInetSocketAddress("127.0.2.1", 9000), newInetSocketAddress("127.0.2.2", 9000)))), childBalancer.addresses); + assertThat(childBalancer.addresses.get(0).getAttributes() + .get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)).isEqualTo(CLUSTER); assertThat(childBalancer.addresses.get(0).getAttributes() .get(XdsInternalAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); } diff --git a/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java b/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java new file mode 100644 index 00000000000..c8b32f376ee --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java @@ -0,0 +1,50 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static java.util.Objects.requireNonNull; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NoopClientCall; +import io.grpc.Status; + +/** + * An interceptor that fails all RPCs with the provided status. + */ +final class FailingClientInterceptor implements ClientInterceptor { + private final Status status; + + public FailingClientInterceptor(Status status) { + this.status = requireNonNull(status, "status"); + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new NoopClientCall() { + @Override + public void start(Listener responseListener, Metadata headers) { + responseListener.onClose(status, new Metadata()); + } + }; + } +} diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java index 7bd1590885e..0bd3283cb79 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java @@ -29,6 +29,7 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import io.grpc.xds.internal.MetricReportUtils.ParsedMetricName; import java.io.IOException; import java.util.Map; import org.junit.Test; @@ -112,16 +113,19 @@ public void parseLoadBalancingConfigDefaultValues() throws IOException { } @Test - public void parseLoadBalancingConfigCustomMetrics() throws IOException { + public void parseLoadBalancingConfigCustomMetricsIgnoresInvalid() throws IOException { System.setProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", "true"); try { - String lbConfig = "{\"metricNamesForComputingUtilization\" : [\"foo\", \"bar\"]}"; + String lbConfig = + "{\"metricNamesForComputingUtilization\" : " + + "[\"utilization.foo\", \"invalid_name\", \"named_metrics.bar\"]}"; ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( parseJsonObject(lbConfig)); assertThat(configOrError.getConfig()).isNotNull(); WeightedRoundRobinLoadBalancerConfig config = (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); - assertThat(config.metricNamesForComputingUtilization).containsExactly("foo", "bar"); + assertThat(config.parsedMetricNamesForComputingUtilization).containsExactly( + ParsedMetricName.parse("utilization.foo"), ParsedMetricName.parse("named_metrics.bar")); } finally { System.clearProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS"); } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index d495521123a..bac62d1a103 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -291,11 +291,11 @@ public void wrrLifeCycle() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -348,11 +348,11 @@ public void enableOobLoadReportConfig() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -409,11 +409,11 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport(r1); + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r1); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport(r2); + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r2); weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport(r3); + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r3); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); Map pickCount = new HashMap<>(); @@ -611,11 +611,11 @@ public void blackoutPeriod() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedCount = isEnabledHappyEyeballs() ? 2 : 1; @@ -676,11 +676,11 @@ public void updateWeightTimer() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -695,11 +695,11 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, .setAttributes(affinity).build())); assertThat(getNumFilteredPendingTasks()).isEqualTo(1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated @@ -732,11 +732,11 @@ public void weightExpired() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -840,11 +840,11 @@ public void unknownWeightIsAvgWeight() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); @@ -883,11 +883,11 @@ public void pickFromOtherThread() throws Exception { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, - weightedConfig.metricNamesForComputingUtilization).onLoadReport( + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); @@ -1224,7 +1224,8 @@ public void metrics() { // can be calculated, but it's still essentially round_robin Iterator childLbStates = wrr.getChildLbStates().iterator(); ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); @@ -1232,11 +1233,13 @@ public void metrics() { // Now send a second child LB state an ORCA update, so there's real weights ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); @@ -1355,7 +1358,8 @@ public void customMetric_priority_overAppUtil() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); Map namedMetrics = new HashMap<>(); namedMetrics.put("cost", 0.5); @@ -1389,7 +1393,8 @@ public void customMetric_invalid_fallbackToAppUtil() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); // custom metric is NaN, App util = 0.8 Map namedMetrics = new HashMap<>(); @@ -1424,7 +1429,8 @@ public void customMetric_mapLookup_used() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); Map namedMetrics = new HashMap<>(); namedMetrics.put("cost", 0.5); @@ -1456,7 +1462,8 @@ public void customMetric_shouldFilterOutAndFallbackToCpu() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); // custom metric is NaN, but CPU is 0.1 Map namedMetrics = new HashMap<>(); @@ -1493,7 +1500,8 @@ public void customMetric_multipleMetrics_maxUsed() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); Map namedMetrics = new HashMap<>(); namedMetrics.put("cost", 0.5); @@ -1528,7 +1536,8 @@ public void customMetric_allInvalid_fallbackToCpu() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); Map namedMetrics = new HashMap<>(); namedMetrics.put("cost", Double.NaN); @@ -1564,7 +1573,8 @@ public void customMetric_mixInvalidAndValid_validUsed() { WeightedChildLbState weightedChild = (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( - weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); Map namedMetrics = new HashMap<>(); namedMetrics.put("cost", Double.NaN); diff --git a/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java index 7bae7000eaf..522eb29c001 100644 --- a/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java @@ -409,6 +409,32 @@ public void testTcpListenerErrors() { testWatcher.verifyStats(0, 1); } + @Test + public void testControlPlaneError() { + Status forcedStatus = Status.NOT_FOUND + .withDescription("expected") + .withCause(new IllegalArgumentException("a random exception")); + xdsClient.shutdown(); + xdsClient = XdsTestUtils.createXdsClient( + Collections.singletonList("control-plane"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder.forName(serverInfo.target()) + .directExecutor() + .intercept(new FailingClientInterceptor(forcedStatus)) + .build()), + fakeClock); + xdsDependencyManager = new XdsDependencyManager( + xdsClient, syncContext, serverName, serverName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(forcedStatus.getDescription()) + .andCause(forcedStatus.getCause())))); + testWatcher.verifyStats(0, 1); + } + @Test public void testMissingRds() { String rdsName = "badRdsName"; diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index c8ad9f1c670..6b39106f18c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -77,6 +77,7 @@ import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.internal.security.TlsContextManagerImpl; import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.handler.ssl.NotSslRecordException; import java.io.File; import java.io.FileOutputStream; @@ -378,7 +379,11 @@ public void tlsClientServer_autoSniValidation_sniFromHostname() public void tlsClientServer_autoSniValidation_noSniApplicable_usesMatcherFromCmnVdnCtx() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); + boolean originalUseChannelAuthorityIfNoSniApplicable = + CertificateUtils.useChannelAuthorityIfNoSniApplicable; try { + CertificateUtils.useChannelAuthorityIfNoSniApplicable = + true; setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, @@ -398,6 +403,8 @@ public void tlsClientServer_autoSniValidation_noSniApplicable_usesMatcherFromCmn getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); unaryRpc(/* requestMessage= */ "buddy", blockingStub); } finally { + CertificateUtils.useChannelAuthorityIfNoSniApplicable = + originalUseChannelAuthorityIfNoSniApplicable; Files.deleteIfExists(trustStoreFilePath); clearTrustStoreSystemProperties(); } diff --git a/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java b/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java new file mode 100644 index 00000000000..64786c4fb3b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java @@ -0,0 +1,279 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.base.Stopwatch; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.grpc.InsecureChannelCredentials; +import io.grpc.MethodDescriptor; +import io.grpc.SynchronizationContext; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.client.XdsClient.ResourceStore; +import io.grpc.xds.client.XdsClient.XdsResponseHandler; +import io.grpc.xds.client.XdsTransportFactory.EventHandler; +import io.grpc.xds.client.XdsTransportFactory.StreamingCall; +import io.grpc.xds.client.XdsTransportFactory.XdsTransport; +import java.util.Collections; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link ControlPlaneClient}. */ +@RunWith(JUnit4.class) +public class ControlPlaneClientTest { + + private static final String CDS_TYPE_URL = "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + private static final String EDS_TYPE_URL = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + + private final SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError("Uncaught exception in sync context", e); + }); + private final FakeClock fakeClock = new FakeClock(); + private final ServerInfo serverInfo = + ServerInfo.create("eds-control-plane:8443", InsecureChannelCredentials.create()); + private final Node bootstrapNode = Node.newBuilder().setId("test-node").build(); + + @Mock private XdsTransport xdsTransport; + @Mock private StreamingCall streamingCall; + @Mock private XdsResponseHandler responseHandler; + @Mock private ResourceStore resourceStore; + @Mock private BackoffPolicy.Provider backoffPolicyProvider; + @Mock private MessagePrettyPrinter messagePrinter; + @Mock private XdsResourceType cdsType; + @Mock private XdsResourceType edsType; + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private ControlPlaneClient cpc; + private ArgumentCaptor> handlerCaptor; + + @Before + @SuppressWarnings("unchecked") + public void setUp() { + when(cdsType.typeUrl()).thenReturn(CDS_TYPE_URL); + when(cdsType.typeName()).thenReturn("CDS"); + when(edsType.typeUrl()).thenReturn(EDS_TYPE_URL); + when(edsType.typeName()).thenReturn("EDS"); + + when(xdsTransport.createStreamingCall( + anyString(), + any(MethodDescriptor.Marshaller.class), + any(MethodDescriptor.Marshaller.class))) + .thenReturn(streamingCall); + when(streamingCall.isReady()).thenReturn(true); + + handlerCaptor = ArgumentCaptor.forClass(EventHandler.class); + + cpc = new ControlPlaneClient( + xdsTransport, + serverInfo, + bootstrapNode, + responseHandler, + resourceStore, + fakeClock.getScheduledExecutorService(), + syncContext, + backoffPolicyProvider, + () -> Stopwatch.createUnstarted(fakeClock.getTicker()), + messagePrinter); + } + + /** + * Reproduces the bug where, when an ADS stream is opened to an authority-specific server (e.g. + * an EDS-only control plane), {@code sendDiscoveryRequests} previously emitted an empty + * DiscoveryRequest for every globally-subscribed resource type — including types this server + * does not handle. Authority-specific servers may reject those requests with UNIMPLEMENTED and + * tear down the stream, blocking the legitimate request that follows. + * + *

Asserts that the empty CDS request is suppressed and only the EDS request (which has + * resources for this server) goes on the wire. + */ + @Test + public void streamReady_skipsEmptyDiscoveryRequestForUnsubscribedType() { + // CDS is globally subscribed (e.g. against a different authority) but has no resources on + // this server. EDS has one resource on this server. + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + // Triggers stream creation and registers the EventHandler. + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + + // Drive the stream into the connected state. onReady() flips sentInitialRequest=true and + // re-invokes sendDiscoveryRequests, which iterates the globally-subscribed types. + handlerCaptor.getValue().onReady(); + + // EDS request was sent with the one resource for this server. + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, atLeastOnce()).sendMessage(sent.capture()); + ImmutableSet sentTypes = sent.getAllValues().stream() + .map(DiscoveryRequest::getTypeUrl) + .collect(ImmutableSet.toImmutableSet()); + assertThat(sentTypes).contains(EDS_TYPE_URL); + assertThat(sentTypes).doesNotContain(CDS_TYPE_URL); + + // Confirm the EDS request actually carried the resource name. + DiscoveryRequest edsReq = sent.getAllValues().stream() + .filter(r -> r.getTypeUrl().equals(EDS_TYPE_URL)) + .findFirst() + .orElseThrow(() -> new AssertionError("EDS request not sent")); + assertThat(edsReq.getResourceNamesList()).containsExactly("foo-endpoint"); + } + + /** + * If a server has resources for every globally-subscribed type, the empty-skip guard is a + * no-op: a DiscoveryRequest is sent for every type. This guards against the skip becoming + * over-eager and dropping legitimate subscriptions. + */ + @Test + public void streamReady_sendsRequestForAllTypesWhenAllHaveResources() { + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(ImmutableList.of("foo-cluster")); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, times(2)).sendMessage(sent.capture()); + ImmutableSet sentTypes = sent.getAllValues().stream() + .map(DiscoveryRequest::getTypeUrl) + .collect(ImmutableSet.toImmutableSet()); + assertThat(sentTypes).containsExactly(CDS_TYPE_URL, EDS_TYPE_URL); + } + + /** + * If only one type has a subscription on this server, no request is sent for the unsubscribed + * type. This is the canonical multi-authority federation case (e.g. fabric authority owns CDS, + * eds-control-plane owns EDS — the eds-control-plane stream should only see EDS). + */ + @Test + public void streamReady_skipsTypeWithNoSubscription() { + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + verify(streamingCall, never()).sendMessage( + argThatTypeUrlIs(CDS_TYPE_URL)); + verify(streamingCall).sendMessage(argThatTypeUrlIs(EDS_TYPE_URL)); + } + + /** + * Per the ResourceStore contract in XdsClient.java, an empty collection from + * getSubscribedResources indicates a wildcard subscription. The skip-on-empty guard must not + * suppress wildcard requests on initial stream ready — the server needs the empty-resource-list + * DiscoveryRequest to start streaming, and the watcher's missing-resource timers must start. + */ + @Test + public void streamReady_sendsWildcardRequestAndStartsTimers() { + Map> subscribedTypes = ImmutableMap.of(CDS_TYPE_URL, cdsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + // Empty collection == wildcard subscription per the ResourceStore contract. + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(Collections.emptyList()); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, atLeastOnce()).sendMessage(sent.capture()); + DiscoveryRequest cdsReq = sent.getAllValues().stream() + .filter(r -> r.getTypeUrl().equals(CDS_TYPE_URL)) + .findFirst() + .orElseThrow(() -> new AssertionError("CDS wildcard request not sent")); + assertThat(cdsReq.getResourceNamesList()).isEmpty(); + + verify(resourceStore).startMissingResourceTimers(Collections.emptyList(), cdsType); + } + + /** + * If a watch is canceled after the initial DiscoveryRequest goes out but before any response + * is ACKed, the empty unsubscribe must still be sent — otherwise the server keeps the stale + * subscription until the stream resets. The skip guard must gate on per-stream send history, + * not on the {@code versions} map (which is only populated on ACK). + */ + @Test + public void cancelBeforeAck_sendsEmptyUnsubscribe() { + Map> subscribedTypes = ImmutableMap.of(CDS_TYPE_URL, cdsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(ImmutableList.of("foo-cluster")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + // Initial DiscoveryRequest with the resource went out. No DiscoveryResponse has been ACKed. + verify(streamingCall).sendMessage(argThatTypeUrlIs(CDS_TYPE_URL)); + + // Cancel the watch before any response arrives: store now reports no subscription. + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + syncContext.execute(() -> cpc.adjustResourceSubscription(cdsType)); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, times(2)).sendMessage(sent.capture()); + DiscoveryRequest unsub = sent.getAllValues().get(1); + assertThat(unsub.getTypeUrl()).isEqualTo(CDS_TYPE_URL); + assertThat(unsub.getResourceNamesList()).isEmpty(); + } + + private static DiscoveryRequest argThatTypeUrlIs(String typeUrl) { + return argThat(req -> req != null && typeUrl.equals(req.getTypeUrl())); + } +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java index bf5e0ae9ede..9d7a3910216 100644 --- a/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java @@ -35,59 +35,76 @@ public class MetricReportUtilsTest { @Test - public void getMetric_cpuUtilization() { + public void getMetricValue_cpuUtilization() { MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); - OptionalDouble result = MetricReportUtils.getMetric(report, "cpu_utilization"); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("cpu_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); assertTrue(result.isPresent()); assertEquals(0.5, result.getAsDouble(), 0.0001); } @Test - public void getMetric_applicationUtilization() { + public void getMetricValue_applicationUtilization() { MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); - OptionalDouble result = MetricReportUtils.getMetric(report, "application_utilization"); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("application_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); assertTrue(result.isPresent()); assertEquals(0.1, result.getAsDouble(), 0.0001); } @Test - public void getMetric_memUtilization() { + public void getMetricValue_memUtilization() { MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); - OptionalDouble result = MetricReportUtils.getMetric(report, "mem_utilization"); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("mem_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); assertTrue(result.isPresent()); assertEquals(0.2, result.getAsDouble(), 0.0001); } @Test - public void getMetric_utilizationMetric() { + public void getMetricValue_utilizationMetric() { Map utilizationMetrics = new HashMap<>(); utilizationMetrics.put("foo", 1.23); MetricReport report = InternalCallMetricRecorder.createMetricReport( - 0, 0, 0, 0, 0, Collections.emptyMap(), utilizationMetrics, Collections.emptyMap()); + 0, 0, 0, 0, 0, Collections.emptyMap(), utilizationMetrics, Collections.emptyMap()); - OptionalDouble result = MetricReportUtils.getMetric(report, "utilization.foo"); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("utilization.foo"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); assertTrue(result.isPresent()); assertEquals(1.23, result.getAsDouble(), 0.0001); - assertFalse(MetricReportUtils.getMetric(report, "utilization.bar").isPresent()); + + MetricReportUtils.ParsedMetricName bad = + MetricReportUtils.ParsedMetricName.parse("utilization.bar"); + assertFalse(MetricReportUtils.getMetricValue(report, bad).isPresent()); } @Test - public void getMetric_namedMetric() { + public void getMetricValue_namedMetric() { Map namedMetrics = new HashMap<>(); namedMetrics.put("foo", 7.89); MetricReport report = createMetricReport(0, 0, 0, 0, 0, namedMetrics); - OptionalDouble result = MetricReportUtils.getMetric(report, "named_metrics.foo"); + + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("named_metrics.foo"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); assertTrue(result.isPresent()); assertEquals(7.89, result.getAsDouble(), 0.0001); - assertFalse(MetricReportUtils.getMetric(report, "named_metrics.bar").isPresent()); + MetricReportUtils.ParsedMetricName bad = + MetricReportUtils.ParsedMetricName.parse("named_metrics.bar"); + assertFalse(MetricReportUtils.getMetricValue(report, bad).isPresent()); } @Test - public void getMetric_unknownPrefix() { - MetricReport report = createMetricReport(0, 0, 0, 0, 0, Collections.emptyMap()); - assertFalse(MetricReportUtils.getMetric(report, "unknown.foo").isPresent()); - assertFalse(MetricReportUtils.getMetric(report, "foo").isPresent()); + public void getMetricValue_invalidMetric() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + MetricReportUtils.ParsedMetricName invalid = + MetricReportUtils.ParsedMetricName.parse("invalid_metric"); + assertFalse(MetricReportUtils.getMetricValue(report, invalid).isPresent()); } private MetricReport createMetricReport(double cpu, double app, double mem, double qps, diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java new file mode 100644 index 00000000000..c0598590ebc --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java @@ -0,0 +1,242 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.re2j.Pattern; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationFilterTest { + + private static final int MAX_HEADER_LENGTH = 16384; + + private static HeaderValueOption header(String key, ByteString value) { + return HeaderValueOption.create(io.grpc.xds.internal.grpcservice.HeaderValue.create(key, value), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + } + + private static HeaderValueOption header(String key, String value) { + return HeaderValueOption.create(io.grpc.xds.internal.grpcservice.HeaderValue.create(key, value), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + } + + @Test + public void filter_validationRules_dropsInvalidHeaders() throws Exception { + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.empty()); + String longString = Strings.repeat("a", MAX_HEADER_LENGTH + 1); + ByteString longBytes = ByteString.copyFrom(new byte[MAX_HEADER_LENGTH + 1]); + + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of( + header("add-key", "add-value"), header(":authority", "new-authority"), + header("host", "new-host"), header(":scheme", "https"), header(":method", "PUT"), + header("resp-add-key", "resp-add-value"), header(":scheme", "https"), + header(":path", "/new-path"), header(":grpc-trace-bin", "binary-value"), + header(":alt-svc", "h3=:443"), header("user-agent", "new-agent"), + header("Valid-Key", "value"), header("", "value"), header(longString, "value"), + header("long-value-key", longString), header("long-bin-key-bin", longBytes), + header("grpc-timeout", "10S"), header("valid-key-lower", "value")), + ImmutableList.of("remove-key", "host", ":authority", ":scheme", ":method", ":foo", ":bar", + "Valid-Key", "", longString, "grpc-timeout", "UPPER-REMOVE", "lower-remove")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headersToRemove()).containsExactly("remove-key", "lower-remove"); + assertThat(filtered.headers()).containsExactly( + header("add-key", "add-value"), header("resp-add-key", "resp-add-value"), + header("user-agent", "new-agent"), header("valid-key-lower", "value")); + } + + @Test + public void filter_validationRules_throwsOnInvalidHeaders() throws Exception { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + String longString = Strings.repeat("a", MAX_HEADER_LENGTH + 1); + + // Test system headers modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header(":path", "/new-path")), ImmutableList.of()))); + + // Test system headers removal + assertThrows(HeaderMutationDisallowedException.class, + () -> filter.filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of(":path")))); + + // Test uppercase header modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("Valid-Key", "value")), ImmutableList.of()))); + + // Test uppercase header removal + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("UPPER-REMOVE")))); + + // Test empty header + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(header("", "value")), ImmutableList.of()))); + + // Test long header key + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of(longString)))); + } + + + @Test + public void filter_mutationRules_disallowAll_dropsAll() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header("add-key", "add-value"), header("resp-add-key", "resp-add-value")), + ImmutableList.of("remove-key")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headers()).isEmpty(); + assertThat(filtered.headersToRemove()).isEmpty(); + } + + @Test + public void filter_mutationRules_disallowAll_throws() throws Exception { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Test add header + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("add-key", "add-value")), ImmutableList.of()))); + + // Test remove header + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("remove-key")))); + + // Test response header + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("resp-add-key", "resp-add-value")), ImmutableList.of()))); + } + + + @Test + public void filter_mutationRules_disallowExpression_dropsMatching() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header("x-public", "value"), header("x-private-key", "value"), + header("x-public-resp", "value"), header("x-private-resp", "value")), + ImmutableList.of("x-public-remove", "x-private-remove")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headersToRemove()).containsExactly("x-public-remove"); + assertThat(filtered.headers()).containsExactly(header("x-public", "value"), + header("x-public-resp", "value")); + } + + @Test + public void filter_mutationRules_disallowExpression_throws() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Test disallowed key modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("x-private-key", "value")), ImmutableList.of()))); + + // Test disallowed key removal + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("x-private-remove")))); + } + + + @Test + public void filter_mutationRules_precedence() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowExpression(Pattern.compile("^x-allowed-but-disallowed-.*")) + .build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Case 1: allowExpression overrides disallowAll + HeaderMutations mutations1 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), header("not-allowed", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")); + HeaderMutations filtered1 = filter.filter(mutations1); + assertThat(filtered1.headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered1.headers()).containsExactly(header("x-allowed-key", "value")); + + // Case 2: disallowExpression overrides allowExpression + HeaderMutations mutations2 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-but-disallowed-key", "value")), + ImmutableList.of("x-allowed-but-disallowed-remove")); + HeaderMutations filtered2 = filter.filter(mutations2); + assertThat(filtered2.headers()).isEmpty(); + assertThat(filtered2.headersToRemove()).isEmpty(); + } + + @Test + public void filter_mutationRules_precedence_throws() throws Exception { + // Case 1: allowExpression overrides disallowAll (does not throw) + HeaderMutationRulesConfig rules1 = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowIsError(true) + .build(); + HeaderMutationFilter filter1 = new HeaderMutationFilter(Optional.of(rules1)); + HeaderMutations mutations1 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value")), ImmutableList.of("x-allowed-remove")); + HeaderMutations filtered1 = filter1.filter(mutations1); + assertThat(filtered1.headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered1.headers()).containsExactly(header("x-allowed-key", "value")); + + // Case 2: disallowExpression overrides allowExpression (throws) + HeaderMutationRulesConfig rules2 = HeaderMutationRulesConfig.builder() + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowExpression(Pattern.compile("^x-allowed-but-disallowed-.*")) + .disallowIsError(true) + .build(); + HeaderMutationFilter filter2 = new HeaderMutationFilter(Optional.of(rules2)); + assertThrows(HeaderMutationDisallowedException.class, + () -> filter2.filter(HeaderMutations.create( + ImmutableList.of(header("x-allowed-but-disallowed-key", "value")), + ImmutableList.of()))); + + assertThrows(HeaderMutationDisallowedException.class, () -> filter2.filter(HeaderMutations + .create( + ImmutableList.of(), ImmutableList.of("x-allowed-but-disallowed-remove")))); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java new file mode 100644 index 00000000000..ef7f22b7ac8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationsTest { + @Test + public void testCreate() { + HeaderValueOption header = HeaderValueOption.create( + HeaderValue.create("key", "value"), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header), ImmutableList.of("remove-key")); + assertThat(mutations.headers()).containsExactly(header); + assertThat(mutations.headersToRemove()).containsExactly("remove-key"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java new file mode 100644 index 00000000000..b6806760f9b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java @@ -0,0 +1,315 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.testing.TestLogHandler; +import com.google.protobuf.ByteString; +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutatorTest { + + private static final Metadata.Key BINARY_KEY = + Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + private static final Metadata.Key APPEND_KEY = + Metadata.Key.of("append-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADD_KEY = + Metadata.Key.of("add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_KEY = + Metadata.Key.of("overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVE_KEY = + Metadata.Key.of("remove-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_ADD_KEY = + Metadata.Key.of("new-add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_OVERWRITE_KEY = + Metadata.Key.of("new-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_KEY = + Metadata.Key.of("overwrite-if-exists-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_ABSENT_KEY = + Metadata.Key.of("overwrite-if-exists-absent-key", Metadata.ASCII_STRING_MARSHALLER); + + private final HeaderMutator headerMutator = HeaderMutator.create(); + + private static final TestLogHandler logHandler = new TestLogHandler(); + private static final Logger logger = Logger.getLogger(HeaderMutator.class.getName()); + + @Before + public void setUp() { + logHandler.clear(); + logger.addHandler(logHandler); + logger.setLevel(Level.WARNING); + } + + @After + public void tearDown() { + logger.removeHandler(logHandler); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.create(HeaderValue.create(key, value), action, false); + } + + @Test + public void applyMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + headers.put(REMOVE_KEY, "remove-value-original"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "original-value"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + APPEND_KEY.name(), + "append-value-2", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header( + OVERWRITE_KEY.name(), + "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + NEW_OVERWRITE_KEY.name(), + "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + OVERWRITE_IF_EXISTS_KEY.name(), + "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS), + header( + OVERWRITE_IF_EXISTS_ABSENT_KEY.name(), + "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("new-value"); + assertThat(headers.containsKey(OVERWRITE_IF_EXISTS_ABSENT_KEY)).isFalse(); + } + + @Test + public void applyMutations_removalHasPriority() { + Metadata headers = new Metadata(); + headers.put(REMOVE_KEY, "value"); + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + REMOVE_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + } + + @Test + public void applyMutations_binary() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = + HeaderValueOption.create( + HeaderValue.create(BINARY_KEY.name(), ByteString.copyFrom(value)), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + false); + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + APPEND_KEY.name(), + "append-value-2", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header( + OVERWRITE_KEY.name(), + "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + NEW_OVERWRITE_KEY.name(), + "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), ImmutableList.of()); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + } + + @Test + public void applyResponseMutations_binary() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = + HeaderValueOption.create( + HeaderValue.create(BINARY_KEY.name(), ByteString.copyFrom(value)), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + false); + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyMutations_keepEmptyValue() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "existing-value"); + headers.put(OVERWRITE_KEY, "existing-value"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "existing-value"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header(NEW_ADD_KEY.name(), "", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(APPEND_KEY.name(), "", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(OVERWRITE_KEY.name(), "", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "", HeaderAppendAction.ADD_IF_ABSENT), + header(OVERWRITE_IF_EXISTS_KEY.name(), "", HeaderAppendAction.OVERWRITE_IF_EXISTS), + HeaderValueOption.create( + HeaderValue.create("keep-empty-key", ""), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + true), + HeaderValueOption.create( + HeaderValue.create("keep-empty-overwrite-key", ""), + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD, + true), + HeaderValueOption.create( + HeaderValue.create("keep-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, true), + HeaderValueOption.create( + HeaderValue.create("ignore-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false), + HeaderValueOption.create( + HeaderValue.create("overwrite-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD, false)), + ImmutableList.of()); + + headers.put( + Metadata.Key.of("keep-empty-overwrite-key", Metadata.ASCII_STRING_MARSHALLER), "old"); + + Metadata.Key overwriteEmptyBinKey = + Metadata.Key.of("overwrite-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + byte[] originalBinValue = new byte[] {1, 2, 3}; + headers.put(overwriteEmptyBinKey, originalBinValue); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(NEW_ADD_KEY)).isFalse(); + assertThat(headers.getAll(APPEND_KEY)).containsExactly("existing-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("existing-value"); + assertThat(headers.containsKey(ADD_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("existing-value"); + + Metadata.Key keepEmptyKey = + Metadata.Key.of("keep-empty-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata.Key keepEmptyOverwriteKey = + Metadata.Key.of("keep-empty-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + + assertThat(headers.containsKey(keepEmptyKey)).isTrue(); + assertThat(headers.get(keepEmptyKey)).isEqualTo(""); + assertThat(headers.containsKey(keepEmptyOverwriteKey)).isTrue(); + assertThat(headers.get(keepEmptyOverwriteKey)).isEqualTo(""); + + Metadata.Key keepEmptyBinKey = + Metadata.Key.of("keep-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata.Key ignoreEmptyBinKey = + Metadata.Key.of("ignore-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + + assertThat(headers.containsKey(keepEmptyBinKey)).isTrue(); + assertThat(headers.get(keepEmptyBinKey)).isEqualTo(new byte[0]); + assertThat(headers.containsKey(ignoreEmptyBinKey)).isFalse(); + assertThat(headers.get(overwriteEmptyBinKey)).isEqualTo(originalBinValue); + } + + @Test + public void applyMutations_binaryRemoval() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + headers.put(BINARY_KEY, value); + HeaderMutations mutations = + HeaderMutations.create(ImmutableList.of(), ImmutableList.of(BINARY_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(BINARY_KEY)).isFalse(); + } + + @Test + public void applyMutations_stringValueWithBinaryKey_ignored() { + Metadata headers = new Metadata(); + HeaderValueOption option = HeaderValueOption.create(HeaderValue.create("some-key-bin", "value"), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + + Metadata.Key key = Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + assertThat(headers.containsKey(key)).isFalse(); + } + + @Test + public void applyMutations_binaryValueWithAsciiKey_ignored() { + Metadata headers = new Metadata(); + HeaderValueOption option = HeaderValueOption.create( + HeaderValue.create("some-key", ByteString.copyFrom(new byte[] {1})), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + + Metadata.Key key = Metadata.Key.of("some-key", Metadata.ASCII_STRING_MARSHALLER); + assertThat(headers.containsKey(key)).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java new file mode 100644 index 00000000000..49c43749135 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderValueOptionTest { + + @Test + public void create_withAllFields_success() { + HeaderValue header = HeaderValue.create("key1", "value1"); + HeaderValueOption option = HeaderValueOption.create( + header, HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, true); + + assertThat(option.header()).isEqualTo(header); + assertThat(option.appendAction()).isEqualTo(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD); + assertThat(option.keepEmptyValue()).isTrue(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index f11c661e211..125b7e65aa6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -209,7 +209,7 @@ public void updateSslContextAndExtendedX509TrustManager( protected void onException(Throwable throwable) { future.set(throwable); } - }, false); + }, true); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); @@ -356,7 +356,7 @@ public void updateSslContextAndExtendedX509TrustManager( protected void onException(Throwable throwable) { future.set(throwable); } - }, false); + }, true); channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); @@ -493,7 +493,7 @@ public void updateSslContextAndExtendedX509TrustManager( protected void onException(Throwable throwable) { future.set(throwable); } - }, false); + }, true); executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index 620ee0a7ff7..f6fdc51dece 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -261,6 +261,59 @@ public void certAndKeyFileUpdateOnly() verifyTimeServiceAndScheduledFuture(); } + @Test + public void certFileUpdateOnly() + throws IOException, CertificateException, InterruptedException { + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + // Ideally we'd use a matching cert/key pair here, but we don't actually have any ready-made. + // The test doesn't notice they don't match though. + populateTarget( + CLIENT_PEM_FILE, SERVER_0_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + // It's normal to get a newer cert while continuing to use the same private key + populateTarget(SERVER_0_PEM_FILE, null, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void keyFileUpdateOnly() + throws IOException, CertificateException, InterruptedException { + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + // Assume the key/cert is not updated atomically and we see a tear between them. Or maybe this + // was just a bug. + populateTarget( + SERVER_0_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + // Even though it is strange the key updated without a cert update, we do still want to use the + // new files, as this recovers from the earlier tear. + populateTarget(null, SERVER_0_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + @Test public void spiffeTrustMapFileUpdateOnly() throws Exception { provider = new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, null,