// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos

package socket_test

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"syscall"
	"testing"

	"golang.org/x/net/internal/socket"
	"golang.org/x/net/nettest"
)

func TestSocket(t *testing.T) {
	t.Run("Option", func(t *testing.T) {
		testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
	})
}

// testSocketOption writes a fixed integer value to so on a local UDP socket
// and reads it back. Only a lower bound is asserted (the value read must be
// at least the value written), presumably because the OS may adjust the
// requested size. The test is skipped when no local UDP listener can be
// created on this platform.
func testSocketOption(t *testing.T, so *socket.Option) {
	const want = 2048

	pc, err := nettest.NewLocalPacketListener("udp")
	if err != nil {
		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
	}
	defer pc.Close()

	conn, err := socket.NewConn(pc.(net.Conn))
	if err != nil {
		t.Fatal(err)
	}
	if err := so.SetInt(conn, want); err != nil {
		t.Fatal(err)
	}

	got, err := so.GetInt(conn)
	switch {
	case err != nil:
		t.Fatal(err)
	case got < want:
		t.Fatalf("got %d; want greater than or equal to %d", got, want)
	}
}

// mockControl describes one control message used as a fixture by
// TestControlMessage: the header level and type plus the payload bytes.
type mockControl struct {
	Level, Type int    // header fields written by MarshalHeader/Marshal
	Data        []byte // payload; nil for a header-only message
}

// TestControlMessage checks that building a control-message buffer with
// MarshalHeader + Data + Next produces the same bytes as building it with
// Marshal. It also checks that Parse/ParseHeader recover each message's
// level, type and data length. The parse check runs on the full buffer and,
// when the last message carries tail padding, on the buffer with that tail
// padding trimmed off.
func TestControlMessage(t *testing.T) {
	switch runtime.GOOS {
	case "windows":
		t.Skipf("not supported on %s", runtime.GOOS)
	}

	for _, tt := range []struct {
		cs []mockControl
	}{
		{
			[]mockControl{
				{Level: 1, Type: 1},
			},
		},
		{
			[]mockControl{
				{Level: 2, Type: 2, Data: []byte{0xfe}},
			},
		},
		{
			[]mockControl{
				{Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
			},
		},
		{
			[]mockControl{
				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
			},
		},
		{
			[]mockControl{
				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
				{Level: 2, Type: 2, Data: []byte{0xfe}},
			},
		},
	} {
		// Build a zeroed buffer with room for every message, and record how
		// many padding bytes follow the data of the last message.
		var w []byte
		var tailPadLen int
		mm := socket.NewControlMessage([]int{0})
		for i, c := range tt.cs {
			m := socket.NewControlMessage([]int{len(c.Data)})
			// l is the size of a message sized for len(c.Data) bytes minus
			// the size of one sized for zero bytes. Anything beyond
			// len(c.Data) is treated as padding.
			l := len(m) - len(mm)
			if i == len(tt.cs)-1 && l > len(c.Data) {
				tailPadLen = l - len(c.Data)
			}
			w = append(w, m...)
		}

		// Fill a copy of the buffer using the piecewise API...
		var err error
		ww := make([]byte, len(w))
		copy(ww, w)
		m := socket.ControlMessage(ww)
		for _, c := range tt.cs {
			if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
				t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
			}
			copy(m.Data(len(c.Data)), c.Data)
			m = m.Next(len(c.Data))
		}
		// ...and the original buffer using Marshal; both must agree.
		m = socket.ControlMessage(w)
		for _, c := range tt.cs {
			m, err = m.Marshal(c.Level, c.Type, c.Data)
			if err != nil {
				t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
			}
		}
		if !bytes.Equal(ww, w) {
			t.Fatalf("got %#v; want %#v", ww, w)
		}

		ws := [][]byte{w}
		if tailPadLen > 0 {
			// Test a message with no tail padding.
			ws = append(ws, w[:len(w)-tailPadLen])
		}
		for _, w := range ws {
			ms, err := socket.ControlMessage(w).Parse()
			if err != nil {
				t.Fatalf("(%v).Parse() = %v", tt.cs, err)
			}
			// Guard the tt.cs[i] indexing below. Extra messages would panic,
			// and missing ones would silently skip the header checks.
			if len(ms) != len(tt.cs) {
				t.Fatalf("(%v).Parse() returned %d messages; want %d", tt.cs, len(ms), len(tt.cs))
			}
			for i, m := range ms {
				lvl, typ, dataLen, err := m.ParseHeader()
				if err != nil {
					t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
				}
				if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
					t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
				}
			}
		}
	}
}

// TestUDP round-trips a datagram through socket.Conn.SendMsg/RecvMsg over a
// local UDP socket in two ways:
//   - with an explicit destination, sent from the wrapped listener socket;
//   - with no destination, sent from a dialed socket.
//
// On Linux and Android it also exercises SendMsgs/RecvMsgs the same way. There
// it additionally checks that SendMsgs fails when the socket is not connected
// and no destination is supplied. Finally it checks that sending empty buffers
// does not crash.
func TestUDP(t *testing.T) {
	switch runtime.GOOS {
	case "windows":
		t.Skipf("not supported on %s", runtime.GOOS)
	}

	c, err := nettest.NewLocalPacketListener("udp")
	if err != nil {
		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
	}
	defer c.Close()
	// test that wrapped connections work with NewConn too
	type wrappedConn struct{ *net.UDPConn }
	cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
	if err != nil {
		t.Fatal(err)
	}

	// create a dialed connection talking (only) to c/cc
	cDialed, err := net.Dial("udp", c.LocalAddr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cDialed.Close()
	ccDialed, err := socket.NewConn(cDialed)
	if err != nil {
		t.Fatal(err)
	}

	const data = "HELLO-R-U-THERE"
	messageTests := []struct {
		name string
		conn *socket.Conn
		dest net.Addr
	}{
		{
			name: "Message",
			conn: cc,
			dest: c.LocalAddr(),
		},
		{
			name: "Message-dialed",
			conn: ccDialed,
			dest: nil,
		},
	}
	for _, tt := range messageTests {
		t.Run(tt.name, func(t *testing.T) {
			wm := socket.Message{
				Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
				Addr:    tt.dest,
			}
			if err := tt.conn.SendMsg(&wm, 0); err != nil {
				t.Fatal(err)
			}
			b := make([]byte, 32)
			rm := socket.Message{
				Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
			}
			if err := cc.RecvMsg(&rm, 0); err != nil {
				t.Fatal(err)
			}
			received := string(b[:rm.N])
			if received != data {
				t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
			}
		})
	}

	switch runtime.GOOS {
	case "android", "linux":
		messagesTests := []struct {
			name string
			conn *socket.Conn
			dest net.Addr
		}{
			{
				name: "Messages",
				conn: cc,
				dest: c.LocalAddr(),
			},
			{
				name: "Messages-dialed",
				conn: ccDialed,
				dest: nil,
			},
		}
		for _, tt := range messagesTests {
			t.Run(tt.name, func(t *testing.T) {
				wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
				wms := []socket.Message{
					{Buffers: wmbs[:1], Addr: tt.dest},
					{Buffers: wmbs[1:], Addr: tt.dest},
				}
				n, err := tt.conn.SendMsgs(wms, 0)
				if err != nil {
					t.Fatal(err)
				}
				if n != len(wms) {
					t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
				}
				rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
				rms := []socket.Message{
					{Buffers: [][]byte{rmbs[0]}},
					{Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
				}
				// RecvMsgs may return fewer messages than requested, so keep
				// reading until both datagrams have arrived.
				nrecv := 0
				for nrecv < len(rms) {
					n, err := cc.RecvMsgs(rms[nrecv:], 0)
					if err != nil {
						t.Fatal(err)
					}
					nrecv += n
				}
				// The datagrams may arrive in either order.
				received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
				assembled := received0 + received1
				assembledReordered := received1 + received0
				if assembled != data && assembledReordered != data {
					t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
				}
			})
		}
		t.Run("Messages-undialed-no-dst", func(t *testing.T) {
			// sending without destination address should fail.
			// This checks that the internally recycled buffers are reset correctly.
			data := []byte("HELLO-R-U-THERE")
			wmbs := bytes.SplitAfter(data, []byte("-"))
			wms := []socket.Message{
				{Buffers: wmbs[:1], Addr: nil},
				{Buffers: wmbs[1:], Addr: nil},
			}
			// Any nil error is a failure. The previous check
			// (n != 0 && err == nil) let a silent (0, nil) result pass.
			n, err := cc.SendMsgs(wms, 0)
			if err == nil {
				t.Fatalf("SendMsgs() = %d, nil; expected error, destination address required", n)
			}
		})
	}

	// The behavior of transmission for a zero-byte payload depends
	// on each platform implementation. Some may transmit only
	// protocol header and options, others may transmit nothing.
	// We test only that SendMsg and SendMsgs will not crash with
	// empty buffers, so their results are deliberately ignored.
	wm := socket.Message{
		Buffers: [][]byte{{}},
		Addr:    c.LocalAddr(),
	}
	_ = cc.SendMsg(&wm, 0)
	wms := []socket.Message{
		{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
	}
	_, _ = cc.SendMsgs(wms, 0)
}

// BenchmarkUDP measures datagram round trips over a local UDP socket for
// batch sizes 1, 2, 4, ..., 512.
//
// "Iter-N" performs N sequential SendMsg/RecvMsg pairs per iteration. On
// Linux and Android, "Batch-N" instead issues one SendMsgs call with N
// messages followed by one RecvMsgs call with N receive slots.
func BenchmarkUDP(b *testing.B) {
	pc, err := nettest.NewLocalPacketListener("udp")
	if err != nil {
		b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
	}
	defer pc.Close()
	conn, err := socket.NewConn(pc.(net.Conn))
	if err != nil {
		b.Fatal(err)
	}

	payload := []byte("HELLO-R-U-THERE")
	dst := pc.LocalAddr()
	wm := socket.Message{Buffers: [][]byte{payload}, Addr: dst}
	rm := socket.Message{Buffers: [][]byte{make([]byte, 128)}, OOB: make([]byte, 128)}

	batching := runtime.GOOS == "android" || runtime.GOOS == "linux"
	for size := 1; size <= 512; size *= 2 {
		b.Run(fmt.Sprintf("Iter-%d", size), func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				for k := 0; k < size; k++ {
					if err := conn.SendMsg(&wm, 0); err != nil {
						b.Fatal(err)
					}
					if err := conn.RecvMsg(&rm, 0); err != nil {
						b.Fatal(err)
					}
				}
			}
		})
		if !batching {
			continue
		}

		// Build the batch buffers outside the timed sub-benchmark.
		outgoing := make([]socket.Message, size)
		incoming := make([]socket.Message, size)
		for k := 0; k < size; k++ {
			outgoing[k].Buffers = [][]byte{payload}
			outgoing[k].Addr = dst
			incoming[k].Buffers = [][]byte{make([]byte, 128)}
			incoming[k].OOB = make([]byte, 128)
		}
		b.Run(fmt.Sprintf("Batch-%d", size), func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				if _, err := conn.SendMsgs(outgoing, 0); err != nil {
					b.Fatal(err)
				}
				if _, err := conn.RecvMsgs(incoming, 0); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}

// TestRace writes small ipv4 programs to a temporary directory and runs each
// with "go run -race". Each program contains a deliberate unsynchronized
// access to a buffer handed to ipv4.PacketConn.WriteTo or ReadFrom. The test
// expects the race detector's "WARNING: DATA RACE" report in the output.
//
// It only runs on linux/amd64, linux/ppc64le and linux/arm64 with the gc
// compiler. It logs rather than fails when the toolchain reports that -race
// requires cgo.
func TestRace(t *testing.T) {
	switch runtime.GOOS + "/" + runtime.GOARCH {
	case "linux/amd64", "linux/ppc64le", "linux/arm64":
	default:
		t.Skip("skipping test on non-race-enabled host.")
	}
	if runtime.Compiler == "gccgo" {
		t.Skip("skipping race test when built with gccgo")
	}

	programs := []string{
		`
package main
import (
	"log"
	"net"

	"golang.org/x/net/ipv4"
)

var g byte

func main() {
	c, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		log.Fatalf("ListenPacket: %v", err)
	}
	cc := ipv4.NewPacketConn(c)
	sync := make(chan bool)
	src := make([]byte, 100)
	dst := make([]byte, 100)
	go func() {
		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
			log.Fatalf("WriteTo: %v", err)
		}
	}()
	go func() {
		if _, _, _, err := cc.ReadFrom(dst); err != nil {
			log.Fatalf("ReadFrom: %v", err)
		}
		sync <- true
	}()
	g = dst[0]
	<-sync
}
`,
		`
package main
import (
	"log"
	"net"

	"golang.org/x/net/ipv4"
)

func main() {
	c, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		log.Fatalf("ListenPacket: %v", err)
	}
	cc := ipv4.NewPacketConn(c)
	sync := make(chan bool)
	src := make([]byte, 100)
	dst := make([]byte, 100)
	go func() {
		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
			log.Fatalf("WriteTo: %v", err)
		}
		sync <- true
	}()
	src[0] = 0
	go func() {
		if _, _, _, err := cc.ReadFrom(dst); err != nil {
			log.Fatalf("ReadFrom: %v", err)
		}
	}()
	<-sync
}
`,
	}

	dir, err := ioutil.TempDir("", "testrace")
	if err != nil {
		t.Fatalf("failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(dir)

	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
	t.Logf("%s version", goBinary)
	versionOut, err := exec.Command(goBinary, "version").CombinedOutput()
	if len(versionOut) > 0 {
		t.Logf("%s", versionOut)
	}
	if err != nil {
		t.Fatalf("go version failed: %v", err)
	}

	for idx, prog := range programs {
		t.Run(fmt.Sprintf("test %d", idx), func(t *testing.T) {
			src := filepath.Join(dir, fmt.Sprintf("test%d.go", idx))
			if err := ioutil.WriteFile(src, []byte(prog), 0644); err != nil {
				t.Fatalf("failed to write file: %v", err)
			}
			t.Logf("%s run -race %s", goBinary, src)
			runOut, runErr := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
			if len(runOut) > 0 {
				t.Logf("%s", runOut)
			}
			switch output := string(runOut); {
			case strings.Contains(output, "-race requires cgo"):
				t.Log("CGO is not enabled so can't use -race")
			case !strings.Contains(output, "WARNING: DATA RACE"):
				t.Errorf("race not detected for test %d: err:%v", idx, runErr)
			}
		})
	}
}