@@ -41,13 +41,12 @@ func newWebsocketPump(
4141 frontend : frontend ,
4242 backend : backend ,
4343 logger : logger ,
44- done : make (chan struct {}),
44+ done : make (chan struct {}, 1 ),
4545 }
4646}
4747
4848func (p * websocketPump ) run () error {
4949 failure := make (chan error , 2 )
50- done := make (chan struct {}, 2 )
5150
5251 p .frontend .SetPingHandler (p .pumpPings (p .frontend , p .backend , "f->b" ))
5352 p .backend .SetPongHandler (p .pumpPongs (p .backend , p .frontend , "b->f" ))
@@ -58,45 +57,58 @@ func (p *websocketPump) run() error {
5857 p .frontend .SetCloseHandler (p .pumpCloseMessages (p .frontend , p .backend , "f->b" ))
5958 p .backend .SetCloseHandler (p .pumpCloseMessages (p .backend , p .frontend , "b->f" ))
6059
61- go p .pumpMessages (p .backend , p .frontend , p .cfg .proxy .ForwardTimeout , "b->f" , done , failure )
62- go p .pumpMessages (p .frontend , p .backend , p .cfg .proxy .BackwardTimeout , "f->b" , done , failure )
60+ doneForward := make (chan struct {}, 1 )
61+ go p .pumpMessages (p .backend , p .frontend , p .cfg .proxy .ForwardTimeout , "b->f" , doneForward , failure )
62+
63+ doneReverse := make (chan struct {}, 1 )
64+ go p .pumpMessages (p .frontend , p .backend , p .cfg .proxy .BackwardTimeout , "f->b" , doneReverse , failure )
6365
6466 p .active .Store (true )
6567
68+ errs := make ([]error , 0 )
69+
70+ loop:
6671 for {
6772 select {
6873 case <- p .done :
69- done <- struct {}{}
7074 p .active .Store (false )
71- return nil
75+ doneForward <- struct {}{}
76+ doneReverse <- struct {}{}
77+ break loop
7278
7379 case err := <- failure :
74- errs := make ([]error , 0 , 2 )
7580 errs = append (errs , err )
81+ break loop
82+ }
83+ }
7684
77- exhaustErrs:
78- for {
79- select {
80- case err := <- failure :
81- errs = append (errs , err )
82- default :
83- break exhaustErrs
84- }
85- }
86- err = utils .FlattenErrors (errs )
87-
88- done <- struct {}{}
89- p .active .Store (false )
90- return err
85+ exhaustErrors:
86+ for {
87+ select {
88+ case err := <- failure :
89+ errs = append (errs , err )
90+ default :
91+ break exhaustErrors
9192 }
9293 }
94+
95+ return utils .FlattenErrors (errs )
9396}
9497
9598func (p * websocketPump ) stop () error {
96- p .done <- struct {}{}
99+ if ! p .active .Load () {
100+ return nil // already stopped
101+ }
97102
98103 errs := make ([]error , 0 )
99104
105+ select {
106+ case p .done <- struct {}{}:
107+ // no-op
108+ default :
109+ errs = append (errs , errors .New ("double-closing the pump" ))
110+ }
111+
100112 countdown := 60
101113 for p .active .Load () && countdown > 0 {
102114 time .Sleep (time .Second )
@@ -130,8 +142,19 @@ func (p *websocketPump) pumpMessages(
130142 )
131143
132144 messages := make (chan * websocketMessage , 16 )
133- doneReads := make (chan struct {})
134- doneWrites := make (chan struct {})
145+ doneReads := make (chan struct {}, 1 )
146+ doneWrites := make (chan struct {}, 1 )
147+
148+ notifyOnFailure := func (err error ) {
149+ select {
150+ case failure <- err :
151+ // no-op
152+ default :
153+ l .Warn ("Dropping websocket pump failure b/c the channel is full" ,
154+ zap .Error (err ),
155+ )
156+ }
157+ }
135158
136159 go func () { // read
137160 for {
@@ -141,13 +164,13 @@ func (p *websocketPump) pumpMessages(
141164
142165 default :
143166 if err := from .SetReadDeadline (utils .Deadline (timeout )); err != nil {
144- failure <- err
167+ notifyOnFailure ( err )
145168 continue
146169 }
147170
148171 msgType , bytes , err := from .ReadMessage ()
149172 if err != nil {
150- failure <- err
173+ notifyOnFailure ( err )
151174 continue
152175 }
153176
@@ -213,12 +236,12 @@ func (p *websocketPump) pumpMessages(
213236 }
214237
215238 if err := to .SetWriteDeadline (utils .Deadline (timeout )); err != nil {
216- failure <- err
239+ notifyOnFailure ( err )
217240 continue
218241 }
219242
220243 if err := to .WriteMessage (m .msgType , m .bytes ); err != nil {
221- failure <- err
244+ notifyOnFailure ( err )
222245 continue
223246 }
224247
0 commit comments