Skip to content
25 changes: 15 additions & 10 deletions sei-tendermint/internal/p2p/mux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ var errFrameAfterClose = errors.New("received frame after CLOSE frame")
var errTooManyMsgs = errors.New("too many messages")
var errTooLargeMsg = errors.New("message too large")
var errUnknownKind = errors.New("unknown kind")
var errStreamKindMismatch = errors.New("stream kind mismatch")
var errAlreadyOpened = errors.New("stream already opened")
var errAlreadyClosed = errors.New("stream already closed")

type Config struct {
// Maximal number of bytes in a frame (excluding header).
Expand Down Expand Up @@ -121,20 +124,25 @@ func newRunner(mux *Mux) *runner {
}
}

// getOrAccept() gets the current state of the stream with the given id (kind is ignored).
// getOrAccept() gets the current state of the stream for the given header message.
// If the stream does not exist yet, it tries to create it as an accept (inbound) stream.
// In that case the inbound stream limit for the given kind is checked.
func (r *runner) getOrAccept(id streamID, kind StreamKind) (*streamState, error) {
func (r *runner) getOrAccept(h *pb.Header) (*streamState, error) {
id := streamIDFromRemote(h.Id)
kind := StreamKind(h.GetKind())
for inner := range r.inner.RLock() {
s, ok := inner.streams[id]
if ok {
if h.Kind != nil && s.kind != kind {
return nil, errStreamKindMismatch
}
return s, nil
}
}
if id.isConnect() || h.Kind == nil {
return nil, errUnknownStream
}
for inner := range r.inner.Lock() {
if id.isConnect() {
return nil, errUnknownStream
}
if inner.acceptsSem[kind] == 0 {
return nil, errTooManyAccepts
}
Expand Down Expand Up @@ -258,10 +266,7 @@ func (r *runner) runRecv(ctx context.Context, conn conn.Conn) error {
if err := proto.Unmarshal(headerRaw, &h); err != nil {
return err
}
id := streamIDFromRemote(h.Id)
kind := StreamKind(h.GetKind())

s, err := r.getOrAccept(id, kind)
s, err := r.getOrAccept(&h)
if err != nil {
return err
}
Expand All @@ -276,7 +281,7 @@ func (r *runner) runRecv(ctx context.Context, conn conn.Conn) error {
return err
}
if !s.id.isConnect() {
r.mux.kinds[kind].acceptsQueue <- s
r.mux.kinds[s.kind].acceptsQueue <- s
}
}
if we := h.GetWindowEnd(); we > 0 {
Expand Down
137 changes: 137 additions & 0 deletions sei-tendermint/internal/p2p/mux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,140 @@ func TestProtocol_UnknownKind(t *testing.T) {
t.Fatal(err)
}
}

func TestProtocol_OpenWithoutKind(t *testing.T) {
err := scope.Run(t.Context(), func(ctx context.Context, s scope.Scope) error {
c1, c2 := conn.NewTestConn()
kind := StreamKind(0)

// Bad mux.
badMux := NewMux(makeConfig(kind))
s.SpawnBg(func() error { return badMux.Run(ctx, c1) })
s.SpawnBg(func() error {
// Send OPEN with no kind set - it should be rejected, instead of defaulting to 0.
for queue, ctrl := range badMux.queue.Lock() {
f := queue.Get(streamID(0))
f.Header.Kind = nil
ctrl.Updated()
}
return nil
})
mux := NewMux(makeConfig(kind))
err := mux.Run(ctx, c2)
t.Logf("mux terminated: %v", err)
if !errors.Is(err, errUnknownStream) {
return fmt.Errorf("err = %v, want %v", err, errUnknownStream)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestProtocol_ConnectDoubleOpen(t *testing.T) {
err := scope.Run(t.Context(), func(ctx context.Context, s scope.Scope) error {
c1, c2 := conn.NewTestConn()
kind := StreamKind(0)
maxMsgSize := uint64(10)
window := uint64(3)

// Bad mux.
badMux := NewMux(makeConfig(kind))
s.SpawnBg(func() error { return badMux.Run(ctx, c1) })
s.SpawnBg(func() error {
t.Log("Connect stream.")
stream, err := badMux.Connect(ctx, kind, 0, 0)
if err != nil {
return fmt.Errorf("mux2.Connect(): %w", err)
}
defer stream.Close()
t.Log("Send another opening frame for the same stream.")
for queue, ctrl := range badMux.queue.Lock() {
f := queue.Get(stream.state.id)
f.Header.Kind = utils.Alloc(uint64(kind))
ctrl.Updated()
}
<-ctx.Done()
return nil
})

mux := NewMux(makeConfig(kind))
s.SpawnBg(func() error {
t.Log("Accept stream")
stream, err := mux.Accept(ctx, kind, maxMsgSize, window)
if err != nil {
return fmt.Errorf("mux.Accept(): %w", err)
}
defer stream.Close()
<-ctx.Done()
return nil
})
err := mux.Run(ctx, c2)
t.Logf("mux terminated: %v", err)
if !errors.Is(err, errAlreadyOpened) {
return fmt.Errorf("err = %v, want %v", err, errAlreadyOpened)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestProtocol_AcceptDoubleOpen(t *testing.T) {
rng := utils.TestRng()
err := scope.Run(t.Context(), func(ctx context.Context, s scope.Scope) error {
c1, c2 := conn.NewTestConn()
kind := StreamKind(0)
maxMsgSize := uint64(10)
window := uint64(3)

// Bad mux.
badMux := NewMux(makeConfig(kind))
s.SpawnBg(func() error { return badMux.Run(ctx, c1) })
s.SpawnBg(func() error {
t.Log("Accept stream.")
stream, err := badMux.Accept(ctx, kind, 0, 0)
if err != nil {
return fmt.Errorf("mux2.Connect(): %w", err)
}
defer stream.Close()
t.Log("Send 2 messages to make sure that the OPEN frame is flushed")
for range 2 {
if err := stream.Send(ctx, utils.GenBytes(rng, int(maxMsgSize))); err != nil {
return utils.IgnoreCancel(err)
}
}
t.Log("Send another OPEN frame for the same stream.")
for queue, ctrl := range badMux.queue.Lock() {
f := queue.Get(stream.state.id)
f.Header.Kind = utils.Alloc(uint64(stream.state.kind))
ctrl.Updated()
}
<-ctx.Done()
return nil
})

mux := NewMux(makeConfig(kind))
s.SpawnBg(func() error {
t.Log("Connect stream")
stream, err := mux.Connect(ctx, kind, maxMsgSize, window)
if err != nil {
return fmt.Errorf("mux.Accept(): %w", err)
}
defer stream.Close()
<-ctx.Done()
return nil
})
err := mux.Run(ctx, c2)
t.Logf("mux terminated: %v", err)
if !errors.Is(err, errAlreadyOpened) {
return fmt.Errorf("err = %v, want %v", err, errAlreadyOpened)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions sei-tendermint/internal/p2p/mux/stream_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func newStreamState(id streamID, kind StreamKind) *streamState {
func (s *streamState) RemoteOpen(maxMsgSize uint64) error {
for inner, ctrl := range s.inner.Lock() {
if inner.send.remoteOpened {
return fmt.Errorf("already opened")
return errAlreadyOpened
}
// Do not allow remote open before we connect.
if s.id.isConnect() && !inner.recv.opened {
Expand All @@ -77,7 +77,7 @@ func (s *streamState) RemoteOpen(maxMsgSize uint64) error {
func (s *streamState) RemoteClose() error {
for inner, ctrl := range s.inner.Lock() {
if inner.closed.remote {
return fmt.Errorf("already closed")
return errAlreadyClosed
}
inner.closed.remote = true
ctrl.Updated()
Expand Down
Loading