1
2
3
4
5 package gin
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "net"
14 "net/http"
15 "net/http/httputil"
16 "os"
17 "runtime"
18 "strings"
19 "time"
20 )
21
// Byte sequences shared by the stack-trace helpers (function and source).
var (
	// slash separates import-path elements in a fully qualified function name.
	slash = []byte("/")
	// dot separates the package name from the rest of a function name.
	dot = []byte(".")
	// centerDot ("·") is rewritten to dot in rendered function names.
	centerDot = []byte("·")
	// dunno is the placeholder shown when a name or source line is unknown.
	dunno = []byte("???")
)
28
29
30 type RecoveryFunc func(c *Context, err any)
31
32
33 func Recovery() HandlerFunc {
34 return RecoveryWithWriter(DefaultErrorWriter)
35 }
36
37
38 func CustomRecovery(handle RecoveryFunc) HandlerFunc {
39 return RecoveryWithWriter(DefaultErrorWriter, handle)
40 }
41
42
43 func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
44 if len(recovery) > 0 {
45 return CustomRecoveryWithWriter(out, recovery[0])
46 }
47 return CustomRecoveryWithWriter(out, defaultHandleRecovery)
48 }
49
50
51 func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
52 var logger *log.Logger
53 if out != nil {
54 logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
55 }
56 return func(c *Context) {
57 defer func() {
58 if err := recover(); err != nil {
59
60
61 var brokenPipe bool
62 if ne, ok := err.(*net.OpError); ok {
63 var se *os.SyscallError
64 if errors.As(ne, &se) {
65 seStr := strings.ToLower(se.Error())
66 if strings.Contains(seStr, "broken pipe") ||
67 strings.Contains(seStr, "connection reset by peer") {
68 brokenPipe = true
69 }
70 }
71 }
72 if logger != nil {
73 stack := stack(3)
74 httpRequest, _ := httputil.DumpRequest(c.Request, false)
75 headers := strings.Split(string(httpRequest), "\r\n")
76 for idx, header := range headers {
77 current := strings.Split(header, ":")
78 if current[0] == "Authorization" {
79 headers[idx] = current[0] + ": *"
80 }
81 }
82 headersToStr := strings.Join(headers, "\r\n")
83 if brokenPipe {
84 logger.Printf("%s\n%s%s", err, headersToStr, reset)
85 } else if IsDebugging() {
86 logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
87 timeFormat(time.Now()), headersToStr, err, stack, reset)
88 } else {
89 logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
90 timeFormat(time.Now()), err, stack, reset)
91 }
92 }
93 if brokenPipe {
94
95 c.Error(err.(error))
96 c.Abort()
97 } else {
98 handle(c, err)
99 }
100 }
101 }()
102 c.Next()
103 }
104 }
105
106 func defaultHandleRecovery(c *Context, _ any) {
107 c.AbortWithStatus(http.StatusInternalServerError)
108 }
109
110
111 func stack(skip int) []byte {
112 buf := new(bytes.Buffer)
113
114
115 var lines [][]byte
116 var lastFile string
117 for i := skip; ; i++ {
118 pc, file, line, ok := runtime.Caller(i)
119 if !ok {
120 break
121 }
122
123 fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
124 if file != lastFile {
125 data, err := os.ReadFile(file)
126 if err != nil {
127 continue
128 }
129 lines = bytes.Split(data, []byte{'\n'})
130 lastFile = file
131 }
132 fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
133 }
134 return buf.Bytes()
135 }
136
137
138 func source(lines [][]byte, n int) []byte {
139 n--
140 if n < 0 || n >= len(lines) {
141 return dunno
142 }
143 return bytes.TrimSpace(lines[n])
144 }
145
146
147 func function(pc uintptr) []byte {
148 fn := runtime.FuncForPC(pc)
149 if fn == nil {
150 return dunno
151 }
152 name := []byte(fn.Name())
153
154
155
156
157
158
159
160
161 if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
162 name = name[lastSlash+1:]
163 }
164 if period := bytes.Index(name, dot); period >= 0 {
165 name = name[period+1:]
166 }
167 name = bytes.ReplaceAll(name, centerDot, dot)
168 return name
169 }
170
171
// timeFormat renders t as "YYYY/MM/DD - HH:MM:SS" (24-hour clock) for the
// recovery log header.
func timeFormat(t time.Time) string {
	const layout = "2006/01/02 - 15:04:05"
	return t.Format(layout)
}
175
View as plain text