...

Source file src/github.com/gin-gonic/gin/recovery.go

Documentation: github.com/gin-gonic/gin

     1  // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
     2  // Use of this source code is governed by a MIT style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gin
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httputil"
    16  	"os"
    17  	"runtime"
    18  	"strings"
    19  	"time"
    20  )
    21  
    22  var (
    23  	dunno     = []byte("???")
    24  	centerDot = []byte("·")
    25  	dot       = []byte(".")
    26  	slash     = []byte("/")
    27  )
    28  
    29  // RecoveryFunc defines the function passable to CustomRecovery.
    30  type RecoveryFunc func(c *Context, err any)
    31  
    32  // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
    33  func Recovery() HandlerFunc {
    34  	return RecoveryWithWriter(DefaultErrorWriter)
    35  }
    36  
    37  // CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
    38  func CustomRecovery(handle RecoveryFunc) HandlerFunc {
    39  	return RecoveryWithWriter(DefaultErrorWriter, handle)
    40  }
    41  
    42  // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
    43  func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
    44  	if len(recovery) > 0 {
    45  		return CustomRecoveryWithWriter(out, recovery[0])
    46  	}
    47  	return CustomRecoveryWithWriter(out, defaultHandleRecovery)
    48  }
    49  
    50  // CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
    51  func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
    52  	var logger *log.Logger
    53  	if out != nil {
    54  		logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
    55  	}
    56  	return func(c *Context) {
    57  		defer func() {
    58  			if err := recover(); err != nil {
    59  				// Check for a broken connection, as it is not really a
    60  				// condition that warrants a panic stack trace.
    61  				var brokenPipe bool
    62  				if ne, ok := err.(*net.OpError); ok {
    63  					var se *os.SyscallError
    64  					if errors.As(ne, &se) {
    65  						seStr := strings.ToLower(se.Error())
    66  						if strings.Contains(seStr, "broken pipe") ||
    67  							strings.Contains(seStr, "connection reset by peer") {
    68  							brokenPipe = true
    69  						}
    70  					}
    71  				}
    72  				if logger != nil {
    73  					stack := stack(3)
    74  					httpRequest, _ := httputil.DumpRequest(c.Request, false)
    75  					headers := strings.Split(string(httpRequest), "\r\n")
    76  					for idx, header := range headers {
    77  						current := strings.Split(header, ":")
    78  						if current[0] == "Authorization" {
    79  							headers[idx] = current[0] + ": *"
    80  						}
    81  					}
    82  					headersToStr := strings.Join(headers, "\r\n")
    83  					if brokenPipe {
    84  						logger.Printf("%s\n%s%s", err, headersToStr, reset)
    85  					} else if IsDebugging() {
    86  						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
    87  							timeFormat(time.Now()), headersToStr, err, stack, reset)
    88  					} else {
    89  						logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
    90  							timeFormat(time.Now()), err, stack, reset)
    91  					}
    92  				}
    93  				if brokenPipe {
    94  					// If the connection is dead, we can't write a status to it.
    95  					c.Error(err.(error)) //nolint: errcheck
    96  					c.Abort()
    97  				} else {
    98  					handle(c, err)
    99  				}
   100  			}
   101  		}()
   102  		c.Next()
   103  	}
   104  }
   105  
   106  func defaultHandleRecovery(c *Context, _ any) {
   107  	c.AbortWithStatus(http.StatusInternalServerError)
   108  }
   109  
   110  // stack returns a nicely formatted stack frame, skipping skip frames.
   111  func stack(skip int) []byte {
   112  	buf := new(bytes.Buffer) // the returned data
   113  	// As we loop, we open files and read them. These variables record the currently
   114  	// loaded file.
   115  	var lines [][]byte
   116  	var lastFile string
   117  	for i := skip; ; i++ { // Skip the expected number of frames
   118  		pc, file, line, ok := runtime.Caller(i)
   119  		if !ok {
   120  			break
   121  		}
   122  		// Print this much at least.  If we can't find the source, it won't show.
   123  		fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
   124  		if file != lastFile {
   125  			data, err := os.ReadFile(file)
   126  			if err != nil {
   127  				continue
   128  			}
   129  			lines = bytes.Split(data, []byte{'\n'})
   130  			lastFile = file
   131  		}
   132  		fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
   133  	}
   134  	return buf.Bytes()
   135  }
   136  
   137  // source returns a space-trimmed slice of the n'th line.
   138  func source(lines [][]byte, n int) []byte {
   139  	n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
   140  	if n < 0 || n >= len(lines) {
   141  		return dunno
   142  	}
   143  	return bytes.TrimSpace(lines[n])
   144  }
   145  
   146  // function returns, if possible, the name of the function containing the PC.
   147  func function(pc uintptr) []byte {
   148  	fn := runtime.FuncForPC(pc)
   149  	if fn == nil {
   150  		return dunno
   151  	}
   152  	name := []byte(fn.Name())
   153  	// The name includes the path name to the package, which is unnecessary
   154  	// since the file name is already included.  Plus, it has center dots.
   155  	// That is, we see
   156  	//	runtime/debug.*T·ptrmethod
   157  	// and want
   158  	//	*T.ptrmethod
   159  	// Also the package path might contain dot (e.g. code.google.com/...),
   160  	// so first eliminate the path prefix
   161  	if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
   162  		name = name[lastSlash+1:]
   163  	}
   164  	if period := bytes.Index(name, dot); period >= 0 {
   165  		name = name[period+1:]
   166  	}
   167  	name = bytes.ReplaceAll(name, centerDot, dot)
   168  	return name
   169  }
   170  
   171  // timeFormat returns a customized time string for logger.
   172  func timeFormat(t time.Time) string {
   173  	return t.Format("2006/01/02 - 15:04:05")
   174  }
   175  

View as plain text