...

Source file src/golang.org/x/crypto/curve25519/curve25519_test.go

Documentation: golang.org/x/crypto/curve25519

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package curve25519_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"testing"
    12  
    13  	"golang.org/x/crypto/curve25519"
    14  )
    15  
    16  const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
    17  
    18  func TestX25519Basepoint(t *testing.T) {
    19  	x := make([]byte, 32)
    20  	x[0] = 1
    21  
    22  	for i := 0; i < 200; i++ {
    23  		var err error
    24  		x, err = curve25519.X25519(x, curve25519.Basepoint)
    25  		if err != nil {
    26  			t.Fatal(err)
    27  		}
    28  	}
    29  
    30  	result := hex.EncodeToString(x)
    31  	if result != expectedHex {
    32  		t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
    33  	}
    34  }
    35  
    36  func TestLowOrderPoints(t *testing.T) {
    37  	scalar := make([]byte, curve25519.ScalarSize)
    38  	if _, err := rand.Read(scalar); err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	for i, p := range lowOrderPoints {
    42  		out, err := curve25519.X25519(scalar, p)
    43  		if err == nil {
    44  			t.Errorf("%d: expected error, got nil", i)
    45  		}
    46  		if out != nil {
    47  			t.Errorf("%d: expected nil output, got %x", i, out)
    48  		}
    49  	}
    50  }
    51  
    52  func TestTestVectors(t *testing.T) {
    53  	t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) })
    54  	t.Run("X25519", func(t *testing.T) {
    55  		testTestVectors(t, func(dst, scalar, point *[32]byte) {
    56  			out, err := curve25519.X25519(scalar[:], point[:])
    57  			if err != nil {
    58  				t.Fatal(err)
    59  			}
    60  			copy(dst[:], out)
    61  		})
    62  	})
    63  }
    64  
    65  func testTestVectors(t *testing.T, scalarMult func(dst, scalar, point *[32]byte)) {
    66  	for _, tv := range testVectors {
    67  		var got [32]byte
    68  		scalarMult(&got, &tv.In, &tv.Base)
    69  		if !bytes.Equal(got[:], tv.Expect[:]) {
    70  			t.Logf("    in = %x", tv.In)
    71  			t.Logf("  base = %x", tv.Base)
    72  			t.Logf("   got = %x", got)
    73  			t.Logf("expect = %x", tv.Expect)
    74  			t.Fail()
    75  		}
    76  	}
    77  }
    78  
    79  // TestHighBitIgnored tests the following requirement in RFC 7748:
    80  //
    81  //	When receiving such an array, implementations of X25519 (but not X448) MUST
    82  //	mask the most significant bit in the final byte.
    83  //
    84  // Regression test for issue #30095.
    85  func TestHighBitIgnored(t *testing.T) {
    86  	var s, u [32]byte
    87  	rand.Read(s[:])
    88  	rand.Read(u[:])
    89  
    90  	var hi0, hi1 [32]byte
    91  
    92  	u[31] &= 0x7f
    93  	curve25519.ScalarMult(&hi0, &s, &u)
    94  
    95  	u[31] |= 0x80
    96  	curve25519.ScalarMult(&hi1, &s, &u)
    97  
    98  	if !bytes.Equal(hi0[:], hi1[:]) {
    99  		t.Errorf("high bit of group point should not affect result")
   100  	}
   101  }
   102  
   103  var benchmarkSink byte
   104  
   105  func BenchmarkX25519Basepoint(b *testing.B) {
   106  	scalar := make([]byte, curve25519.ScalarSize)
   107  	if _, err := rand.Read(scalar); err != nil {
   108  		b.Fatal(err)
   109  	}
   110  
   111  	b.ResetTimer()
   112  	for i := 0; i < b.N; i++ {
   113  		out, err := curve25519.X25519(scalar, curve25519.Basepoint)
   114  		if err != nil {
   115  			b.Fatal(err)
   116  		}
   117  		benchmarkSink ^= out[0]
   118  	}
   119  }
   120  
   121  func BenchmarkX25519(b *testing.B) {
   122  	scalar := make([]byte, curve25519.ScalarSize)
   123  	if _, err := rand.Read(scalar); err != nil {
   124  		b.Fatal(err)
   125  	}
   126  	point, err := curve25519.X25519(scalar, curve25519.Basepoint)
   127  	if err != nil {
   128  		b.Fatal(err)
   129  	}
   130  	if _, err := rand.Read(scalar); err != nil {
   131  		b.Fatal(err)
   132  	}
   133  
   134  	b.ResetTimer()
   135  	for i := 0; i < b.N; i++ {
   136  		out, err := curve25519.X25519(scalar, point)
   137  		if err != nil {
   138  			b.Fatal(err)
   139  		}
   140  		benchmarkSink ^= out[0]
   141  	}
   142  }
   143  

View as plain text