1
2
3
4
5
6 package sockstest
7
8 import (
9 "errors"
10 "io"
11 "net"
12
13 "golang.org/x/net/internal/socks"
14 "golang.org/x/net/nettest"
15 )
16
17
// An AuthRequest represents a SOCKS authentication (method-selection)
// request as decoded by ParseAuthRequest.
type AuthRequest struct {
	// Version is the protocol version byte of the request.
	Version int

	// Methods lists the authentication methods offered by the client.
	// ParseAuthRequest leaves it nil when the request offers none.
	Methods []socks.AuthMethod
}
22
23
24 func ParseAuthRequest(b []byte) (*AuthRequest, error) {
25 if len(b) < 2 {
26 return nil, errors.New("short auth request")
27 }
28 if b[0] != socks.Version5 {
29 return nil, errors.New("unexpected protocol version")
30 }
31 if len(b)-2 < int(b[1]) {
32 return nil, errors.New("short auth request")
33 }
34 req := &AuthRequest{Version: int(b[0])}
35 if b[1] > 0 {
36 req.Methods = make([]socks.AuthMethod, b[1])
37 for i, m := range b[2 : 2+b[1]] {
38 req.Methods[i] = socks.AuthMethod(m)
39 }
40 }
41 return req, nil
42 }
43
44
45 func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
46 return []byte{byte(ver), byte(m)}, nil
47 }
48
49
// A CmdRequest represents a SOCKS command request as decoded by
// ParseCmdRequest.
type CmdRequest struct {
	// Version is the protocol version byte of the request.
	Version int

	// Cmd is the requested command.
	Cmd socks.Command

	// Addr is the requested destination: either IP or Name is set,
	// together with Port.
	Addr socks.Addr
}
55
56
57 func ParseCmdRequest(b []byte) (*CmdRequest, error) {
58 if len(b) < 7 {
59 return nil, errors.New("short cmd request")
60 }
61 if b[0] != socks.Version5 {
62 return nil, errors.New("unexpected protocol version")
63 }
64 if socks.Command(b[1]) != socks.CmdConnect {
65 return nil, errors.New("unexpected command")
66 }
67 if b[2] != 0 {
68 return nil, errors.New("non-zero reserved field")
69 }
70 req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
71 l := 2
72 off := 4
73 switch b[3] {
74 case socks.AddrTypeIPv4:
75 l += net.IPv4len
76 req.Addr.IP = make(net.IP, net.IPv4len)
77 case socks.AddrTypeIPv6:
78 l += net.IPv6len
79 req.Addr.IP = make(net.IP, net.IPv6len)
80 case socks.AddrTypeFQDN:
81 l += int(b[4])
82 off = 5
83 default:
84 return nil, errors.New("unknown address type")
85 }
86 if len(b[off:]) < l {
87 return nil, errors.New("short cmd request")
88 }
89 if req.Addr.IP != nil {
90 copy(req.Addr.IP, b[off:])
91 } else {
92 req.Addr.Name = string(b[off : off+l-2])
93 }
94 req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
95 return req, nil
96 }
97
98
99 func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
100 b := make([]byte, 4)
101 b[0] = byte(ver)
102 b[1] = byte(reply)
103 if a.Name != "" {
104 if len(a.Name) > 255 {
105 return nil, errors.New("fqdn too long")
106 }
107 b[3] = socks.AddrTypeFQDN
108 b = append(b, byte(len(a.Name)))
109 b = append(b, a.Name...)
110 } else if ip4 := a.IP.To4(); ip4 != nil {
111 b[3] = socks.AddrTypeIPv4
112 b = append(b, ip4...)
113 } else if ip6 := a.IP.To16(); ip6 != nil {
114 b[3] = socks.AddrTypeIPv6
115 b = append(b, ip6...)
116 } else {
117 return nil, errors.New("unknown address type")
118 }
119 b = append(b, byte(a.Port>>8), byte(a.Port))
120 return b, nil
121 }
122
123
// A Server is a SOCKS server used for handshake testing.
type Server struct {
	ln net.Listener // listener on which client connections are accepted
}
127
128
129 func (s *Server) Addr() net.Addr {
130 return s.ln.Addr()
131 }
132
133
134
135
136 func (s *Server) TargetAddr() net.Addr {
137 a := s.ln.Addr()
138 switch a := a.(type) {
139 case *net.TCPAddr:
140 if a.IP.To4() != nil {
141 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
142 }
143 if a.IP.To16() != nil && a.IP.To4() == nil {
144 return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
145 }
146 }
147 return nil
148 }
149
150
151 func (s *Server) Close() error {
152 return s.ln.Close()
153 }
154
155 func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
156 c, err := s.ln.Accept()
157 if err != nil {
158 return
159 }
160 defer c.Close()
161 go s.serve(authFunc, cmdFunc)
162 b := make([]byte, 512)
163 n, err := c.Read(b)
164 if err != nil {
165 return
166 }
167 if err := authFunc(c, b[:n]); err != nil {
168 return
169 }
170 n, err = c.Read(b)
171 if err != nil {
172 return
173 }
174 if err := cmdFunc(c, b[:n]); err != nil {
175 return
176 }
177 }
178
179
180
181
182
183 func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
184 var err error
185 s := new(Server)
186 s.ln, err = nettest.NewLocalListener("tcp")
187 if err != nil {
188 return nil, err
189 }
190 go s.serve(authFunc, cmdFunc)
191 return s, nil
192 }
193
194
195 func NoAuthRequired(rw io.ReadWriter, b []byte) error {
196 req, err := ParseAuthRequest(b)
197 if err != nil {
198 return err
199 }
200 b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
201 if err != nil {
202 return err
203 }
204 n, err := rw.Write(b)
205 if err != nil {
206 return err
207 }
208 if n != len(b) {
209 return errors.New("short write")
210 }
211 return nil
212 }
213
214
215
216 func NoProxyRequired(rw io.ReadWriter, b []byte) error {
217 req, err := ParseCmdRequest(b)
218 if err != nil {
219 return err
220 }
221 req.Addr.Port += 1
222 if req.Addr.Name != "" {
223 req.Addr.Name = "boundaddr.doesnotexist"
224 } else if req.Addr.IP.To4() != nil {
225 req.Addr.IP = net.IPv4(127, 0, 0, 1)
226 } else {
227 req.Addr.IP = net.IPv6loopback
228 }
229 b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
230 if err != nil {
231 return err
232 }
233 n, err := rw.Write(b)
234 if err != nil {
235 return err
236 }
237 if n != len(b) {
238 return errors.New("short write")
239 }
240 return nil
241 }
242
View as plain text