1
2
3
4
5 package tls
6
7 import (
8 "crypto/aes"
9 "crypto/cipher"
10 "crypto/hmac"
11 "crypto/sha256"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "io"
16
17 "golang.org/x/crypto/cryptobyte"
18 )
19
20
// SessionState is the resumable state of a TLS session. Bytes serializes it,
// ParseSessionState reconstructs it, and EncryptTicket / DecryptTicket
// protect the serialized form when it travels as a session ticket.
type SessionState struct {
	// Extra is a list of opaque byte strings carried along with the session.
	// Bytes writes each entry with a 24-bit length prefix inside a 24-bit
	// length-prefixed list, and ParseSessionState reads them back in order.
	// Nothing in this file interprets their contents.
	Extra [][]byte

	// EarlyData is encoded as a single 0/1 byte. When it is set, Bytes also
	// serializes alpnProtocol and ParseSessionState expects to read it back.
	// NOTE(review): how the handshake consumes this flag is not visible in
	// this file — confirm against the handshake code.
	EarlyData bool

	// version is the protocol version. Client sessions with
	// version >= VersionTLS13 additionally carry useBy and ageAdd.
	version uint16
	// isClient is encoded as the type byte: 2 for client, 1 for server.
	isClient    bool
	cipherSuite uint16

	// createdAt is a Unix time in seconds; sessionState fills it from
	// c.config.time().Unix().
	createdAt uint64
	// secret must be non-empty for ParseSessionState to accept an encoding.
	secret          []byte
	extMasterSecret bool
	// peerCertificates, ocspResponse and scts are encoded together as a
	// Certificate message via marshalCertificate.
	peerCertificates []*x509.Certificate
	// activeCertHandles holds the globalCertCache handles obtained for every
	// certificate parsed by ParseSessionState (peer and chain certificates).
	activeCertHandles []*activeCert
	ocspResponse      []byte
	scts              [][]byte
	// verifiedChains is encoded without the first certificate of each chain;
	// ParseSessionState substitutes peerCertificates[0] for it. Bytes fails
	// on an empty chain.
	verifiedChains [][]*x509.Certificate
	// alpnProtocol is serialized only when EarlyData is set.
	alpnProtocol string

	// useBy and ageAdd are serialized only for client sessions with
	// version >= VersionTLS13.
	// NOTE(review): presumably the ticket expiry time and the ticket age
	// obfuscation value — confirm against the client handshake code.
	useBy  uint64
	ageAdd uint32
}
100
101
102
103
104
105
106
107 func (s *SessionState) Bytes() ([]byte, error) {
108 var b cryptobyte.Builder
109 b.AddUint16(s.version)
110 if s.isClient {
111 b.AddUint8(2)
112 } else {
113 b.AddUint8(1)
114 }
115 b.AddUint16(s.cipherSuite)
116 addUint64(&b, s.createdAt)
117 b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
118 b.AddBytes(s.secret)
119 })
120 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
121 for _, extra := range s.Extra {
122 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
123 b.AddBytes(extra)
124 })
125 }
126 })
127 if s.extMasterSecret {
128 b.AddUint8(1)
129 } else {
130 b.AddUint8(0)
131 }
132 if s.EarlyData {
133 b.AddUint8(1)
134 } else {
135 b.AddUint8(0)
136 }
137 marshalCertificate(&b, Certificate{
138 Certificate: certificatesToBytesSlice(s.peerCertificates),
139 OCSPStaple: s.ocspResponse,
140 SignedCertificateTimestamps: s.scts,
141 })
142 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
143 for _, chain := range s.verifiedChains {
144 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
145
146 if len(chain) == 0 {
147 b.SetError(errors.New("tls: internal error: empty verified chain"))
148 return
149 }
150 for _, cert := range chain[1:] {
151 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
152 b.AddBytes(cert.Raw)
153 })
154 }
155 })
156 }
157 })
158 if s.EarlyData {
159 b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
160 b.AddBytes([]byte(s.alpnProtocol))
161 })
162 }
163 if s.isClient {
164 if s.version >= VersionTLS13 {
165 addUint64(&b, s.useBy)
166 b.AddUint32(s.ageAdd)
167 }
168 }
169 return b.Bytes()
170 }
171
// certificatesToBytesSlice returns the Raw encoding of each certificate in
// certs, in order. The returned slices alias the certificates' Raw fields
// (nothing is copied), and the result is a non-nil slice even when certs is
// empty.
func certificatesToBytesSlice(certs []*x509.Certificate) [][]byte {
	raw := make([][]byte, len(certs))
	for i, cert := range certs {
		raw[i] = cert.Raw
	}
	return raw
}
179
180
181 func ParseSessionState(data []byte) (*SessionState, error) {
182 ss := &SessionState{}
183 s := cryptobyte.String(data)
184 var typ, extMasterSecret, earlyData uint8
185 var cert Certificate
186 var extra cryptobyte.String
187 if !s.ReadUint16(&ss.version) ||
188 !s.ReadUint8(&typ) ||
189 (typ != 1 && typ != 2) ||
190 !s.ReadUint16(&ss.cipherSuite) ||
191 !readUint64(&s, &ss.createdAt) ||
192 !readUint8LengthPrefixed(&s, &ss.secret) ||
193 !s.ReadUint24LengthPrefixed(&extra) ||
194 !s.ReadUint8(&extMasterSecret) ||
195 !s.ReadUint8(&earlyData) ||
196 len(ss.secret) == 0 ||
197 !unmarshalCertificate(&s, &cert) {
198 return nil, errors.New("tls: invalid session encoding")
199 }
200 for !extra.Empty() {
201 var e []byte
202 if !readUint24LengthPrefixed(&extra, &e) {
203 return nil, errors.New("tls: invalid session encoding")
204 }
205 ss.Extra = append(ss.Extra, e)
206 }
207 switch extMasterSecret {
208 case 0:
209 ss.extMasterSecret = false
210 case 1:
211 ss.extMasterSecret = true
212 default:
213 return nil, errors.New("tls: invalid session encoding")
214 }
215 switch earlyData {
216 case 0:
217 ss.EarlyData = false
218 case 1:
219 ss.EarlyData = true
220 default:
221 return nil, errors.New("tls: invalid session encoding")
222 }
223 for _, cert := range cert.Certificate {
224 c, err := globalCertCache.newCert(cert)
225 if err != nil {
226 return nil, err
227 }
228 ss.activeCertHandles = append(ss.activeCertHandles, c)
229 ss.peerCertificates = append(ss.peerCertificates, c.cert)
230 }
231 ss.ocspResponse = cert.OCSPStaple
232 ss.scts = cert.SignedCertificateTimestamps
233 var chainList cryptobyte.String
234 if !s.ReadUint24LengthPrefixed(&chainList) {
235 return nil, errors.New("tls: invalid session encoding")
236 }
237 for !chainList.Empty() {
238 var certList cryptobyte.String
239 if !chainList.ReadUint24LengthPrefixed(&certList) {
240 return nil, errors.New("tls: invalid session encoding")
241 }
242 var chain []*x509.Certificate
243 if len(ss.peerCertificates) == 0 {
244 return nil, errors.New("tls: invalid session encoding")
245 }
246 chain = append(chain, ss.peerCertificates[0])
247 for !certList.Empty() {
248 var cert []byte
249 if !readUint24LengthPrefixed(&certList, &cert) {
250 return nil, errors.New("tls: invalid session encoding")
251 }
252 c, err := globalCertCache.newCert(cert)
253 if err != nil {
254 return nil, err
255 }
256 ss.activeCertHandles = append(ss.activeCertHandles, c)
257 chain = append(chain, c.cert)
258 }
259 ss.verifiedChains = append(ss.verifiedChains, chain)
260 }
261 if ss.EarlyData {
262 var alpn []byte
263 if !readUint8LengthPrefixed(&s, &alpn) {
264 return nil, errors.New("tls: invalid session encoding")
265 }
266 ss.alpnProtocol = string(alpn)
267 }
268 if isClient := typ == 2; !isClient {
269 if !s.Empty() {
270 return nil, errors.New("tls: invalid session encoding")
271 }
272 return ss, nil
273 }
274 ss.isClient = true
275 if len(ss.peerCertificates) == 0 {
276 return nil, errors.New("tls: no server certificates in client session")
277 }
278 if ss.version < VersionTLS13 {
279 if !s.Empty() {
280 return nil, errors.New("tls: invalid session encoding")
281 }
282 return ss, nil
283 }
284 if !s.ReadUint64(&ss.useBy) || !s.ReadUint32(&ss.ageAdd) || !s.Empty() {
285 return nil, errors.New("tls: invalid session encoding")
286 }
287 return ss, nil
288 }
289
290
291
292 func (c *Conn) sessionState() (*SessionState, error) {
293 return &SessionState{
294 version: c.vers,
295 cipherSuite: c.cipherSuite,
296 createdAt: uint64(c.config.time().Unix()),
297 alpnProtocol: c.clientProtocol,
298 peerCertificates: c.peerCertificates,
299 activeCertHandles: c.activeCertHandles,
300 ocspResponse: c.ocspResponse,
301 scts: c.scts,
302 isClient: c.isClient,
303 extMasterSecret: c.extMasterSecret,
304 verifiedChains: c.verifiedChains,
305 }, nil
306 }
307
308
309
310 func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
311 ticketKeys := c.ticketKeys(nil)
312 stateBytes, err := ss.Bytes()
313 if err != nil {
314 return nil, err
315 }
316 return c.encryptTicket(stateBytes, ticketKeys)
317 }
318
319 func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
320 if len(ticketKeys) == 0 {
321 return nil, errors.New("tls: internal error: session ticket keys unavailable")
322 }
323
324 encrypted := make([]byte, aes.BlockSize+len(state)+sha256.Size)
325 iv := encrypted[:aes.BlockSize]
326 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
327 authenticated := encrypted[:len(encrypted)-sha256.Size]
328 macBytes := encrypted[len(encrypted)-sha256.Size:]
329
330 if _, err := io.ReadFull(c.rand(), iv); err != nil {
331 return nil, err
332 }
333 key := ticketKeys[0]
334 block, err := aes.NewCipher(key.aesKey[:])
335 if err != nil {
336 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
337 }
338 cipher.NewCTR(block, iv).XORKeyStream(ciphertext, state)
339
340 mac := hmac.New(sha256.New, key.hmacKey[:])
341 mac.Write(authenticated)
342 mac.Sum(macBytes[:0])
343
344 return encrypted, nil
345 }
346
347
348
349
350
351 func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
352 ticketKeys := c.ticketKeys(nil)
353 stateBytes := c.decryptTicket(identity, ticketKeys)
354 if stateBytes == nil {
355 return nil, nil
356 }
357 s, err := ParseSessionState(stateBytes)
358 if err != nil {
359 return nil, nil
360 }
361 return s, nil
362 }
363
364 func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
365 if len(encrypted) < aes.BlockSize+sha256.Size {
366 return nil
367 }
368
369 iv := encrypted[:aes.BlockSize]
370 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
371 authenticated := encrypted[:len(encrypted)-sha256.Size]
372 macBytes := encrypted[len(encrypted)-sha256.Size:]
373
374 for _, key := range ticketKeys {
375 mac := hmac.New(sha256.New, key.hmacKey[:])
376 mac.Write(authenticated)
377 expected := mac.Sum(nil)
378
379 if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
380 continue
381 }
382
383 block, err := aes.NewCipher(key.aesKey[:])
384 if err != nil {
385 return nil
386 }
387 plaintext := make([]byte, len(ciphertext))
388 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
389
390 return plaintext
391 }
392
393 return nil
394 }
395
396
397
// ClientSessionState pairs a session ticket with the SessionState it
// resumes. It is built by NewResumptionState and unpacked by
// ResumptionState.
type ClientSessionState struct {
	// ticket is the opaque ticket (identity) bytes, stored as given.
	ticket []byte
	// session is the associated resumable state, stored as given.
	session *SessionState
}
402
403
404
405
406
407
408 func (cs *ClientSessionState) ResumptionState() (ticket []byte, state *SessionState, err error) {
409 return cs.ticket, cs.session, nil
410 }
411
412
413
414
415
416
417 func NewResumptionState(ticket []byte, state *SessionState) (*ClientSessionState, error) {
418 return &ClientSessionState{
419 ticket: ticket, session: state,
420 }, nil
421 }
422
View as plain text