1
2
3
4
5 package netutil
6
7 import (
8 "context"
9 "errors"
10 "io"
11 "net"
12 "sync"
13 "sync/atomic"
14 "testing"
15 "time"
16 )
17
18 func TestLimitListenerOverload(t *testing.T) {
19 const (
20 max = 5
21 attempts = max * 2
22 msg = "bye\n"
23 )
24
25 l, err := net.Listen("tcp", "127.0.0.1:0")
26 if err != nil {
27 t.Fatal(err)
28 }
29 l = LimitListener(l, max)
30
31 var wg sync.WaitGroup
32 wg.Add(1)
33 saturated := make(chan struct{})
34 go func() {
35 defer wg.Done()
36
37 accepted := 0
38 for {
39 c, err := l.Accept()
40 if err != nil {
41 break
42 }
43 accepted++
44 if accepted == max {
45 close(saturated)
46 }
47 io.WriteString(c, msg)
48
49
50 defer c.Close()
51 }
52 t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted)
53
54
55
56
57 if accepted != max {
58 t.Errorf("want exactly %d", max)
59 }
60 }()
61
62 dialCtx, cancelDial := context.WithCancel(context.Background())
63 defer cancelDial()
64 dialer := &net.Dialer{}
65
66 var dialed, served int32
67 var pendingDials sync.WaitGroup
68 for n := attempts; n > 0; n-- {
69 wg.Add(1)
70 pendingDials.Add(1)
71 go func() {
72 defer wg.Done()
73
74 c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String())
75 pendingDials.Done()
76 if err != nil {
77 t.Log(err)
78 return
79 }
80 atomic.AddInt32(&dialed, 1)
81 defer c.Close()
82
83
84
85
86
87
88 if b, err := io.ReadAll(c); len(b) < len(msg) {
89 t.Log(err)
90 return
91 }
92 atomic.AddInt32(&served, 1)
93 }()
94 }
95
96
97
98
99 <-saturated
100 time.Sleep(10 * time.Millisecond)
101 cancelDial()
102
103
104 pendingDials.Wait()
105 l.Close()
106 wg.Wait()
107
108 t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts)
109
110
111
112
113
114 if served > max {
115 t.Errorf("expected at most %d served", max)
116 }
117 }
118
119 func TestLimitListenerSaturation(t *testing.T) {
120 const (
121 max = 5
122 attemptsPerWave = max * 2
123 waves = 10
124 msg = "bye\n"
125 )
126
127 l, err := net.Listen("tcp", "127.0.0.1:0")
128 if err != nil {
129 t.Fatal(err)
130 }
131 l = LimitListener(l, max)
132
133 acceptDone := make(chan struct{})
134 defer func() {
135 l.Close()
136 <-acceptDone
137 }()
138 go func() {
139 defer close(acceptDone)
140
141 var open, peakOpen int32
142 var (
143 saturated = make(chan struct{})
144 saturatedOnce sync.Once
145 )
146 var wg sync.WaitGroup
147 for {
148 c, err := l.Accept()
149 if err != nil {
150 break
151 }
152 if n := atomic.AddInt32(&open, 1); n > peakOpen {
153 peakOpen = n
154 if n == max {
155 saturatedOnce.Do(func() {
156
157
158
159 time.AfterFunc(10*time.Millisecond, func() { close(saturated) })
160 })
161 }
162 }
163 wg.Add(1)
164 go func() {
165 <-saturated
166 io.WriteString(c, msg)
167 atomic.AddInt32(&open, -1)
168 c.Close()
169 wg.Done()
170 }()
171 }
172 wg.Wait()
173
174 t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen)
175 if peakOpen > max {
176 t.Errorf("want at most %d", max)
177 }
178 }()
179
180 for wave := 0; wave < waves; wave++ {
181 var dialed, served int32
182 var wg sync.WaitGroup
183 for n := attemptsPerWave; n > 0; n-- {
184 wg.Add(1)
185 go func() {
186 defer wg.Done()
187
188 c, err := net.Dial(l.Addr().Network(), l.Addr().String())
189 if err != nil {
190 t.Log(err)
191 return
192 }
193 atomic.AddInt32(&dialed, 1)
194 defer c.Close()
195
196 if b, err := io.ReadAll(c); len(b) < len(msg) {
197 t.Log(err)
198 return
199 }
200 atomic.AddInt32(&served, 1)
201 }()
202 }
203 wg.Wait()
204
205 t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave)
206
207
208
209
210
211
212 if dialed < max {
213 t.Errorf("expected at least %d dialed", max)
214 }
215 if served < dialed {
216 t.Errorf("expected all dialed connections to be served")
217 }
218 }
219 }
220
// errFake is the error returned by every call to errorListener.Accept.
var errFake = errors.New("fake error from errorListener")

// errorListener is a net.Listener whose Accept always fails with errFake.
// As a zero value its embedded Listener is nil, so Accept is the only method
// that is safe to call on it.
type errorListener struct {
	net.Listener
}

// Compile-time check that errorListener satisfies net.Listener.
var _ net.Listener = errorListener{}

// Accept never yields a connection; it always reports errFake.
func (errorListener) Accept() (net.Conn, error) {
	return nil, errFake
}
230
231
232 func TestLimitListenerError(t *testing.T) {
233 const n = 2
234 ll := LimitListener(errorListener{}, n)
235 for i := 0; i < n+1; i++ {
236 _, err := ll.Accept()
237 if err != errFake {
238 t.Fatalf("Accept error = %v; want errFake", err)
239 }
240 }
241 }
242
243 func TestLimitListenerClose(t *testing.T) {
244 ln, err := net.Listen("tcp", "127.0.0.1:0")
245 if err != nil {
246 t.Fatal(err)
247 }
248 defer ln.Close()
249 ln = LimitListener(ln, 1)
250
251 errCh := make(chan error)
252 go func() {
253 defer close(errCh)
254 c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
255 if err != nil {
256 errCh <- err
257 return
258 }
259 c.Close()
260 }()
261
262 c, err := ln.Accept()
263 if err != nil {
264 t.Fatal(err)
265 }
266 defer c.Close()
267
268 err = <-errCh
269 if err != nil {
270 t.Fatalf("Dial: %v", err)
271 }
272
273
274
275 timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() })
276
277 c, err = ln.Accept()
278 if err == nil {
279 c.Close()
280 t.Errorf("Unexpected successful Accept()")
281 }
282 if timer.Stop() {
283 t.Errorf("Accept returned before listener closed: %v", err)
284 }
285 }
286
View as plain text