...

Source file src/golang.org/x/crypto/ssh/mux_test.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"sync"
    12  	"testing"
    13  )
    14  
    15  func muxPair() (*mux, *mux) {
    16  	a, b := memPipe()
    17  
    18  	s := newMux(a)
    19  	c := newMux(b)
    20  
    21  	return s, c
    22  }
    23  
    24  // Returns both ends of a channel, and the mux for the 2nd
    25  // channel.
    26  func channelPair(t *testing.T) (*channel, *channel, *mux) {
    27  	c, s := muxPair()
    28  
    29  	res := make(chan *channel, 1)
    30  	go func() {
    31  		newCh, ok := <-s.incomingChannels
    32  		if !ok {
    33  			t.Error("no incoming channel")
    34  			close(res)
    35  			return
    36  		}
    37  		if newCh.ChannelType() != "chan" {
    38  			t.Errorf("got type %q want chan", newCh.ChannelType())
    39  			newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
    40  			close(res)
    41  			return
    42  		}
    43  		ch, _, err := newCh.Accept()
    44  		if err != nil {
    45  			t.Errorf("accept: %v", err)
    46  			close(res)
    47  			return
    48  		}
    49  		res <- ch.(*channel)
    50  	}()
    51  
    52  	ch, err := c.openChannel("chan", nil)
    53  	if err != nil {
    54  		t.Fatalf("OpenChannel: %v", err)
    55  	}
    56  	w := <-res
    57  	if w == nil {
    58  		t.Fatal("unable to get write channel")
    59  	}
    60  
    61  	return w, ch, c
    62  }
    63  
    64  // Test that stderr and stdout can be addressed from different
    65  // goroutines. This is intended for use with the race detector.
    66  func TestMuxChannelExtendedThreadSafety(t *testing.T) {
    67  	writer, reader, mux := channelPair(t)
    68  	defer writer.Close()
    69  	defer reader.Close()
    70  	defer mux.Close()
    71  
    72  	var wr, rd sync.WaitGroup
    73  	magic := "hello world"
    74  
    75  	wr.Add(2)
    76  	go func() {
    77  		io.WriteString(writer, magic)
    78  		wr.Done()
    79  	}()
    80  	go func() {
    81  		io.WriteString(writer.Stderr(), magic)
    82  		wr.Done()
    83  	}()
    84  
    85  	rd.Add(2)
    86  	go func() {
    87  		c, err := io.ReadAll(reader)
    88  		if string(c) != magic {
    89  			t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
    90  		}
    91  		rd.Done()
    92  	}()
    93  	go func() {
    94  		c, err := io.ReadAll(reader.Stderr())
    95  		if string(c) != magic {
    96  			t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
    97  		}
    98  		rd.Done()
    99  	}()
   100  
   101  	wr.Wait()
   102  	writer.CloseWrite()
   103  	rd.Wait()
   104  }
   105  
   106  func TestMuxReadWrite(t *testing.T) {
   107  	s, c, mux := channelPair(t)
   108  	defer s.Close()
   109  	defer c.Close()
   110  	defer mux.Close()
   111  
   112  	magic := "hello world"
   113  	magicExt := "hello stderr"
   114  	var wg sync.WaitGroup
   115  	t.Cleanup(wg.Wait)
   116  	wg.Add(1)
   117  	go func() {
   118  		defer wg.Done()
   119  		_, err := s.Write([]byte(magic))
   120  		if err != nil {
   121  			t.Errorf("Write: %v", err)
   122  			return
   123  		}
   124  		_, err = s.Extended(1).Write([]byte(magicExt))
   125  		if err != nil {
   126  			t.Errorf("Write: %v", err)
   127  			return
   128  		}
   129  	}()
   130  
   131  	var buf [1024]byte
   132  	n, err := c.Read(buf[:])
   133  	if err != nil {
   134  		t.Fatalf("server Read: %v", err)
   135  	}
   136  	got := string(buf[:n])
   137  	if got != magic {
   138  		t.Fatalf("server: got %q want %q", got, magic)
   139  	}
   140  
   141  	n, err = c.Extended(1).Read(buf[:])
   142  	if err != nil {
   143  		t.Fatalf("server Read: %v", err)
   144  	}
   145  
   146  	got = string(buf[:n])
   147  	if got != magicExt {
   148  		t.Fatalf("server: got %q want %q", got, magic)
   149  	}
   150  }
   151  
   152  func TestMuxChannelOverflow(t *testing.T) {
   153  	reader, writer, mux := channelPair(t)
   154  	defer reader.Close()
   155  	defer writer.Close()
   156  	defer mux.Close()
   157  
   158  	var wg sync.WaitGroup
   159  	t.Cleanup(wg.Wait)
   160  	wg.Add(1)
   161  	go func() {
   162  		defer wg.Done()
   163  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   164  			t.Errorf("could not fill window: %v", err)
   165  		}
   166  		writer.Write(make([]byte, 1))
   167  	}()
   168  	writer.remoteWin.waitWriterBlocked()
   169  
   170  	// Send 1 byte.
   171  	packet := make([]byte, 1+4+4+1)
   172  	packet[0] = msgChannelData
   173  	marshalUint32(packet[1:], writer.remoteId)
   174  	marshalUint32(packet[5:], uint32(1))
   175  	packet[9] = 42
   176  
   177  	if err := writer.mux.conn.writePacket(packet); err != nil {
   178  		t.Errorf("could not send packet")
   179  	}
   180  	if _, err := reader.SendRequest("hello", true, nil); err == nil {
   181  		t.Errorf("SendRequest succeeded.")
   182  	}
   183  }
   184  
   185  func TestMuxChannelReadUnblock(t *testing.T) {
   186  	reader, writer, mux := channelPair(t)
   187  	defer reader.Close()
   188  	defer writer.Close()
   189  	defer mux.Close()
   190  
   191  	var wg sync.WaitGroup
   192  	t.Cleanup(wg.Wait)
   193  	wg.Add(1)
   194  	go func() {
   195  		defer wg.Done()
   196  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   197  			t.Errorf("could not fill window: %v", err)
   198  		}
   199  		if _, err := writer.Write(make([]byte, 1)); err != nil {
   200  			t.Errorf("Write: %v", err)
   201  		}
   202  		writer.Close()
   203  	}()
   204  
   205  	writer.remoteWin.waitWriterBlocked()
   206  
   207  	buf := make([]byte, 32768)
   208  	for {
   209  		_, err := reader.Read(buf)
   210  		if err == io.EOF {
   211  			break
   212  		}
   213  		if err != nil {
   214  			t.Fatalf("Read: %v", err)
   215  		}
   216  	}
   217  }
   218  
   219  func TestMuxChannelCloseWriteUnblock(t *testing.T) {
   220  	reader, writer, mux := channelPair(t)
   221  	defer reader.Close()
   222  	defer writer.Close()
   223  	defer mux.Close()
   224  
   225  	var wg sync.WaitGroup
   226  	t.Cleanup(wg.Wait)
   227  	wg.Add(1)
   228  	go func() {
   229  		defer wg.Done()
   230  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   231  			t.Errorf("could not fill window: %v", err)
   232  		}
   233  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   234  			t.Errorf("got %v, want EOF for unblock write", err)
   235  		}
   236  	}()
   237  
   238  	writer.remoteWin.waitWriterBlocked()
   239  	reader.Close()
   240  }
   241  
   242  func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
   243  	reader, writer, mux := channelPair(t)
   244  	defer reader.Close()
   245  	defer writer.Close()
   246  	defer mux.Close()
   247  
   248  	var wg sync.WaitGroup
   249  	t.Cleanup(wg.Wait)
   250  	wg.Add(1)
   251  	go func() {
   252  		defer wg.Done()
   253  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   254  			t.Errorf("could not fill window: %v", err)
   255  		}
   256  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   257  			t.Errorf("got %v, want EOF for unblock write", err)
   258  		}
   259  	}()
   260  
   261  	writer.remoteWin.waitWriterBlocked()
   262  	mux.Close()
   263  }
   264  
   265  func TestMuxReject(t *testing.T) {
   266  	client, server := muxPair()
   267  	defer server.Close()
   268  	defer client.Close()
   269  
   270  	var wg sync.WaitGroup
   271  	t.Cleanup(wg.Wait)
   272  	wg.Add(1)
   273  	go func() {
   274  		defer wg.Done()
   275  
   276  		ch, ok := <-server.incomingChannels
   277  		if !ok {
   278  			t.Error("cannot accept channel")
   279  			return
   280  		}
   281  		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
   282  			t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
   283  			ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
   284  			return
   285  		}
   286  		ch.Reject(RejectionReason(42), "message")
   287  	}()
   288  
   289  	ch, err := client.openChannel("ch", []byte("extra"))
   290  	if ch != nil {
   291  		t.Fatal("openChannel not rejected")
   292  	}
   293  
   294  	ocf, ok := err.(*OpenChannelError)
   295  	if !ok {
   296  		t.Errorf("got %#v want *OpenChannelError", err)
   297  	} else if ocf.Reason != 42 || ocf.Message != "message" {
   298  		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
   299  	}
   300  
   301  	want := "ssh: rejected: unknown reason 42 (message)"
   302  	if err.Error() != want {
   303  		t.Errorf("got %q, want %q", err.Error(), want)
   304  	}
   305  }
   306  
   307  func TestMuxChannelRequest(t *testing.T) {
   308  	client, server, mux := channelPair(t)
   309  	defer server.Close()
   310  	defer client.Close()
   311  	defer mux.Close()
   312  
   313  	var received int
   314  	var wg sync.WaitGroup
   315  	t.Cleanup(wg.Wait)
   316  	wg.Add(1)
   317  	go func() {
   318  		for r := range server.incomingRequests {
   319  			received++
   320  			r.Reply(r.Type == "yes", nil)
   321  		}
   322  		wg.Done()
   323  	}()
   324  	_, err := client.SendRequest("yes", false, nil)
   325  	if err != nil {
   326  		t.Fatalf("SendRequest: %v", err)
   327  	}
   328  	ok, err := client.SendRequest("yes", true, nil)
   329  	if err != nil {
   330  		t.Fatalf("SendRequest: %v", err)
   331  	}
   332  
   333  	if !ok {
   334  		t.Errorf("SendRequest(yes): %v", ok)
   335  
   336  	}
   337  
   338  	ok, err = client.SendRequest("no", true, nil)
   339  	if err != nil {
   340  		t.Fatalf("SendRequest: %v", err)
   341  	}
   342  	if ok {
   343  		t.Errorf("SendRequest(no): %v", ok)
   344  	}
   345  
   346  	client.Close()
   347  	wg.Wait()
   348  
   349  	if received != 3 {
   350  		t.Errorf("got %d requests, want %d", received, 3)
   351  	}
   352  }
   353  
   354  func TestMuxUnknownChannelRequests(t *testing.T) {
   355  	clientPipe, serverPipe := memPipe()
   356  	client := newMux(clientPipe)
   357  	defer serverPipe.Close()
   358  	defer client.Close()
   359  
   360  	kDone := make(chan error, 1)
   361  	go func() {
   362  		// Ignore unknown channel messages that don't want a reply.
   363  		err := serverPipe.writePacket(Marshal(channelRequestMsg{
   364  			PeersID:             1,
   365  			Request:             "keepalive@openssh.com",
   366  			WantReply:           false,
   367  			RequestSpecificData: []byte{},
   368  		}))
   369  		if err != nil {
   370  			kDone <- fmt.Errorf("send: %w", err)
   371  			return
   372  		}
   373  
   374  		// Send a keepalive, which should get a channel failure message
   375  		// in response.
   376  		err = serverPipe.writePacket(Marshal(channelRequestMsg{
   377  			PeersID:             2,
   378  			Request:             "keepalive@openssh.com",
   379  			WantReply:           true,
   380  			RequestSpecificData: []byte{},
   381  		}))
   382  		if err != nil {
   383  			kDone <- fmt.Errorf("send: %w", err)
   384  			return
   385  		}
   386  
   387  		packet, err := serverPipe.readPacket()
   388  		if err != nil {
   389  			kDone <- fmt.Errorf("read packet: %w", err)
   390  			return
   391  		}
   392  		decoded, err := decode(packet)
   393  		if err != nil {
   394  			kDone <- fmt.Errorf("decode failed: %w", err)
   395  			return
   396  		}
   397  
   398  		switch msg := decoded.(type) {
   399  		case *channelRequestFailureMsg:
   400  			if msg.PeersID != 2 {
   401  				kDone <- fmt.Errorf("received response to wrong message: %v", msg)
   402  				return
   403  
   404  			}
   405  		default:
   406  			kDone <- fmt.Errorf("unexpected channel message: %v", msg)
   407  			return
   408  		}
   409  
   410  		kDone <- nil
   411  
   412  		// Receive and respond to the keepalive to confirm the mux is
   413  		// still processing requests.
   414  		packet, err = serverPipe.readPacket()
   415  		if err != nil {
   416  			kDone <- fmt.Errorf("read packet: %w", err)
   417  			return
   418  		}
   419  		if packet[0] != msgGlobalRequest {
   420  			kDone <- errors.New("expected global request")
   421  			return
   422  		}
   423  
   424  		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
   425  			Data: []byte{},
   426  		}))
   427  		if err != nil {
   428  			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
   429  			return
   430  		}
   431  
   432  		close(kDone)
   433  	}()
   434  
   435  	// Wait for the server to send the keepalive message and receive back a
   436  	// response.
   437  	if err := <-kDone; err != nil {
   438  		t.Fatal(err)
   439  	}
   440  
   441  	// Confirm client hasn't closed.
   442  	if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
   443  		t.Fatalf("failed to send keepalive: %v", err)
   444  	}
   445  
   446  	// Wait for the server to shut down.
   447  	if err := <-kDone; err != nil {
   448  		t.Fatal(err)
   449  	}
   450  }
   451  
   452  func TestMuxClosedChannel(t *testing.T) {
   453  	clientPipe, serverPipe := memPipe()
   454  	client := newMux(clientPipe)
   455  	defer serverPipe.Close()
   456  	defer client.Close()
   457  
   458  	kDone := make(chan error, 1)
   459  	go func() {
   460  		// Open the channel.
   461  		packet, err := serverPipe.readPacket()
   462  		if err != nil {
   463  			kDone <- fmt.Errorf("read packet: %w", err)
   464  			return
   465  		}
   466  		if packet[0] != msgChannelOpen {
   467  			kDone <- errors.New("expected chan open")
   468  			return
   469  		}
   470  
   471  		var openMsg channelOpenMsg
   472  		if err := Unmarshal(packet, &openMsg); err != nil {
   473  			kDone <- fmt.Errorf("unmarshal: %w", err)
   474  			return
   475  		}
   476  
   477  		// Send back the opened channel confirmation.
   478  		err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{
   479  			PeersID:       openMsg.PeersID,
   480  			MyID:          0,
   481  			MyWindow:      0,
   482  			MaxPacketSize: channelMaxPacket,
   483  		}))
   484  		if err != nil {
   485  			kDone <- fmt.Errorf("send: %w", err)
   486  			return
   487  		}
   488  
   489  		// Close the channel.
   490  		err = serverPipe.writePacket(Marshal(channelCloseMsg{
   491  			PeersID: openMsg.PeersID,
   492  		}))
   493  		if err != nil {
   494  			kDone <- fmt.Errorf("send: %w", err)
   495  			return
   496  		}
   497  
   498  		// Send a keepalive message on the channel we just closed.
   499  		err = serverPipe.writePacket(Marshal(channelRequestMsg{
   500  			PeersID:             openMsg.PeersID,
   501  			Request:             "keepalive@openssh.com",
   502  			WantReply:           true,
   503  			RequestSpecificData: []byte{},
   504  		}))
   505  		if err != nil {
   506  			kDone <- fmt.Errorf("send: %w", err)
   507  			return
   508  		}
   509  
   510  		// Receive the channel closed response.
   511  		packet, err = serverPipe.readPacket()
   512  		if err != nil {
   513  			kDone <- fmt.Errorf("read packet: %w", err)
   514  			return
   515  		}
   516  		if packet[0] != msgChannelClose {
   517  			kDone <- errors.New("expected channel close")
   518  			return
   519  		}
   520  
   521  		// Receive the keepalive response failure.
   522  		packet, err = serverPipe.readPacket()
   523  		if err != nil {
   524  			kDone <- fmt.Errorf("read packet: %w", err)
   525  			return
   526  		}
   527  		if packet[0] != msgChannelFailure {
   528  			kDone <- errors.New("expected channel failure")
   529  			return
   530  		}
   531  		kDone <- nil
   532  
   533  		// Receive and respond to the keepalive to confirm the mux is
   534  		// still processing requests.
   535  		packet, err = serverPipe.readPacket()
   536  		if err != nil {
   537  			kDone <- fmt.Errorf("read packet: %w", err)
   538  			return
   539  		}
   540  		if packet[0] != msgGlobalRequest {
   541  			kDone <- errors.New("expected global request")
   542  			return
   543  		}
   544  
   545  		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
   546  			Data: []byte{},
   547  		}))
   548  		if err != nil {
   549  			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
   550  			return
   551  		}
   552  
   553  		close(kDone)
   554  	}()
   555  
   556  	// Open a channel.
   557  	ch, err := client.openChannel("chan", nil)
   558  	if err != nil {
   559  		t.Fatalf("OpenChannel: %v", err)
   560  	}
   561  	defer ch.Close()
   562  
   563  	// Wait for the server to close the channel and send the keepalive.
   564  	<-kDone
   565  
   566  	// Make sure the channel closed.
   567  	if _, ok := <-ch.incomingRequests; ok {
   568  		t.Fatalf("channel not closed")
   569  	}
   570  
   571  	// Confirm client hasn't closed
   572  	if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil {
   573  		t.Fatalf("failed to send keepalive: %v", err)
   574  	}
   575  
   576  	// Wait for the server to shut down.
   577  	<-kDone
   578  }
   579  
   580  func TestMuxGlobalRequest(t *testing.T) {
   581  	var sawPeek bool
   582  	var wg sync.WaitGroup
   583  	defer func() {
   584  		wg.Wait()
   585  		if !sawPeek {
   586  			t.Errorf("never saw 'peek' request")
   587  		}
   588  	}()
   589  
   590  	clientMux, serverMux := muxPair()
   591  	defer serverMux.Close()
   592  	defer clientMux.Close()
   593  
   594  	wg.Add(1)
   595  	go func() {
   596  		defer wg.Done()
   597  		for r := range serverMux.incomingRequests {
   598  			sawPeek = sawPeek || r.Type == "peek"
   599  			if r.WantReply {
   600  				err := r.Reply(r.Type == "yes",
   601  					append([]byte(r.Type), r.Payload...))
   602  				if err != nil {
   603  					t.Errorf("AckRequest: %v", err)
   604  				}
   605  			}
   606  		}
   607  	}()
   608  
   609  	_, _, err := clientMux.SendRequest("peek", false, nil)
   610  	if err != nil {
   611  		t.Errorf("SendRequest: %v", err)
   612  	}
   613  
   614  	ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
   615  	if !ok || string(data) != "yesa" || err != nil {
   616  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   617  			ok, data, err)
   618  	}
   619  	if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
   620  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   621  			ok, data, err)
   622  	}
   623  
   624  	if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
   625  		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
   626  			ok, data, err)
   627  	}
   628  }
   629  
   630  func TestMuxGlobalRequestUnblock(t *testing.T) {
   631  	clientMux, serverMux := muxPair()
   632  	defer serverMux.Close()
   633  	defer clientMux.Close()
   634  
   635  	result := make(chan error, 1)
   636  	go func() {
   637  		_, _, err := clientMux.SendRequest("hello", true, nil)
   638  		result <- err
   639  	}()
   640  
   641  	<-serverMux.incomingRequests
   642  	serverMux.conn.Close()
   643  	err := <-result
   644  
   645  	if err != io.EOF {
   646  		t.Errorf("want EOF, got %v", io.EOF)
   647  	}
   648  }
   649  
   650  func TestMuxChannelRequestUnblock(t *testing.T) {
   651  	a, b, connB := channelPair(t)
   652  	defer a.Close()
   653  	defer b.Close()
   654  	defer connB.Close()
   655  
   656  	result := make(chan error, 1)
   657  	go func() {
   658  		_, err := a.SendRequest("hello", true, nil)
   659  		result <- err
   660  	}()
   661  
   662  	<-b.incomingRequests
   663  	connB.conn.Close()
   664  	err := <-result
   665  
   666  	if err != io.EOF {
   667  		t.Errorf("want EOF, got %v", err)
   668  	}
   669  }
   670  
   671  func TestMuxCloseChannel(t *testing.T) {
   672  	r, w, mux := channelPair(t)
   673  	defer mux.Close()
   674  	defer r.Close()
   675  	defer w.Close()
   676  
   677  	result := make(chan error, 1)
   678  	go func() {
   679  		var b [1024]byte
   680  		_, err := r.Read(b[:])
   681  		result <- err
   682  	}()
   683  	if err := w.Close(); err != nil {
   684  		t.Errorf("w.Close: %v", err)
   685  	}
   686  
   687  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   688  		t.Errorf("got err %v, want io.EOF after Close", err)
   689  	}
   690  
   691  	if err := <-result; err != io.EOF {
   692  		t.Errorf("got %v (%T), want io.EOF", err, err)
   693  	}
   694  }
   695  
   696  func TestMuxCloseWriteChannel(t *testing.T) {
   697  	r, w, mux := channelPair(t)
   698  	defer mux.Close()
   699  
   700  	result := make(chan error, 1)
   701  	go func() {
   702  		var b [1024]byte
   703  		_, err := r.Read(b[:])
   704  		result <- err
   705  	}()
   706  	if err := w.CloseWrite(); err != nil {
   707  		t.Errorf("w.CloseWrite: %v", err)
   708  	}
   709  
   710  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   711  		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
   712  	}
   713  
   714  	if err := <-result; err != io.EOF {
   715  		t.Errorf("got %v (%T), want io.EOF", err, err)
   716  	}
   717  }
   718  
   719  func TestMuxInvalidRecord(t *testing.T) {
   720  	a, b := muxPair()
   721  	defer a.Close()
   722  	defer b.Close()
   723  
   724  	packet := make([]byte, 1+4+4+1)
   725  	packet[0] = msgChannelData
   726  	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
   727  	marshalUint32(packet[5:], 1)
   728  	packet[9] = 42
   729  
   730  	a.conn.writePacket(packet)
   731  	go a.SendRequest("hello", false, nil)
   732  	// 'a' wrote an invalid packet, so 'b' has exited.
   733  	req, ok := <-b.incomingRequests
   734  	if ok {
   735  		t.Errorf("got request %#v after receiving invalid packet", req)
   736  	}
   737  }
   738  
   739  func TestZeroWindowAdjust(t *testing.T) {
   740  	a, b, mux := channelPair(t)
   741  	defer a.Close()
   742  	defer b.Close()
   743  	defer mux.Close()
   744  
   745  	go func() {
   746  		io.WriteString(a, "hello")
   747  		// bogus adjust.
   748  		a.sendMessage(windowAdjustMsg{})
   749  		io.WriteString(a, "world")
   750  		a.Close()
   751  	}()
   752  
   753  	want := "helloworld"
   754  	c, _ := io.ReadAll(b)
   755  	if string(c) != want {
   756  		t.Errorf("got %q want %q", c, want)
   757  	}
   758  }
   759  
   760  func TestMuxMaxPacketSize(t *testing.T) {
   761  	a, b, mux := channelPair(t)
   762  	defer a.Close()
   763  	defer b.Close()
   764  	defer mux.Close()
   765  
   766  	large := make([]byte, a.maxRemotePayload+1)
   767  	packet := make([]byte, 1+4+4+1+len(large))
   768  	packet[0] = msgChannelData
   769  	marshalUint32(packet[1:], a.remoteId)
   770  	marshalUint32(packet[5:], uint32(len(large)))
   771  	packet[9] = 42
   772  
   773  	if err := a.mux.conn.writePacket(packet); err != nil {
   774  		t.Errorf("could not send packet")
   775  	}
   776  
   777  	var wg sync.WaitGroup
   778  	t.Cleanup(wg.Wait)
   779  	wg.Add(1)
   780  	go func() {
   781  		a.SendRequest("hello", false, nil)
   782  		wg.Done()
   783  	}()
   784  
   785  	_, ok := <-b.incomingRequests
   786  	if ok {
   787  		t.Errorf("connection still alive after receiving large packet.")
   788  	}
   789  }
   790  
   791  func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
   792  	s, c, mux := channelPair(t)
   793  	cTransport := mux.conn.(*memTransport)
   794  	defer s.Close()
   795  	defer c.Close()
   796  	defer mux.Close()
   797  
   798  	var wg sync.WaitGroup
   799  	t.Cleanup(wg.Wait)
   800  
   801  	data := make([]byte, 1024)
   802  
   803  	wg.Add(1)
   804  	go func() {
   805  		defer wg.Done()
   806  		_, err := s.Write(data)
   807  		if err != nil {
   808  			t.Errorf("Write: %v", err)
   809  			return
   810  		}
   811  	}()
   812  	cWritesInit := cTransport.getWriteCount()
   813  	buf := make([]byte, 1)
   814  	for i := 0; i < len(data); i++ {
   815  		n, err := c.Read(buf)
   816  		if n != len(buf) || err != nil {
   817  			t.Fatalf("Read: %v, %v", n, err)
   818  		}
   819  	}
   820  	cWrites := cTransport.getWriteCount() - cWritesInit
   821  	// reading 1 KiB should not cause any window updates to be sent, but allow
   822  	// for some unexpected writes
   823  	if cWrites > 30 {
   824  		t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
   825  	}
   826  }
   827  
   828  // Don't ship code with debug=true.
   829  func TestDebug(t *testing.T) {
   830  	if debugMux {
   831  		t.Error("mux debug switched on")
   832  	}
   833  	if debugHandshake {
   834  		t.Error("handshake debug switched on")
   835  	}
   836  	if debugTransport {
   837  		t.Error("transport debug switched on")
   838  	}
   839  }
   840  

View as plain text