...

Source file src/golang.org/x/sys/windows/mkwinsyscall/mkwinsyscall.go

Documentation: golang.org/x/sys/windows/mkwinsyscall

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  /*
mkwinsyscall generates Windows system call bodies.
     7  
     8  It parses all files specified on command line containing function
     9  prototypes (like syscall_windows.go) and prints system call bodies
    10  to standard output.
    11  
    12  The prototypes are marked by lines beginning with "//sys" and read
    13  like func declarations if //sys is replaced by func, but:
    14  
    15    - The parameter lists must give a name for each argument. This
    16      includes return parameters.
    17  
    18    - The parameter lists must give a type for each argument:
    19      the (x, y, z int) shorthand is not allowed.
    20  
    21    - If the return parameter is an error number, it must be named err.
    22  
    23    - If go func name needs to be different from its winapi dll name,
    24      the winapi name could be specified at the end, after "=" sign, like
    25      //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
    26  
  - Each function that returns err needs to supply a condition that the
    return value of the winapi function is tested against to detect failure.
    On failure, err is set to the Windows "last error"; otherwise it is nil.
    The condition can be provided at the end of the //sys declaration, like
    //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
    and is [failretval==0] by default.
    33  
  - If the DLL function name (after the "=" sign) ends in a "?", a missing
    function is non-fatal: an error is returned instead of panicking.
    36  
    37  Usage:
    38  
    39  	mkwinsyscall [flags] [path ...]
    40  
    41  The flags are:
    42  
	-output
		Specify output file name (outputs to console if blank).
	-systemdll
		Whether all DLLs should be loaded from the Windows system directory (default true).
	-trace
		Generate print statement after every syscall.
    47  */
    48  package main
    49  
    50  import (
    51  	"bufio"
    52  	"bytes"
    53  	"errors"
    54  	"flag"
    55  	"fmt"
    56  	"go/format"
    57  	"go/parser"
    58  	"go/token"
    59  	"io"
    60  	"log"
    61  	"os"
    62  	"path/filepath"
    63  	"runtime"
    64  	"sort"
    65  	"strconv"
    66  	"strings"
    67  	"text/template"
    68  )
    69  
    70  var (
    71  	filename       = flag.String("output", "", "output file name (standard output if omitted)")
    72  	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
    73  	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
    74  )
    75  
    76  func trim(s string) string {
    77  	return strings.Trim(s, " \t")
    78  }
    79  
    80  var packageName string
    81  
    82  func packagename() string {
    83  	return packageName
    84  }
    85  
    86  func windowsdot() string {
    87  	if packageName == "windows" {
    88  		return ""
    89  	}
    90  	return "windows."
    91  }
    92  
    93  func syscalldot() string {
    94  	if packageName == "syscall" {
    95  		return ""
    96  	}
    97  	return "syscall."
    98  }
    99  
   100  // Param is function parameter
   101  type Param struct {
   102  	Name      string
   103  	Type      string
   104  	fn        *Fn
   105  	tmpVarIdx int
   106  }
   107  
   108  // tmpVar returns temp variable name that will be used to represent p during syscall.
   109  func (p *Param) tmpVar() string {
   110  	if p.tmpVarIdx < 0 {
   111  		p.tmpVarIdx = p.fn.curTmpVarIdx
   112  		p.fn.curTmpVarIdx++
   113  	}
   114  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
   115  }
   116  
   117  // BoolTmpVarCode returns source code for bool temp variable.
   118  func (p *Param) BoolTmpVarCode() string {
   119  	const code = `var %[1]s uint32
   120  	if %[2]s {
   121  		%[1]s = 1
   122  	}`
   123  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   124  }
   125  
   126  // BoolPointerTmpVarCode returns source code for bool temp variable.
   127  func (p *Param) BoolPointerTmpVarCode() string {
   128  	const code = `var %[1]s uint32
   129  	if *%[2]s {
   130  		%[1]s = 1
   131  	}`
   132  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   133  }
   134  
   135  // SliceTmpVarCode returns source code for slice temp variable.
   136  func (p *Param) SliceTmpVarCode() string {
   137  	const code = `var %s *%s
   138  	if len(%s) > 0 {
   139  		%s = &%s[0]
   140  	}`
   141  	tmp := p.tmpVar()
   142  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   143  }
   144  
   145  // StringTmpVarCode returns source code for string temp variable.
   146  func (p *Param) StringTmpVarCode() string {
   147  	errvar := p.fn.Rets.ErrorVarName()
   148  	if errvar == "" {
   149  		errvar = "_"
   150  	}
   151  	tmp := p.tmpVar()
   152  	const code = `var %s %s
   153  	%s, %s = %s(%s)`
   154  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   155  	if errvar == "-" {
   156  		return s
   157  	}
   158  	const morecode = `
   159  	if %s != nil {
   160  		return
   161  	}`
   162  	return s + fmt.Sprintf(morecode, errvar)
   163  }
   164  
   165  // TmpVarCode returns source code for temp variable.
   166  func (p *Param) TmpVarCode() string {
   167  	switch {
   168  	case p.Type == "bool":
   169  		return p.BoolTmpVarCode()
   170  	case p.Type == "*bool":
   171  		return p.BoolPointerTmpVarCode()
   172  	case strings.HasPrefix(p.Type, "[]"):
   173  		return p.SliceTmpVarCode()
   174  	default:
   175  		return ""
   176  	}
   177  }
   178  
   179  // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable.
   180  func (p *Param) TmpVarReadbackCode() string {
   181  	switch {
   182  	case p.Type == "*bool":
   183  		return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar())
   184  	default:
   185  		return ""
   186  	}
   187  }
   188  
   189  // TmpVarHelperCode returns source code for helper's temp variable.
   190  func (p *Param) TmpVarHelperCode() string {
   191  	if p.Type != "string" {
   192  		return ""
   193  	}
   194  	return p.StringTmpVarCode()
   195  }
   196  
   197  // SyscallArgList returns source code fragments representing p parameter
   198  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   199  // the first element and length.
   200  func (p *Param) SyscallArgList() []string {
   201  	t := p.HelperType()
   202  	var s string
   203  	switch {
   204  	case t == "*bool":
   205  		s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar())
   206  	case t[0] == '*':
   207  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   208  	case t == "bool":
   209  		s = p.tmpVar()
   210  	case strings.HasPrefix(t, "[]"):
   211  		return []string{
   212  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   213  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   214  		}
   215  	default:
   216  		s = p.Name
   217  	}
   218  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   219  }
   220  
   221  // IsError determines if p parameter is used to return error.
   222  func (p *Param) IsError() bool {
   223  	return p.Name == "err" && p.Type == "error"
   224  }
   225  
   226  // HelperType returns type of parameter p used in helper function.
   227  func (p *Param) HelperType() string {
   228  	if p.Type == "string" {
   229  		return p.fn.StrconvType()
   230  	}
   231  	return p.Type
   232  }
   233  
   234  // join concatenates parameters ps into a string with sep separator.
   235  // Each parameter is converted into string by applying fn to it
   236  // before conversion.
   237  func join(ps []*Param, fn func(*Param) string, sep string) string {
   238  	if len(ps) == 0 {
   239  		return ""
   240  	}
   241  	a := make([]string, 0)
   242  	for _, p := range ps {
   243  		a = append(a, fn(p))
   244  	}
   245  	return strings.Join(a, sep)
   246  }
   247  
   248  // Rets describes function return parameters.
   249  type Rets struct {
   250  	Name          string
   251  	Type          string
   252  	ReturnsError  bool
   253  	FailCond      string
   254  	fnMaybeAbsent bool
   255  }
   256  
   257  // ErrorVarName returns error variable name for r.
   258  func (r *Rets) ErrorVarName() string {
   259  	if r.ReturnsError {
   260  		return "err"
   261  	}
   262  	if r.Type == "error" {
   263  		return r.Name
   264  	}
   265  	return ""
   266  }
   267  
   268  // ToParams converts r into slice of *Param.
   269  func (r *Rets) ToParams() []*Param {
   270  	ps := make([]*Param, 0)
   271  	if len(r.Name) > 0 {
   272  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   273  	}
   274  	if r.ReturnsError {
   275  		ps = append(ps, &Param{Name: "err", Type: "error"})
   276  	}
   277  	return ps
   278  }
   279  
   280  // List returns source code of syscall return parameters.
   281  func (r *Rets) List() string {
   282  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   283  	if len(s) > 0 {
   284  		s = "(" + s + ")"
   285  	} else if r.fnMaybeAbsent {
   286  		s = "(err error)"
   287  	}
   288  	return s
   289  }
   290  
   291  // PrintList returns source code of trace printing part correspondent
   292  // to syscall return values.
   293  func (r *Rets) PrintList() string {
   294  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   295  }
   296  
   297  // SetReturnValuesCode returns source code that accepts syscall return values.
   298  func (r *Rets) SetReturnValuesCode() string {
   299  	if r.Name == "" && !r.ReturnsError {
   300  		return ""
   301  	}
   302  	retvar := "r0"
   303  	if r.Name == "" {
   304  		retvar = "r1"
   305  	}
   306  	errvar := "_"
   307  	if r.ReturnsError {
   308  		errvar = "e1"
   309  	}
   310  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   311  }
   312  
   313  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   314  	const code = `if %s {
   315  		err = errnoErr(e1)
   316  	}`
   317  	cond := retvar + " == 0"
   318  	if r.FailCond != "" {
   319  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   320  	}
   321  	return fmt.Sprintf(code, cond)
   322  }
   323  
   324  // SetErrorCode returns source code that sets return parameters.
   325  func (r *Rets) SetErrorCode() string {
   326  	const code = `if r0 != 0 {
   327  		%s = %sErrno(r0)
   328  	}`
   329  	const ntstatus = `if r0 != 0 {
   330  		ntstatus = %sNTStatus(r0)
   331  	}`
   332  	if r.Name == "" && !r.ReturnsError {
   333  		return ""
   334  	}
   335  	if r.Name == "" {
   336  		return r.useLongHandleErrorCode("r1")
   337  	}
   338  	if r.Type == "error" && r.Name == "ntstatus" {
   339  		return fmt.Sprintf(ntstatus, windowsdot())
   340  	}
   341  	if r.Type == "error" {
   342  		return fmt.Sprintf(code, r.Name, syscalldot())
   343  	}
   344  	s := ""
   345  	switch {
   346  	case r.Type[0] == '*':
   347  		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
   348  	case r.Type == "bool":
   349  		s = fmt.Sprintf("%s = r0 != 0", r.Name)
   350  	default:
   351  		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
   352  	}
   353  	if !r.ReturnsError {
   354  		return s
   355  	}
   356  	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
   357  }
   358  
   359  // Fn describes syscall function.
   360  type Fn struct {
   361  	Name        string
   362  	Params      []*Param
   363  	Rets        *Rets
   364  	PrintTrace  bool
   365  	dllname     string
   366  	dllfuncname string
   367  	src         string
   368  	// TODO: get rid of this field and just use parameter index instead
   369  	curTmpVarIdx int // insure tmp variables have uniq names
   370  }
   371  
   372  // extractParams parses s to extract function parameters.
   373  func extractParams(s string, f *Fn) ([]*Param, error) {
   374  	s = trim(s)
   375  	if s == "" {
   376  		return nil, nil
   377  	}
   378  	a := strings.Split(s, ",")
   379  	ps := make([]*Param, len(a))
   380  	for i := range ps {
   381  		s2 := trim(a[i])
   382  		b := strings.Split(s2, " ")
   383  		if len(b) != 2 {
   384  			b = strings.Split(s2, "\t")
   385  			if len(b) != 2 {
   386  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   387  			}
   388  		}
   389  		ps[i] = &Param{
   390  			Name:      trim(b[0]),
   391  			Type:      trim(b[1]),
   392  			fn:        f,
   393  			tmpVarIdx: -1,
   394  		}
   395  	}
   396  	return ps, nil
   397  }
   398  
   399  // extractSection extracts text out of string s starting after start
   400  // and ending just before end. found return value will indicate success,
   401  // and prefix, body and suffix will contain correspondent parts of string s.
   402  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   403  	s = trim(s)
   404  	if strings.HasPrefix(s, string(start)) {
   405  		// no prefix
   406  		body = s[1:]
   407  	} else {
   408  		a := strings.SplitN(s, string(start), 2)
   409  		if len(a) != 2 {
   410  			return "", "", s, false
   411  		}
   412  		prefix = a[0]
   413  		body = a[1]
   414  	}
   415  	a := strings.SplitN(body, string(end), 2)
   416  	if len(a) != 2 {
   417  		return "", "", "", false
   418  	}
   419  	return prefix, a[0], a[1], true
   420  }
   421  
   422  // newFn parses string s and return created function Fn.
   423  func newFn(s string) (*Fn, error) {
   424  	s = trim(s)
   425  	f := &Fn{
   426  		Rets:       &Rets{},
   427  		src:        s,
   428  		PrintTrace: *printTraceFlag,
   429  	}
   430  	// function name and args
   431  	prefix, body, s, found := extractSection(s, '(', ')')
   432  	if !found || prefix == "" {
   433  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   434  	}
   435  	f.Name = prefix
   436  	var err error
   437  	f.Params, err = extractParams(body, f)
   438  	if err != nil {
   439  		return nil, err
   440  	}
   441  	// return values
   442  	_, body, s, found = extractSection(s, '(', ')')
   443  	if found {
   444  		r, err := extractParams(body, f)
   445  		if err != nil {
   446  			return nil, err
   447  		}
   448  		switch len(r) {
   449  		case 0:
   450  		case 1:
   451  			if r[0].IsError() {
   452  				f.Rets.ReturnsError = true
   453  			} else {
   454  				f.Rets.Name = r[0].Name
   455  				f.Rets.Type = r[0].Type
   456  			}
   457  		case 2:
   458  			if !r[1].IsError() {
   459  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   460  			}
   461  			f.Rets.ReturnsError = true
   462  			f.Rets.Name = r[0].Name
   463  			f.Rets.Type = r[0].Type
   464  		default:
   465  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   466  		}
   467  	}
   468  	// fail condition
   469  	_, body, s, found = extractSection(s, '[', ']')
   470  	if found {
   471  		f.Rets.FailCond = body
   472  	}
   473  	// dll and dll function names
   474  	s = trim(s)
   475  	if s == "" {
   476  		return f, nil
   477  	}
   478  	if !strings.HasPrefix(s, "=") {
   479  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   480  	}
   481  	s = trim(s[1:])
   482  	if i := strings.LastIndex(s, "."); i >= 0 {
   483  		f.dllname = s[:i]
   484  		f.dllfuncname = s[i+1:]
   485  	} else {
   486  		f.dllfuncname = s
   487  	}
   488  	if f.dllfuncname == "" {
   489  		return nil, fmt.Errorf("function name is not specified in %q", s)
   490  	}
   491  	if n := f.dllfuncname; strings.HasSuffix(n, "?") {
   492  		f.dllfuncname = n[:len(n)-1]
   493  		f.Rets.fnMaybeAbsent = true
   494  	}
   495  	return f, nil
   496  }
   497  
   498  // DLLName returns DLL name for function f.
   499  func (f *Fn) DLLName() string {
   500  	if f.dllname == "" {
   501  		return "kernel32"
   502  	}
   503  	return f.dllname
   504  }
   505  
   506  // DLLVar returns a valid Go identifier that represents DLLName.
   507  func (f *Fn) DLLVar() string {
   508  	id := strings.Map(func(r rune) rune {
   509  		switch r {
   510  		case '.', '-':
   511  			return '_'
   512  		default:
   513  			return r
   514  		}
   515  	}, f.DLLName())
   516  	if !token.IsIdentifier(id) {
   517  		panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName()))
   518  	}
   519  	return id
   520  }
   521  
   522  // DLLFuncName returns DLL function name for function f.
   523  func (f *Fn) DLLFuncName() string {
   524  	if f.dllfuncname == "" {
   525  		return f.Name
   526  	}
   527  	return f.dllfuncname
   528  }
   529  
   530  // ParamList returns source code for function f parameters.
   531  func (f *Fn) ParamList() string {
   532  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   533  }
   534  
   535  // HelperParamList returns source code for helper function f parameters.
   536  func (f *Fn) HelperParamList() string {
   537  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
   538  }
   539  
   540  // ParamPrintList returns source code of trace printing part correspondent
   541  // to syscall input parameters.
   542  func (f *Fn) ParamPrintList() string {
   543  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   544  }
   545  
   546  // ParamCount return number of syscall parameters for function f.
   547  func (f *Fn) ParamCount() int {
   548  	n := 0
   549  	for _, p := range f.Params {
   550  		n += len(p.SyscallArgList())
   551  	}
   552  	return n
   553  }
   554  
   555  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   556  // to use. It returns parameter count for correspondent SyscallX function.
   557  func (f *Fn) SyscallParamCount() int {
   558  	n := f.ParamCount()
   559  	switch {
   560  	case n <= 3:
   561  		return 3
   562  	case n <= 6:
   563  		return 6
   564  	case n <= 9:
   565  		return 9
   566  	case n <= 12:
   567  		return 12
   568  	case n <= 15:
   569  		return 15
   570  	case n <= 42: // current SyscallN limit
   571  		return n
   572  	default:
   573  		panic("too many arguments to system call")
   574  	}
   575  }
   576  
   577  // Syscall determines which SyscallX function to use for function f.
   578  func (f *Fn) Syscall() string {
   579  	c := f.SyscallParamCount()
   580  	if c == 3 {
   581  		return syscalldot() + "Syscall"
   582  	}
   583  	if c > 15 {
   584  		return syscalldot() + "SyscallN"
   585  	}
   586  	return syscalldot() + "Syscall" + strconv.Itoa(c)
   587  }
   588  
   589  // SyscallParamList returns source code for SyscallX parameters for function f.
   590  func (f *Fn) SyscallParamList() string {
   591  	a := make([]string, 0)
   592  	for _, p := range f.Params {
   593  		a = append(a, p.SyscallArgList()...)
   594  	}
   595  	for len(a) < f.SyscallParamCount() {
   596  		a = append(a, "0")
   597  	}
   598  	return strings.Join(a, ", ")
   599  }
   600  
   601  // HelperCallParamList returns source code of call into function f helper.
   602  func (f *Fn) HelperCallParamList() string {
   603  	a := make([]string, 0, len(f.Params))
   604  	for _, p := range f.Params {
   605  		s := p.Name
   606  		if p.Type == "string" {
   607  			s = p.tmpVar()
   608  		}
   609  		a = append(a, s)
   610  	}
   611  	return strings.Join(a, ", ")
   612  }
   613  
   614  // MaybeAbsent returns source code for handling functions that are possibly unavailable.
   615  func (p *Fn) MaybeAbsent() string {
   616  	if !p.Rets.fnMaybeAbsent {
   617  		return ""
   618  	}
   619  	const code = `%[1]s = proc%[2]s.Find()
   620  	if %[1]s != nil {
   621  		return
   622  	}`
   623  	errorVar := p.Rets.ErrorVarName()
   624  	if errorVar == "" {
   625  		errorVar = "err"
   626  	}
   627  	return fmt.Sprintf(code, errorVar, p.DLLFuncName())
   628  }
   629  
   630  // IsUTF16 is true, if f is W (utf16) function. It is false
   631  // for all A (ascii) functions.
   632  func (f *Fn) IsUTF16() bool {
   633  	s := f.DLLFuncName()
   634  	return s[len(s)-1] == 'W'
   635  }
   636  
   637  // StrconvFunc returns name of Go string to OS string function for f.
   638  func (f *Fn) StrconvFunc() string {
   639  	if f.IsUTF16() {
   640  		return syscalldot() + "UTF16PtrFromString"
   641  	}
   642  	return syscalldot() + "BytePtrFromString"
   643  }
   644  
   645  // StrconvType returns Go type name used for OS string for f.
   646  func (f *Fn) StrconvType() string {
   647  	if f.IsUTF16() {
   648  		return "*uint16"
   649  	}
   650  	return "*byte"
   651  }
   652  
   653  // HasStringParam is true, if f has at least one string parameter.
   654  // Otherwise it is false.
   655  func (f *Fn) HasStringParam() bool {
   656  	for _, p := range f.Params {
   657  		if p.Type == "string" {
   658  			return true
   659  		}
   660  	}
   661  	return false
   662  }
   663  
   664  // HelperName returns name of function f helper.
   665  func (f *Fn) HelperName() string {
   666  	if !f.HasStringParam() {
   667  		return f.Name
   668  	}
   669  	return "_" + f.Name
   670  }
   671  
   672  // DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when
   673  // naming a variable that refers to the DLL.
   674  type DLL struct {
   675  	Name string
   676  	Var  string
   677  }
   678  
   679  // Source files and functions.
   680  type Source struct {
   681  	Funcs           []*Fn
   682  	DLLFuncNames    []*Fn
   683  	Files           []string
   684  	StdLibImports   []string
   685  	ExternalImports []string
   686  }
   687  
   688  func (src *Source) Import(pkg string) {
   689  	src.StdLibImports = append(src.StdLibImports, pkg)
   690  	sort.Strings(src.StdLibImports)
   691  }
   692  
   693  func (src *Source) ExternalImport(pkg string) {
   694  	src.ExternalImports = append(src.ExternalImports, pkg)
   695  	sort.Strings(src.ExternalImports)
   696  }
   697  
   698  // ParseFiles parses files listed in fs and extracts all syscall
   699  // functions listed in sys comments. It returns source files
   700  // and functions collection *Source if successful.
   701  func ParseFiles(fs []string) (*Source, error) {
   702  	src := &Source{
   703  		Funcs: make([]*Fn, 0),
   704  		Files: make([]string, 0),
   705  		StdLibImports: []string{
   706  			"unsafe",
   707  		},
   708  		ExternalImports: make([]string, 0),
   709  	}
   710  	for _, file := range fs {
   711  		if err := src.ParseFile(file); err != nil {
   712  			return nil, err
   713  		}
   714  	}
   715  	src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs))
   716  	uniq := make(map[string]bool, len(src.Funcs))
   717  	for _, fn := range src.Funcs {
   718  		name := fn.DLLFuncName()
   719  		if !uniq[name] {
   720  			src.DLLFuncNames = append(src.DLLFuncNames, fn)
   721  			uniq[name] = true
   722  		}
   723  	}
   724  	return src, nil
   725  }
   726  
   727  // DLLs return dll names for a source set src.
   728  func (src *Source) DLLs() []DLL {
   729  	uniq := make(map[string]bool)
   730  	r := make([]DLL, 0)
   731  	for _, f := range src.Funcs {
   732  		id := f.DLLVar()
   733  		if _, found := uniq[id]; !found {
   734  			uniq[id] = true
   735  			r = append(r, DLL{f.DLLName(), id})
   736  		}
   737  	}
   738  	sort.Slice(r, func(i, j int) bool {
   739  		return r[i].Var < r[j].Var
   740  	})
   741  	return r
   742  }
   743  
   744  // ParseFile adds additional file path to a source set src.
   745  func (src *Source) ParseFile(path string) error {
   746  	file, err := os.Open(path)
   747  	if err != nil {
   748  		return err
   749  	}
   750  	defer file.Close()
   751  
   752  	s := bufio.NewScanner(file)
   753  	for s.Scan() {
   754  		t := trim(s.Text())
   755  		if len(t) < 7 {
   756  			continue
   757  		}
   758  		if !strings.HasPrefix(t, "//sys") {
   759  			continue
   760  		}
   761  		t = t[5:]
   762  		if !(t[0] == ' ' || t[0] == '\t') {
   763  			continue
   764  		}
   765  		f, err := newFn(t[1:])
   766  		if err != nil {
   767  			return err
   768  		}
   769  		src.Funcs = append(src.Funcs, f)
   770  	}
   771  	if err := s.Err(); err != nil {
   772  		return err
   773  	}
   774  	src.Files = append(src.Files, path)
   775  	sort.Slice(src.Funcs, func(i, j int) bool {
   776  		fi, fj := src.Funcs[i], src.Funcs[j]
   777  		if fi.DLLName() == fj.DLLName() {
   778  			return fi.DLLFuncName() < fj.DLLFuncName()
   779  		}
   780  		return fi.DLLName() < fj.DLLName()
   781  	})
   782  
   783  	// get package name
   784  	fset := token.NewFileSet()
   785  	_, err = file.Seek(0, 0)
   786  	if err != nil {
   787  		return err
   788  	}
   789  	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
   790  	if err != nil {
   791  		return err
   792  	}
   793  	packageName = pkg.Name.Name
   794  
   795  	return nil
   796  }
   797  
   798  // IsStdRepo reports whether src is part of standard library.
   799  func (src *Source) IsStdRepo() (bool, error) {
   800  	if len(src.Files) == 0 {
   801  		return false, errors.New("no input files provided")
   802  	}
   803  	abspath, err := filepath.Abs(src.Files[0])
   804  	if err != nil {
   805  		return false, err
   806  	}
   807  	goroot := runtime.GOROOT()
   808  	if runtime.GOOS == "windows" {
   809  		abspath = strings.ToLower(abspath)
   810  		goroot = strings.ToLower(goroot)
   811  	}
   812  	sep := string(os.PathSeparator)
   813  	if !strings.HasSuffix(goroot, sep) {
   814  		goroot += sep
   815  	}
   816  	return strings.HasPrefix(abspath, goroot), nil
   817  }
   818  
   819  // Generate output source file from a source set src.
   820  func (src *Source) Generate(w io.Writer) error {
   821  	const (
   822  		pkgStd         = iota // any package in std library
   823  		pkgXSysWindows        // x/sys/windows package
   824  		pkgOther
   825  	)
   826  	isStdRepo, err := src.IsStdRepo()
   827  	if err != nil {
   828  		return err
   829  	}
   830  	var pkgtype int
   831  	switch {
   832  	case isStdRepo:
   833  		pkgtype = pkgStd
   834  	case packageName == "windows":
   835  		// TODO: this needs better logic than just using package name
   836  		pkgtype = pkgXSysWindows
   837  	default:
   838  		pkgtype = pkgOther
   839  	}
   840  	if *systemDLL {
   841  		switch pkgtype {
   842  		case pkgStd:
   843  			src.Import("internal/syscall/windows/sysdll")
   844  		case pkgXSysWindows:
   845  		default:
   846  			src.ExternalImport("golang.org/x/sys/windows")
   847  		}
   848  	}
   849  	if packageName != "syscall" {
   850  		src.Import("syscall")
   851  	}
   852  	funcMap := template.FuncMap{
   853  		"packagename": packagename,
   854  		"syscalldot":  syscalldot,
   855  		"newlazydll": func(dll string) string {
   856  			arg := "\"" + dll + ".dll\""
   857  			if !*systemDLL {
   858  				return syscalldot() + "NewLazyDLL(" + arg + ")"
   859  			}
   860  			switch pkgtype {
   861  			case pkgStd:
   862  				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
   863  			case pkgXSysWindows:
   864  				return "NewLazySystemDLL(" + arg + ")"
   865  			default:
   866  				return "windows.NewLazySystemDLL(" + arg + ")"
   867  			}
   868  		},
   869  	}
   870  	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
   871  	err = t.Execute(w, src)
   872  	if err != nil {
   873  		return errors.New("Failed to execute template: " + err.Error())
   874  	}
   875  	return nil
   876  }
   877  
   878  func writeTempSourceFile(data []byte) (string, error) {
   879  	f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go")
   880  	if err != nil {
   881  		return "", err
   882  	}
   883  	_, err = f.Write(data)
   884  	if closeErr := f.Close(); err == nil {
   885  		err = closeErr
   886  	}
   887  	if err != nil {
   888  		os.Remove(f.Name()) // best effort
   889  		return "", err
   890  	}
   891  	return f.Name(), nil
   892  }
   893  
   894  func usage() {
   895  	fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n")
   896  	flag.PrintDefaults()
   897  	os.Exit(1)
   898  }
   899  
   900  func main() {
   901  	flag.Usage = usage
   902  	flag.Parse()
   903  	if len(flag.Args()) <= 0 {
   904  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   905  		usage()
   906  	}
   907  
   908  	src, err := ParseFiles(flag.Args())
   909  	if err != nil {
   910  		log.Fatal(err)
   911  	}
   912  
   913  	var buf bytes.Buffer
   914  	if err := src.Generate(&buf); err != nil {
   915  		log.Fatal(err)
   916  	}
   917  
   918  	data, err := format.Source(buf.Bytes())
   919  	if err != nil {
   920  		log.Printf("failed to format source: %v", err)
   921  		f, err := writeTempSourceFile(buf.Bytes())
   922  		if err != nil {
   923  			log.Fatalf("failed to write unformatted source to file: %v", err)
   924  		}
   925  		log.Fatalf("for diagnosis, wrote unformatted source to %v", f)
   926  	}
   927  	if *filename == "" {
   928  		_, err = os.Stdout.Write(data)
   929  	} else {
   930  		err = os.WriteFile(*filename, data, 0644)
   931  	}
   932  	if err != nil {
   933  		log.Fatal(err)
   934  	}
   935  }
   936  
// srcTemplate is the text/template that Source.Generate executes to produce
// the output file. The "main" template emits the file header, the imports,
// the errnoErr helper, and the mod*/proc* variables. Then, for each Fn, it
// emits a string-converting wrapper ("helperbody", only when the Fn has a
// string parameter) followed by the syscall-making body ("funcbody"). The
// template functions packagename, syscalldot and newlazydll are supplied by
// Generate; everything else is a method of Source, Fn, Param or Rets.
//
// TODO: use println instead to print in the following template
const srcTemplate = `

{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT.

package {{packagename}}

import (
{{range .StdLibImports}}"{{.}}"
{{end}}

{{range .ExternalImports}}"{{.}}"
{{end}}
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL error     = {{syscalldot}}EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e {{syscalldot}}Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.Var}} = {{newlazydll .Name}}
{{end}}{{end}}

{{define "funcnames"}}{{range .DLLFuncNames}}	proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}

{{define "funcbody"}}
func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "maybeabsent" .}}	{{template "tmpvars" .}}	{{template "syscall" .}}	{{template "tmpvarsreadback" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}

{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}}
{{end}}{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(),{{if le .ParamCount 15}} {{.ParamCount}},{{end}} {{.SyscallParamList}}){{end}}

{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}}
{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`
  1030  

View as plain text