1
2
3
4
5
6
7
8
9
10 package main
11
12 import (
13 "bytes"
14 "context"
15 "crypto/tls"
16 "errors"
17 "flag"
18 "fmt"
19 "io"
20 "log"
21 "log/slog"
22 "net"
23 "net/url"
24 "os"
25 "path/filepath"
26 "sync"
27
28 "golang.org/x/net/internal/quic"
29 "golang.org/x/net/internal/quic/qlog"
30 )
31
// Command-line flags. Every flag is a string that defaults to empty.
var (
	// Endpoint: the address handed to quic.Listen, for both roles.
	listen = flag.String("listen", "", "listen address")

	// Server role: TLS key pair and the directory requests are served from.
	// Setting -root is also what enables accepting incoming request streams.
	root = flag.String("root", "", "serve files from this root")
	cert = flag.String("cert", "", "certificate")
	pkey = flag.String("key", "", "private key")

	// Client role: the directory fetched files are written into.
	output = flag.String("output", "", "directory to write files to")

	// Diagnostics: directory passed to the qlog handler.
	qlogdir = flag.String("qlog", "", "directory to write qlog output to")
)
40
41 func main() {
42 ctx := context.Background()
43 flag.Parse()
44 urls := flag.Args()
45
46 config := &quic.Config{
47 TLSConfig: &tls.Config{
48 InsecureSkipVerify: true,
49 MinVersion: tls.VersionTLS13,
50 NextProtos: []string{"hq-interop"},
51 },
52 MaxBidiRemoteStreams: -1,
53 MaxUniRemoteStreams: -1,
54 QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
55 Level: quic.QLogLevelFrame,
56 Dir: *qlogdir,
57 })),
58 }
59 if *cert != "" {
60 c, err := tls.LoadX509KeyPair(*cert, *pkey)
61 if err != nil {
62 log.Fatal(err)
63 }
64 config.TLSConfig.Certificates = []tls.Certificate{c}
65 }
66 if *root != "" {
67 config.MaxBidiRemoteStreams = 100
68 }
69 if keylog := os.Getenv("SSLKEYLOGFILE"); keylog != "" {
70 f, err := os.Create(keylog)
71 if err != nil {
72 log.Fatal(err)
73 }
74 defer f.Close()
75 config.TLSConfig.KeyLogWriter = f
76 }
77
78 testcase := os.Getenv("TESTCASE")
79 switch testcase {
80 case "handshake", "keyupdate":
81 basicTest(ctx, config, urls)
82 return
83 case "chacha20":
84
85
86
87
88 case "transfer":
89
90
91
92
93 config.MaxStreamReadBufferSize = 64 << 10
94 config.MaxConnReadBufferSize = 64 << 10
95 basicTest(ctx, config, urls)
96 return
97 case "http3":
98
99 case "multiconnect":
100
101 case "resumption":
102
103 case "retry":
104
105 case "versionnegotiation":
106
107
108
109
110
111 if *listen != "" && len(urls) == 0 {
112 basicTest(ctx, config, urls)
113 return
114 }
115 case "v2":
116
117 case "zerortt":
118
119 }
120 fmt.Printf("unsupported test case %q\n", testcase)
121 os.Exit(127)
122 }
123
124
125
126
127
128
129 func basicTest(ctx context.Context, config *quic.Config, urls []string) {
130 l, err := quic.Listen("udp", *listen, config)
131 if err != nil {
132 log.Fatal(err)
133 }
134 log.Printf("listening on %v", l.LocalAddr())
135
136 byAuthority := map[string][]*url.URL{}
137 for _, s := range urls {
138 u, addr, err := parseURL(s)
139 if err != nil {
140 log.Fatal(err)
141 }
142 byAuthority[addr] = append(byAuthority[addr], u)
143 }
144 var g sync.WaitGroup
145 defer g.Wait()
146 for addr, u := range byAuthority {
147 addr, u := addr, u
148 g.Add(1)
149 go func() {
150 defer g.Done()
151 fetchFrom(ctx, l, addr, u)
152 }()
153 }
154
155 if config.MaxBidiRemoteStreams >= 0 {
156 serve(ctx, l)
157 }
158 }
159
160 func serve(ctx context.Context, l *quic.Endpoint) error {
161 for {
162 c, err := l.Accept(ctx)
163 if err != nil {
164 return err
165 }
166 go serveConn(ctx, c)
167 }
168 }
169
170 func serveConn(ctx context.Context, c *quic.Conn) {
171 for {
172 s, err := c.AcceptStream(ctx)
173 if err != nil {
174 return
175 }
176 go func() {
177 if err := serveReq(ctx, s); err != nil {
178 log.Print("serveReq:", err)
179 }
180 }()
181 }
182 }
183
184 func serveReq(ctx context.Context, s *quic.Stream) error {
185 defer s.Close()
186 req, err := io.ReadAll(s)
187 if err != nil {
188 return err
189 }
190 if !bytes.HasSuffix(req, []byte("\r\n")) {
191 return errors.New("invalid request")
192 }
193 req = bytes.TrimSuffix(req, []byte("\r\n"))
194 if !bytes.HasPrefix(req, []byte("GET /")) {
195 return errors.New("invalid request")
196 }
197 req = bytes.TrimPrefix(req, []byte("GET /"))
198 if !filepath.IsLocal(string(req)) {
199 return errors.New("invalid request")
200 }
201 f, err := os.Open(filepath.Join(*root, string(req)))
202 if err != nil {
203 return err
204 }
205 defer f.Close()
206 _, err = io.Copy(s, f)
207 return err
208 }
209
// parseURL parses s and derives the host:port authority to dial for it.
//
// It returns the parsed URL together with the authority, using port 443 when s
// does not name a port. An error is returned when s cannot be parsed or has no
// host, as happens when the scheme is omitted.
func parseURL(s string) (u *url.URL, authority string, err error) {
	u, err = url.Parse(s)
	if err != nil {
		return nil, "", err
	}
	host := u.Hostname()
	if host == "" {
		// Without this check the authority would be ":443" and the mistake
		// would only show up later as a dial failure.
		return nil, "", fmt.Errorf("url %q: missing host", s)
	}
	port := u.Port()
	if port == "" {
		port = "443"
	}
	return u, net.JoinHostPort(host, port), nil
}
223
224 func fetchFrom(ctx context.Context, l *quic.Endpoint, addr string, urls []*url.URL) {
225 conn, err := l.Dial(ctx, "udp", addr)
226 if err != nil {
227 log.Printf("%v: %v", addr, err)
228 return
229 }
230 log.Printf("connected to %v", addr)
231 defer conn.Close()
232 var g sync.WaitGroup
233 for _, u := range urls {
234 u := u
235 g.Add(1)
236 go func() {
237 defer g.Done()
238 if err := fetchOne(ctx, conn, u); err != nil {
239 log.Printf("fetch %v: %v", u, err)
240 } else {
241 log.Printf("fetched %v", u)
242 }
243 }()
244 }
245 g.Wait()
246 }
247
248 func fetchOne(ctx context.Context, conn *quic.Conn, u *url.URL) error {
249 if len(u.Path) == 0 || u.Path[0] != '/' || !filepath.IsLocal(u.Path[1:]) {
250 return errors.New("invalid path")
251 }
252 file, err := os.Create(filepath.Join(*output, u.Path[1:]))
253 if err != nil {
254 return err
255 }
256 s, err := conn.NewStream(ctx)
257 if err != nil {
258 return err
259 }
260 defer s.Close()
261 if _, err := s.Write([]byte("GET " + u.Path + "\r\n")); err != nil {
262 return err
263 }
264 s.CloseWrite()
265 if _, err := io.Copy(file, s); err != nil {
266 return err
267 }
268 return nil
269 }
270
View as plain text