Skip to content

Commit 647c1c2

Browse files
authored
Fix WsSession.send race (27.x) (helidon-io#11787)
* Fix websocket send race * Publish websocket close intent before locking
1 parent 96ac8fb commit 647c1c2

5 files changed

Lines changed: 690 additions & 31 deletions

File tree

webclient/websocket/src/main/java/io/helidon/webclient/websocket/ClientWsConnection.java

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, 2025 Oracle and/or its affiliates.
2+
* Copyright (c) 2023, 2026 Oracle and/or its affiliates.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@
1818

1919
import java.nio.charset.StandardCharsets;
2020
import java.util.Optional;
21+
import java.util.concurrent.atomic.AtomicBoolean;
22+
import java.util.concurrent.locks.Lock;
23+
import java.util.concurrent.locks.ReentrantLock;
2124

2225
import io.helidon.common.buffers.BufferData;
2326
import io.helidon.common.buffers.DataReader;
@@ -44,10 +47,11 @@ public class ClientWsConnection implements WsSession, Runnable {
4447
private final BufferData sendBuffer = BufferData.growing(1024);
4548
private final ClientConnection connection;
4649
private final HelidonSocket helidonSocket;
50+
private final Lock sendLock = new ReentrantLock();
4751

4852
private ContinuationType recvContinuation = ContinuationType.NONE;
4953
private boolean sendContinuation;
50-
private boolean closeSent;
54+
private final AtomicBoolean closeSent = new AtomicBoolean();
5155
private boolean terminated;
5256

5357
ClientWsConnection(ClientConnection connection,
@@ -113,22 +117,42 @@ public void run() {
113117

114118
@Override
115119
public WsSession send(String text, boolean last) {
116-
return send(ClientWsFrame.data(text, last));
120+
sendLock.lock();
121+
try {
122+
return sendLocked(ClientWsFrame.data(text, last));
123+
} finally {
124+
sendLock.unlock();
125+
}
117126
}
118127

119128
@Override
120129
public WsSession send(BufferData bufferData, boolean last) {
121-
return send(ClientWsFrame.data(bufferData, last));
130+
sendLock.lock();
131+
try {
132+
return sendLocked(ClientWsFrame.data(bufferData, last));
133+
} finally {
134+
sendLock.unlock();
135+
}
122136
}
123137

124138
@Override
125139
public WsSession ping(BufferData bufferData) {
126-
return send(ClientWsFrame.control(WsOpCode.PING, bufferData));
140+
sendLock.lock();
141+
try {
142+
return sendLocked(ClientWsFrame.control(WsOpCode.PING, bufferData));
143+
} finally {
144+
sendLock.unlock();
145+
}
127146
}
128147

129148
@Override
130149
public WsSession pong(BufferData bufferData) {
131-
return send(ClientWsFrame.control(WsOpCode.PONG, bufferData));
150+
sendLock.lock();
151+
try {
152+
return sendLocked(ClientWsFrame.control(WsOpCode.PONG, bufferData));
153+
} finally {
154+
sendLock.unlock();
155+
}
132156
}
133157

134158
/**
@@ -142,17 +166,24 @@ public WsSession pong(BufferData bufferData) {
142166
*/
143167
@Override
144168
public WsSession close(int code, String reason) {
145-
closeSent = true;
169+
if (!closeSent.compareAndSet(false, true)) {
170+
return this;
171+
}
146172

147-
// send empty close (no code or reason) if code is negative
148-
if (code < 0) {
149-
send(ClientWsFrame.control(WsOpCode.CLOSE, BufferData.empty()));
150-
} else {
151-
byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8);
152-
BufferData bufferData = BufferData.create(2 + reasonBytes.length);
153-
bufferData.writeInt16(code);
154-
bufferData.write(reasonBytes);
155-
send(ClientWsFrame.control(WsOpCode.CLOSE, bufferData));
173+
sendLock.lock();
174+
try {
175+
// send empty close (no code or reason) if code is negative
176+
if (code < 0) {
177+
sendLocked(ClientWsFrame.control(WsOpCode.CLOSE, BufferData.empty()));
178+
} else {
179+
byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8);
180+
BufferData bufferData = BufferData.create(2 + reasonBytes.length);
181+
bufferData.writeInt16(code);
182+
bufferData.write(reasonBytes);
183+
sendLocked(ClientWsFrame.control(WsOpCode.CLOSE, bufferData));
184+
}
185+
} finally {
186+
sendLock.unlock();
156187
}
157188
return this;
158189
}
@@ -175,7 +206,7 @@ public SocketContext socketContext() {
175206
return helidonSocket;
176207
}
177208

178-
private ClientWsConnection send(ClientWsFrame frame) {
209+
private ClientWsConnection sendLocked(ClientWsFrame frame) {
179210
WsOpCode opCode = frame.opCode();
180211
if (opCode == WsOpCode.TEXT || opCode == WsOpCode.BINARY) {
181212
if (sendContinuation) {
@@ -230,7 +261,7 @@ private void doRun() {
230261
} catch (DataReader.InsufficientDataAvailableException e) {
231262
return;
232263
} catch (WsCloseException e) {
233-
if (!closeSent) {
264+
if (!closeSent.get()) {
234265
try {
235266
close(e.closeCode(), e.getMessage());
236267
} catch (Exception ex) {
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/*
2+
* Copyright (c) 2026 Oracle and/or its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.helidon.webclient.websocket;
18+
19+
import java.lang.reflect.Field;
20+
import java.time.Duration;
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.TimeUnit;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
26+
import io.helidon.common.buffers.BufferData;
27+
import io.helidon.common.buffers.DataReader;
28+
import io.helidon.common.buffers.DataWriter;
29+
import io.helidon.common.socket.HelidonSocket;
30+
import io.helidon.common.socket.PeerInfo;
31+
import io.helidon.webclient.api.ClientConnection;
32+
import io.helidon.websocket.WsCloseCodes;
33+
import io.helidon.websocket.WsListener;
34+
import io.helidon.websocket.WsSession;
35+
36+
import org.junit.jupiter.api.Test;
37+
38+
import static org.hamcrest.MatcherAssert.assertThat;
39+
import static org.hamcrest.Matchers.is;
40+
import static org.hamcrest.Matchers.nullValue;
41+
import static org.junit.jupiter.api.Assertions.assertAll;
42+
import static org.junit.jupiter.api.Assertions.fail;
43+
44+
class ClientWsConnectionTest {
45+
private static final long TEST_TIMEOUT_SECONDS = 5;
46+
47+
@Test
48+
void closePublishesStateBeforeWaitingForSendLock() throws Exception {
49+
BlockingDataWriter dataWriter = new BlockingDataWriter();
50+
ClientWsConnection connection = ClientWsConnection.create(new TestClientConnection(dataWriter),
51+
new WsListener() {
52+
@Override
53+
public void onMessage(WsSession session,
54+
String text,
55+
boolean last) {
56+
}
57+
});
58+
59+
AtomicReference<Throwable> sendFailure = new AtomicReference<>();
60+
Thread sendThread = new Thread(() -> invoke(() -> connection.send("hello", true), sendFailure), "client-ws-send");
61+
sendThread.start();
62+
dataWriter.awaitFirstWrite();
63+
64+
AtomicReference<Throwable> closeFailure = new AtomicReference<>();
65+
Thread closeThread = new Thread(() -> invoke(() -> connection.close(WsCloseCodes.NORMAL_CLOSE, "done"),
66+
closeFailure),
67+
"client-ws-close");
68+
closeThread.start();
69+
70+
assertThat("close flag was not published while send lock was held",
71+
awaitCloseSent(connection),
72+
is(true));
73+
74+
dataWriter.releaseFirstWrite();
75+
sendThread.join(TimeUnit.SECONDS.toMillis(TEST_TIMEOUT_SECONDS));
76+
closeThread.join(TimeUnit.SECONDS.toMillis(TEST_TIMEOUT_SECONDS));
77+
78+
assertAll(
79+
() -> assertThat("send thread failed", sendFailure.get(), is(nullValue())),
80+
() -> assertThat("close thread failed", closeFailure.get(), is(nullValue())),
81+
() -> assertThat("send thread did not finish", sendThread.isAlive(), is(false)),
82+
() -> assertThat("close thread did not finish", closeThread.isAlive(), is(false))
83+
);
84+
}
85+
86+
private static boolean awaitCloseSent(ClientWsConnection connection)
87+
throws ReflectiveOperationException, InterruptedException {
88+
Field field = ClientWsConnection.class.getDeclaredField("closeSent");
89+
field.setAccessible(true);
90+
AtomicBoolean closeSent = (AtomicBoolean) field.get(connection);
91+
92+
long timeoutNanos = TimeUnit.SECONDS.toNanos(TEST_TIMEOUT_SECONDS);
93+
long deadline = System.nanoTime() + timeoutNanos;
94+
while (System.nanoTime() < deadline) {
95+
if (closeSent.get()) {
96+
return true;
97+
}
98+
TimeUnit.MILLISECONDS.sleep(10);
99+
}
100+
return closeSent.get();
101+
}
102+
103+
private static void invoke(ThrowingRunnable action, AtomicReference<Throwable> failure) {
104+
try {
105+
action.run();
106+
} catch (Throwable t) {
107+
failure.set(t);
108+
}
109+
}
110+
111+
@FunctionalInterface
112+
private interface ThrowingRunnable {
113+
void run() throws Exception;
114+
}
115+
116+
private static final class TestClientConnection implements ClientConnection {
117+
private final DataWriter dataWriter;
118+
private final HelidonSocket socket = new TestHelidonSocket();
119+
120+
private TestClientConnection(DataWriter dataWriter) {
121+
this.dataWriter = dataWriter;
122+
}
123+
124+
@Override
125+
public DataReader reader() {
126+
return null;
127+
}
128+
129+
@Override
130+
public DataWriter writer() {
131+
return dataWriter;
132+
}
133+
134+
@Override
135+
public String channelId() {
136+
return "test";
137+
}
138+
139+
@Override
140+
public HelidonSocket helidonSocket() {
141+
return socket;
142+
}
143+
144+
@Override
145+
public void readTimeout(Duration readTimeout) {
146+
}
147+
148+
@Override
149+
public void closeResource() {
150+
}
151+
}
152+
153+
private static final class TestHelidonSocket implements HelidonSocket {
154+
@Override
155+
public void close() {
156+
}
157+
158+
@Override
159+
public void idle() {
160+
}
161+
162+
@Override
163+
public boolean isConnected() {
164+
return true;
165+
}
166+
167+
@Override
168+
public void write(BufferData buffer) {
169+
}
170+
171+
@Override
172+
public PeerInfo remotePeer() {
173+
return null;
174+
}
175+
176+
@Override
177+
public PeerInfo localPeer() {
178+
return null;
179+
}
180+
181+
@Override
182+
public boolean isSecure() {
183+
return false;
184+
}
185+
186+
@Override
187+
public String socketId() {
188+
return "test";
189+
}
190+
191+
@Override
192+
public String childSocketId() {
193+
return "test";
194+
}
195+
196+
@Override
197+
public byte[] get() {
198+
return new byte[0];
199+
}
200+
}
201+
202+
private static final class BlockingDataWriter implements DataWriter {
203+
private final CountDownLatch firstWriteStarted = new CountDownLatch(1);
204+
private final CountDownLatch releaseFirstWrite = new CountDownLatch(1);
205+
private final AtomicBoolean firstWrite = new AtomicBoolean(true);
206+
207+
@Override
208+
public void write(BufferData... buffers) {
209+
throw new UnsupportedOperationException();
210+
}
211+
212+
@Override
213+
public void write(BufferData buffer) {
214+
throw new UnsupportedOperationException();
215+
}
216+
217+
@Override
218+
public void writeNow(BufferData... buffers) {
219+
throw new UnsupportedOperationException();
220+
}
221+
222+
@Override
223+
public void writeNow(BufferData buffer) {
224+
if (firstWrite.compareAndSet(true, false)) {
225+
firstWriteStarted.countDown();
226+
try {
227+
if (!releaseFirstWrite.await(TEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
228+
fail("Timed out waiting to release first write");
229+
}
230+
} catch (InterruptedException e) {
231+
Thread.currentThread().interrupt();
232+
fail("Interrupted while waiting to release first write", e);
233+
}
234+
}
235+
}
236+
237+
void awaitFirstWrite() throws InterruptedException {
238+
if (!firstWriteStarted.await(TEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
239+
fail("Timed out waiting for first websocket send");
240+
}
241+
}
242+
243+
void releaseFirstWrite() {
244+
releaseFirstWrite.countDown();
245+
}
246+
}
247+
}

0 commit comments

Comments
 (0)