...

Source file src/golang.org/x/crypto/ssh/mempipe_test.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"io"
     9  	"sync"
    10  	"testing"
    11  )
    12  
    13  // An in-memory packetConn. It is safe to call Close and writePacket
    14  // from different goroutines.
    15  type memTransport struct {
    16  	eof        bool
    17  	pending    [][]byte
    18  	write      *memTransport
    19  	writeCount uint64
    20  	sync.Mutex
    21  	*sync.Cond
    22  }
    23  
    24  func (t *memTransport) readPacket() ([]byte, error) {
    25  	t.Lock()
    26  	defer t.Unlock()
    27  	for {
    28  		if len(t.pending) > 0 {
    29  			r := t.pending[0]
    30  			t.pending = t.pending[1:]
    31  			return r, nil
    32  		}
    33  		if t.eof {
    34  			return nil, io.EOF
    35  		}
    36  		t.Cond.Wait()
    37  	}
    38  }
    39  
    40  func (t *memTransport) closeSelf() error {
    41  	t.Lock()
    42  	defer t.Unlock()
    43  	if t.eof {
    44  		return io.EOF
    45  	}
    46  	t.eof = true
    47  	t.Cond.Broadcast()
    48  	return nil
    49  }
    50  
    51  func (t *memTransport) Close() error {
    52  	err := t.write.closeSelf()
    53  	t.closeSelf()
    54  	return err
    55  }
    56  
    57  func (t *memTransport) writePacket(p []byte) error {
    58  	t.write.Lock()
    59  	defer t.write.Unlock()
    60  	if t.write.eof {
    61  		return io.EOF
    62  	}
    63  	c := make([]byte, len(p))
    64  	copy(c, p)
    65  	t.write.pending = append(t.write.pending, c)
    66  	t.write.Cond.Signal()
    67  	t.writeCount++
    68  	return nil
    69  }
    70  
    71  func (t *memTransport) getWriteCount() uint64 {
    72  	t.write.Lock()
    73  	defer t.write.Unlock()
    74  	return t.writeCount
    75  }
    76  
    77  func memPipe() (a, b packetConn) {
    78  	t1 := memTransport{}
    79  	t2 := memTransport{}
    80  	t1.write = &t2
    81  	t2.write = &t1
    82  	t1.Cond = sync.NewCond(&t1.Mutex)
    83  	t2.Cond = sync.NewCond(&t2.Mutex)
    84  	return &t1, &t2
    85  }
    86  
    87  func TestMemPipe(t *testing.T) {
    88  	a, b := memPipe()
    89  	if err := a.writePacket([]byte{42}); err != nil {
    90  		t.Fatalf("writePacket: %v", err)
    91  	}
    92  	if wc := a.(*memTransport).getWriteCount(); wc != 1 {
    93  		t.Fatalf("got %v, want 1", wc)
    94  	}
    95  	if err := a.Close(); err != nil {
    96  		t.Fatal("Close: ", err)
    97  	}
    98  	p, err := b.readPacket()
    99  	if err != nil {
   100  		t.Fatal("readPacket: ", err)
   101  	}
   102  	if len(p) != 1 || p[0] != 42 {
   103  		t.Fatalf("got %v, want {42}", p)
   104  	}
   105  	p, err = b.readPacket()
   106  	if err != io.EOF {
   107  		t.Fatalf("got %v, %v, want EOF", p, err)
   108  	}
   109  	if wc := b.(*memTransport).getWriteCount(); wc != 0 {
   110  		t.Fatalf("got %v, want 0", wc)
   111  	}
   112  }
   113  
   114  func TestDoubleClose(t *testing.T) {
   115  	a, _ := memPipe()
   116  	err := a.Close()
   117  	if err != nil {
   118  		t.Errorf("Close: %v", err)
   119  	}
   120  	err = a.Close()
   121  	if err != io.EOF {
   122  		t.Errorf("expect EOF on double close.")
   123  	}
   124  }
   125  

View as plain text