...

Source file src/net/net_fake.go

Documentation: net

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Fake networking for js/wasm and wasip1/wasm.
     6  // It is intended to allow tests of other packages to pass.
     7  
     8  //go:build js || wasip1
     9  
    10  package net
    11  
    12  import (
    13  	"context"
    14  	"errors"
    15  	"io"
    16  	"os"
    17  	"runtime"
    18  	"sync"
    19  	"sync/atomic"
    20  	"syscall"
    21  	"time"
    22  )
    23  
// Process-wide registries shared by every fake socket.
var (
	// sockets maps a bound local address to the socket that owns it.
	// fakeListen and fakeConnect register entries with LoadOrStore, and
	// fakeNetFD.Close removes them with CompareAndDelete.
	sockets sync.Map // fakeSockAddr → *netFD
	// NOTE(review): fakeSocketIDs is not referenced anywhere in the visible
	// part of this file — confirm whether it is still needed.
	fakeSocketIDs sync.Map // fakeNetFD.id → *netFD
	// fakePorts records automatically assigned ports so that assignFakeAddr
	// never hands the same ephemeral port to two live sockets.
	fakePorts sync.Map // int (port #) → *netFD
	// nextPortCounter is the shared cursor for ephemeral port allocation;
	// assignFakeAddr advances it atomically and wraps it back into range.
	nextPortCounter atomic.Int32
)
    30  
// defaultBuffer is the initial soft limit, in bytes, on data buffered in the
// packet queues created by fakeListen and fakeConnect.
const defaultBuffer = 65535
    32  
// fakeSockAddr is the comparable key used to index the sockets map. It pairs
// the address family with the address's String form, so the same text in
// different families yields distinct keys.
type fakeSockAddr struct {
	family  int
	address string
}
    37  
    38  func fakeAddr(sa sockaddr) fakeSockAddr {
    39  	return fakeSockAddr{
    40  		family:  sa.family(),
    41  		address: sa.String(),
    42  	}
    43  }
    44  
    45  // socket returns a network file descriptor that is ready for
    46  // I/O using the fake network.
    47  func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
    48  	if raddr != nil && ctrlCtxFn != nil {
    49  		return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
    50  	}
    51  	switch sotype {
    52  	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET, syscall.SOCK_DGRAM:
    53  	default:
    54  		return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
    55  	}
    56  
    57  	fd := &netFD{
    58  		family: family,
    59  		sotype: sotype,
    60  		net:    net,
    61  	}
    62  	fd.fakeNetFD = newFakeNetFD(fd)
    63  
    64  	if raddr == nil {
    65  		if err := fakeListen(fd, laddr); err != nil {
    66  			fd.Close()
    67  			return nil, err
    68  		}
    69  		return fd, nil
    70  	}
    71  
    72  	if err := fakeConnect(ctx, fd, laddr, raddr); err != nil {
    73  		fd.Close()
    74  		return nil, err
    75  	}
    76  	return fd, nil
    77  }
    78  
    79  func validateResolvedAddr(net string, family int, sa sockaddr) error {
    80  	validateIP := func(ip IP) error {
    81  		switch family {
    82  		case syscall.AF_INET:
    83  			if len(ip) != 4 {
    84  				return &AddrError{
    85  					Err:  "non-IPv4 address",
    86  					Addr: ip.String(),
    87  				}
    88  			}
    89  		case syscall.AF_INET6:
    90  			if len(ip) != 16 {
    91  				return &AddrError{
    92  					Err:  "non-IPv6 address",
    93  					Addr: ip.String(),
    94  				}
    95  			}
    96  		default:
    97  			panic("net: unexpected address family in validateResolvedAddr")
    98  		}
    99  		return nil
   100  	}
   101  
   102  	switch net {
   103  	case "tcp", "tcp4", "tcp6":
   104  		sa, ok := sa.(*TCPAddr)
   105  		if !ok {
   106  			return &AddrError{
   107  				Err:  "non-TCP address for " + net + " network",
   108  				Addr: sa.String(),
   109  			}
   110  		}
   111  		if err := validateIP(sa.IP); err != nil {
   112  			return err
   113  		}
   114  		if sa.Port <= 0 || sa.Port >= 1<<16 {
   115  			return &AddrError{
   116  				Err:  "port out of range",
   117  				Addr: sa.String(),
   118  			}
   119  		}
   120  		return nil
   121  
   122  	case "udp", "udp4", "udp6":
   123  		sa, ok := sa.(*UDPAddr)
   124  		if !ok {
   125  			return &AddrError{
   126  				Err:  "non-UDP address for " + net + " network",
   127  				Addr: sa.String(),
   128  			}
   129  		}
   130  		if err := validateIP(sa.IP); err != nil {
   131  			return err
   132  		}
   133  		if sa.Port <= 0 || sa.Port >= 1<<16 {
   134  			return &AddrError{
   135  				Err:  "port out of range",
   136  				Addr: sa.String(),
   137  			}
   138  		}
   139  		return nil
   140  
   141  	case "unix", "unixgram", "unixpacket":
   142  		sa, ok := sa.(*UnixAddr)
   143  		if !ok {
   144  			return &AddrError{
   145  				Err:  "non-Unix address for " + net + " network",
   146  				Addr: sa.String(),
   147  			}
   148  		}
   149  		if sa.Name != "" {
   150  			i := len(sa.Name) - 1
   151  			for i > 0 && !os.IsPathSeparator(sa.Name[i]) {
   152  				i--
   153  			}
   154  			for i > 0 && os.IsPathSeparator(sa.Name[i]) {
   155  				i--
   156  			}
   157  			if i <= 0 {
   158  				return &AddrError{
   159  					Err:  "unix socket name missing path component",
   160  					Addr: sa.Name,
   161  				}
   162  			}
   163  			if _, err := os.Stat(sa.Name[:i+1]); err != nil {
   164  				return &AddrError{
   165  					Err:  err.Error(),
   166  					Addr: sa.Name,
   167  				}
   168  			}
   169  		}
   170  		return nil
   171  
   172  	default:
   173  		return &AddrError{
   174  			Err:  syscall.EAFNOSUPPORT.Error(),
   175  			Addr: sa.String(),
   176  		}
   177  	}
   178  }
   179  
   180  func matchIPFamily(family int, addr sockaddr) sockaddr {
   181  	convertIP := func(ip IP) IP {
   182  		switch family {
   183  		case syscall.AF_INET:
   184  			return ip.To4()
   185  		case syscall.AF_INET6:
   186  			return ip.To16()
   187  		default:
   188  			return ip
   189  		}
   190  	}
   191  
   192  	switch addr := addr.(type) {
   193  	case *TCPAddr:
   194  		ip := convertIP(addr.IP)
   195  		if ip == nil || len(ip) == len(addr.IP) {
   196  			return addr
   197  		}
   198  		return &TCPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
   199  	case *UDPAddr:
   200  		ip := convertIP(addr.IP)
   201  		if ip == nil || len(ip) == len(addr.IP) {
   202  			return addr
   203  		}
   204  		return &UDPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
   205  	default:
   206  		return addr
   207  	}
   208  }
   209  
// fakeNetFD holds the fake-network state for one netFD. It is created by
// newFakeNetFD. Which fields are in use depends on the socket's role:
//   - stream and seqpacket listeners (fakeListen) use the incoming channels;
//   - datagram sockets receive through queue and, having no peer, Write
//     resolves their destination from fd.raddr through the sockets map;
//   - connected stream and seqpacket sockets (fakeConnect) read from queue
//     and write into peer's queue.
type fakeNetFD struct {
	fd           *netFD // the netFD passed to newFakeNetFD
	assignedPort int    // 0 if no port has been assigned for this socket

	queue *packetQueue // incoming packets
	peer  *netFD       // connected peer (for outgoing packets); nil for listeners and PacketConns
	// Deadline timers. SetReadDeadline/SetWriteDeadline replace a timer
	// that has already expired instead of resetting it.
	readDeadline  atomic.Pointer[deadlineTimer]
	writeDeadline atomic.Pointer[deadlineTimer]

	fakeAddr fakeSockAddr // cached fakeSockAddr equivalent of fd.laddr; zero if never set

	// The incoming channels hold incoming connections that have not yet been accepted.
	// All of these channels are 1-buffered, and at most one of them holds the
	// pending queue at any time. Receiving it grants exclusive ownership until
	// the holder sends it back.
	incoming      chan []*netFD // holds the queue when it has >0 but <SOMAXCONN pending connections; closed when the Listener is closed
	incomingFull  chan []*netFD // holds the queue when it has SOMAXCONN pending connections
	incomingEmpty chan bool     // holds true when the incoming queue is empty
}
   227  
   228  func newFakeNetFD(fd *netFD) *fakeNetFD {
   229  	ffd := &fakeNetFD{fd: fd}
   230  	ffd.readDeadline.Store(newDeadlineTimer(noDeadline))
   231  	ffd.writeDeadline.Store(newDeadlineTimer(noDeadline))
   232  	return ffd
   233  }
   234  
   235  func (ffd *fakeNetFD) Read(p []byte) (n int, err error) {
   236  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, false, nil)
   237  	return n, err
   238  }
   239  
   240  func (ffd *fakeNetFD) Write(p []byte) (nn int, err error) {
   241  	peer := ffd.peer
   242  	if peer == nil {
   243  		if ffd.fd.raddr == nil {
   244  			return 0, os.NewSyscallError("write", syscall.ENOTCONN)
   245  		}
   246  		peeri, _ := sockets.Load(fakeAddr(ffd.fd.raddr.(sockaddr)))
   247  		if peeri == nil {
   248  			return 0, os.NewSyscallError("write", syscall.ECONNRESET)
   249  		}
   250  		peer = peeri.(*netFD)
   251  		if peer.queue == nil {
   252  			return 0, os.NewSyscallError("write", syscall.ECONNRESET)
   253  		}
   254  	}
   255  
   256  	if peer.fakeNetFD == nil {
   257  		return 0, os.NewSyscallError("write", syscall.EINVAL)
   258  	}
   259  	return peer.queue.write(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr))
   260  }
   261  
// Close releases the fake-network resources held by ffd and returns the first
// non-nil error reported while closing its queues.
//
// In order, it:
//   - unregisters ffd's address from sockets, if the entry still maps to ffd.fd;
//   - closes the read side of its own queue, discarding unread packets;
//   - closes the write side of the peer's queue;
//   - disarms both deadline timers;
//   - for listeners, closes the incoming channel and closes every connection
//     that was queued but never accepted;
//   - releases its automatically assigned port, if any.
func (ffd *fakeNetFD) Close() (err error) {
	// CompareAndDelete leaves the entry alone if another socket owns this
	// address (for example after fakeListen lost a bind race).
	if ffd.fakeAddr != (fakeSockAddr{}) {
		sockets.CompareAndDelete(ffd.fakeAddr, ffd.fd)
	}

	if ffd.queue != nil {
		if closeErr := ffd.queue.closeRead(); err == nil {
			err = closeErr
		}
	}
	// Closing the write side of the peer's queue lets the peer's reads end
	// once it has drained whatever is still buffered.
	if ffd.peer != nil {
		if closeErr := ffd.peer.queue.closeWrite(); err == nil {
			err = closeErr
		}
	}
	ffd.readDeadline.Load().Reset(noDeadline)
	ffd.writeDeadline.Load().Reset(noDeadline)

	if ffd.incoming != nil {
		// Take ownership of the pending-connection queue from whichever
		// channel holds it. ok is false if ffd.incoming was already closed.
		var (
			incoming []*netFD
			ok       bool
		)
		select {
		case _, ok = <-ffd.incomingEmpty:
		case incoming, ok = <-ffd.incoming:
		case incoming, ok = <-ffd.incomingFull:
		}
		if ok {
			// Sends on ffd.incoming require a receive first.
			// Since we successfully received, no other goroutine may
			// send on it at this point, and we may safely close it.
			close(ffd.incoming)

			// Connections that were queued but never accepted.
			for _, c := range incoming {
				c.Close()
			}
		}
	}

	if ffd.assignedPort != 0 {
		fakePorts.CompareAndDelete(ffd.assignedPort, ffd.fd)
	}

	return err
}
   308  
   309  func (ffd *fakeNetFD) closeRead() error {
   310  	return ffd.queue.closeRead()
   311  }
   312  
   313  func (ffd *fakeNetFD) closeWrite() error {
   314  	if ffd.peer == nil {
   315  		return os.NewSyscallError("closeWrite", syscall.ENOTCONN)
   316  	}
   317  	return ffd.peer.queue.closeWrite()
   318  }
   319  
   320  func (ffd *fakeNetFD) accept(laddr Addr) (*netFD, error) {
   321  	if ffd.incoming == nil {
   322  		return nil, os.NewSyscallError("accept", syscall.EINVAL)
   323  	}
   324  
   325  	var (
   326  		incoming []*netFD
   327  		ok       bool
   328  	)
   329  	select {
   330  	case <-ffd.readDeadline.Load().expired:
   331  		return nil, os.ErrDeadlineExceeded
   332  	case incoming, ok = <-ffd.incoming:
   333  		if !ok {
   334  			return nil, ErrClosed
   335  		}
   336  	case incoming, ok = <-ffd.incomingFull:
   337  	}
   338  
   339  	peer := incoming[0]
   340  	incoming = incoming[1:]
   341  	if len(incoming) == 0 {
   342  		ffd.incomingEmpty <- true
   343  	} else {
   344  		ffd.incoming <- incoming
   345  	}
   346  	return peer, nil
   347  }
   348  
   349  func (ffd *fakeNetFD) SetDeadline(t time.Time) error {
   350  	err1 := ffd.SetReadDeadline(t)
   351  	err2 := ffd.SetWriteDeadline(t)
   352  	if err1 != nil {
   353  		return err1
   354  	}
   355  	return err2
   356  }
   357  
   358  func (ffd *fakeNetFD) SetReadDeadline(t time.Time) error {
   359  	dt := ffd.readDeadline.Load()
   360  	if !dt.Reset(t) {
   361  		ffd.readDeadline.Store(newDeadlineTimer(t))
   362  	}
   363  	return nil
   364  }
   365  
   366  func (ffd *fakeNetFD) SetWriteDeadline(t time.Time) error {
   367  	dt := ffd.writeDeadline.Load()
   368  	if !dt.Reset(t) {
   369  		ffd.writeDeadline.Store(newDeadlineTimer(t))
   370  	}
   371  	return nil
   372  }
   373  
// maxPacketSize is the largest payload, in bytes, that a single queued packet
// may carry. packetQueue.send rejects larger buffers with EMSGSIZE, and
// packetQueue.write splits larger buffers into chunks of at most this size.
const maxPacketSize = 65535
   375  
// packet is one node in a packetQueue's singly linked list of queued data.
type packet struct {
	buf       []byte   // payload copied from the sender
	bufOffset int      // bytes of buf already consumed by partial (stream) reads
	next      *packet  // next packet in the queue, or nil
	from      sockaddr // sender's address, as passed to packetQueue.send
}
   382  
   383  func (p *packet) clear() {
   384  	p.buf = p.buf[:0]
   385  	p.bufOffset = 0
   386  	p.next = nil
   387  	p.from = nil
   388  }
   389  
// packetPool recycles packets (and their buffers). packetQueue.send takes
// packets from it. recvfrom and closeRead clear packets before returning
// them here.
var packetPool = sync.Pool{
	New: func() any { return new(packet) },
}
   393  
// packetQueueState is the mutable state of a packetQueue. Only the goroutine
// that has received it from one of the queue's channels may read or modify
// it, until it is handed back with put.
type packetQueueState struct {
	head, tail      *packet // queued packets not yet fully read, oldest first; tail is only meaningful while head != nil
	nBytes          int     // number of bytes enqueued in the packet buffers starting from head
	readBufferBytes int     // soft limit on nBytes; no more packets may be enqueued once nBytes reaches it
	readClosed      bool    // true if the reader of the queue has stopped reading
	writeClosed     bool    // true if the writer of the queue has stopped writing; the reader sees either io.EOF or syscall.ECONNRESET when they have read all buffered packets
	noLinger        bool    // if true, the reader sees ECONNRESET instead of EOF
}
   402  
// A packetQueue is a set of 1-buffered channels implementing a FIFO queue
// of packets.
//
// The queue's single packetQueueState sits in exactly one of the three
// channels whenever no goroutine is using it. Receiving it acts like taking a
// lock, and put releases it into the channel that matches its condition.
// Callers therefore wait for the condition they need simply by choosing which
// channels to receive from.
type packetQueue struct {
	empty chan packetQueueState // contains configuration parameters when the queue is empty and not closed
	ready chan packetQueueState // contains the packets when non-empty or closed
	full  chan packetQueueState // contains the packets when buffer is full and not closed
}
   410  
   411  func newPacketQueue(readBufferBytes int) *packetQueue {
   412  	pq := &packetQueue{
   413  		empty: make(chan packetQueueState, 1),
   414  		ready: make(chan packetQueueState, 1),
   415  		full:  make(chan packetQueueState, 1),
   416  	}
   417  	pq.put(packetQueueState{
   418  		readBufferBytes: readBufferBytes,
   419  	})
   420  	return pq
   421  }
   422  
   423  func (pq *packetQueue) get() packetQueueState {
   424  	var q packetQueueState
   425  	select {
   426  	case q = <-pq.empty:
   427  	case q = <-pq.ready:
   428  	case q = <-pq.full:
   429  	}
   430  	return q
   431  }
   432  
// put releases ownership of q by storing it in the channel that matches its
// condition. That choice decides who can take it next:
//   - ready if either side is closed, so waiting readers and writers observe
//     the closure;
//   - full if nBytes has reached readBufferBytes (taken by readers, by get,
//     and by non-blocking senders, which then report ENOBUFS);
//   - empty if no packets are queued (taken by senders, by get, and by
//     zero-length reads);
//   - ready otherwise.
func (pq *packetQueue) put(q packetQueueState) {
	switch {
	case q.readClosed || q.writeClosed:
		pq.ready <- q
	case q.nBytes >= q.readBufferBytes:
		pq.full <- q
	case q.head == nil:
		if q.nBytes > 0 {
			// Deferred so that q is still stored in pq.empty before the
			// panic unwinds; otherwise the state would be lost and every
			// later get would block forever.
			defer panic("net: put with nil packet list and nonzero nBytes")
		}
		pq.empty <- q
	default:
		pq.ready <- q
	}
}
   448  
   449  func (pq *packetQueue) closeRead() error {
   450  	q := pq.get()
   451  
   452  	// Discard any unread packets.
   453  	for q.head != nil {
   454  		p := q.head
   455  		q.head = p.next
   456  		p.clear()
   457  		packetPool.Put(p)
   458  	}
   459  	q.nBytes = 0
   460  
   461  	q.readClosed = true
   462  	pq.put(q)
   463  	return nil
   464  }
   465  
   466  func (pq *packetQueue) closeWrite() error {
   467  	q := pq.get()
   468  	q.writeClosed = true
   469  	pq.put(q)
   470  	return nil
   471  }
   472  
   473  func (pq *packetQueue) setLinger(linger bool) error {
   474  	q := pq.get()
   475  	defer func() { pq.put(q) }()
   476  
   477  	if q.writeClosed {
   478  		return ErrClosed
   479  	}
   480  	q.noLinger = !linger
   481  	return nil
   482  }
   483  
   484  func (pq *packetQueue) write(dt *deadlineTimer, b []byte, from sockaddr) (n int, err error) {
   485  	for {
   486  		dn := len(b)
   487  		if dn > maxPacketSize {
   488  			dn = maxPacketSize
   489  		}
   490  
   491  		dn, err = pq.send(dt, b[:dn], from, true)
   492  		n += dn
   493  		if err != nil {
   494  			return n, err
   495  		}
   496  
   497  		b = b[dn:]
   498  		if len(b) == 0 {
   499  			return n, nil
   500  		}
   501  	}
   502  }
   503  
// send enqueues a copy of b as a single packet sent from the address from.
//
// Argument errors: EINVAL if from is nil, EMSGSIZE if b is larger than
// maxPacketSize.
//
// If block is true, send waits until the queue is below its buffer limit (or
// closed). If block is false, a full queue makes it fail with ENOBUFS instead.
//
// Other failures:
//   - os.ErrDeadlineExceeded if dt expires before the packet is queued;
//   - ErrClosed if the queue's write side is closed;
//   - ECONNRESET if its read side is closed.
//
// On success it returns len(b).
func (pq *packetQueue) send(dt *deadlineTimer, b []byte, from sockaddr, block bool) (n int, err error) {
	if from == nil {
		return 0, os.NewSyscallError("send", syscall.EINVAL)
	}
	if len(b) > maxPacketSize {
		return 0, os.NewSyscallError("send", syscall.EMSGSIZE)
	}

	var q packetQueueState
	// full stays nil for blocking sends. A receive from a nil channel never
	// proceeds, so a blocking send waits until put stores the state in
	// pq.empty or pq.ready, i.e. until the queue is below its limit or closed.
	var full chan packetQueueState
	if !block {
		full = pq.full
	}

	// Before we check dt.expired, yield to other goroutines.
	// This may help to prevent starvation of the goroutine that runs the
	// deadlineTimer's time.After callback.
	//
	// TODO(#65178): Remove this when the runtime scheduler no longer starves
	// runnable goroutines.
	runtime.Gosched()

	select {
	case <-dt.expired:
		return 0, os.ErrDeadlineExceeded

	case q = <-full:
		// Non-blocking send on a full queue: release the state unchanged.
		pq.put(q)
		return 0, os.NewSyscallError("send", syscall.ENOBUFS)

	case q = <-pq.empty:
	case q = <-pq.ready:
	}
	// We own q until this deferred put releases it, together with any
	// packet appended below.
	defer func() { pq.put(q) }()

	// Don't allow a packet to be sent if the deadline has expired,
	// even if the select above chose a different branch.
	select {
	case <-dt.expired:
		return 0, os.ErrDeadlineExceeded
	default:
	}
	if q.writeClosed {
		return 0, ErrClosed
	} else if q.readClosed {
		return 0, os.NewSyscallError("send", syscall.ECONNRESET)
	}

	// Copy b into a pooled packet, so the caller may reuse b afterwards.
	p := packetPool.Get().(*packet)
	p.buf = append(p.buf[:0], b...)
	p.from = from

	// Append p to the end of the linked list.
	if q.head == nil {
		q.head = p
	} else {
		q.tail.next = p
	}
	q.tail = p
	q.nBytes += len(p.buf)

	return len(b), nil
}
   566  
// recvfrom reads from the packet at the head of the queue into b. It returns
// the number of bytes copied and that packet's sender.
//
// If wholePacket is true, the head packet is always consumed and any bytes
// that did not fit in b are discarded (datagram semantics). Otherwise only
// the copied bytes are consumed and the rest stays queued (stream semantics).
//
// If checkFrom is non-nil, it is called with the sender's address before any
// data is copied. If it returns an error, recvfrom returns that error and the
// packet stays at the head of the queue.
//
// When nothing is queued, recvfrom blocks until data arrives, a side is
// closed, or dt expires. A zero-length b is the exception: it returns
// (0, nil, nil) without waiting. Once drained, a queue whose reader closed
// yields ErrClosed. A queue whose writer closed yields io.EOF, or ECONNRESET
// if linger was disabled.
func (pq *packetQueue) recvfrom(dt *deadlineTimer, b []byte, wholePacket bool, checkFrom func(sockaddr) error) (n int, from sockaddr, err error) {
	var q packetQueueState
	// empty stays nil (never selectable) unless len(b) == 0.
	var empty chan packetQueueState
	if len(b) == 0 {
		// For consistency with the implementation on Unix platforms,
		// allow a zero-length Read to proceed if the queue is empty.
		// (Without this, TestZeroByteRead deadlocks.)
		empty = pq.empty
	}

	// Before we check dt.expired, yield to other goroutines.
	// This may help to prevent starvation of the goroutine that runs the
	// deadlineTimer's time.After callback.
	//
	// TODO(#65178): Remove this when the runtime scheduler no longer starves
	// runnable goroutines.
	runtime.Gosched()

	select {
	case <-dt.expired:
		return 0, nil, os.ErrDeadlineExceeded
	case q = <-empty:
	case q = <-pq.ready:
	case q = <-pq.full:
	}
	// We own q until this deferred put releases it.
	defer func() { pq.put(q) }()

	p := q.head
	if p == nil {
		switch {
		case q.readClosed:
			return 0, nil, ErrClosed
		case q.writeClosed:
			if q.noLinger {
				return 0, nil, os.NewSyscallError("recvfrom", syscall.ECONNRESET)
			}
			return 0, nil, io.EOF
		case len(b) == 0:
			return 0, nil, nil
		default:
			// This should be impossible: pq.full should only contain a non-empty list,
			// pq.ready should either contain a non-empty list or indicate that the
			// connection is closed, and we should only receive from pq.empty if
			// len(b) == 0.
			panic("net: nil packet list from non-closed packetQueue")
		}
	}

	// Re-check the deadline: the select above may have chosen the queue
	// even though dt had also expired.
	select {
	case <-dt.expired:
		return 0, nil, os.ErrDeadlineExceeded
	default:
	}

	if checkFrom != nil {
		// A rejected packet is left at the head of the queue.
		if err := checkFrom(p.from); err != nil {
			return 0, nil, err
		}
	}

	n = copy(b, p.buf[p.bufOffset:])
	from = p.from
	if wholePacket || p.bufOffset+n == len(p.buf) {
		// Done with this packet: either the caller wants whole packets
		// (the unread remainder is dropped) or it has now been fully read.
		q.head = p.next
		q.nBytes -= len(p.buf)
		p.clear()
		packetPool.Put(p)
	} else {
		// Partial stream read: remember how far we got.
		p.bufOffset += n
	}

	return n, from, nil
}
   640  
   641  // setReadBuffer sets a soft limit on the number of bytes available to read
   642  // from the pipe.
   643  func (pq *packetQueue) setReadBuffer(bytes int) error {
   644  	if bytes <= 0 {
   645  		return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
   646  	}
   647  	q := pq.get() // Use the queue as a lock.
   648  	q.readBufferBytes = bytes
   649  	pq.put(q)
   650  	return nil
   651  }
   652  
// deadlineTimer is a resettable deadline.
//
// timer is a 1-buffered channel used as a mutex around the *time.Timer it
// holds (nil until a timer has been started). expired is closed by the
// timer's callback when the deadline passes and is never reopened. That is
// why an expired deadlineTimer must be replaced rather than reset.
type deadlineTimer struct {
	timer   chan *time.Timer
	expired chan struct{}
}
   657  
   658  func newDeadlineTimer(deadline time.Time) *deadlineTimer {
   659  	dt := &deadlineTimer{
   660  		timer:   make(chan *time.Timer, 1),
   661  		expired: make(chan struct{}),
   662  	}
   663  	dt.timer <- nil
   664  	dt.Reset(deadline)
   665  	return dt
   666  }
   667  
// Reset attempts to reset the timer.
// If the timer has already expired, Reset returns false.
//
// Passing noDeadline stops a running timer. Passing a deadline in the past
// does not return until dt.expired has been closed, so the expiry is already
// visible to other callers when Reset returns.
func (dt *deadlineTimer) Reset(deadline time.Time) bool {
	// dt.timer acts as a mutex. Receiving takes the current *time.Timer (nil
	// if none was started), and the deferred send stores back whatever timer
	// holds when Reset returns.
	timer := <-dt.timer
	defer func() { dt.timer <- timer }()

	if deadline.Equal(noDeadline) {
		// Stop fails only if the timer already fired (this function never
		// stores a stopped timer). In that case expired is, or is about to
		// be, closed and cannot be undone, so report false.
		if timer != nil && timer.Stop() {
			timer = nil
		}
		return timer == nil
	}

	d := time.Until(deadline)
	if d < 0 {
		// Ensure that a deadline in the past takes effect immediately.
		// Deferred calls run last-in-first-out, so this wait finishes
		// before the timer is handed back to dt.timer.
		defer func() { <-dt.expired }()
	}

	if timer == nil {
		timer = time.AfterFunc(d, func() { close(dt.expired) })
		return true
	}
	if !timer.Stop() {
		// Already fired: expired is closed for good.
		return false
	}
	timer.Reset(d)
	return true
}
   697  
   698  func sysSocket(family, sotype, proto int) (int, error) {
   699  	return 0, os.NewSyscallError("sysSocket", syscall.ENOSYS)
   700  }
   701  
// fakeListen binds fd to laddr (or to an automatically chosen address) and
// registers it in sockets, so that fakeConnect and datagram writers can find
// it.
//
// Stream and seqpacket sockets become listeners with an empty pending
// connection queue. Datagram sockets get a packet queue for receiving.
// Binding an address that another socket already holds fails with EADDRINUSE.
// Other errors are wrapped in an *AddrError naming laddr when laddr is
// non-nil.
//
// On success fd.fakeNetFD is replaced by the newly built state. On failure
// the new state is closed and fd.fakeNetFD is left unchanged, except in the
// EADDRINUSE case, where it is set to nil.
func fakeListen(fd *netFD, laddr sockaddr) (err error) {
	// wrapErr turns an errno into an *os.SyscallError. Unless the error is
	// EADDRINUSE or already an *AddrError, it then wraps it in an *AddrError
	// naming laddr (when laddr is non-nil).
	wrapErr := func(err error) error {
		if errno, ok := err.(syscall.Errno); ok {
			err = os.NewSyscallError("listen", errno)
		}
		if errors.Is(err, syscall.EADDRINUSE) {
			return err
		}
		if laddr != nil {
			if _, ok := err.(*AddrError); !ok {
				err = &AddrError{
					Err:  err.Error(),
					Addr: laddr.String(),
				}
			}
		}
		return err
	}

	// Build the listener state separately. It is installed in fd only if
	// registration succeeds; otherwise the deferred cleanup closes it.
	ffd := newFakeNetFD(fd)
	defer func() {
		if fd.fakeNetFD != ffd {
			// Failed to register listener; clean up.
			ffd.Close()
		}
	}()

	// On success this stores the chosen address in fd.laddr (ffd.fd == fd).
	if err := ffd.assignFakeAddr(matchIPFamily(fd.family, laddr)); err != nil {
		return wrapErr(err)
	}

	ffd.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
	switch fd.sotype {
	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
		// A new listener has no pending connections, so the queue's
		// ownership token starts out in incomingEmpty.
		ffd.incoming = make(chan []*netFD, 1)
		ffd.incomingFull = make(chan []*netFD, 1)
		ffd.incomingEmpty = make(chan bool, 1)
		ffd.incomingEmpty <- true
	case syscall.SOCK_DGRAM:
		ffd.queue = newPacketQueue(defaultBuffer)
	default:
		return wrapErr(syscall.EINVAL)
	}

	fd.fakeNetFD = ffd
	if _, dup := sockets.LoadOrStore(ffd.fakeAddr, fd); dup {
		// Another socket owns this address. Unset fd.fakeNetFD so the
		// deferred cleanup closes ffd.
		fd.fakeNetFD = nil
		return wrapErr(syscall.EADDRINUSE)
	}

	return nil
}
   754  
   755  func fakeConnect(ctx context.Context, fd *netFD, laddr, raddr sockaddr) error {
   756  	wrapErr := func(err error) error {
   757  		if errno, ok := err.(syscall.Errno); ok {
   758  			err = os.NewSyscallError("connect", errno)
   759  		}
   760  		if errors.Is(err, syscall.EADDRINUSE) {
   761  			return err
   762  		}
   763  		if terr, ok := err.(interface{ Timeout() bool }); !ok || !terr.Timeout() {
   764  			// For consistency with the net implementation on other platforms,
   765  			// if we don't need to preserve the Timeout-ness of err we should
   766  			// wrap it in an AddrError. (Unfortunately we can't wrap errors
   767  			// that convey structured information, because AddrError reduces
   768  			// the wrapped Err to a flat string.)
   769  			if _, ok := err.(*AddrError); !ok {
   770  				err = &AddrError{
   771  					Err:  err.Error(),
   772  					Addr: raddr.String(),
   773  				}
   774  			}
   775  		}
   776  		return err
   777  	}
   778  
   779  	if fd.isConnected {
   780  		return wrapErr(syscall.EISCONN)
   781  	}
   782  	if ctx.Err() != nil {
   783  		return wrapErr(syscall.ETIMEDOUT)
   784  	}
   785  
   786  	fd.raddr = matchIPFamily(fd.family, raddr)
   787  	if err := validateResolvedAddr(fd.net, fd.family, fd.raddr.(sockaddr)); err != nil {
   788  		return wrapErr(err)
   789  	}
   790  
   791  	if err := fd.fakeNetFD.assignFakeAddr(laddr); err != nil {
   792  		return wrapErr(err)
   793  	}
   794  	fd.fakeNetFD.queue = newPacketQueue(defaultBuffer)
   795  
   796  	switch fd.sotype {
   797  	case syscall.SOCK_DGRAM:
   798  		if ua, ok := fd.laddr.(*UnixAddr); !ok || ua.Name != "" {
   799  			fd.fakeNetFD.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
   800  			if _, dup := sockets.LoadOrStore(fd.fakeNetFD.fakeAddr, fd); dup {
   801  				return wrapErr(syscall.EADDRINUSE)
   802  			}
   803  		}
   804  		fd.isConnected = true
   805  		return nil
   806  
   807  	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
   808  	default:
   809  		return wrapErr(syscall.EINVAL)
   810  	}
   811  
   812  	fa := fakeAddr(raddr)
   813  	lni, ok := sockets.Load(fa)
   814  	if !ok {
   815  		return wrapErr(syscall.ECONNREFUSED)
   816  	}
   817  	ln := lni.(*netFD)
   818  	if ln.sotype != fd.sotype {
   819  		return wrapErr(syscall.EPROTOTYPE)
   820  	}
   821  	if ln.incoming == nil {
   822  		return wrapErr(syscall.ECONNREFUSED)
   823  	}
   824  
   825  	peer := &netFD{
   826  		family:      ln.family,
   827  		sotype:      ln.sotype,
   828  		net:         ln.net,
   829  		laddr:       ln.laddr,
   830  		raddr:       fd.laddr,
   831  		isConnected: true,
   832  	}
   833  	peer.fakeNetFD = newFakeNetFD(fd)
   834  	peer.fakeNetFD.queue = newPacketQueue(defaultBuffer)
   835  	defer func() {
   836  		if fd.peer != peer {
   837  			// Failed to connect; clean up.
   838  			peer.Close()
   839  		}
   840  	}()
   841  
   842  	var incoming []*netFD
   843  	select {
   844  	case <-ctx.Done():
   845  		return wrapErr(syscall.ETIMEDOUT)
   846  	case ok = <-ln.incomingEmpty:
   847  	case incoming, ok = <-ln.incoming:
   848  	}
   849  	if !ok {
   850  		return wrapErr(syscall.ECONNREFUSED)
   851  	}
   852  
   853  	fd.isConnected = true
   854  	fd.peer = peer
   855  	peer.peer = fd
   856  
   857  	incoming = append(incoming, peer)
   858  	if len(incoming) >= listenerBacklog() {
   859  		ln.incomingFull <- incoming
   860  	} else {
   861  		ln.incoming <- incoming
   862  	}
   863  	return nil
   864  }
   865  
// assignFakeAddr fills in defaults for addr, validates the result with
// validateResolvedAddr, and stores it in ffd.fd.laddr.
//
// For TCP and UDP networks:
//   - a nil address or nil IP defaults to 127.0.0.1, converted to the 4- or
//     16-byte form for ffd.fd.family where possible;
//   - a zero port is replaced by the next port from nextPortCounter that is
//     not already held in fakePorts. The port is recorded in fakePorts and in
//     ffd.assignedPort. Once the whole port space has been scanned without
//     success, it returns EADDRINUSE;
//   - addresses of another concrete type are passed to validateResolvedAddr
//     unchanged, which rejects them.
//
// For Unix networks the name is copied into a new *UnixAddr whose Net is the
// socket's network. Other networks fail with EAFNOSUPPORT.
func (ffd *fakeNetFD) assignFakeAddr(addr sockaddr) error {
	// validate checks sa and, if acceptable, records it as the local address.
	validate := func(sa sockaddr) error {
		if err := validateResolvedAddr(ffd.fd.net, ffd.fd.family, sa); err != nil {
			return err
		}
		ffd.fd.laddr = sa
		return nil
	}

	// assignIP fills in the default IP and port of a TCP or UDP address and
	// then validates a fresh copy of it.
	assignIP := func(addr sockaddr) error {
		var (
			ip   IP
			port int
			zone string
		)
		switch addr := addr.(type) {
		case *TCPAddr:
			if addr != nil {
				ip = addr.IP
				port = addr.Port
				zone = addr.Zone
			}
		case *UDPAddr:
			if addr != nil {
				ip = addr.IP
				port = addr.Port
				zone = addr.Zone
			}
		default:
			return validate(addr)
		}

		if ip == nil {
			ip = IPv4(127, 0, 0, 1)
		}
		switch ffd.fd.family {
		case syscall.AF_INET:
			if ip4 := ip.To4(); ip4 != nil {
				ip = ip4
			}
		case syscall.AF_INET6:
			if ip16 := ip.To16(); ip16 != nil {
				ip = ip16
			}
		}
		// NOTE(review): ip cannot be nil here. It was defaulted above, and
		// the conversions only ever replace it with a non-nil value, so this
		// check appears to be dead code.
		if ip == nil {
			return syscall.EINVAL
		}

		if port == 0 {
			// Allocate an ephemeral port. prevPort and portWrapped let
			// nextPort notice when the shared counter has gone all the way
			// around the port space, so that allocation gives up instead of
			// looping forever when every port is taken.
			var prevPort int32
			portWrapped := false
			nextPort := func() (int, bool) {
				for {
					port := nextPortCounter.Add(1)
					if port <= 0 || port >= 1<<16 {
						// nextPortCounter ran off the end of the port space.
						// Bump it back into range.
						for {
							if nextPortCounter.CompareAndSwap(port, 0) {
								break
							}
							if port = nextPortCounter.Load(); port >= 0 && port < 1<<16 {
								break
							}
						}
						if portWrapped {
							// This is the second wraparound, so we've scanned the whole port space
							// at least once already and it's time to give up.
							return 0, false
						}
						portWrapped = true
						prevPort = 0
						continue
					}

					if port <= prevPort {
						// nextPortCounter has wrapped around since the last time we read it.
						if portWrapped {
							// This is the second wraparound, so we've scanned the whole port space
							// at least once already and it's time to give up.
							return 0, false
						} else {
							portWrapped = true
						}
					}

					prevPort = port
					return int(port), true
				}
			}

			// Claim the first candidate that no other socket holds.
			for {
				var ok bool
				port, ok = nextPort()
				if !ok {
					ffd.assignedPort = 0
					return syscall.EADDRINUSE
				}

				ffd.assignedPort = int(port)
				if _, dup := fakePorts.LoadOrStore(ffd.assignedPort, ffd.fd); !dup {
					break
				}
			}
		}

		switch addr.(type) {
		case *TCPAddr:
			return validate(&TCPAddr{IP: ip, Port: port, Zone: zone})
		case *UDPAddr:
			return validate(&UDPAddr{IP: ip, Port: port, Zone: zone})
		default:
			panic("unreachable")
		}
	}

	switch ffd.fd.net {
	case "tcp", "tcp4", "tcp6":
		if addr == nil {
			return assignIP(new(TCPAddr))
		}
		return assignIP(addr)

	case "udp", "udp4", "udp6":
		if addr == nil {
			return assignIP(new(UDPAddr))
		}
		return assignIP(addr)

	case "unix", "unixgram", "unixpacket":
		uaddr, ok := addr.(*UnixAddr)
		if !ok && addr != nil {
			return &AddrError{
				Err:  "non-Unix address for " + ffd.fd.net + " network",
				Addr: addr.String(),
			}
		}
		if uaddr == nil {
			return validate(&UnixAddr{Net: ffd.fd.net})
		}
		return validate(&UnixAddr{Net: ffd.fd.net, Name: uaddr.Name})

	default:
		// NOTE(review): addr.String() panics here if addr is a nil interface.
		return &AddrError{
			Err:  syscall.EAFNOSUPPORT.Error(),
			Addr: addr.String(),
		}
	}
}
  1016  
  1017  func (ffd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
  1018  	if ffd.queue == nil {
  1019  		return 0, nil, os.NewSyscallError("readFrom", syscall.EINVAL)
  1020  	}
  1021  
  1022  	n, from, err := ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, nil)
  1023  
  1024  	if from != nil {
  1025  		// Convert the net.sockaddr to a syscall.Sockaddr type.
  1026  		var saErr error
  1027  		sa, saErr = from.sockaddr(ffd.fd.family)
  1028  		if err == nil {
  1029  			err = saErr
  1030  		}
  1031  	}
  1032  
  1033  	return n, sa, err
  1034  }
  1035  
  1036  func (ffd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
  1037  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
  1038  		fromSA, err := from.sockaddr(syscall.AF_INET)
  1039  		if err != nil {
  1040  			return err
  1041  		}
  1042  		if fromSA == nil {
  1043  			return os.NewSyscallError("readFromInet4", syscall.EINVAL)
  1044  		}
  1045  		*sa = *(fromSA.(*syscall.SockaddrInet4))
  1046  		return nil
  1047  	})
  1048  	return n, err
  1049  }
  1050  
  1051  func (ffd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
  1052  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
  1053  		fromSA, err := from.sockaddr(syscall.AF_INET6)
  1054  		if err != nil {
  1055  			return err
  1056  		}
  1057  		if fromSA == nil {
  1058  			return os.NewSyscallError("readFromInet6", syscall.EINVAL)
  1059  		}
  1060  		*sa = *(fromSA.(*syscall.SockaddrInet6))
  1061  		return nil
  1062  	})
  1063  	return n, err
  1064  }
  1065  
  1066  func (ffd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
  1067  	if flags != 0 {
  1068  		return 0, 0, 0, nil, os.NewSyscallError("readMsg", syscall.ENOTSUP)
  1069  	}
  1070  	n, sa, err = ffd.readFrom(p)
  1071  	return n, 0, 0, sa, err
  1072  }
  1073  
  1074  func (ffd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
  1075  	if flags != 0 {
  1076  		return 0, 0, 0, os.NewSyscallError("readMsgInet4", syscall.ENOTSUP)
  1077  	}
  1078  	n, err = ffd.readFromInet4(p, sa)
  1079  	return n, 0, 0, err
  1080  }
  1081  
  1082  func (ffd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
  1083  	if flags != 0 {
  1084  		return 0, 0, 0, os.NewSyscallError("readMsgInet6", syscall.ENOTSUP)
  1085  	}
  1086  	n, err = ffd.readFromInet6(p, sa)
  1087  	return n, 0, 0, err
  1088  }
  1089  
  1090  func (ffd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
  1091  	if len(oob) > 0 {
  1092  		return 0, 0, os.NewSyscallError("writeMsg", syscall.ENOTSUP)
  1093  	}
  1094  	n, err = ffd.writeTo(p, sa)
  1095  	return n, 0, err
  1096  }
  1097  
  1098  func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
  1099  	return ffd.writeMsg(p, oob, sa)
  1100  }
  1101  
  1102  func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
  1103  	return ffd.writeMsg(p, oob, sa)
  1104  }
  1105  
  1106  func (ffd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
  1107  	raddr := ffd.fd.raddr
  1108  	if sa != nil {
  1109  		if ffd.fd.isConnected {
  1110  			return 0, os.NewSyscallError("writeTo", syscall.EISCONN)
  1111  		}
  1112  		raddr = ffd.fd.addrFunc()(sa)
  1113  	}
  1114  	if raddr == nil {
  1115  		return 0, os.NewSyscallError("writeTo", syscall.EINVAL)
  1116  	}
  1117  
  1118  	peeri, _ := sockets.Load(fakeAddr(raddr.(sockaddr)))
  1119  	if peeri == nil {
  1120  		if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1121  			return len(p), nil
  1122  		}
  1123  		return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
  1124  	}
  1125  	peer := peeri.(*netFD)
  1126  	if peer.queue == nil {
  1127  		if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1128  			return len(p), nil
  1129  		}
  1130  		return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
  1131  	}
  1132  
  1133  	block := true
  1134  	if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1135  		block = false
  1136  	}
  1137  	return peer.queue.send(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr), block)
  1138  }
  1139  
  1140  func (ffd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
  1141  	return ffd.writeTo(p, sa)
  1142  }
  1143  
  1144  func (ffd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
  1145  	return ffd.writeTo(p, sa)
  1146  }
  1147  
  1148  func (ffd *fakeNetFD) dup() (f *os.File, err error) {
  1149  	return nil, os.NewSyscallError("dup", syscall.ENOSYS)
  1150  }
  1151  
  1152  func (ffd *fakeNetFD) setReadBuffer(bytes int) error {
  1153  	if ffd.queue == nil {
  1154  		return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
  1155  	}
  1156  	ffd.queue.setReadBuffer(bytes)
  1157  	return nil
  1158  }
  1159  
  1160  func (ffd *fakeNetFD) setWriteBuffer(bytes int) error {
  1161  	return os.NewSyscallError("setWriteBuffer", syscall.ENOTSUP)
  1162  }
  1163  
  1164  func (ffd *fakeNetFD) setLinger(sec int) error {
  1165  	if sec < 0 || ffd.peer == nil {
  1166  		return os.NewSyscallError("setLinger", syscall.EINVAL)
  1167  	}
  1168  	ffd.peer.queue.setLinger(sec > 0)
  1169  	return nil
  1170  }
  1171  

View as plain text