...

Source file src/golang.org/x/crypto/acme/autocert/autocert_test.go

Documentation: golang.org/x/crypto/acme/autocert

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package autocert
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto"
    11  	"crypto/ecdsa"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/rsa"
    15  	"crypto/tls"
    16  	"crypto/x509"
    17  	"crypto/x509/pkix"
    18  	"encoding/asn1"
    19  	"fmt"
    20  	"io"
    21  	"math/big"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"reflect"
    25  	"strings"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"golang.org/x/crypto/acme"
    31  	"golang.org/x/crypto/acme/autocert/internal/acmetest"
    32  )
    33  
    34  var (
    35  	exampleDomain     = "example.org"
    36  	exampleCertKey    = certKey{domain: exampleDomain}
    37  	exampleCertKeyRSA = certKey{domain: exampleDomain, isRSA: true}
    38  )
    39  
// memCache is an in-memory cache that the tests plug into Manager.Cache.
// Its methods only touch keyData while holding mu, so it may be used from
// several goroutines at once.
type memCache struct {
	t       *testing.T // receives an error from Put for keys that are not filename-safe
	mu      sync.Mutex // guards keyData
	keyData map[string][]byte
}
    45  
    46  func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
    47  	m.mu.Lock()
    48  	defer m.mu.Unlock()
    49  
    50  	v, ok := m.keyData[key]
    51  	if !ok {
    52  		return nil, ErrCacheMiss
    53  	}
    54  	return v, nil
    55  }
    56  
    57  // filenameSafe returns whether all characters in s are printable ASCII
    58  // and safe to use in a filename on most filesystems.
    59  func filenameSafe(s string) bool {
    60  	for _, c := range s {
    61  		if c < 0x20 || c > 0x7E {
    62  			return false
    63  		}
    64  		switch c {
    65  		case '\\', '/', ':', '*', '?', '"', '<', '>', '|':
    66  			return false
    67  		}
    68  	}
    69  	return true
    70  }
    71  
    72  func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
    73  	if !filenameSafe(key) {
    74  		m.t.Errorf("invalid characters in cache key %q", key)
    75  	}
    76  
    77  	m.mu.Lock()
    78  	defer m.mu.Unlock()
    79  
    80  	m.keyData[key] = data
    81  	return nil
    82  }
    83  
    84  func (m *memCache) Delete(ctx context.Context, key string) error {
    85  	m.mu.Lock()
    86  	defer m.mu.Unlock()
    87  
    88  	delete(m.keyData, key)
    89  	return nil
    90  }
    91  
    92  func newMemCache(t *testing.T) *memCache {
    93  	return &memCache{
    94  		t:       t,
    95  		keyData: make(map[string][]byte),
    96  	}
    97  }
    98  
    99  func (m *memCache) numCerts() int {
   100  	m.mu.Lock()
   101  	defer m.mu.Unlock()
   102  
   103  	res := 0
   104  	for key := range m.keyData {
   105  		if strings.HasSuffix(key, "+token") ||
   106  			strings.HasSuffix(key, "+key") ||
   107  			strings.HasSuffix(key, "+http-01") {
   108  			continue
   109  		}
   110  		res++
   111  	}
   112  	return res
   113  }
   114  
   115  func dummyCert(pub interface{}, san ...string) ([]byte, error) {
   116  	return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
   117  }
   118  
   119  func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
   120  	// use EC key to run faster on 386
   121  	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	t := &x509.Certificate{
   126  		SerialNumber:          randomSerial(),
   127  		NotBefore:             start,
   128  		NotAfter:              end,
   129  		BasicConstraintsValid: true,
   130  		KeyUsage:              x509.KeyUsageKeyEncipherment,
   131  		DNSNames:              san,
   132  	}
   133  	if pub == nil {
   134  		pub = &key.PublicKey
   135  	}
   136  	return x509.CreateCertificate(rand.Reader, t, t, pub, key)
   137  }
   138  
   139  func randomSerial() *big.Int {
   140  	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 32))
   141  	if err != nil {
   142  		panic(err)
   143  	}
   144  	return serial
   145  }
   146  
   147  type algorithmSupport int
   148  
   149  const (
   150  	algRSA algorithmSupport = iota
   151  	algECDSA
   152  )
   153  
   154  func clientHelloInfo(sni string, alg algorithmSupport) *tls.ClientHelloInfo {
   155  	hello := &tls.ClientHelloInfo{
   156  		ServerName:   sni,
   157  		CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
   158  	}
   159  	if alg == algECDSA {
   160  		hello.CipherSuites = append(hello.CipherSuites, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305)
   161  	}
   162  	return hello
   163  }
   164  
   165  func testManager(t *testing.T) *Manager {
   166  	man := &Manager{
   167  		Prompt: AcceptTOS,
   168  		Cache:  newMemCache(t),
   169  	}
   170  	t.Cleanup(man.stopRenew)
   171  	return man
   172  }
   173  
   174  func TestGetCertificate(t *testing.T) {
   175  	tests := []struct {
   176  		name        string
   177  		hello       *tls.ClientHelloInfo
   178  		domain      string
   179  		expectError string
   180  		prepare     func(t *testing.T, man *Manager, s *acmetest.CAServer)
   181  		verify      func(t *testing.T, man *Manager, leaf *x509.Certificate)
   182  		disableALPN bool
   183  		disableHTTP bool
   184  	}{
   185  		{
   186  			name:        "ALPN",
   187  			hello:       clientHelloInfo("example.org", algECDSA),
   188  			domain:      "example.org",
   189  			disableHTTP: true,
   190  		},
   191  		{
   192  			name:        "HTTP",
   193  			hello:       clientHelloInfo("example.org", algECDSA),
   194  			domain:      "example.org",
   195  			disableALPN: true,
   196  		},
   197  		{
   198  			name:   "nilPrompt",
   199  			hello:  clientHelloInfo("example.org", algECDSA),
   200  			domain: "example.org",
   201  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   202  				man.Prompt = nil
   203  			},
   204  			expectError: "Manager.Prompt not set",
   205  		},
   206  		{
   207  			name:   "trailingDot",
   208  			hello:  clientHelloInfo("example.org.", algECDSA),
   209  			domain: "example.org",
   210  		},
   211  		{
   212  			name:   "unicodeIDN",
   213  			hello:  clientHelloInfo("éé.com", algECDSA),
   214  			domain: "xn--9caa.com",
   215  		},
   216  		{
   217  			name:   "unicodeIDN/mixedCase",
   218  			hello:  clientHelloInfo("éÉ.com", algECDSA),
   219  			domain: "xn--9caa.com",
   220  		},
   221  		{
   222  			name:   "upperCase",
   223  			hello:  clientHelloInfo("EXAMPLE.ORG", algECDSA),
   224  			domain: "example.org",
   225  		},
   226  		{
   227  			name:   "goodCache",
   228  			hello:  clientHelloInfo("example.org", algECDSA),
   229  			domain: "example.org",
   230  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   231  				// Make a valid cert and cache it.
   232  				c := s.Start().LeafCert(exampleDomain, "ECDSA",
   233  					// Use a time before the Let's Encrypt revocation cutoff to also test
   234  					// that non-Let's Encrypt certificates are not renewed.
   235  					time.Date(2022, time.January, 1, 0, 0, 0, 0, time.UTC),
   236  					time.Date(2122, time.January, 1, 0, 0, 0, 0, time.UTC),
   237  				)
   238  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   239  					t.Fatalf("man.cachePut: %v", err)
   240  				}
   241  			},
   242  			// Break the server to check that the cache is used.
   243  			disableALPN: true, disableHTTP: true,
   244  		},
   245  		{
   246  			name:   "expiredCache",
   247  			hello:  clientHelloInfo("example.org", algECDSA),
   248  			domain: "example.org",
   249  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   250  				// Make an expired cert and cache it.
   251  				c := s.Start().LeafCert(exampleDomain, "ECDSA", time.Now().Add(-10*time.Minute), time.Now().Add(-5*time.Minute))
   252  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   253  					t.Fatalf("man.cachePut: %v", err)
   254  				}
   255  			},
   256  		},
   257  		{
   258  			name:   "forceRSA",
   259  			hello:  clientHelloInfo("example.org", algECDSA),
   260  			domain: "example.org",
   261  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   262  				man.ForceRSA = true
   263  			},
   264  			verify: func(t *testing.T, man *Manager, leaf *x509.Certificate) {
   265  				if _, ok := leaf.PublicKey.(*ecdsa.PublicKey); !ok {
   266  					t.Errorf("leaf.PublicKey is %T; want *ecdsa.PublicKey", leaf.PublicKey)
   267  				}
   268  			},
   269  		},
   270  		{
   271  			name:   "goodLetsEncrypt",
   272  			hello:  clientHelloInfo("example.org", algECDSA),
   273  			domain: "example.org",
   274  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   275  				// Make a valid certificate issued after the TLS-ALPN-01
   276  				// revocation window and cache it.
   277  				s.IssuerName(pkix.Name{Country: []string{"US"},
   278  					Organization: []string{"Let's Encrypt"}, CommonName: "R3"})
   279  				c := s.Start().LeafCert(exampleDomain, "ECDSA",
   280  					time.Date(2022, time.January, 26, 12, 0, 0, 0, time.UTC),
   281  					time.Date(2122, time.January, 1, 0, 0, 0, 0, time.UTC),
   282  				)
   283  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   284  					t.Fatalf("man.cachePut: %v", err)
   285  				}
   286  			},
   287  			// Break the server to check that the cache is used.
   288  			disableALPN: true, disableHTTP: true,
   289  		},
   290  		{
   291  			name:   "revokedLetsEncrypt",
   292  			hello:  clientHelloInfo("example.org", algECDSA),
   293  			domain: "example.org",
   294  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   295  				// Make a certificate issued during the TLS-ALPN-01
   296  				// revocation window and cache it.
   297  				s.IssuerName(pkix.Name{Country: []string{"US"},
   298  					Organization: []string{"Let's Encrypt"}, CommonName: "R3"})
   299  				c := s.Start().LeafCert(exampleDomain, "ECDSA",
   300  					time.Date(2022, time.January, 1, 0, 0, 0, 0, time.UTC),
   301  					time.Date(2122, time.January, 1, 0, 0, 0, 0, time.UTC),
   302  				)
   303  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   304  					t.Fatalf("man.cachePut: %v", err)
   305  				}
   306  			},
   307  			verify: func(t *testing.T, man *Manager, leaf *x509.Certificate) {
   308  				if leaf.NotBefore.Before(time.Now().Add(-10 * time.Minute)) {
   309  					t.Error("certificate was not reissued")
   310  				}
   311  			},
   312  		},
   313  		{
   314  			// TestGetCertificate/tokenCache tests the fallback of token
   315  			// certificate fetches to cache when Manager.certTokens misses.
   316  			name:   "tokenCacheALPN",
   317  			hello:  clientHelloInfo("example.org", algECDSA),
   318  			domain: "example.org",
   319  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   320  				// Make a separate manager with a shared cache, simulating
   321  				// separate nodes that serve requests for the same domain.
   322  				man2 := testManager(t)
   323  				man2.Cache = man.Cache
   324  				// Redirect the verification request to man2, although the
   325  				// client request will hit man, testing that they can complete a
   326  				// verification by communicating through the cache.
   327  				s.ResolveGetCertificate("example.org", man2.GetCertificate)
   328  			},
   329  			// Drop the default verification paths.
   330  			disableALPN: true,
   331  		},
   332  		{
   333  			name:   "tokenCacheHTTP",
   334  			hello:  clientHelloInfo("example.org", algECDSA),
   335  			domain: "example.org",
   336  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   337  				man2 := testManager(t)
   338  				man2.Cache = man.Cache
   339  				s.ResolveHandler("example.org", man2.HTTPHandler(nil))
   340  			},
   341  			disableHTTP: true,
   342  		},
   343  		{
   344  			name:   "ecdsa",
   345  			hello:  clientHelloInfo("example.org", algECDSA),
   346  			domain: "example.org",
   347  			verify: func(t *testing.T, man *Manager, leaf *x509.Certificate) {
   348  				if _, ok := leaf.PublicKey.(*ecdsa.PublicKey); !ok {
   349  					t.Error("an ECDSA client was served a non-ECDSA certificate")
   350  				}
   351  			},
   352  		},
   353  		{
   354  			name:   "rsa",
   355  			hello:  clientHelloInfo("example.org", algRSA),
   356  			domain: "example.org",
   357  			verify: func(t *testing.T, man *Manager, leaf *x509.Certificate) {
   358  				if _, ok := leaf.PublicKey.(*rsa.PublicKey); !ok {
   359  					t.Error("an RSA client was served a non-RSA certificate")
   360  				}
   361  			},
   362  		},
   363  		{
   364  			name:   "wrongCacheKeyType",
   365  			hello:  clientHelloInfo("example.org", algECDSA),
   366  			domain: "example.org",
   367  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   368  				// Make an RSA cert and cache it without suffix.
   369  				c := s.Start().LeafCert(exampleDomain, "RSA", time.Now(), time.Now().Add(90*24*time.Hour))
   370  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   371  					t.Fatalf("man.cachePut: %v", err)
   372  				}
   373  			},
   374  			verify: func(t *testing.T, man *Manager, leaf *x509.Certificate) {
   375  				// The RSA cached cert should be silently ignored and replaced.
   376  				if _, ok := leaf.PublicKey.(*ecdsa.PublicKey); !ok {
   377  					t.Error("an ECDSA client was served a non-ECDSA certificate")
   378  				}
   379  				if numCerts := man.Cache.(*memCache).numCerts(); numCerts != 1 {
   380  					t.Errorf("found %d certificates in cache; want %d", numCerts, 1)
   381  				}
   382  			},
   383  		},
   384  		{
   385  			name:   "almostExpiredCache",
   386  			hello:  clientHelloInfo("example.org", algECDSA),
   387  			domain: "example.org",
   388  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   389  				man.RenewBefore = 24 * time.Hour
   390  				// Cache an almost expired cert.
   391  				c := s.Start().LeafCert(exampleDomain, "ECDSA", time.Now(), time.Now().Add(10*time.Minute))
   392  				if err := man.cachePut(context.Background(), exampleCertKey, c); err != nil {
   393  					t.Fatalf("man.cachePut: %v", err)
   394  				}
   395  			},
   396  		},
   397  		{
   398  			name:   "provideExternalAuth",
   399  			hello:  clientHelloInfo("example.org", algECDSA),
   400  			domain: "example.org",
   401  			prepare: func(t *testing.T, man *Manager, s *acmetest.CAServer) {
   402  				s.ExternalAccountRequired()
   403  
   404  				man.ExternalAccountBinding = &acme.ExternalAccountBinding{
   405  					KID: "test-key",
   406  					Key: make([]byte, 32),
   407  				}
   408  			},
   409  		},
   410  	}
   411  	for _, tt := range tests {
   412  		t.Run(tt.name, func(t *testing.T) {
   413  			man := testManager(t)
   414  			s := acmetest.NewCAServer(t)
   415  			if !tt.disableALPN {
   416  				s.ResolveGetCertificate(tt.domain, man.GetCertificate)
   417  			}
   418  			if !tt.disableHTTP {
   419  				s.ResolveHandler(tt.domain, man.HTTPHandler(nil))
   420  			}
   421  
   422  			if tt.prepare != nil {
   423  				tt.prepare(t, man, s)
   424  			}
   425  
   426  			s.Start()
   427  
   428  			man.Client = &acme.Client{DirectoryURL: s.URL()}
   429  
   430  			tlscert, err := man.GetCertificate(tt.hello)
   431  			if tt.expectError != "" {
   432  				if err == nil {
   433  					t.Fatal("expected error, got certificate")
   434  				}
   435  				if !strings.Contains(err.Error(), tt.expectError) {
   436  					t.Errorf("got %q, expected %q", err, tt.expectError)
   437  				}
   438  				return
   439  			}
   440  			if err != nil {
   441  				t.Fatalf("man.GetCertificate: %v", err)
   442  			}
   443  
   444  			leaf, err := x509.ParseCertificate(tlscert.Certificate[0])
   445  			if err != nil {
   446  				t.Fatal(err)
   447  			}
   448  			opts := x509.VerifyOptions{
   449  				DNSName:       tt.domain,
   450  				Intermediates: x509.NewCertPool(),
   451  				Roots:         s.Roots(),
   452  			}
   453  			for _, cert := range tlscert.Certificate[1:] {
   454  				c, err := x509.ParseCertificate(cert)
   455  				if err != nil {
   456  					t.Fatal(err)
   457  				}
   458  				opts.Intermediates.AddCert(c)
   459  			}
   460  			if _, err := leaf.Verify(opts); err != nil {
   461  				t.Error(err)
   462  			}
   463  
   464  			if san := leaf.DNSNames[0]; san != tt.domain {
   465  				t.Errorf("got SAN %q, expected %q", san, tt.domain)
   466  			}
   467  
   468  			if tt.verify != nil {
   469  				tt.verify(t, man, leaf)
   470  			}
   471  		})
   472  	}
   473  }
   474  
   475  func TestGetCertificate_failedAttempt(t *testing.T) {
   476  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   477  		w.WriteHeader(http.StatusBadRequest)
   478  	}))
   479  	defer ts.Close()
   480  
   481  	d := createCertRetryAfter
   482  	f := testDidRemoveState
   483  	defer func() {
   484  		createCertRetryAfter = d
   485  		testDidRemoveState = f
   486  	}()
   487  	createCertRetryAfter = 0
   488  	done := make(chan struct{})
   489  	testDidRemoveState = func(ck certKey) {
   490  		if ck != exampleCertKey {
   491  			t.Errorf("testDidRemoveState: domain = %v; want %v", ck, exampleCertKey)
   492  		}
   493  		close(done)
   494  	}
   495  
   496  	man := &Manager{
   497  		Prompt: AcceptTOS,
   498  		Client: &acme.Client{
   499  			DirectoryURL: ts.URL,
   500  		},
   501  	}
   502  	defer man.stopRenew()
   503  	hello := clientHelloInfo(exampleDomain, algECDSA)
   504  	if _, err := man.GetCertificate(hello); err == nil {
   505  		t.Error("GetCertificate: err is nil")
   506  	}
   507  
   508  	<-done
   509  	man.stateMu.Lock()
   510  	defer man.stateMu.Unlock()
   511  	if v, exist := man.state[exampleCertKey]; exist {
   512  		t.Errorf("state exists for %v: %+v", exampleCertKey, v)
   513  	}
   514  }
   515  
   516  func TestRevokeFailedAuthz(t *testing.T) {
   517  	ca := acmetest.NewCAServer(t)
   518  	// Make the authz unfulfillable on the client side, so it will be left
   519  	// pending at the end of the verification attempt.
   520  	ca.ChallengeTypes("fake-01", "fake-02")
   521  	ca.Start()
   522  
   523  	m := testManager(t)
   524  	m.Client = &acme.Client{DirectoryURL: ca.URL()}
   525  
   526  	_, err := m.GetCertificate(clientHelloInfo("example.org", algECDSA))
   527  	if err == nil {
   528  		t.Fatal("expected GetCertificate to fail")
   529  	}
   530  
   531  	logTicker := time.NewTicker(3 * time.Second)
   532  	defer logTicker.Stop()
   533  	for {
   534  		authz, err := m.Client.GetAuthorization(context.Background(), ca.URL()+"/authz/0")
   535  		if err != nil {
   536  			t.Fatal(err)
   537  		}
   538  		if authz.Status == acme.StatusDeactivated {
   539  			return
   540  		}
   541  
   542  		select {
   543  		case <-logTicker.C:
   544  			t.Logf("still waiting on revocations")
   545  		default:
   546  		}
   547  		time.Sleep(50 * time.Millisecond)
   548  	}
   549  }
   550  
   551  func TestHTTPHandlerDefaultFallback(t *testing.T) {
   552  	tt := []struct {
   553  		method, url  string
   554  		wantCode     int
   555  		wantLocation string
   556  	}{
   557  		{"GET", "http://example.org", 302, "https://example.org/"},
   558  		{"GET", "http://example.org/foo", 302, "https://example.org/foo"},
   559  		{"GET", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
   560  		{"GET", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
   561  		{"GET", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
   562  		{"GET", "http://example.org:80/foo?a=b", 302, "https://example.org:443/foo?a=b"},
   563  		{"GET", "http://example.org:80/foo%20bar", 302, "https://example.org:443/foo%20bar"},
   564  		{"GET", "http://[2602:d1:xxxx::c60a]:1234", 302, "https://[2602:d1:xxxx::c60a]:443/"},
   565  		{"GET", "http://[2602:d1:xxxx::c60a]", 302, "https://[2602:d1:xxxx::c60a]/"},
   566  		{"GET", "http://[2602:d1:xxxx::c60a]/foo?a=b", 302, "https://[2602:d1:xxxx::c60a]/foo?a=b"},
   567  		{"HEAD", "http://example.org", 302, "https://example.org/"},
   568  		{"HEAD", "http://example.org/foo", 302, "https://example.org/foo"},
   569  		{"HEAD", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
   570  		{"HEAD", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
   571  		{"HEAD", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
   572  		{"POST", "http://example.org", 400, ""},
   573  		{"PUT", "http://example.org", 400, ""},
   574  		{"GET", "http://example.org/.well-known/acme-challenge/x", 404, ""},
   575  	}
   576  	var m Manager
   577  	h := m.HTTPHandler(nil)
   578  	for i, test := range tt {
   579  		r := httptest.NewRequest(test.method, test.url, nil)
   580  		w := httptest.NewRecorder()
   581  		h.ServeHTTP(w, r)
   582  		if w.Code != test.wantCode {
   583  			t.Errorf("%d: w.Code = %d; want %d", i, w.Code, test.wantCode)
   584  			t.Errorf("%d: body: %s", i, w.Body.Bytes())
   585  		}
   586  		if v := w.Header().Get("Location"); v != test.wantLocation {
   587  			t.Errorf("%d: Location = %q; want %q", i, v, test.wantLocation)
   588  		}
   589  	}
   590  }
   591  
   592  func TestAccountKeyCache(t *testing.T) {
   593  	m := Manager{Cache: newMemCache(t)}
   594  	ctx := context.Background()
   595  	k1, err := m.accountKey(ctx)
   596  	if err != nil {
   597  		t.Fatal(err)
   598  	}
   599  	k2, err := m.accountKey(ctx)
   600  	if err != nil {
   601  		t.Fatal(err)
   602  	}
   603  	if !reflect.DeepEqual(k1, k2) {
   604  		t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
   605  	}
   606  }
   607  
   608  func TestCache(t *testing.T) {
   609  	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   610  	if err != nil {
   611  		t.Fatal(err)
   612  	}
   613  	cert, err := dummyCert(ecdsaKey.Public(), exampleDomain)
   614  	if err != nil {
   615  		t.Fatal(err)
   616  	}
   617  	ecdsaCert := &tls.Certificate{
   618  		Certificate: [][]byte{cert},
   619  		PrivateKey:  ecdsaKey,
   620  	}
   621  
   622  	rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
   623  	if err != nil {
   624  		t.Fatal(err)
   625  	}
   626  	cert, err = dummyCert(rsaKey.Public(), exampleDomain)
   627  	if err != nil {
   628  		t.Fatal(err)
   629  	}
   630  	rsaCert := &tls.Certificate{
   631  		Certificate: [][]byte{cert},
   632  		PrivateKey:  rsaKey,
   633  	}
   634  
   635  	man := &Manager{Cache: newMemCache(t)}
   636  	defer man.stopRenew()
   637  	ctx := context.Background()
   638  
   639  	if err := man.cachePut(ctx, exampleCertKey, ecdsaCert); err != nil {
   640  		t.Fatalf("man.cachePut: %v", err)
   641  	}
   642  	if err := man.cachePut(ctx, exampleCertKeyRSA, rsaCert); err != nil {
   643  		t.Fatalf("man.cachePut: %v", err)
   644  	}
   645  
   646  	res, err := man.cacheGet(ctx, exampleCertKey)
   647  	if err != nil {
   648  		t.Fatalf("man.cacheGet: %v", err)
   649  	}
   650  	if res == nil || !bytes.Equal(res.Certificate[0], ecdsaCert.Certificate[0]) {
   651  		t.Errorf("man.cacheGet = %+v; want %+v", res, ecdsaCert)
   652  	}
   653  
   654  	res, err = man.cacheGet(ctx, exampleCertKeyRSA)
   655  	if err != nil {
   656  		t.Fatalf("man.cacheGet: %v", err)
   657  	}
   658  	if res == nil || !bytes.Equal(res.Certificate[0], rsaCert.Certificate[0]) {
   659  		t.Errorf("man.cacheGet = %+v; want %+v", res, rsaCert)
   660  	}
   661  }
   662  
   663  func TestHostWhitelist(t *testing.T) {
   664  	policy := HostWhitelist("example.com", "EXAMPLE.ORG", "*.example.net", "éÉ.com")
   665  	tt := []struct {
   666  		host  string
   667  		allow bool
   668  	}{
   669  		{"example.com", true},
   670  		{"example.org", true},
   671  		{"xn--9caa.com", true}, // éé.com
   672  		{"one.example.com", false},
   673  		{"two.example.org", false},
   674  		{"three.example.net", false},
   675  		{"dummy", false},
   676  	}
   677  	for i, test := range tt {
   678  		err := policy(nil, test.host)
   679  		if err != nil && test.allow {
   680  			t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
   681  		}
   682  		if err == nil && !test.allow {
   683  			t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
   684  		}
   685  	}
   686  }
   687  
   688  func TestValidCert(t *testing.T) {
   689  	key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   690  	if err != nil {
   691  		t.Fatal(err)
   692  	}
   693  	key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   694  	if err != nil {
   695  		t.Fatal(err)
   696  	}
   697  	key3, err := rsa.GenerateKey(rand.Reader, 512)
   698  	if err != nil {
   699  		t.Fatal(err)
   700  	}
   701  	cert1, err := dummyCert(key1.Public(), "example.org")
   702  	if err != nil {
   703  		t.Fatal(err)
   704  	}
   705  	cert2, err := dummyCert(key2.Public(), "example.org")
   706  	if err != nil {
   707  		t.Fatal(err)
   708  	}
   709  	cert3, err := dummyCert(key3.Public(), "example.org")
   710  	if err != nil {
   711  		t.Fatal(err)
   712  	}
   713  	now := time.Now()
   714  	early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
   715  	if err != nil {
   716  		t.Fatal(err)
   717  	}
   718  	expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
   719  	if err != nil {
   720  		t.Fatal(err)
   721  	}
   722  
   723  	tt := []struct {
   724  		ck   certKey
   725  		key  crypto.Signer
   726  		cert [][]byte
   727  		ok   bool
   728  	}{
   729  		{certKey{domain: "example.org"}, key1, [][]byte{cert1}, true},
   730  		{certKey{domain: "example.org", isRSA: true}, key3, [][]byte{cert3}, true},
   731  		{certKey{domain: "example.org"}, key1, [][]byte{cert1, cert2, cert3}, true},
   732  		{certKey{domain: "example.org"}, key1, [][]byte{cert1, {1}}, false},
   733  		{certKey{domain: "example.org"}, key1, [][]byte{{1}}, false},
   734  		{certKey{domain: "example.org"}, key1, [][]byte{cert2}, false},
   735  		{certKey{domain: "example.org"}, key2, [][]byte{cert1}, false},
   736  		{certKey{domain: "example.org"}, key1, [][]byte{cert3}, false},
   737  		{certKey{domain: "example.org"}, key3, [][]byte{cert1}, false},
   738  		{certKey{domain: "example.net"}, key1, [][]byte{cert1}, false},
   739  		{certKey{domain: "example.org"}, key1, [][]byte{early}, false},
   740  		{certKey{domain: "example.org"}, key1, [][]byte{expired}, false},
   741  		{certKey{domain: "example.org", isRSA: true}, key1, [][]byte{cert1}, false},
   742  		{certKey{domain: "example.org"}, key3, [][]byte{cert3}, false},
   743  	}
   744  	for i, test := range tt {
   745  		leaf, err := validCert(test.ck, test.cert, test.key, now)
   746  		if err != nil && test.ok {
   747  			t.Errorf("%d: err = %v", i, err)
   748  		}
   749  		if err == nil && !test.ok {
   750  			t.Errorf("%d: err is nil", i)
   751  		}
   752  		if err == nil && test.ok && leaf == nil {
   753  			t.Errorf("%d: leaf is nil", i)
   754  		}
   755  	}
   756  }
   757  
   758  type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
   759  
   760  func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
   761  	return f(ctx, key)
   762  }
   763  
   764  func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
   765  	return fmt.Errorf("unsupported Put of %q = %q", key, data)
   766  }
   767  
   768  func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
   769  	return fmt.Errorf("unsupported Delete of %q", key)
   770  }
   771  
   772  func TestManagerGetCertificateBogusSNI(t *testing.T) {
   773  	m := Manager{
   774  		Prompt: AcceptTOS,
   775  		Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
   776  			return nil, fmt.Errorf("cache.Get of %s", key)
   777  		}),
   778  	}
   779  	tests := []struct {
   780  		name    string
   781  		wantErr string
   782  	}{
   783  		{"foo.com", "cache.Get of foo.com"},
   784  		{"foo.com.", "cache.Get of foo.com"},
   785  		{`a\b.com`, "acme/autocert: server name contains invalid character"},
   786  		{`a/b.com`, "acme/autocert: server name contains invalid character"},
   787  		{"", "acme/autocert: missing server name"},
   788  		{"foo", "acme/autocert: server name component count invalid"},
   789  		{".foo", "acme/autocert: server name component count invalid"},
   790  		{"foo.", "acme/autocert: server name component count invalid"},
   791  		{"fo.o", "cache.Get of fo.o"},
   792  	}
   793  	for _, tt := range tests {
   794  		_, err := m.GetCertificate(clientHelloInfo(tt.name, algECDSA))
   795  		got := fmt.Sprint(err)
   796  		if got != tt.wantErr {
   797  			t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
   798  		}
   799  	}
   800  }
   801  
   802  func TestCertRequest(t *testing.T) {
   803  	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   804  	if err != nil {
   805  		t.Fatal(err)
   806  	}
   807  	// An extension from RFC7633. Any will do.
   808  	ext := pkix.Extension{
   809  		Id:    asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1},
   810  		Value: []byte("dummy"),
   811  	}
   812  	b, err := certRequest(key, "example.org", []pkix.Extension{ext})
   813  	if err != nil {
   814  		t.Fatalf("certRequest: %v", err)
   815  	}
   816  	r, err := x509.ParseCertificateRequest(b)
   817  	if err != nil {
   818  		t.Fatalf("ParseCertificateRequest: %v", err)
   819  	}
   820  	var found bool
   821  	for _, v := range r.Extensions {
   822  		if v.Id.Equal(ext.Id) {
   823  			found = true
   824  			break
   825  		}
   826  	}
   827  	if !found {
   828  		t.Errorf("want %v in Extensions: %v", ext, r.Extensions)
   829  	}
   830  }
   831  
   832  func TestSupportsECDSA(t *testing.T) {
   833  	tests := []struct {
   834  		CipherSuites     []uint16
   835  		SignatureSchemes []tls.SignatureScheme
   836  		SupportedCurves  []tls.CurveID
   837  		ecdsaOk          bool
   838  	}{
   839  		{[]uint16{
   840  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   841  		}, nil, nil, false},
   842  		{[]uint16{
   843  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   844  		}, nil, nil, true},
   845  
   846  		// SignatureSchemes limits, not extends, CipherSuites
   847  		{[]uint16{
   848  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   849  		}, []tls.SignatureScheme{
   850  			tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
   851  		}, nil, false},
   852  		{[]uint16{
   853  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   854  		}, []tls.SignatureScheme{
   855  			tls.PKCS1WithSHA256,
   856  		}, nil, false},
   857  		{[]uint16{
   858  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   859  		}, []tls.SignatureScheme{
   860  			tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
   861  		}, nil, true},
   862  
   863  		{[]uint16{
   864  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   865  		}, []tls.SignatureScheme{
   866  			tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
   867  		}, []tls.CurveID{
   868  			tls.CurveP521,
   869  		}, false},
   870  		{[]uint16{
   871  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   872  		}, []tls.SignatureScheme{
   873  			tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
   874  		}, []tls.CurveID{
   875  			tls.CurveP256,
   876  			tls.CurveP521,
   877  		}, true},
   878  	}
   879  	for i, tt := range tests {
   880  		result := supportsECDSA(&tls.ClientHelloInfo{
   881  			CipherSuites:     tt.CipherSuites,
   882  			SignatureSchemes: tt.SignatureSchemes,
   883  			SupportedCurves:  tt.SupportedCurves,
   884  		})
   885  		if result != tt.ecdsaOk {
   886  			t.Errorf("%d: supportsECDSA = %v; want %v", i, result, tt.ecdsaOk)
   887  		}
   888  	}
   889  }
   890  
   891  func TestEndToEndALPN(t *testing.T) {
   892  	const domain = "example.org"
   893  
   894  	// ACME CA server
   895  	ca := acmetest.NewCAServer(t).Start()
   896  
   897  	// User HTTPS server.
   898  	m := &Manager{
   899  		Prompt: AcceptTOS,
   900  		Client: &acme.Client{DirectoryURL: ca.URL()},
   901  	}
   902  	us := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   903  		w.Write([]byte("OK"))
   904  	}))
   905  	us.TLS = &tls.Config{
   906  		NextProtos: []string{"http/1.1", acme.ALPNProto},
   907  		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   908  			cert, err := m.GetCertificate(hello)
   909  			if err != nil {
   910  				t.Errorf("m.GetCertificate: %v", err)
   911  			}
   912  			return cert, err
   913  		},
   914  	}
   915  	us.StartTLS()
   916  	defer us.Close()
   917  	// In TLS-ALPN challenge verification, CA connects to the domain:443 in question.
   918  	// Because the domain won't resolve in tests, we need to tell the CA
   919  	// where to dial to instead.
   920  	ca.Resolve(domain, strings.TrimPrefix(us.URL, "https://"))
   921  
   922  	// A client visiting user's HTTPS server.
   923  	tr := &http.Transport{
   924  		TLSClientConfig: &tls.Config{
   925  			RootCAs:    ca.Roots(),
   926  			ServerName: domain,
   927  		},
   928  	}
   929  	client := &http.Client{Transport: tr}
   930  	res, err := client.Get(us.URL)
   931  	if err != nil {
   932  		t.Fatal(err)
   933  	}
   934  	defer res.Body.Close()
   935  	b, err := io.ReadAll(res.Body)
   936  	if err != nil {
   937  		t.Fatal(err)
   938  	}
   939  	if v := string(b); v != "OK" {
   940  		t.Errorf("user server response: %q; want 'OK'", v)
   941  	}
   942  }
   943  
   944  func TestEndToEndHTTP(t *testing.T) {
   945  	const domain = "example.org"
   946  
   947  	// ACME CA server.
   948  	ca := acmetest.NewCAServer(t).ChallengeTypes("http-01").Start()
   949  
   950  	// User HTTP server for the ACME challenge.
   951  	m := testManager(t)
   952  	m.Client = &acme.Client{DirectoryURL: ca.URL()}
   953  	s := httptest.NewServer(m.HTTPHandler(nil))
   954  	defer s.Close()
   955  
   956  	// User HTTPS server.
   957  	ss := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   958  		w.Write([]byte("OK"))
   959  	}))
   960  	ss.TLS = &tls.Config{
   961  		NextProtos: []string{"http/1.1", acme.ALPNProto},
   962  		GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   963  			cert, err := m.GetCertificate(hello)
   964  			if err != nil {
   965  				t.Errorf("m.GetCertificate: %v", err)
   966  			}
   967  			return cert, err
   968  		},
   969  	}
   970  	ss.StartTLS()
   971  	defer ss.Close()
   972  
   973  	// Redirect the CA requests to the HTTP server.
   974  	ca.Resolve(domain, strings.TrimPrefix(s.URL, "http://"))
   975  
   976  	// A client visiting user's HTTPS server.
   977  	tr := &http.Transport{
   978  		TLSClientConfig: &tls.Config{
   979  			RootCAs:    ca.Roots(),
   980  			ServerName: domain,
   981  		},
   982  	}
   983  	client := &http.Client{Transport: tr}
   984  	res, err := client.Get(ss.URL)
   985  	if err != nil {
   986  		t.Fatal(err)
   987  	}
   988  	defer res.Body.Close()
   989  	b, err := io.ReadAll(res.Body)
   990  	if err != nil {
   991  		t.Fatal(err)
   992  	}
   993  	if v := string(b); v != "OK" {
   994  		t.Errorf("user server response: %q; want 'OK'", v)
   995  	}
   996  }
   997  

View as plain text