1
2
3
4
5 package websocket
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "fmt"
11 "io"
12 "log"
13 "net"
14 "net/http"
15 "net/http/httptest"
16 "net/url"
17 "reflect"
18 "runtime"
19 "strings"
20 "sync"
21 "testing"
22 "time"
23 )
24
// Shared state for tests that talk to the live test server: serverAddr holds
// the host:port recorded by startServer, and once guarantees startServer is
// invoked at most one time across all tests.
var (
	serverAddr string
	once       sync.Once
)
27
28 func echoServer(ws *Conn) {
29 defer ws.Close()
30 io.Copy(ws, ws)
31 }
32
// Count is the JSON message exchanged with countServer.
type Count struct {
	// S is the text payload; countServer replies with S repeated N times
	// (using the already-incremented N).
	S string
	// N is a counter that countServer increments before replying.
	N int
}
37
38 func countServer(ws *Conn) {
39 defer ws.Close()
40 for {
41 var count Count
42 err := JSON.Receive(ws, &count)
43 if err != nil {
44 return
45 }
46 count.N++
47 count.S = strings.Repeat(count.S, count.N)
48 err = JSON.Send(ws, count)
49 if err != nil {
50 return
51 }
52 }
53 }
54
55 type testCtrlAndDataHandler struct {
56 hybiFrameHandler
57 }
58
59 func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
60 h.hybiFrameHandler.conn.wio.Lock()
61 defer h.hybiFrameHandler.conn.wio.Unlock()
62 w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
63 if err != nil {
64 return 0, err
65 }
66 n, err := w.Write(b)
67 w.Close()
68 return n, err
69 }
70
71 func ctrlAndDataServer(ws *Conn) {
72 defer ws.Close()
73 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
74 ws.frameHandler = h
75
76 go func() {
77 for i := 0; ; i++ {
78 var b []byte
79 if i%2 != 0 {
80 b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
81 }
82 if _, err := h.WritePing(b); err != nil {
83 break
84 }
85 if _, err := h.WritePong(b); err != nil {
86 break
87 }
88 time.Sleep(10 * time.Millisecond)
89 }
90 }()
91
92 b := make([]byte, 128)
93 for {
94 n, err := ws.Read(b)
95 if err != nil {
96 break
97 }
98 if _, err := ws.Write(b[:n]); err != nil {
99 break
100 }
101 }
102 }
103
104 func subProtocolHandshake(config *Config, req *http.Request) error {
105 for _, proto := range config.Protocol {
106 if proto == "chat" {
107 config.Protocol = []string{proto}
108 return nil
109 }
110 }
111 return ErrBadWebSocketProtocol
112 }
113
114 func subProtoServer(ws *Conn) {
115 for _, proto := range ws.Config().Protocol {
116 io.WriteString(ws, proto)
117 }
118 }
119
120 func startServer() {
121 http.Handle("/echo", Handler(echoServer))
122 http.Handle("/count", Handler(countServer))
123 http.Handle("/ctrldata", Handler(ctrlAndDataServer))
124 subproto := Server{
125 Handshake: subProtocolHandshake,
126 Handler: Handler(subProtoServer),
127 }
128 http.Handle("/subproto", subproto)
129 server := httptest.NewServer(nil)
130 serverAddr = server.Listener.Addr().String()
131 log.Print("Test WebSocket server listening on ", serverAddr)
132 }
133
134 func newConfig(t *testing.T, path string) *Config {
135 config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
136 return config
137 }
138
139 func TestEcho(t *testing.T) {
140 once.Do(startServer)
141
142
143 client, err := net.Dial("tcp", serverAddr)
144 if err != nil {
145 t.Fatal("dialing", err)
146 }
147 conn, err := NewClient(newConfig(t, "/echo"), client)
148 if err != nil {
149 t.Errorf("WebSocket handshake error: %v", err)
150 return
151 }
152
153 msg := []byte("hello, world\n")
154 if _, err := conn.Write(msg); err != nil {
155 t.Errorf("Write: %v", err)
156 }
157 var actual_msg = make([]byte, 512)
158 n, err := conn.Read(actual_msg)
159 if err != nil {
160 t.Errorf("Read: %v", err)
161 }
162 actual_msg = actual_msg[0:n]
163 if !bytes.Equal(msg, actual_msg) {
164 t.Errorf("Echo: expected %q got %q", msg, actual_msg)
165 }
166 conn.Close()
167 }
168
169 func TestAddr(t *testing.T) {
170 once.Do(startServer)
171
172
173 client, err := net.Dial("tcp", serverAddr)
174 if err != nil {
175 t.Fatal("dialing", err)
176 }
177 conn, err := NewClient(newConfig(t, "/echo"), client)
178 if err != nil {
179 t.Errorf("WebSocket handshake error: %v", err)
180 return
181 }
182
183 ra := conn.RemoteAddr().String()
184 if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
185 t.Errorf("Bad remote addr: %v", ra)
186 }
187 la := conn.LocalAddr().String()
188 if !strings.HasPrefix(la, "http://") {
189 t.Errorf("Bad local addr: %v", la)
190 }
191 conn.Close()
192 }
193
194 func TestCount(t *testing.T) {
195 once.Do(startServer)
196
197
198 client, err := net.Dial("tcp", serverAddr)
199 if err != nil {
200 t.Fatal("dialing", err)
201 }
202 conn, err := NewClient(newConfig(t, "/count"), client)
203 if err != nil {
204 t.Errorf("WebSocket handshake error: %v", err)
205 return
206 }
207
208 var count Count
209 count.S = "hello"
210 if err := JSON.Send(conn, count); err != nil {
211 t.Errorf("Write: %v", err)
212 }
213 if err := JSON.Receive(conn, &count); err != nil {
214 t.Errorf("Read: %v", err)
215 }
216 if count.N != 1 {
217 t.Errorf("count: expected %d got %d", 1, count.N)
218 }
219 if count.S != "hello" {
220 t.Errorf("count: expected %q got %q", "hello", count.S)
221 }
222 if err := JSON.Send(conn, count); err != nil {
223 t.Errorf("Write: %v", err)
224 }
225 if err := JSON.Receive(conn, &count); err != nil {
226 t.Errorf("Read: %v", err)
227 }
228 if count.N != 2 {
229 t.Errorf("count: expected %d got %d", 2, count.N)
230 }
231 if count.S != "hellohello" {
232 t.Errorf("count: expected %q got %q", "hellohello", count.S)
233 }
234 conn.Close()
235 }
236
237 func TestWithQuery(t *testing.T) {
238 once.Do(startServer)
239
240 client, err := net.Dial("tcp", serverAddr)
241 if err != nil {
242 t.Fatal("dialing", err)
243 }
244
245 config := newConfig(t, "/echo")
246 config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
247 if err != nil {
248 t.Fatal("location url", err)
249 }
250
251 ws, err := NewClient(config, client)
252 if err != nil {
253 t.Errorf("WebSocket handshake: %v", err)
254 return
255 }
256 ws.Close()
257 }
258
259 func testWithProtocol(t *testing.T, subproto []string) (string, error) {
260 once.Do(startServer)
261
262 client, err := net.Dial("tcp", serverAddr)
263 if err != nil {
264 t.Fatal("dialing", err)
265 }
266
267 config := newConfig(t, "/subproto")
268 config.Protocol = subproto
269
270 ws, err := NewClient(config, client)
271 if err != nil {
272 return "", err
273 }
274 msg := make([]byte, 16)
275 n, err := ws.Read(msg)
276 if err != nil {
277 return "", err
278 }
279 ws.Close()
280 return string(msg[:n]), nil
281 }
282
283 func TestWithProtocol(t *testing.T) {
284 proto, err := testWithProtocol(t, []string{"chat"})
285 if err != nil {
286 t.Errorf("SubProto: unexpected error: %v", err)
287 }
288 if proto != "chat" {
289 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
290 }
291 }
292
293 func TestWithTwoProtocol(t *testing.T) {
294 proto, err := testWithProtocol(t, []string{"test", "chat"})
295 if err != nil {
296 t.Errorf("SubProto: unexpected error: %v", err)
297 }
298 if proto != "chat" {
299 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
300 }
301 }
302
303 func TestWithBadProtocol(t *testing.T) {
304 _, err := testWithProtocol(t, []string{"test"})
305 if err != ErrBadStatus {
306 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
307 }
308 }
309
310 func TestHTTP(t *testing.T) {
311 once.Do(startServer)
312
313
314
315
316 resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
317 if err != nil {
318 t.Errorf("Get: error %#v", err)
319 return
320 }
321 if resp == nil {
322 t.Error("Get: resp is null")
323 return
324 }
325 if resp.StatusCode != http.StatusBadRequest {
326 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
327 }
328 }
329
330 func TestTrailingSpaces(t *testing.T) {
331
332
333
334 once.Do(startServer)
335 config := newConfig(t, "/echo")
336 for i := 0; i < 30; i++ {
337
338 ws, err := DialConfig(config)
339 if err != nil {
340 t.Errorf("Dial #%d failed: %v", i, err)
341 break
342 }
343 ws.Close()
344 }
345 }
346
347 func TestDialConfigBadVersion(t *testing.T) {
348 once.Do(startServer)
349 config := newConfig(t, "/echo")
350 config.Version = 1234
351
352 _, err := DialConfig(config)
353
354 if dialerr, ok := err.(*DialError); ok {
355 if dialerr.Err != ErrBadProtocolVersion {
356 t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
357 }
358 }
359 }
360
361 func TestDialConfigWithDialer(t *testing.T) {
362 once.Do(startServer)
363 config := newConfig(t, "/echo")
364 config.Dialer = &net.Dialer{
365 Deadline: time.Now().Add(-time.Minute),
366 }
367 _, err := DialConfig(config)
368 dialerr, ok := err.(*DialError)
369 if !ok {
370 t.Fatalf("DialError expected, got %#v", err)
371 }
372 neterr, ok := dialerr.Err.(*net.OpError)
373 if !ok {
374 t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
375 }
376 if !neterr.Timeout() {
377 t.Fatalf("expected timeout error, got %#v", neterr)
378 }
379 }
380
381 func TestSmallBuffer(t *testing.T) {
382
383
384 once.Do(startServer)
385
386
387 client, err := net.Dial("tcp", serverAddr)
388 if err != nil {
389 t.Fatal("dialing", err)
390 }
391 conn, err := NewClient(newConfig(t, "/echo"), client)
392 if err != nil {
393 t.Errorf("WebSocket handshake error: %v", err)
394 return
395 }
396
397 msg := []byte("hello, world\n")
398 if _, err := conn.Write(msg); err != nil {
399 t.Errorf("Write: %v", err)
400 }
401 var small_msg = make([]byte, 8)
402 n, err := conn.Read(small_msg)
403 if err != nil {
404 t.Errorf("Read: %v", err)
405 }
406 if !bytes.Equal(msg[:len(small_msg)], small_msg) {
407 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
408 }
409 var second_msg = make([]byte, len(msg))
410 n, err = conn.Read(second_msg)
411 if err != nil {
412 t.Errorf("Read: %v", err)
413 }
414 second_msg = second_msg[0:n]
415 if !bytes.Equal(msg[len(small_msg):], second_msg) {
416 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
417 }
418 conn.Close()
419 }
420
// parseAuthorityTests pairs URLs with the authority string parseAuthority is
// expected to return. For the ws and wss schemes, a missing port is filled in
// as 80 and 443 respectively. For any other scheme, the host is expected back
// unchanged, with or without a port.
var parseAuthorityTests = []struct {
	in  *url.URL
	out string
}{
	// WebSocket schemes: the default port is appended when absent.
	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},

	// Other schemes: no default port is added.
	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
}
484
485 func TestParseAuthority(t *testing.T) {
486 for _, tt := range parseAuthorityTests {
487 out := parseAuthority(tt.in)
488 if out != tt.out {
489 t.Errorf("got %v; want %v", out, tt.out)
490 }
491 }
492 }
493
// closerConn wraps a net.Conn and counts calls to Close, so tests can check
// that the wrapped connection was closed.
type closerConn struct {
	// Conn is the wrapped connection; every method except Close is promoted
	// from it unchanged.
	net.Conn
	// closed counts how many times Close has been called.
	closed int
}

// Close increments the call counter, then closes the wrapped connection and
// returns its error.
func (cc *closerConn) Close() error {
	cc.closed++
	return cc.Conn.Close()
}
503
504 func TestClose(t *testing.T) {
505 if runtime.GOOS == "plan9" {
506 t.Skip("see golang.org/issue/11454")
507 }
508
509 once.Do(startServer)
510
511 conn, err := net.Dial("tcp", serverAddr)
512 if err != nil {
513 t.Fatal("dialing", err)
514 }
515
516 cc := closerConn{Conn: conn}
517
518 client, err := NewClient(newConfig(t, "/echo"), &cc)
519 if err != nil {
520 t.Fatalf("WebSocket handshake: %v", err)
521 }
522
523
524
525 conn.SetDeadline(time.Now().Add(-10 * time.Minute))
526
527 if err := client.Close(); err == nil {
528 t.Errorf("ws.Close(): expected error, got %v", err)
529 }
530 if cc.closed < 1 {
531 t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
532 }
533 }
534
// originTests pairs incoming requests with the origin URL that Origin should
// extract from them. A request with an Origin header yields the parsed URL;
// a request without one yields a nil *url.URL.
var originTests = []struct {
	req    *http.Request
	origin *url.URL
}{
	{
		req: &http.Request{
			Header: http.Header{"Origin": {"http://www.example.com"}},
		},
		origin: &url.URL{Scheme: "http", Host: "www.example.com"},
	},
	// No Origin header: the expected origin is nil.
	{req: &http.Request{}},
}
554
555 func TestOrigin(t *testing.T) {
556 conf := newConfig(t, "/echo")
557 conf.Version = ProtocolVersionHybi13
558 for i, tt := range originTests {
559 origin, err := Origin(conf, tt.req)
560 if err != nil {
561 t.Error(err)
562 continue
563 }
564 if !reflect.DeepEqual(origin, tt.origin) {
565 t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
566 continue
567 }
568 }
569 }
570
571 func TestCtrlAndData(t *testing.T) {
572 once.Do(startServer)
573
574 c, err := net.Dial("tcp", serverAddr)
575 if err != nil {
576 t.Fatal(err)
577 }
578 ws, err := NewClient(newConfig(t, "/ctrldata"), c)
579 if err != nil {
580 t.Fatal(err)
581 }
582 defer ws.Close()
583
584 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
585 ws.frameHandler = h
586
587 b := make([]byte, 128)
588 for i := 0; i < 2; i++ {
589 data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
590 if _, err := ws.Write(data); err != nil {
591 t.Fatalf("#%d: %v", i, err)
592 }
593 var ctrl []byte
594 if i%2 != 0 {
595 ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
596 }
597 if _, err := h.WritePing(ctrl); err != nil {
598 t.Fatalf("#%d: %v", i, err)
599 }
600 n, err := ws.Read(b)
601 if err != nil {
602 t.Fatalf("#%d: %v", i, err)
603 }
604 if !bytes.Equal(b[:n], data) {
605 t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
606 }
607 }
608 }
609
610 func TestCodec_ReceiveLimited(t *testing.T) {
611 const limit = 2048
612 var payloads [][]byte
613 for _, size := range []int{
614 1024,
615 2048,
616 4096,
617 2048,
618 } {
619 b := make([]byte, size)
620 rand.Read(b)
621 payloads = append(payloads, b)
622 }
623 handlerDone := make(chan struct{})
624 limitedHandler := func(ws *Conn) {
625 defer close(handlerDone)
626 ws.MaxPayloadBytes = limit
627 defer ws.Close()
628 for i, p := range payloads {
629 t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
630 var recv []byte
631 err := Message.Receive(ws, &recv)
632 switch err {
633 case nil:
634 case ErrFrameTooLarge:
635 if len(p) <= limit {
636 t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
637 }
638 continue
639 default:
640 t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
641 }
642 if len(recv) > limit {
643 t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
644 }
645 if !bytes.Equal(p, recv) {
646 t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
647 }
648 }
649 }
650 server := httptest.NewServer(Handler(limitedHandler))
651 defer server.CloseClientConnections()
652 defer server.Close()
653 addr := server.Listener.Addr().String()
654 ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
655 if err != nil {
656 t.Fatal(err)
657 }
658 defer ws.Close()
659 for i, p := range payloads {
660 if err := Message.Send(ws, p); err != nil {
661 t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
662 }
663 }
664 <-handlerDone
665 }
666
View as plain text