...

Source file src/crypto/internal/nistec/generate.go

Documentation: crypto/internal/nistec

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  
     7  package main
     8  
     9  // Running this generator requires addchain v0.4.0, which can be installed with
    10  //
    11  //   go install github.com/mmcloughlin/addchain/cmd/addchain@v0.4.0
    12  //
    13  
    14  import (
    15  	"bytes"
    16  	"crypto/elliptic"
    17  	"fmt"
    18  	"go/format"
    19  	"io"
    20  	"log"
    21  	"math/big"
    22  	"os"
    23  	"os/exec"
    24  	"strings"
    25  	"text/template"
    26  )
    27  
    28  var curves = []struct {
    29  	P         string
    30  	Element   string
    31  	Params    *elliptic.CurveParams
    32  	BuildTags string
    33  }{
    34  	{
    35  		P:       "P224",
    36  		Element: "fiat.P224Element",
    37  		Params:  elliptic.P224().Params(),
    38  	},
    39  	{
    40  		P:         "P256",
    41  		Element:   "fiat.P256Element",
    42  		Params:    elliptic.P256().Params(),
    43  		BuildTags: "!amd64 && !arm64 && !ppc64le && !s390x",
    44  	},
    45  	{
    46  		P:       "P384",
    47  		Element: "fiat.P384Element",
    48  		Params:  elliptic.P384().Params(),
    49  	},
    50  	{
    51  		P:       "P521",
    52  		Element: "fiat.P521Element",
    53  		Params:  elliptic.P521().Params(),
    54  	},
    55  }
    56  
    57  func main() {
    58  	t := template.Must(template.New("tmplNISTEC").Parse(tmplNISTEC))
    59  
    60  	tmplAddchainFile, err := os.CreateTemp("", "addchain-template")
    61  	if err != nil {
    62  		log.Fatal(err)
    63  	}
    64  	defer os.Remove(tmplAddchainFile.Name())
    65  	if _, err := io.WriteString(tmplAddchainFile, tmplAddchain); err != nil {
    66  		log.Fatal(err)
    67  	}
    68  	if err := tmplAddchainFile.Close(); err != nil {
    69  		log.Fatal(err)
    70  	}
    71  
    72  	for _, c := range curves {
    73  		p := strings.ToLower(c.P)
    74  		elementLen := (c.Params.BitSize + 7) / 8
    75  		B := fmt.Sprintf("%#v", c.Params.B.FillBytes(make([]byte, elementLen)))
    76  		Gx := fmt.Sprintf("%#v", c.Params.Gx.FillBytes(make([]byte, elementLen)))
    77  		Gy := fmt.Sprintf("%#v", c.Params.Gy.FillBytes(make([]byte, elementLen)))
    78  
    79  		log.Printf("Generating %s.go...", p)
    80  		f, err := os.Create(p + ".go")
    81  		if err != nil {
    82  			log.Fatal(err)
    83  		}
    84  		defer f.Close()
    85  		buf := &bytes.Buffer{}
    86  		if err := t.Execute(buf, map[string]interface{}{
    87  			"P": c.P, "p": p, "B": B, "Gx": Gx, "Gy": Gy,
    88  			"Element": c.Element, "ElementLen": elementLen,
    89  			"BuildTags": c.BuildTags,
    90  		}); err != nil {
    91  			log.Fatal(err)
    92  		}
    93  		out, err := format.Source(buf.Bytes())
    94  		if err != nil {
    95  			log.Fatal(err)
    96  		}
    97  		if _, err := f.Write(out); err != nil {
    98  			log.Fatal(err)
    99  		}
   100  
   101  		// If p = 3 mod 4, implement modular square root by exponentiation.
   102  		mod4 := new(big.Int).Mod(c.Params.P, big.NewInt(4))
   103  		if mod4.Cmp(big.NewInt(3)) != 0 {
   104  			continue
   105  		}
   106  
   107  		exp := new(big.Int).Add(c.Params.P, big.NewInt(1))
   108  		exp.Div(exp, big.NewInt(4))
   109  
   110  		tmp, err := os.CreateTemp("", "addchain-"+p)
   111  		if err != nil {
   112  			log.Fatal(err)
   113  		}
   114  		defer os.Remove(tmp.Name())
   115  		cmd := exec.Command("addchain", "search", fmt.Sprintf("%d", exp))
   116  		cmd.Stderr = os.Stderr
   117  		cmd.Stdout = tmp
   118  		if err := cmd.Run(); err != nil {
   119  			log.Fatal(err)
   120  		}
   121  		if err := tmp.Close(); err != nil {
   122  			log.Fatal(err)
   123  		}
   124  		cmd = exec.Command("addchain", "gen", "-tmpl", tmplAddchainFile.Name(), tmp.Name())
   125  		cmd.Stderr = os.Stderr
   126  		out, err = cmd.Output()
   127  		if err != nil {
   128  			log.Fatal(err)
   129  		}
   130  		out = bytes.Replace(out, []byte("Element"), []byte(c.Element), -1)
   131  		out = bytes.Replace(out, []byte("sqrtCandidate"), []byte(p+"SqrtCandidate"), -1)
   132  		out, err = format.Source(out)
   133  		if err != nil {
   134  			log.Fatal(err)
   135  		}
   136  		if _, err := f.Write(out); err != nil {
   137  			log.Fatal(err)
   138  		}
   139  	}
   140  }
   141  
   142  const tmplNISTEC = `// Copyright 2022 The Go Authors. All rights reserved.
   143  // Use of this source code is governed by a BSD-style
   144  // license that can be found in the LICENSE file.
   145  
   146  // Code generated by generate.go. DO NOT EDIT.
   147  
   148  {{ if .BuildTags }}
   149  //go:build {{ .BuildTags }}
   150  {{ end }}
   151  
   152  package nistec
   153  
   154  import (
   155  	"crypto/internal/nistec/fiat"
   156  	"crypto/subtle"
   157  	"errors"
   158  	"sync"
   159  )
   160  
   161  // {{.p}}ElementLength is the length of an element of the base or scalar field,
   162  // which have the same bytes length for all NIST P curves.
   163  const {{.p}}ElementLength = {{ .ElementLen }}
   164  
   165  // {{.P}}Point is a {{.P}} point. The zero value is NOT valid.
   166  type {{.P}}Point struct {
   167  	// The point is represented in projective coordinates (X:Y:Z),
   168  	// where x = X/Z and y = Y/Z.
   169  	x, y, z *{{.Element}}
   170  }
   171  
   172  // New{{.P}}Point returns a new {{.P}}Point representing the point at infinity point.
   173  func New{{.P}}Point() *{{.P}}Point {
   174  	return &{{.P}}Point{
   175  		x: new({{.Element}}),
   176  		y: new({{.Element}}).One(),
   177  		z: new({{.Element}}),
   178  	}
   179  }
   180  
   181  // SetGenerator sets p to the canonical generator and returns p.
   182  func (p *{{.P}}Point) SetGenerator() *{{.P}}Point {
   183  	p.x.SetBytes({{.Gx}})
   184  	p.y.SetBytes({{.Gy}})
   185  	p.z.One()
   186  	return p
   187  }
   188  
   189  // Set sets p = q and returns p.
   190  func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point {
   191  	p.x.Set(q.x)
   192  	p.y.Set(q.y)
   193  	p.z.Set(q.z)
   194  	return p
   195  }
   196  
   197  // SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
   198  // b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
   199  // the curve, it returns nil and an error, and the receiver is unchanged.
   200  // Otherwise, it returns p.
   201  func (p *{{.P}}Point) SetBytes(b []byte) (*{{.P}}Point, error) {
   202  	switch {
   203  	// Point at infinity.
   204  	case len(b) == 1 && b[0] == 0:
   205  		return p.Set(New{{.P}}Point()), nil
   206  
   207  	// Uncompressed form.
   208  	case len(b) == 1+2*{{.p}}ElementLength && b[0] == 4:
   209  		x, err := new({{.Element}}).SetBytes(b[1 : 1+{{.p}}ElementLength])
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  		y, err := new({{.Element}}).SetBytes(b[1+{{.p}}ElementLength:])
   214  		if err != nil {
   215  			return nil, err
   216  		}
   217  		if err := {{.p}}CheckOnCurve(x, y); err != nil {
   218  			return nil, err
   219  		}
   220  		p.x.Set(x)
   221  		p.y.Set(y)
   222  		p.z.One()
   223  		return p, nil
   224  
   225  	// Compressed form.
   226  	case len(b) == 1+{{.p}}ElementLength && (b[0] == 2 || b[0] == 3):
   227  		x, err := new({{.Element}}).SetBytes(b[1:])
   228  		if err != nil {
   229  			return nil, err
   230  		}
   231  
   232  		// y² = x³ - 3x + b
   233  		y := {{.p}}Polynomial(new({{.Element}}), x)
   234  		if !{{.p}}Sqrt(y, y) {
   235  			return nil, errors.New("invalid {{.P}} compressed point encoding")
   236  		}
   237  
   238  		// Select the positive or negative root, as indicated by the least
   239  		// significant bit, based on the encoding type byte.
   240  		otherRoot := new({{.Element}})
   241  		otherRoot.Sub(otherRoot, y)
   242  		cond := y.Bytes()[{{.p}}ElementLength-1]&1 ^ b[0]&1
   243  		y.Select(otherRoot, y, int(cond))
   244  
   245  		p.x.Set(x)
   246  		p.y.Set(y)
   247  		p.z.One()
   248  		return p, nil
   249  
   250  	default:
   251  		return nil, errors.New("invalid {{.P}} point encoding")
   252  	}
   253  }
   254  
   255  
   256  var _{{.p}}B *{{.Element}}
   257  var _{{.p}}BOnce sync.Once
   258  
   259  func {{.p}}B() *{{.Element}} {
   260  	_{{.p}}BOnce.Do(func() {
   261  		_{{.p}}B, _ = new({{.Element}}).SetBytes({{.B}})
   262  	})
   263  	return _{{.p}}B
   264  }
   265  
   266  // {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2.
   267  func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} {
   268  	y2.Square(x)
   269  	y2.Mul(y2, x)
   270  
   271  	threeX := new({{.Element}}).Add(x, x)
   272  	threeX.Add(threeX, x)
   273  	y2.Sub(y2, threeX)
   274  
   275  	return y2.Add(y2, {{.p}}B())
   276  }
   277  
   278  func {{.p}}CheckOnCurve(x, y *{{.Element}}) error {
   279  	// y² = x³ - 3x + b
   280  	rhs := {{.p}}Polynomial(new({{.Element}}), x)
   281  	lhs := new({{.Element}}).Square(y)
   282  	if rhs.Equal(lhs) != 1 {
   283  		return errors.New("{{.P}} point not on curve")
   284  	}
   285  	return nil
   286  }
   287  
   288  // Bytes returns the uncompressed or infinity encoding of p, as specified in
   289  // SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
   290  // infinity is shorter than all other encodings.
   291  func (p *{{.P}}Point) Bytes() []byte {
   292  	// This function is outlined to make the allocations inline in the caller
   293  	// rather than happen on the heap.
   294  	var out [1+2*{{.p}}ElementLength]byte
   295  	return p.bytes(&out)
   296  }
   297  
   298  func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte {
   299  	if p.z.IsZero() == 1 {
   300  		return append(out[:0], 0)
   301  	}
   302  
   303  	zinv := new({{.Element}}).Invert(p.z)
   304  	x := new({{.Element}}).Mul(p.x, zinv)
   305  	y := new({{.Element}}).Mul(p.y, zinv)
   306  
   307  	buf := append(out[:0], 4)
   308  	buf = append(buf, x.Bytes()...)
   309  	buf = append(buf, y.Bytes()...)
   310  	return buf
   311  }
   312  
   313  // BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
   314  // Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
   315  func (p *{{.P}}Point) BytesX() ([]byte, error) {
   316  	// This function is outlined to make the allocations inline in the caller
   317  	// rather than happen on the heap.
   318  	var out [{{.p}}ElementLength]byte
   319  	return p.bytesX(&out)
   320  }
   321  
   322  func (p *{{.P}}Point) bytesX(out *[{{.p}}ElementLength]byte) ([]byte, error) {
   323  	if p.z.IsZero() == 1 {
   324  		return nil, errors.New("{{.P}} point is the point at infinity")
   325  	}
   326  
   327  	zinv := new({{.Element}}).Invert(p.z)
   328  	x := new({{.Element}}).Mul(p.x, zinv)
   329  
   330  	return append(out[:0], x.Bytes()...), nil
   331  }
   332  
   333  // BytesCompressed returns the compressed or infinity encoding of p, as
   334  // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
   335  // point at infinity is shorter than all other encodings.
   336  func (p *{{.P}}Point) BytesCompressed() []byte {
   337  	// This function is outlined to make the allocations inline in the caller
   338  	// rather than happen on the heap.
   339  	var out [1 + {{.p}}ElementLength]byte
   340  	return p.bytesCompressed(&out)
   341  }
   342  
   343  func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte {
   344  	if p.z.IsZero() == 1 {
   345  		return append(out[:0], 0)
   346  	}
   347  
   348  	zinv := new({{.Element}}).Invert(p.z)
   349  	x := new({{.Element}}).Mul(p.x, zinv)
   350  	y := new({{.Element}}).Mul(p.y, zinv)
   351  
   352  	// Encode the sign of the y coordinate (indicated by the least significant
   353  	// bit) as the encoding type (2 or 3).
   354  	buf := append(out[:0], 2)
   355  	buf[0] |= y.Bytes()[{{.p}}ElementLength-1] & 1
   356  	buf = append(buf, x.Bytes()...)
   357  	return buf
   358  }
   359  
   360  // Add sets q = p1 + p2, and returns q. The points may overlap.
   361  func (q *{{.P}}Point) Add(p1, p2 *{{.P}}Point) *{{.P}}Point {
   362  	// Complete addition formula for a = -3 from "Complete addition formulas for
   363  	// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
   364  
   365  	t0 := new({{.Element}}).Mul(p1.x, p2.x)   // t0 := X1 * X2
   366  	t1 := new({{.Element}}).Mul(p1.y, p2.y)   // t1 := Y1 * Y2
   367  	t2 := new({{.Element}}).Mul(p1.z, p2.z)   // t2 := Z1 * Z2
   368  	t3 := new({{.Element}}).Add(p1.x, p1.y)   // t3 := X1 + Y1
   369  	t4 := new({{.Element}}).Add(p2.x, p2.y)   // t4 := X2 + Y2
   370  	t3.Mul(t3, t4)                            // t3 := t3 * t4
   371  	t4.Add(t0, t1)                            // t4 := t0 + t1
   372  	t3.Sub(t3, t4)                            // t3 := t3 - t4
   373  	t4.Add(p1.y, p1.z)                        // t4 := Y1 + Z1
   374  	x3 := new({{.Element}}).Add(p2.y, p2.z)   // X3 := Y2 + Z2
   375  	t4.Mul(t4, x3)                            // t4 := t4 * X3
   376  	x3.Add(t1, t2)                            // X3 := t1 + t2
   377  	t4.Sub(t4, x3)                            // t4 := t4 - X3
   378  	x3.Add(p1.x, p1.z)                        // X3 := X1 + Z1
   379  	y3 := new({{.Element}}).Add(p2.x, p2.z)   // Y3 := X2 + Z2
   380  	x3.Mul(x3, y3)                            // X3 := X3 * Y3
   381  	y3.Add(t0, t2)                            // Y3 := t0 + t2
   382  	y3.Sub(x3, y3)                            // Y3 := X3 - Y3
   383  	z3 := new({{.Element}}).Mul({{.p}}B(), t2)  // Z3 := b * t2
   384  	x3.Sub(y3, z3)                            // X3 := Y3 - Z3
   385  	z3.Add(x3, x3)                            // Z3 := X3 + X3
   386  	x3.Add(x3, z3)                            // X3 := X3 + Z3
   387  	z3.Sub(t1, x3)                            // Z3 := t1 - X3
   388  	x3.Add(t1, x3)                            // X3 := t1 + X3
   389  	y3.Mul({{.p}}B(), y3)                     // Y3 := b * Y3
   390  	t1.Add(t2, t2)                            // t1 := t2 + t2
   391  	t2.Add(t1, t2)                            // t2 := t1 + t2
   392  	y3.Sub(y3, t2)                            // Y3 := Y3 - t2
   393  	y3.Sub(y3, t0)                            // Y3 := Y3 - t0
   394  	t1.Add(y3, y3)                            // t1 := Y3 + Y3
   395  	y3.Add(t1, y3)                            // Y3 := t1 + Y3
   396  	t1.Add(t0, t0)                            // t1 := t0 + t0
   397  	t0.Add(t1, t0)                            // t0 := t1 + t0
   398  	t0.Sub(t0, t2)                            // t0 := t0 - t2
   399  	t1.Mul(t4, y3)                            // t1 := t4 * Y3
   400  	t2.Mul(t0, y3)                            // t2 := t0 * Y3
   401  	y3.Mul(x3, z3)                            // Y3 := X3 * Z3
   402  	y3.Add(y3, t2)                            // Y3 := Y3 + t2
   403  	x3.Mul(t3, x3)                            // X3 := t3 * X3
   404  	x3.Sub(x3, t1)                            // X3 := X3 - t1
   405  	z3.Mul(t4, z3)                            // Z3 := t4 * Z3
   406  	t1.Mul(t3, t0)                            // t1 := t3 * t0
   407  	z3.Add(z3, t1)                            // Z3 := Z3 + t1
   408  
   409  	q.x.Set(x3)
   410  	q.y.Set(y3)
   411  	q.z.Set(z3)
   412  	return q
   413  }
   414  
   415  // Double sets q = p + p, and returns q. The points may overlap.
   416  func (q *{{.P}}Point) Double(p *{{.P}}Point) *{{.P}}Point {
   417  	// Complete addition formula for a = -3 from "Complete addition formulas for
   418  	// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
   419  
   420  	t0 := new({{.Element}}).Square(p.x)      // t0 := X ^ 2
   421  	t1 := new({{.Element}}).Square(p.y)      // t1 := Y ^ 2
   422  	t2 := new({{.Element}}).Square(p.z)      // t2 := Z ^ 2
   423  	t3 := new({{.Element}}).Mul(p.x, p.y)    // t3 := X * Y
   424  	t3.Add(t3, t3)                           // t3 := t3 + t3
   425  	z3 := new({{.Element}}).Mul(p.x, p.z)    // Z3 := X * Z
   426  	z3.Add(z3, z3)                           // Z3 := Z3 + Z3
   427  	y3 := new({{.Element}}).Mul({{.p}}B(), t2) // Y3 := b * t2
   428  	y3.Sub(y3, z3)                           // Y3 := Y3 - Z3
   429  	x3 := new({{.Element}}).Add(y3, y3)      // X3 := Y3 + Y3
   430  	y3.Add(x3, y3)                           // Y3 := X3 + Y3
   431  	x3.Sub(t1, y3)                           // X3 := t1 - Y3
   432  	y3.Add(t1, y3)                           // Y3 := t1 + Y3
   433  	y3.Mul(x3, y3)                           // Y3 := X3 * Y3
   434  	x3.Mul(x3, t3)                           // X3 := X3 * t3
   435  	t3.Add(t2, t2)                           // t3 := t2 + t2
   436  	t2.Add(t2, t3)                           // t2 := t2 + t3
   437  	z3.Mul({{.p}}B(), z3)                    // Z3 := b * Z3
   438  	z3.Sub(z3, t2)                           // Z3 := Z3 - t2
   439  	z3.Sub(z3, t0)                           // Z3 := Z3 - t0
   440  	t3.Add(z3, z3)                           // t3 := Z3 + Z3
   441  	z3.Add(z3, t3)                           // Z3 := Z3 + t3
   442  	t3.Add(t0, t0)                           // t3 := t0 + t0
   443  	t0.Add(t3, t0)                           // t0 := t3 + t0
   444  	t0.Sub(t0, t2)                           // t0 := t0 - t2
   445  	t0.Mul(t0, z3)                           // t0 := t0 * Z3
   446  	y3.Add(y3, t0)                           // Y3 := Y3 + t0
   447  	t0.Mul(p.y, p.z)                         // t0 := Y * Z
   448  	t0.Add(t0, t0)                           // t0 := t0 + t0
   449  	z3.Mul(t0, z3)                           // Z3 := t0 * Z3
   450  	x3.Sub(x3, z3)                           // X3 := X3 - Z3
   451  	z3.Mul(t0, t1)                           // Z3 := t0 * t1
   452  	z3.Add(z3, z3)                           // Z3 := Z3 + Z3
   453  	z3.Add(z3, z3)                           // Z3 := Z3 + Z3
   454  
   455  	q.x.Set(x3)
   456  	q.y.Set(y3)
   457  	q.z.Set(z3)
   458  	return q
   459  }
   460  
   461  // Select sets q to p1 if cond == 1, and to p2 if cond == 0.
   462  func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point {
   463  	q.x.Select(p1.x, p2.x, cond)
   464  	q.y.Select(p1.y, p2.y, cond)
   465  	q.z.Select(p1.z, p2.z, cond)
   466  	return q
   467  }
   468  
   469  // A {{.p}}Table holds the first 15 multiples of a point at offset -1, so [1]P
   470  // is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
   471  // point.
   472  type {{.p}}Table [15]*{{.P}}Point
   473  
   474  // Select selects the n-th multiple of the table base point into p. It works in
   475  // constant time by iterating over every entry of the table. n must be in [0, 15].
   476  func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) {
   477  	if n >= 16 {
   478  		panic("nistec: internal error: {{.p}}Table called with out-of-bounds value")
   479  	}
   480  	p.Set(New{{.P}}Point())
   481  	for i := uint8(1); i < 16; i++ {
   482  		cond := subtle.ConstantTimeByteEq(i, n)
   483  		p.Select(table[i-1], p, cond)
   484  	}
   485  }
   486  
   487  // ScalarMult sets p = scalar * q, and returns p.
   488  func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, error) {
   489  	// Compute a {{.p}}Table for the base point q. The explicit New{{.P}}Point
   490  	// calls get inlined, letting the allocations live on the stack.
   491  	var table = {{.p}}Table{New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(),
   492  		New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(),
   493  		New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(),
   494  		New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point()}
   495  	table[0].Set(q)
   496  	for i := 1; i < 15; i += 2 {
   497  		table[i].Double(table[i/2])
   498  		table[i+1].Add(table[i], q)
   499  	}
   500  
   501  	// Instead of doing the classic double-and-add chain, we do it with a
   502  	// four-bit window: we double four times, and then add [0-15]P.
   503  	t := New{{.P}}Point()
   504  	p.Set(New{{.P}}Point())
   505  	for i, byte := range scalar {
   506  		// No need to double on the first iteration, as p is the identity at
   507  		// this point, and [N]∞ = ∞.
   508  		if i != 0 {
   509  			p.Double(p)
   510  			p.Double(p)
   511  			p.Double(p)
   512  			p.Double(p)
   513  		}
   514  
   515  		windowValue := byte >> 4
   516  		table.Select(t, windowValue)
   517  		p.Add(p, t)
   518  
   519  		p.Double(p)
   520  		p.Double(p)
   521  		p.Double(p)
   522  		p.Double(p)
   523  
   524  		windowValue = byte & 0b1111
   525  		table.Select(t, windowValue)
   526  		p.Add(p, t)
   527  	}
   528  
   529  	return p, nil
   530  }
   531  
   532  var {{.p}}GeneratorTable *[{{.p}}ElementLength * 2]{{.p}}Table
   533  var {{.p}}GeneratorTableOnce sync.Once
   534  
   535  // generatorTable returns a sequence of {{.p}}Tables. The first table contains
   536  // multiples of G. Each successive table is the previous table doubled four
   537  // times.
   538  func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table {
   539  	{{.p}}GeneratorTableOnce.Do(func() {
   540  		{{.p}}GeneratorTable = new([{{.p}}ElementLength * 2]{{.p}}Table)
   541  		base := New{{.P}}Point().SetGenerator()
   542  		for i := 0; i < {{.p}}ElementLength*2; i++ {
   543  			{{.p}}GeneratorTable[i][0] = New{{.P}}Point().Set(base)
   544  			for j := 1; j < 15; j++ {
   545  				{{.p}}GeneratorTable[i][j] = New{{.P}}Point().Add({{.p}}GeneratorTable[i][j-1], base)
   546  			}
   547  			base.Double(base)
   548  			base.Double(base)
   549  			base.Double(base)
   550  			base.Double(base)
   551  		}
   552  	})
   553  	return {{.p}}GeneratorTable
   554  }
   555  
   556  // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
   557  // returns p.
   558  func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) {
   559  	if len(scalar) != {{.p}}ElementLength {
   560  		return nil, errors.New("invalid scalar length")
   561  	}
   562  	tables := p.generatorTable()
   563  
   564  	// This is also a scalar multiplication with a four-bit window like in
   565  	// ScalarMult, but in this case the doublings are precomputed. The value
   566  	// [windowValue]G added at iteration k would normally get doubled
   567  	// (totIterations-k)×4 times, but with a larger precomputation we can
   568  	// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
   569  	// doublings between iterations.
   570  	t := New{{.P}}Point()
   571  	p.Set(New{{.P}}Point())
   572  	tableIndex := len(tables) - 1
   573  	for _, byte := range scalar {
   574  		windowValue := byte >> 4
   575  		tables[tableIndex].Select(t, windowValue)
   576  		p.Add(p, t)
   577  		tableIndex--
   578  
   579  		windowValue = byte & 0b1111
   580  		tables[tableIndex].Select(t, windowValue)
   581  		p.Add(p, t)
   582  		tableIndex--
   583  	}
   584  
   585  	return p, nil
   586  }
   587  
   588  // {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns
   589  // false and e is unchanged. e and x can overlap.
   590  func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) {
   591  	candidate := new({{ .Element }})
   592  	{{.p}}SqrtCandidate(candidate, x)
   593  	square := new({{ .Element }}).Square(candidate)
   594  	if square.Equal(x) != 1 {
   595  		return false
   596  	}
   597  	e.Set(candidate)
   598  	return true
   599  }
   600  `
   601  
// tmplAddchain is the output template passed to "addchain gen -tmpl". It
// renders sqrtCandidate, which computes z = x^((p+1)/4) by following the
// addition chain that "addchain search" found for the exponent (p+1)/4.
//
// "Element" and "sqrtCandidate" are placeholders. Before gofmt'ing the
// rendered code, the generator replaces every "Element" with the curve's
// field element type and every "sqrtCandidate" with "<p>SqrtCandidate".
//
// The data and functions used below (.Ops, .Meta, .Script, .Program, lines,
// format, add, double, shift) are supplied by addchain v0.4.0; their exact
// semantics are defined by that tool, not in this file.
const tmplAddchain = `
// sqrtCandidate sets z to a square root candidate for x. z and x must not overlap.
func sqrtCandidate(z, x *Element) {
	// Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate.
	//
	// The sequence of {{ .Ops.Adds }} multiplications and {{ .Ops.Doubles }} squarings is derived from the
	// following addition chain generated with {{ .Meta.Module }} {{ .Meta.ReleaseTag }}.
	//
	{{- range lines (format .Script) }}
	//	{{ . }}
	{{- end }}
	//

	{{- range .Program.Temporaries }}
	var {{ . }} = new(Element)
	{{- end }}
	{{ range $i := .Program.Instructions -}}
	{{- with add $i.Op }}
	{{ $i.Output }}.Mul({{ .X }}, {{ .Y }})
	{{- end -}}

	{{- with double $i.Op }}
	{{ $i.Output }}.Square({{ .X }})
	{{- end -}}

	{{- with shift $i.Op -}}
	{{- $first := 0 -}}
	{{- if ne $i.Output.Identifier .X.Identifier }}
	{{ $i.Output }}.Square({{ .X }})
	{{- $first = 1 -}}
	{{- end }}
	for s := {{ $first }}; s < {{ .S }}; s++ {
		{{ $i.Output }}.Square({{ $i.Output }})
	}
	{{- end -}}
	{{- end }}
}
`
   640  

View as plain text