diff --git a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java index 3c8d3d07571..67001085b02 100644 --- a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java +++ b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java @@ -218,7 +218,15 @@ private void assureReadyAndDrainedTurnsFalse() { private void runOrBuffer(ActionItem actionItem) throws IOException { WriteState curState = writeState.get(); if (curState.readyAndDrained) { // write to the outputStream directly - actionItem.run(); + try { + actionItem.run(); + } catch (IllegalStateException e) { + if (actionItem == flushAction || actionItem == completeAction) { + throw e; + } + buffer(actionItem, curState); + return; + } if (actionItem == completeAction) { return; } @@ -230,20 +238,26 @@ private void runOrBuffer(ActionItem actionItem) throws IOException { log.finest("the servlet output stream becomes not ready"); } } else { // buffer to the writeChain - writeChain.offer(actionItem); - if (!writeState.compareAndSet(curState, curState.withReadyAndDrained(false))) { - checkState( - writeState.get().readyAndDrained, - "Bug: onWritePossible() should have changed readyAndDrained to true, but not"); - ActionItem lastItem = writeChain.poll(); - if (lastItem != null) { - checkState(lastItem == actionItem, "Bug: lastItem != actionItem"); - runOrBuffer(lastItem); - } - } // state has not changed since + buffer(actionItem, curState); } } + private void buffer(ActionItem actionItem, WriteState curState) throws IOException { + writeChain.offer(actionItem); + if (writeState.compareAndSet(curState, curState.withReadyAndDrained(false))) { + LockSupport.unpark(parkingThread); + } else { + checkState( + writeState.get().readyAndDrained, + "Bug: onWritePossible() should have changed readyAndDrained to true, but not"); + ActionItem lastItem = writeChain.poll(); + if (lastItem != null) { + checkState(lastItem == actionItem, "Bug: lastItem != actionItem"); + runOrBuffer(lastItem); + } + } // state has not changed since + } + /** Write actions, e.g. writeBytes, flush, complete. */ @FunctionalInterface @VisibleForTesting diff --git a/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterTest.java b/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterTest.java new file mode 100644 index 00000000000..e7474078ec9 --- /dev/null +++ b/servlet/src/test/java/io/grpc/servlet/AsyncServletOutputStreamWriterTest.java @@ -0,0 +1,312 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.servlet; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import io.grpc.servlet.AsyncServletOutputStreamWriter.ActionItem; +import io.grpc.servlet.AsyncServletOutputStreamWriter.Log; +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit test for {@link AsyncServletOutputStreamWriter} with a mock isReady supplier. */ +@RunWith(JUnit4.class) +public class AsyncServletOutputStreamWriterTest { + + @Test + public void writeBytes_notReadyException_buffersUntilOnWritePossible() throws IOException { + List actions = new ArrayList<>(); + AtomicBoolean rejectWrites = new AtomicBoolean(true); + + BiFunction writeAction = + (bytes, numBytes) -> () -> { + if (rejectWrites.get()) { + throw new IllegalStateException("not ready"); + } + actions.add("write"); + }; + ActionItem flushAction = () -> { }; + ActionItem completeAction = () -> { }; + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + writeAction, flushAction, completeAction, () -> true, new Log() {}); + + writer.onWritePossible(); + + writer.writeBytes(new byte[]{1}, 1); + + assertEquals("Write should be buffered until onWritePossible", 0, actions.size()); + + rejectWrites.set(false); + writer.onWritePossible(); + assertEquals("Buffered write should drain after onWritePossible", 1, actions.size()); + } + + @Test + public void writeBytes_consecutiveWithIsReadyTrue_allGoDirect() throws IOException { + List writtenData = new ArrayList<>(); + + BiFunction writeAction = + (bytes, numBytes) -> () -> writtenData.add(Arrays.copyOf(bytes, numBytes)); + ActionItem flushAction = () -> { }; + ActionItem completeAction = () -> { }; + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + writeAction, flushAction, completeAction, () -> true, new Log() {}); + + writer.onWritePossible(); + + for (int i = 0; i < 5; i++) { + writer.writeBytes(new byte[]{(byte) i}, 1); + } + + assertEquals("All writes should complete", 5, writtenData.size()); + } + + @Test + public void writeBytes_isReadyFalseAfterWrite_buffersNextWrite() throws IOException { + List writtenData = new ArrayList<>(); + AtomicBoolean isReady = new AtomicBoolean(true); + + BiFunction writeAction = + (bytes, numBytes) -> () -> { + writtenData.add(Arrays.copyOf(bytes, numBytes)); + isReady.set(false); + }; + ActionItem flushAction = () -> { }; + ActionItem completeAction = () -> { }; + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + writeAction, flushAction, completeAction, isReady::get, new Log() {}); + + writer.onWritePossible(); + + writer.writeBytes(new byte[]{1}, 1); + writer.writeBytes(new byte[]{2}, 1); + assertEquals("Second write should be buffered", 1, writtenData.size()); + + isReady.set(true); + writer.onWritePossible(); + assertEquals("Buffered write should drain", 2, writtenData.size()); + } + + @Test + public void flush_isReadyFalse_buffersUntilOnWritePossible() throws IOException { + List actions = new ArrayList<>(); + AtomicBoolean isReady = new AtomicBoolean(true); + + BiFunction writeAction = + (bytes, numBytes) -> () -> actions.add("write"); + ActionItem flushAction = () -> { + actions.add("flush"); + isReady.set(false); + }; + ActionItem completeAction = () -> { }; + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + writeAction, flushAction, completeAction, isReady::get, new Log() {}); + + writer.onWritePossible(); + + writer.flush(); + assertEquals("First flush should execute directly", 1, actions.size()); + + writer.flush(); + assertEquals("Second flush should be buffered", 1, actions.size()); + + isReady.set(true); + writer.onWritePossible(); + assertEquals("Both flushes should complete after onWritePossible", 2, actions.size()); + } + + @Test + public void flush_consecutiveWithIsReadyTrue_bothGoDirect() throws IOException { + List actions = new ArrayList<>(); + + BiFunction writeAction = + (bytes, numBytes) -> () -> actions.add("write"); + ActionItem flushAction = () -> actions.add("flush"); + ActionItem completeAction = () -> { }; + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + writeAction, flushAction, completeAction, () -> true, new Log() {}); + + writer.onWritePossible(); + + writer.flush(); + writer.flush(); + + assertEquals("Both flushes should execute directly", 2, actions.size()); + } + + @Test + public void complete_readyAndDrained_runsDirectly() throws IOException { + AtomicInteger completeCount = new AtomicInteger(); + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + (bytes, numBytes) -> () -> { }, + () -> { }, + completeCount::incrementAndGet, + () -> true, + new Log() {}); + + writer.onWritePossible(); + + writer.complete(); + + assertEquals(1, completeCount.get()); + } + + @Test + public void complete_notReadyAndDrained_buffersUntilOnWritePossible() throws IOException { + AtomicInteger completeCount = new AtomicInteger(); + + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + (bytes, numBytes) -> () -> { }, + () -> { }, + completeCount::incrementAndGet, + () -> true, + new Log() {}); + + writer.complete(); + assertEquals(0, completeCount.get()); + + writer.onWritePossible(); + assertEquals(1, completeCount.get()); + } + + @Test + public void writeBytes_onWritePossibleWinsRace_drainsBufferedWrite() throws Exception { + List actions = new ArrayList<>(); + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + (bytes, numBytes) -> () -> actions.add("write"), + () -> { + }, + () -> { + }, + () -> true, + new Log() { + }); + replaceWriteChain(writer, new ConcurrentLinkedQueue() { + @Override + public boolean offer(ActionItem actionItem) { + boolean offered = super.offer(actionItem); + try { + writer.onWritePossible(); + } catch (IOException e) { + throw new AssertionError(e); + } + return offered; + } + }); + + writer.writeBytes(new byte[]{1}, 1); + + assertEquals(1, actions.size()); + } + + @Test + public void writeBytes_readyStateWinsRace_retriesWrite() throws Exception { + List actions = new ArrayList<>(); + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + (bytes, numBytes) -> () -> actions.add("write"), + () -> { + }, + () -> { + }, + () -> true, + new Log() { + }); + replaceWriteChain(writer, new ConcurrentLinkedQueue() { + @Override + public boolean offer(ActionItem actionItem) { + boolean offered = super.offer(actionItem); + try { + forceReadyAndDrained(writer); + } catch (ReflectiveOperationException e) { + throw new LinkageError(e.getMessage(), e); + } + return offered; + } + }); + + writer.writeBytes(new byte[]{1}, 1); + + assertEquals(1, actions.size()); + } + + private static void replaceWriteChain( + AsyncServletOutputStreamWriter writer, ConcurrentLinkedQueue writeChain) + throws ReflectiveOperationException { + Field writeChainField = AsyncServletOutputStreamWriter.class.getDeclaredField("writeChain"); + writeChainField.setAccessible(true); + writeChainField.set(writer, writeChain); + } + + private static void forceReadyAndDrained(AsyncServletOutputStreamWriter writer) + throws ReflectiveOperationException { + Field writeStateField = AsyncServletOutputStreamWriter.class.getDeclaredField("writeState"); + writeStateField.setAccessible(true); + @SuppressWarnings("unchecked") + AtomicReference writeState = + (AtomicReference) writeStateField.get(writer); + Object curState = writeState.get(); + Method withReadyAndDrained = curState.getClass().getDeclaredMethod( + "withReadyAndDrained", boolean.class); + withReadyAndDrained.setAccessible(true); + writeState.set(withReadyAndDrained.invoke(curState, true)); + } + + @Test + public void flush_notReadyException_isPropagated() throws IOException { + AsyncServletOutputStreamWriter writer = + new AsyncServletOutputStreamWriter( + (bytes, numBytes) -> () -> { }, + () -> { + throw new IllegalStateException("not ready"); + }, + () -> { }, + () -> true, + new Log() {}); + + writer.onWritePossible(); + + assertThrows(IllegalStateException.class, writer::flush); + } +}