1
2
3
4
5
6
7
8
9
10
11
12 package websocket
13
14 import (
15 "bufio"
16 "crypto/tls"
17 "encoding/json"
18 "errors"
19 "io"
20 "io/ioutil"
21 "net"
22 "net/http"
23 "net/url"
24 "sync"
25 "time"
26 )
27
// Protocol version identifiers. ProtocolVersionHybi is an alias for the
// hybi-13 version number, and SupportedProtocolVersion is the same version
// in string form.
const (
	ProtocolVersionHybi13    = 13
	ProtocolVersionHybi      = ProtocolVersionHybi13
	SupportedProtocolVersion = "13"
)

// Frame payload type codes. UnknownFrame is the value the default codec
// reports for values it cannot marshal.
const (
	ContinuationFrame = 0
	TextFrame         = 1
	BinaryFrame       = 2
	CloseFrame        = 8
	PingFrame         = 9
	PongFrame         = 10
	UnknownFrame      = 255
)

// DefaultMaxPayloadBytes (32 MiB) is the payload limit Codec.Receive applies
// when a Conn's MaxPayloadBytes is zero.
const DefaultMaxPayloadBytes = 32 << 20
43
44
// ProtocolError describes a WebSocket protocol-level failure. The message is
// carried verbatim in ErrorString.
type ProtocolError struct {
	ErrorString string
}

// Error implements the error interface by returning ErrorString unchanged.
func (e *ProtocolError) Error() string {
	return e.ErrorString
}
50
51 var (
52 ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
53 ErrBadScheme = &ProtocolError{"bad scheme"}
54 ErrBadStatus = &ProtocolError{"bad status"}
55 ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
56 ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
57 ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
58 ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
59 ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
60 ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
61 ErrBadFrame = &ProtocolError{"bad frame"}
62 ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
63 ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
64 ErrBadRequestMethod = &ProtocolError{"bad method"}
65 ErrNotSupported = &ProtocolError{"not supported"}
66 )
67
68
69
// ErrFrameTooLarge is returned by Codec.Receive when the declared payload
// length of a received hybi frame exceeds the connection's payload limit
// (Conn.MaxPayloadBytes, or DefaultMaxPayloadBytes when that is zero).
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
71
72
// Addr wraps a URL so that it can be used as a net.Addr for a WebSocket
// endpoint. The String method is supplied by the embedded *url.URL.
type Addr struct {
	*url.URL
}

// Network returns the network name for WebSocket addresses, which is always
// "websocket".
func (addr *Addr) Network() string {
	const network = "websocket"
	return network
}
79
80
// Config holds the settings associated with a WebSocket connection.
type Config struct {
	// Location is reported as the remote address of a client connection
	// and as the local address of a server connection.
	Location *url.URL

	// Origin is reported as the local address of a client connection and
	// as the remote address of a server connection.
	Origin *url.URL

	// Protocol lists subprotocol names.
	// NOTE(review): presumably negotiated during the handshake — confirm.
	Protocol []string

	// Version is the protocol version number.
	// NOTE(review): presumably one of the ProtocolVersion* constants — confirm.
	Version int

	// TlsConfig is a TLS configuration.
	// NOTE(review): presumably used for secure (wss) connections — confirm.
	TlsConfig *tls.Config

	// Header holds HTTP header fields.
	// NOTE(review): presumably sent with the opening handshake — confirm.
	Header http.Header

	// Dialer is a network dialer.
	// NOTE(review): presumably used when opening client connections — confirm.
	Dialer *net.Dialer

	// handshakeData is internal per-handshake state; its use is not
	// visible in this part of the file.
	handshakeData map[string]string
}
105
106
// serverHandshaker is the interface implemented by the server side of a
// WebSocket handshake.
type serverHandshaker interface {
	// ReadHandshake processes the handshake request req, with buf
	// available for reading further data.
	// NOTE(review): code is presumably an HTTP status code — confirm
	// against the implementations.
	ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)

	// AcceptHandshake writes the handshake acceptance to buf.
	// NOTE(review): presumably called only after a successful
	// ReadHandshake — confirm.
	AcceptHandshake(buf *bufio.Writer) (err error)

	// NewServerConn creates the server-side Conn for the given buffered
	// reader/writer, underlying stream and originating request.
	NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
119
120
// frameReader is the interface for reading a single WebSocket frame.
type frameReader interface {
	// Reader reads the frame's payload data. Conn.Read treats io.EOF
	// from it as the end of the current frame.
	io.Reader

	// PayloadType returns the frame's payload type code; Codec.Receive
	// passes it on to the codec's Unmarshal function.
	PayloadType() byte

	// HeaderReader returns a reader for the frame's header.
	// NOTE(review): whether this may be nil is not visible here — confirm.
	HeaderReader() io.Reader

	// TrailerReader returns a reader for data following the payload, or
	// nil if there is none. Conn.Read drains it once the payload has
	// been fully read.
	TrailerReader() io.Reader

	// Len returns a length associated with the frame.
	// NOTE(review): presumably the payload bytes not yet read — confirm.
	Len() int
}
138
139
// frameReaderFactory creates a frameReader for the next incoming frame.
// Conn.Read and Codec.Receive call it whenever they need a new frame.
type frameReaderFactory interface {
	NewFrameReader() (r frameReader, err error)
}
143
144
// frameWriter is the interface for writing a single WebSocket frame.
type frameWriter interface {
	// WriteCloser writes the frame's payload; Conn.Write and Codec.Send
	// call Close once after writing the payload.
	io.WriteCloser
}
149
150
// frameWriterFactory creates a frameWriter for an outgoing frame of the
// given payload type.
type frameWriterFactory interface {
	NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
154
// frameHandler processes incoming frames and sends close frames.
type frameHandler interface {
	// HandleFrame inspects a freshly read frame and returns the frame
	// reader the caller should consume. A nil reader with a nil error
	// tells the caller to skip this frame and fetch the next one.
	HandleFrame(frame frameReader) (r frameReader, err error)

	// WriteClose sends a close notification carrying the given status.
	// Conn.Close calls it before closing the underlying stream.
	WriteClose(status int) (err error)
}
159
160
161
162
// Conn represents a WebSocket connection. Reads are serialized by rio and
// writes by wio, so one reader and one writer may proceed at the same time.
type Conn struct {
	config *Config
	// request is nil for client connections and non-nil for server
	// connections (see IsClientConn / IsServerConn).
	request *http.Request

	buf *bufio.ReadWriter
	// rwc is the underlying stream; deadlines can only be set when it is
	// also a net.Conn.
	rwc io.ReadWriteCloser

	// rio guards the read side: the frame reader factory and the
	// partially consumed frameReader.
	rio sync.Mutex
	frameReaderFactory
	// frameReader is the frame currently being read, or nil when the next
	// read must start a new frame.
	frameReader

	// wio guards the write side.
	wio sync.Mutex
	frameWriterFactory

	frameHandler
	// PayloadType is the payload type used for frames sent by Write.
	PayloadType byte
	// defaultCloseStatus is the status passed to WriteClose by Close.
	defaultCloseStatus int

	// MaxPayloadBytes limits the payload size of a frame accepted by
	// Codec.Receive. Zero means DefaultMaxPayloadBytes.
	MaxPayloadBytes int
}
185
186
187
188
189
190
191 func (ws *Conn) Read(msg []byte) (n int, err error) {
192 ws.rio.Lock()
193 defer ws.rio.Unlock()
194 again:
195 if ws.frameReader == nil {
196 frame, err := ws.frameReaderFactory.NewFrameReader()
197 if err != nil {
198 return 0, err
199 }
200 ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
201 if err != nil {
202 return 0, err
203 }
204 if ws.frameReader == nil {
205 goto again
206 }
207 }
208 n, err = ws.frameReader.Read(msg)
209 if err == io.EOF {
210 if trailer := ws.frameReader.TrailerReader(); trailer != nil {
211 io.Copy(ioutil.Discard, trailer)
212 }
213 ws.frameReader = nil
214 goto again
215 }
216 return n, err
217 }
218
219
220
221 func (ws *Conn) Write(msg []byte) (n int, err error) {
222 ws.wio.Lock()
223 defer ws.wio.Unlock()
224 w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
225 if err != nil {
226 return 0, err
227 }
228 n, err = w.Write(msg)
229 w.Close()
230 return n, err
231 }
232
233
234 func (ws *Conn) Close() error {
235 err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
236 err1 := ws.rwc.Close()
237 if err != nil {
238 return err
239 }
240 return err1
241 }
242
243
244 func (ws *Conn) IsClientConn() bool { return ws.request == nil }
245
246
247 func (ws *Conn) IsServerConn() bool { return ws.request != nil }
248
249
250
251 func (ws *Conn) LocalAddr() net.Addr {
252 if ws.IsClientConn() {
253 return &Addr{ws.config.Origin}
254 }
255 return &Addr{ws.config.Location}
256 }
257
258
259
260 func (ws *Conn) RemoteAddr() net.Addr {
261 if ws.IsClientConn() {
262 return &Addr{ws.config.Location}
263 }
264 return &Addr{ws.config.Origin}
265 }
266
// errSetDeadline is returned by the SetDeadline family of methods when the
// connection's underlying stream is not a net.Conn and therefore has no
// deadline support.
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
268
269
270 func (ws *Conn) SetDeadline(t time.Time) error {
271 if conn, ok := ws.rwc.(net.Conn); ok {
272 return conn.SetDeadline(t)
273 }
274 return errSetDeadline
275 }
276
277
278 func (ws *Conn) SetReadDeadline(t time.Time) error {
279 if conn, ok := ws.rwc.(net.Conn); ok {
280 return conn.SetReadDeadline(t)
281 }
282 return errSetDeadline
283 }
284
285
286 func (ws *Conn) SetWriteDeadline(t time.Time) error {
287 if conn, ok := ws.rwc.(net.Conn); ok {
288 return conn.SetWriteDeadline(t)
289 }
290 return errSetDeadline
291 }
292
293
294 func (ws *Conn) Config() *Config { return ws.config }
295
296
297
298 func (ws *Conn) Request() *http.Request { return ws.request }
299
300
// Codec converts between Go values and frame payloads. Send uses Marshal to
// encode a value before writing it, and Receive uses Unmarshal to decode a
// fully read frame.
type Codec struct {
	// Marshal encodes v and reports the payload type of the frame that
	// should carry the encoded data.
	Marshal func(v interface{}) (data []byte, payloadType byte, err error)
	// Unmarshal decodes data, received in a frame of the given payload
	// type, into v.
	Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
305
306
307 func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
308 data, payloadType, err := cd.Marshal(v)
309 if err != nil {
310 return err
311 }
312 ws.wio.Lock()
313 defer ws.wio.Unlock()
314 w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
315 if err != nil {
316 return err
317 }
318 _, err = w.Write(data)
319 w.Close()
320 return err
321 }
322
323
324
325
326
327
328
329 func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
330 ws.rio.Lock()
331 defer ws.rio.Unlock()
332 if ws.frameReader != nil {
333 _, err = io.Copy(ioutil.Discard, ws.frameReader)
334 if err != nil {
335 return err
336 }
337 ws.frameReader = nil
338 }
339 again:
340 frame, err := ws.frameReaderFactory.NewFrameReader()
341 if err != nil {
342 return err
343 }
344 frame, err = ws.frameHandler.HandleFrame(frame)
345 if err != nil {
346 return err
347 }
348 if frame == nil {
349 goto again
350 }
351 maxPayloadBytes := ws.MaxPayloadBytes
352 if maxPayloadBytes == 0 {
353 maxPayloadBytes = DefaultMaxPayloadBytes
354 }
355 if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
356
357
358
359
360
361 ws.frameReader = frame
362 return ErrFrameTooLarge
363 }
364 payloadType := frame.PayloadType()
365 data, err := ioutil.ReadAll(frame)
366 if err != nil {
367 return err
368 }
369 return cd.Unmarshal(data, payloadType, v)
370 }
371
372 func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
373 switch data := v.(type) {
374 case string:
375 return []byte(data), TextFrame, nil
376 case []byte:
377 return data, BinaryFrame, nil
378 }
379 return nil, UnknownFrame, ErrNotSupported
380 }
381
382 func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
383 switch data := v.(type) {
384 case *string:
385 *data = string(msg)
386 return nil
387 case *[]byte:
388 *data = msg
389 return nil
390 }
391 return ErrNotSupported
392 }
393
394
419 var Message = Codec{marshal, unmarshal}
420
421 func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
422 msg, err = json.Marshal(v)
423 return msg, TextFrame, err
424 }
425
// jsonUnmarshal is the Unmarshal function of the JSON codec. It decodes msg
// into v with encoding/json; the payload type is ignored.
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
	err = json.Unmarshal(msg, v)
	return err
}
429
430
449 var JSON = Codec{jsonMarshal, jsonUnmarshal}
450
View as plain text