...

Source file src/os/readfrom_linux_test.go

Documentation: os

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"internal/poll"
    11  	"internal/testpty"
    12  	"io"
    13  	"math/rand"
    14  	"net"
    15  	. "os"
    16  	"path/filepath"
    17  	"runtime"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"syscall"
    22  	"testing"
    23  	"time"
    24  
    25  	"golang.org/x/net/nettest"
    26  )
    27  
    28  func TestCopyFileRange(t *testing.T) {
    29  	sizes := []int{
    30  		1,
    31  		42,
    32  		1025,
    33  		syscall.Getpagesize() + 1,
    34  		32769,
    35  	}
    36  	t.Run("Basic", func(t *testing.T) {
    37  		for _, size := range sizes {
    38  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    39  				testCopyFileRange(t, int64(size), -1)
    40  			})
    41  		}
    42  	})
    43  	t.Run("Limited", func(t *testing.T) {
    44  		t.Run("OneLess", func(t *testing.T) {
    45  			for _, size := range sizes {
    46  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    47  					testCopyFileRange(t, int64(size), int64(size)-1)
    48  				})
    49  			}
    50  		})
    51  		t.Run("Half", func(t *testing.T) {
    52  			for _, size := range sizes {
    53  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    54  					testCopyFileRange(t, int64(size), int64(size)/2)
    55  				})
    56  			}
    57  		})
    58  		t.Run("More", func(t *testing.T) {
    59  			for _, size := range sizes {
    60  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    61  					testCopyFileRange(t, int64(size), int64(size)+7)
    62  				})
    63  			}
    64  		})
    65  	})
    66  	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
    67  		dst, src, data, hook := newCopyFileRangeTest(t, 42)
    68  
    69  		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
    70  		if err != nil {
    71  			t.Fatal(err)
    72  		}
    73  		defer dst2.Close()
    74  
    75  		if _, err := io.Copy(dst2, src); err != nil {
    76  			t.Fatal(err)
    77  		}
    78  		if hook.called {
    79  			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
    80  		}
    81  		mustSeekStart(t, dst2)
    82  		mustContainData(t, dst2, data) // through traditional means
    83  	})
    84  	t.Run("CopyFileItself", func(t *testing.T) {
    85  		hook := hookCopyFileRange(t)
    86  
    87  		f, err := CreateTemp("", "file-readfrom-itself-test")
    88  		if err != nil {
    89  			t.Fatalf("failed to create tmp file: %v", err)
    90  		}
    91  		t.Cleanup(func() {
    92  			f.Close()
    93  			Remove(f.Name())
    94  		})
    95  
    96  		data := []byte("hello world!")
    97  		if _, err := f.Write(data); err != nil {
    98  			t.Fatalf("failed to create and feed the file: %v", err)
    99  		}
   100  
   101  		if err := f.Sync(); err != nil {
   102  			t.Fatalf("failed to save the file: %v", err)
   103  		}
   104  
   105  		// Rewind it.
   106  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   107  			t.Fatalf("failed to rewind the file: %v", err)
   108  		}
   109  
   110  		// Read data from the file itself.
   111  		if _, err := io.Copy(f, f); err != nil {
   112  			t.Fatalf("failed to read from the file: %v", err)
   113  		}
   114  
   115  		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
   116  			t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
   117  		}
   118  
   119  		// Rewind it.
   120  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   121  			t.Fatalf("failed to rewind the file: %v", err)
   122  		}
   123  
   124  		data2, err := io.ReadAll(f)
   125  		if err != nil {
   126  			t.Fatalf("failed to read from the file: %v", err)
   127  		}
   128  
   129  		// It should wind up a double of the original data.
   130  		if strings.Repeat(string(data), 2) != string(data2) {
   131  			t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
   132  		}
   133  	})
   134  	t.Run("NotRegular", func(t *testing.T) {
   135  		t.Run("BothPipes", func(t *testing.T) {
   136  			hook := hookCopyFileRange(t)
   137  
   138  			pr1, pw1, err := Pipe()
   139  			if err != nil {
   140  				t.Fatal(err)
   141  			}
   142  			defer pr1.Close()
   143  			defer pw1.Close()
   144  
   145  			pr2, pw2, err := Pipe()
   146  			if err != nil {
   147  				t.Fatal(err)
   148  			}
   149  			defer pr2.Close()
   150  			defer pw2.Close()
   151  
   152  			// The pipe is empty, and PIPE_BUF is large enough
   153  			// for this, by (POSIX) definition, so there is no
   154  			// need for an additional goroutine.
   155  			data := []byte("hello")
   156  			if _, err := pw1.Write(data); err != nil {
   157  				t.Fatal(err)
   158  			}
   159  			pw1.Close()
   160  
   161  			n, err := io.Copy(pw2, pr1)
   162  			if err != nil {
   163  				t.Fatal(err)
   164  			}
   165  			if n != int64(len(data)) {
   166  				t.Fatalf("transferred %d, want %d", n, len(data))
   167  			}
   168  			if !hook.called {
   169  				t.Fatalf("should have called poll.CopyFileRange")
   170  			}
   171  			pw2.Close()
   172  			mustContainData(t, pr2, data)
   173  		})
   174  		t.Run("DstPipe", func(t *testing.T) {
   175  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   176  			dst.Close()
   177  
   178  			pr, pw, err := Pipe()
   179  			if err != nil {
   180  				t.Fatal(err)
   181  			}
   182  			defer pr.Close()
   183  			defer pw.Close()
   184  
   185  			n, err := io.Copy(pw, src)
   186  			if err != nil {
   187  				t.Fatal(err)
   188  			}
   189  			if n != int64(len(data)) {
   190  				t.Fatalf("transferred %d, want %d", n, len(data))
   191  			}
   192  			if !hook.called {
   193  				t.Fatalf("should have called poll.CopyFileRange")
   194  			}
   195  			pw.Close()
   196  			mustContainData(t, pr, data)
   197  		})
   198  		t.Run("SrcPipe", func(t *testing.T) {
   199  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   200  			src.Close()
   201  
   202  			pr, pw, err := Pipe()
   203  			if err != nil {
   204  				t.Fatal(err)
   205  			}
   206  			defer pr.Close()
   207  			defer pw.Close()
   208  
   209  			// The pipe is empty, and PIPE_BUF is large enough
   210  			// for this, by (POSIX) definition, so there is no
   211  			// need for an additional goroutine.
   212  			if _, err := pw.Write(data); err != nil {
   213  				t.Fatal(err)
   214  			}
   215  			pw.Close()
   216  
   217  			n, err := io.Copy(dst, pr)
   218  			if err != nil {
   219  				t.Fatal(err)
   220  			}
   221  			if n != int64(len(data)) {
   222  				t.Fatalf("transferred %d, want %d", n, len(data))
   223  			}
   224  			if !hook.called {
   225  				t.Fatalf("should have called poll.CopyFileRange")
   226  			}
   227  			mustSeekStart(t, dst)
   228  			mustContainData(t, dst, data)
   229  		})
   230  	})
   231  	t.Run("Nil", func(t *testing.T) {
   232  		var nilFile *File
   233  		anyFile, err := CreateTemp("", "")
   234  		if err != nil {
   235  			t.Fatal(err)
   236  		}
   237  		defer Remove(anyFile.Name())
   238  		defer anyFile.Close()
   239  
   240  		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
   241  			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
   242  		}
   243  		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
   244  			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
   245  		}
   246  		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
   247  			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
   248  		}
   249  
   250  		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
   251  			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   252  		}
   253  		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
   254  			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   255  		}
   256  		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
   257  			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
   258  		}
   259  	})
   260  }
   261  
   262  func TestSpliceFile(t *testing.T) {
   263  	sizes := []int{
   264  		1,
   265  		42,
   266  		1025,
   267  		syscall.Getpagesize() + 1,
   268  		32769,
   269  	}
   270  	t.Run("Basic-TCP", func(t *testing.T) {
   271  		for _, size := range sizes {
   272  			t.Run(strconv.Itoa(size), func(t *testing.T) {
   273  				testSpliceFile(t, "tcp", int64(size), -1)
   274  			})
   275  		}
   276  	})
   277  	t.Run("Basic-Unix", func(t *testing.T) {
   278  		for _, size := range sizes {
   279  			t.Run(strconv.Itoa(size), func(t *testing.T) {
   280  				testSpliceFile(t, "unix", int64(size), -1)
   281  			})
   282  		}
   283  	})
   284  	t.Run("TCP-To-TTY", func(t *testing.T) {
   285  		testSpliceToTTY(t, "tcp", 32768)
   286  	})
   287  	t.Run("Unix-To-TTY", func(t *testing.T) {
   288  		testSpliceToTTY(t, "unix", 32768)
   289  	})
   290  	t.Run("Limited", func(t *testing.T) {
   291  		t.Run("OneLess-TCP", func(t *testing.T) {
   292  			for _, size := range sizes {
   293  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   294  					testSpliceFile(t, "tcp", int64(size), int64(size)-1)
   295  				})
   296  			}
   297  		})
   298  		t.Run("OneLess-Unix", func(t *testing.T) {
   299  			for _, size := range sizes {
   300  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   301  					testSpliceFile(t, "unix", int64(size), int64(size)-1)
   302  				})
   303  			}
   304  		})
   305  		t.Run("Half-TCP", func(t *testing.T) {
   306  			for _, size := range sizes {
   307  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   308  					testSpliceFile(t, "tcp", int64(size), int64(size)/2)
   309  				})
   310  			}
   311  		})
   312  		t.Run("Half-Unix", func(t *testing.T) {
   313  			for _, size := range sizes {
   314  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   315  					testSpliceFile(t, "unix", int64(size), int64(size)/2)
   316  				})
   317  			}
   318  		})
   319  		t.Run("More-TCP", func(t *testing.T) {
   320  			for _, size := range sizes {
   321  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   322  					testSpliceFile(t, "tcp", int64(size), int64(size)+1)
   323  				})
   324  			}
   325  		})
   326  		t.Run("More-Unix", func(t *testing.T) {
   327  			for _, size := range sizes {
   328  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   329  					testSpliceFile(t, "unix", int64(size), int64(size)+1)
   330  				})
   331  			}
   332  		})
   333  	})
   334  }
   335  
   336  func testSpliceFile(t *testing.T, proto string, size, limit int64) {
   337  	dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
   338  	defer cleanup()
   339  
   340  	// If we have a limit, wrap the reader.
   341  	var (
   342  		r  io.Reader
   343  		lr *io.LimitedReader
   344  	)
   345  	if limit >= 0 {
   346  		lr = &io.LimitedReader{N: limit, R: src}
   347  		r = lr
   348  		if limit < int64(len(data)) {
   349  			data = data[:limit]
   350  		}
   351  	} else {
   352  		r = src
   353  	}
   354  	// Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice
   355  	n, err := io.Copy(dst, r)
   356  	if err != nil {
   357  		t.Fatal(err)
   358  	}
   359  
   360  	// We should have called poll.Splice with the right file descriptor arguments.
   361  	if n > 0 && !hook.called {
   362  		t.Fatal("expected to called poll.Splice")
   363  	}
   364  	if hook.called && hook.dstfd != int(dst.Fd()) {
   365  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   366  	}
   367  	sc, ok := src.(syscall.Conn)
   368  	if !ok {
   369  		t.Fatalf("server Conn is not a syscall.Conn")
   370  	}
   371  	rc, err := sc.SyscallConn()
   372  	if err != nil {
   373  		t.Fatalf("server Conn SyscallConn error: %v", err)
   374  	}
   375  	if err = rc.Control(func(fd uintptr) {
   376  		if hook.called && hook.srcfd != int(fd) {
   377  			t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
   378  		}
   379  	}); err != nil {
   380  		t.Fatalf("server Conn Control error: %v", err)
   381  	}
   382  
   383  	// Check that the offsets after the transfer make sense, that the size
   384  	// of the transfer was reported correctly, and that the destination
   385  	// file contains exactly the bytes we expect it to contain.
   386  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   387  	if err != nil {
   388  		t.Fatal(err)
   389  	}
   390  	if dstoff != int64(len(data)) {
   391  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   392  	}
   393  	if n != int64(len(data)) {
   394  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   395  	}
   396  	mustSeekStart(t, dst)
   397  	mustContainData(t, dst, data)
   398  
   399  	// If we had a limit, check that it was updated.
   400  	if lr != nil {
   401  		if want := limit - n; lr.N != want {
   402  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   403  		}
   404  	}
   405  }
   406  
   407  // Issue #59041.
   408  func testSpliceToTTY(t *testing.T, proto string, size int64) {
   409  	var wg sync.WaitGroup
   410  
   411  	// Call wg.Wait as the final deferred function,
   412  	// because the goroutines may block until some of
   413  	// the deferred Close calls.
   414  	defer wg.Wait()
   415  
   416  	pty, ttyName, err := testpty.Open()
   417  	if err != nil {
   418  		t.Skipf("skipping test because pty open failed: %v", err)
   419  	}
   420  	defer pty.Close()
   421  
   422  	// Open the tty directly, rather than via OpenFile.
   423  	// This bypasses the non-blocking support and is required
   424  	// to recreate the problem in the issue (#59041).
   425  	ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
   426  	if err != nil {
   427  		t.Skipf("skipping test becaused failed to open tty: %v", err)
   428  	}
   429  	defer syscall.Close(ttyFD)
   430  
   431  	tty := NewFile(uintptr(ttyFD), "tty")
   432  	defer tty.Close()
   433  
   434  	client, server := createSocketPair(t, proto)
   435  
   436  	data := bytes.Repeat([]byte{'a'}, int(size))
   437  
   438  	wg.Add(1)
   439  	go func() {
   440  		defer wg.Done()
   441  		// The problem (issue #59041) occurs when writing
   442  		// a series of blocks of data. It does not occur
   443  		// when all the data is written at once.
   444  		for i := 0; i < len(data); i += 1024 {
   445  			if _, err := client.Write(data[i : i+1024]); err != nil {
   446  				// If we get here because the client was
   447  				// closed, skip the error.
   448  				if !errors.Is(err, net.ErrClosed) {
   449  					t.Errorf("error writing to socket: %v", err)
   450  				}
   451  				return
   452  			}
   453  		}
   454  		client.Close()
   455  	}()
   456  
   457  	wg.Add(1)
   458  	go func() {
   459  		defer wg.Done()
   460  		buf := make([]byte, 32)
   461  		for {
   462  			if _, err := pty.Read(buf); err != nil {
   463  				if err != io.EOF && !errors.Is(err, ErrClosed) {
   464  					// An error here doesn't matter for
   465  					// our test.
   466  					t.Logf("error reading from pty: %v", err)
   467  				}
   468  				return
   469  			}
   470  		}
   471  	}()
   472  
   473  	// Close Client to wake up the writing goroutine if necessary.
   474  	defer client.Close()
   475  
   476  	_, err = io.Copy(tty, server)
   477  	if err != nil {
   478  		t.Fatal(err)
   479  	}
   480  }
   481  
   482  func testCopyFileRange(t *testing.T, size int64, limit int64) {
   483  	dst, src, data, hook := newCopyFileRangeTest(t, size)
   484  
   485  	// If we have a limit, wrap the reader.
   486  	var (
   487  		realsrc io.Reader
   488  		lr      *io.LimitedReader
   489  	)
   490  	if limit >= 0 {
   491  		lr = &io.LimitedReader{N: limit, R: src}
   492  		realsrc = lr
   493  		if limit < int64(len(data)) {
   494  			data = data[:limit]
   495  		}
   496  	} else {
   497  		realsrc = src
   498  	}
   499  
   500  	// Now call ReadFrom (through io.Copy), which will hopefully call
   501  	// poll.CopyFileRange.
   502  	n, err := io.Copy(dst, realsrc)
   503  	if err != nil {
   504  		t.Fatal(err)
   505  	}
   506  
   507  	// If we didn't have a limit, we should have called poll.CopyFileRange
   508  	// with the right file descriptor arguments.
   509  	if limit > 0 && !hook.called {
   510  		t.Fatal("never called poll.CopyFileRange")
   511  	}
   512  	if hook.called && hook.dstfd != int(dst.Fd()) {
   513  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   514  	}
   515  	if hook.called && hook.srcfd != int(src.Fd()) {
   516  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
   517  	}
   518  
   519  	// Check that the offsets after the transfer make sense, that the size
   520  	// of the transfer was reported correctly, and that the destination
   521  	// file contains exactly the bytes we expect it to contain.
   522  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   523  	if err != nil {
   524  		t.Fatal(err)
   525  	}
   526  	srcoff, err := src.Seek(0, io.SeekCurrent)
   527  	if err != nil {
   528  		t.Fatal(err)
   529  	}
   530  	if dstoff != srcoff {
   531  		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
   532  	}
   533  	if dstoff != int64(len(data)) {
   534  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   535  	}
   536  	if n != int64(len(data)) {
   537  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   538  	}
   539  	mustSeekStart(t, dst)
   540  	mustContainData(t, dst, data)
   541  
   542  	// If we had a limit, check that it was updated.
   543  	if lr != nil {
   544  		if want := limit - n; lr.N != want {
   545  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   546  		}
   547  	}
   548  }
   549  
   550  // newCopyFileRangeTest initializes a new test for copy_file_range.
   551  //
   552  // It creates source and destination files, and populates the source file
   553  // with random data of the specified size. It also hooks package os' call
   554  // to poll.CopyFileRange and returns the hook so it can be inspected.
   555  func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
   556  	t.Helper()
   557  
   558  	hook = hookCopyFileRange(t)
   559  	tmp := t.TempDir()
   560  
   561  	src, err := Create(filepath.Join(tmp, "src"))
   562  	if err != nil {
   563  		t.Fatal(err)
   564  	}
   565  	t.Cleanup(func() { src.Close() })
   566  
   567  	dst, err = Create(filepath.Join(tmp, "dst"))
   568  	if err != nil {
   569  		t.Fatal(err)
   570  	}
   571  	t.Cleanup(func() { dst.Close() })
   572  
   573  	// Populate the source file with data, then rewind it, so it can be
   574  	// consumed by copy_file_range(2).
   575  	prng := rand.New(rand.NewSource(time.Now().Unix()))
   576  	data = make([]byte, size)
   577  	prng.Read(data)
   578  	if _, err := src.Write(data); err != nil {
   579  		t.Fatal(err)
   580  	}
   581  	if _, err := src.Seek(0, io.SeekStart); err != nil {
   582  		t.Fatal(err)
   583  	}
   584  
   585  	return dst, src, data, hook
   586  }
   587  
   588  // newSpliceFileTest initializes a new test for splice.
   589  //
   590  // It creates source sockets and destination file, and populates the source sockets
   591  // with random data of the specified size. It also hooks package os' call
   592  // to poll.Splice and returns the hook so it can be inspected.
   593  func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
   594  	t.Helper()
   595  
   596  	hook := hookSpliceFile(t)
   597  
   598  	client, server := createSocketPair(t, proto)
   599  
   600  	dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
   601  	if err != nil {
   602  		t.Fatal(err)
   603  	}
   604  	t.Cleanup(func() { dst.Close() })
   605  
   606  	randSeed := time.Now().Unix()
   607  	t.Logf("random data seed: %d\n", randSeed)
   608  	prng := rand.New(rand.NewSource(randSeed))
   609  	data := make([]byte, size)
   610  	prng.Read(data)
   611  
   612  	done := make(chan struct{})
   613  	go func() {
   614  		client.Write(data)
   615  		client.Close()
   616  		close(done)
   617  	}()
   618  
   619  	return dst, server, data, hook, func() { <-done }
   620  }
   621  
   622  // mustContainData ensures that the specified file contains exactly the
   623  // specified data.
   624  func mustContainData(t *testing.T, f *File, data []byte) {
   625  	t.Helper()
   626  
   627  	got := make([]byte, len(data))
   628  	if _, err := io.ReadFull(f, got); err != nil {
   629  		t.Fatal(err)
   630  	}
   631  	if !bytes.Equal(got, data) {
   632  		t.Fatalf("didn't get the same data back from %s", f.Name())
   633  	}
   634  	if _, err := f.Read(make([]byte, 1)); err != io.EOF {
   635  		t.Fatalf("not at EOF")
   636  	}
   637  }
   638  
   639  func mustSeekStart(t *testing.T, f *File) {
   640  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   641  		t.Fatal(err)
   642  	}
   643  }
   644  
   645  func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
   646  	h := new(copyFileRangeHook)
   647  	h.install()
   648  	t.Cleanup(h.uninstall)
   649  	return h
   650  }
   651  
   652  type copyFileRangeHook struct {
   653  	called bool
   654  	dstfd  int
   655  	srcfd  int
   656  	remain int64
   657  
   658  	written int64
   659  	handled bool
   660  	err     error
   661  
   662  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   663  }
   664  
   665  func (h *copyFileRangeHook) install() {
   666  	h.original = *PollCopyFileRangeP
   667  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   668  		h.called = true
   669  		h.dstfd = dst.Sysfd
   670  		h.srcfd = src.Sysfd
   671  		h.remain = remain
   672  		h.written, h.handled, h.err = h.original(dst, src, remain)
   673  		return h.written, h.handled, h.err
   674  	}
   675  }
   676  
   677  func (h *copyFileRangeHook) uninstall() {
   678  	*PollCopyFileRangeP = h.original
   679  }
   680  
   681  func hookSpliceFile(t *testing.T) *spliceFileHook {
   682  	h := new(spliceFileHook)
   683  	h.install()
   684  	t.Cleanup(h.uninstall)
   685  	return h
   686  }
   687  
   688  type spliceFileHook struct {
   689  	called bool
   690  	dstfd  int
   691  	srcfd  int
   692  	remain int64
   693  
   694  	written int64
   695  	handled bool
   696  	sc      string
   697  	err     error
   698  
   699  	original func(dst, src *poll.FD, remain int64) (int64, bool, string, error)
   700  }
   701  
   702  func (h *spliceFileHook) install() {
   703  	h.original = *PollSpliceFile
   704  	*PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) {
   705  		h.called = true
   706  		h.dstfd = dst.Sysfd
   707  		h.srcfd = src.Sysfd
   708  		h.remain = remain
   709  		h.written, h.handled, h.sc, h.err = h.original(dst, src, remain)
   710  		return h.written, h.handled, h.sc, h.err
   711  	}
   712  }
   713  
   714  func (h *spliceFileHook) uninstall() {
   715  	*PollSpliceFile = h.original
   716  }
   717  
   718  // On some kernels copy_file_range fails on files in /proc.
   719  func TestProcCopy(t *testing.T) {
   720  	t.Parallel()
   721  
   722  	const cmdlineFile = "/proc/self/cmdline"
   723  	cmdline, err := ReadFile(cmdlineFile)
   724  	if err != nil {
   725  		t.Skipf("can't read /proc file: %v", err)
   726  	}
   727  	in, err := Open(cmdlineFile)
   728  	if err != nil {
   729  		t.Fatal(err)
   730  	}
   731  	defer in.Close()
   732  	outFile := filepath.Join(t.TempDir(), "cmdline")
   733  	out, err := Create(outFile)
   734  	if err != nil {
   735  		t.Fatal(err)
   736  	}
   737  	if _, err := io.Copy(out, in); err != nil {
   738  		t.Fatal(err)
   739  	}
   740  	if err := out.Close(); err != nil {
   741  		t.Fatal(err)
   742  	}
   743  	copy, err := ReadFile(outFile)
   744  	if err != nil {
   745  		t.Fatal(err)
   746  	}
   747  	if !bytes.Equal(cmdline, copy) {
   748  		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
   749  	}
   750  }
   751  
   752  func TestGetPollFDAndNetwork(t *testing.T) {
   753  	t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
   754  	t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
   755  }
   756  
   757  func testGetPollFDAndNetwork(t *testing.T, proto string) {
   758  	_, server := createSocketPair(t, proto)
   759  	sc, ok := server.(syscall.Conn)
   760  	if !ok {
   761  		t.Fatalf("server Conn is not a syscall.Conn")
   762  	}
   763  	rc, err := sc.SyscallConn()
   764  	if err != nil {
   765  		t.Fatalf("server SyscallConn error: %v", err)
   766  	}
   767  	if err = rc.Control(func(fd uintptr) {
   768  		pfd, network := GetPollFDAndNetwork(server)
   769  		if pfd == nil {
   770  			t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
   771  		}
   772  		if string(network) != proto {
   773  			t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
   774  		}
   775  		if pfd.Sysfd != int(fd) {
   776  			t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
   777  		}
   778  		if !pfd.IsStream {
   779  			t.Fatalf("expected IsStream to be true")
   780  		}
   781  		if err = pfd.Init(proto, true); err == nil {
   782  			t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
   783  		}
   784  	}); err != nil {
   785  		t.Fatalf("server Control error: %v", err)
   786  	}
   787  }
   788  
   789  func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
   790  	t.Helper()
   791  	if !nettest.TestableNetwork(proto) {
   792  		t.Skipf("%s does not support %q", runtime.GOOS, proto)
   793  	}
   794  
   795  	ln, err := nettest.NewLocalListener(proto)
   796  	if err != nil {
   797  		t.Fatalf("NewLocalListener error: %v", err)
   798  	}
   799  	t.Cleanup(func() {
   800  		if ln != nil {
   801  			ln.Close()
   802  		}
   803  		if client != nil {
   804  			client.Close()
   805  		}
   806  		if server != nil {
   807  			server.Close()
   808  		}
   809  	})
   810  	ch := make(chan struct{})
   811  	go func() {
   812  		var err error
   813  		server, err = ln.Accept()
   814  		if err != nil {
   815  			t.Errorf("Accept new connection error: %v", err)
   816  		}
   817  		ch <- struct{}{}
   818  	}()
   819  	client, err = net.Dial(proto, ln.Addr().String())
   820  	<-ch
   821  	if err != nil {
   822  		t.Fatalf("Dial new connection error: %v", err)
   823  	}
   824  	return client, server
   825  }
   826  

View as plain text