1
2
3
4
5
6
7
8
9
10 package otr
11
12 import (
13 "bytes"
14 "crypto/sha256"
15 "errors"
16 "hash"
17 "math/big"
18 )
19
// smpFailure is a string-based error type used for the SMP sentinel errors.
// Being a plain string type, its values are comparable with ==.
type smpFailure string

// Error implements the error interface by returning the message verbatim.
func (f smpFailure) Error() string {
	msg := string(f)
	return msg
}
25
26 var smpFailureError = smpFailure("otr: SMP protocol failed")
27 var smpSecretMissingError = smpFailure("otr: mutual secret needed")
28
// smpVersion is the leading byte hashed into the SMP secret by calcSMPSecret.
const smpVersion = 1

// States of the SMP state machine.
const (
	// smpState1: idle; the only state in which an incoming SMP message 1 is accepted.
	smpState1 = 0
	// smpState2: we sent SMP message 1 and expect message 2.
	smpState2 = 1
	// smpState3: we sent SMP message 2 and expect message 3.
	smpState3 = 2
	// smpState4: we sent SMP message 3 and expect message 4.
	smpState4 = 3
)
37
38 type smpState struct {
39 state int
40 a2, a3, b2, b3, pb, qb *big.Int
41 g2a, g3a *big.Int
42 g2, g3 *big.Int
43 g3b, papb, qaqb, ra *big.Int
44 saved *tlv
45 secret *big.Int
46 question string
47 }
48
49 func (c *Conversation) startSMP(question string) (tlvs []tlv) {
50 if c.smp.state != smpState1 {
51 tlvs = append(tlvs, c.generateSMPAbort())
52 }
53 tlvs = append(tlvs, c.generateSMP1(question))
54 c.smp.question = ""
55 c.smp.state = smpState2
56 return
57 }
58
59 func (c *Conversation) resetSMP() {
60 c.smp.state = smpState1
61 c.smp.secret = nil
62 c.smp.question = ""
63 }
64
65 func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
66 data := in.data
67
68 switch in.typ {
69 case tlvTypeSMPAbort:
70 if c.smp.state != smpState1 {
71 err = smpFailureError
72 }
73 c.resetSMP()
74 return
75 case tlvTypeSMP1WithQuestion:
76
77 nulPos := bytes.IndexByte(data, 0)
78 if nulPos == -1 {
79 err = errors.New("otr: SMP message with question didn't contain a NUL byte")
80 return
81 }
82 c.smp.question = string(data[:nulPos])
83 data = data[nulPos+1:]
84 }
85
86 numMPIs, data, ok := getU32(data)
87 if !ok || numMPIs > 20 {
88 err = errors.New("otr: corrupt SMP message")
89 return
90 }
91
92 mpis := make([]*big.Int, numMPIs)
93 for i := range mpis {
94 var ok bool
95 mpis[i], data, ok = getMPI(data)
96 if !ok {
97 err = errors.New("otr: corrupt SMP message")
98 return
99 }
100 }
101
102 switch in.typ {
103 case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
104 if c.smp.state != smpState1 {
105 c.resetSMP()
106 out = c.generateSMPAbort()
107 return
108 }
109 if c.smp.secret == nil {
110 err = smpSecretMissingError
111 return
112 }
113 if err = c.processSMP1(mpis); err != nil {
114 return
115 }
116 c.smp.state = smpState3
117 out = c.generateSMP2()
118 case tlvTypeSMP2:
119 if c.smp.state != smpState2 {
120 c.resetSMP()
121 out = c.generateSMPAbort()
122 return
123 }
124 if out, err = c.processSMP2(mpis); err != nil {
125 out = c.generateSMPAbort()
126 return
127 }
128 c.smp.state = smpState4
129 case tlvTypeSMP3:
130 if c.smp.state != smpState3 {
131 c.resetSMP()
132 out = c.generateSMPAbort()
133 return
134 }
135 if out, err = c.processSMP3(mpis); err != nil {
136 return
137 }
138 c.smp.state = smpState1
139 c.smp.secret = nil
140 complete = true
141 case tlvTypeSMP4:
142 if c.smp.state != smpState4 {
143 c.resetSMP()
144 out = c.generateSMPAbort()
145 return
146 }
147 if err = c.processSMP4(mpis); err != nil {
148 out = c.generateSMPAbort()
149 return
150 }
151 c.smp.state = smpState1
152 c.smp.secret = nil
153 complete = true
154 default:
155 panic("unknown SMP message")
156 }
157
158 return
159 }
160
161 func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
162 h := sha256.New()
163 h.Write([]byte{smpVersion})
164 if weStarted {
165 h.Write(c.PrivateKey.PublicKey.Fingerprint())
166 h.Write(c.TheirPublicKey.Fingerprint())
167 } else {
168 h.Write(c.TheirPublicKey.Fingerprint())
169 h.Write(c.PrivateKey.PublicKey.Fingerprint())
170 }
171 h.Write(c.SSID[:])
172 h.Write(mutualSecret)
173 c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
174 }
175
// generateSMP1 builds SMP message 1 for the initiator.
//
// It draws random exponents a2 and a3 (stored in c.smp for later steps) and
// emits six MPIs: g2a = g^a2, c2, D2, g3a = g^a3, c3, D3. Each pair (c, D)
// is a Schnorr proof of knowledge of the exponent:
// c = SHA256(magic, g^r) and D = r - exponent*c mod q.
// If question is non-empty, it is prepended NUL-terminated and the TLV type
// is tlvTypeSMP1WithQuestion; otherwise the type is tlvTypeSMP1.
func (c *Conversation) generateSMP1(question string) tlv {
	// Scratch buffer handed to randMPI for every draw. Presumably it fixes
	// the size of the random values — TODO confirm against randMPI.
	var randBuf [16]byte
	c.smp.a2 = c.randMPI(randBuf[:])
	c.smp.a3 = c.randMPI(randBuf[:])
	g2a := new(big.Int).Exp(g, c.smp.a2, p)
	g3a := new(big.Int).Exp(g, c.smp.a3, p)
	h := sha256.New()

	// Proof for a2: c2 = SHA256(1, g^r2), D2 = r2 - a2*c2 mod q.
	r2 := c.randMPI(randBuf[:])
	r := new(big.Int).Exp(g, r2, p)
	c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
	d2 := new(big.Int).Mul(c.smp.a2, c2)
	d2.Sub(r2, d2)
	d2.Mod(d2, q)
	// big.Int.Mod is Euclidean, so for q > 0 the result is already
	// non-negative; this adjustment is purely defensive.
	if d2.Sign() < 0 {
		d2.Add(d2, q)
	}

	// Proof for a3: c3 = SHA256(2, g^r3), D3 = r3 - a3*c3 mod q.
	r3 := c.randMPI(randBuf[:])
	r.Exp(g, r3, p)
	c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
	d3 := new(big.Int).Mul(c.smp.a3, c3)
	d3.Sub(r3, d3)
	d3.Mod(d3, q)
	if d3.Sign() < 0 {
		d3.Add(d3, q)
	}

	var ret tlv
	if len(question) > 0 {
		ret.typ = tlvTypeSMP1WithQuestion
		ret.data = append(ret.data, question...)
		ret.data = append(ret.data, 0)
	} else {
		ret.typ = tlvTypeSMP1
	}
	// MPI count followed by the six MPIs.
	ret.data = appendU32(ret.data, 6)
	ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
	return ret
}
216
217 func (c *Conversation) processSMP1(mpis []*big.Int) error {
218 if len(mpis) != 6 {
219 return errors.New("otr: incorrect number of arguments in SMP1 message")
220 }
221 g2a := mpis[0]
222 c2 := mpis[1]
223 d2 := mpis[2]
224 g3a := mpis[3]
225 c3 := mpis[4]
226 d3 := mpis[5]
227 h := sha256.New()
228
229 r := new(big.Int).Exp(g, d2, p)
230 s := new(big.Int).Exp(g2a, c2, p)
231 r.Mul(r, s)
232 r.Mod(r, p)
233 t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
234 if c2.Cmp(t) != 0 {
235 return errors.New("otr: ZKP c2 incorrect in SMP1 message")
236 }
237 r.Exp(g, d3, p)
238 s.Exp(g3a, c3, p)
239 r.Mul(r, s)
240 r.Mod(r, p)
241 t.SetBytes(hashMPIs(h, 2, r))
242 if c3.Cmp(t) != 0 {
243 return errors.New("otr: ZKP c3 incorrect in SMP1 message")
244 }
245
246 c.smp.g2a = g2a
247 c.smp.g3a = g3a
248 return nil
249 }
250
// generateSMP2 builds SMP message 2 for the responder, after processSMP1
// has stored the peer's g2a and g3a.
//
// It draws b2, b3 (b3 kept in c.smp) and nonces r2..r6. It then computes:
//   - g2b = g^b2, g3b = g^b3, with Schnorr proofs (c2, D2) and (c3, D3),
//     using hash magics 3 and 4;
//   - shared generators g2 = g2a^b2 and g3 = g3a^b3, both stored in c.smp;
//   - Pb = g3^r4 and Qb = g^r4 * g2^secret, both stored in c.smp;
//   - a proof (cP, D5, D6) that Pb and Qb were formed this way, using
//     hash magic 5.
//
// The message carries 11 MPIs in that order.
func (c *Conversation) generateSMP2() tlv {
	// The order of these randMPI draws is part of the observable behaviour.
	var randBuf [16]byte
	b2 := c.randMPI(randBuf[:])
	c.smp.b3 = c.randMPI(randBuf[:])
	r2 := c.randMPI(randBuf[:])
	r3 := c.randMPI(randBuf[:])
	r4 := c.randMPI(randBuf[:])
	r5 := c.randMPI(randBuf[:])
	r6 := c.randMPI(randBuf[:])

	g2b := new(big.Int).Exp(g, b2, p)
	g3b := new(big.Int).Exp(g, c.smp.b3, p)

	// Proof for b2: c2 = SHA256(3, g^r2), D2 = r2 - b2*c2 mod q.
	r := new(big.Int).Exp(g, r2, p)
	h := sha256.New()
	c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
	d2 := new(big.Int).Mul(b2, c2)
	d2.Sub(r2, d2)
	d2.Mod(d2, q)
	// big.Int.Mod is Euclidean, so for q > 0 this branch is defensive only.
	if d2.Sign() < 0 {
		d2.Add(d2, q)
	}

	// Proof for b3: c3 = SHA256(4, g^r3), D3 = r3 - b3*c3 mod q.
	r.Exp(g, r3, p)
	c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
	d3 := new(big.Int).Mul(c.smp.b3, c3)
	d3.Sub(r3, d3)
	d3.Mod(d3, q)
	if d3.Sign() < 0 {
		d3.Add(d3, q)
	}

	// Shared generators, then Pb = g3^r4 and Qb = g^r4 * g2^secret.
	c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
	c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
	c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
	c.smp.qb = new(big.Int).Exp(g, r4, p)
	r.Exp(c.smp.g2, c.smp.secret, p)
	c.smp.qb.Mul(c.smp.qb, r)
	c.smp.qb.Mod(c.smp.qb, p)

	// cP = SHA256(5, g3^r5, g^r5 * g2^r6).
	s := new(big.Int)
	s.Exp(c.smp.g2, r6, p)
	r.Exp(g, r5, p)
	s.Mul(r, s)
	s.Mod(s, p)
	r.Exp(c.smp.g3, r5, p)
	cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))

	// D5 = r5 - r4*cP mod q.
	s.Mul(r4, cp)
	r.Sub(r5, s)
	d5 := new(big.Int).Mod(r, q)
	if d5.Sign() < 0 {
		d5.Add(d5, q)
	}

	// D6 = r6 - secret*cP mod q.
	s.Mul(c.smp.secret, cp)
	r.Sub(r6, s)
	d6 := new(big.Int).Mod(r, q)
	if d6.Sign() < 0 {
		d6.Add(d6, q)
	}

	var ret tlv
	ret.typ = tlvTypeSMP2
	ret.data = appendU32(ret.data, 11)
	ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
	return ret
}
321
322 func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
323 if len(mpis) != 11 {
324 err = errors.New("otr: incorrect number of arguments in SMP2 message")
325 return
326 }
327 g2b := mpis[0]
328 c2 := mpis[1]
329 d2 := mpis[2]
330 g3b := mpis[3]
331 c3 := mpis[4]
332 d3 := mpis[5]
333 pb := mpis[6]
334 qb := mpis[7]
335 cp := mpis[8]
336 d5 := mpis[9]
337 d6 := mpis[10]
338 h := sha256.New()
339
340 r := new(big.Int).Exp(g, d2, p)
341 s := new(big.Int).Exp(g2b, c2, p)
342 r.Mul(r, s)
343 r.Mod(r, p)
344 s.SetBytes(hashMPIs(h, 3, r))
345 if c2.Cmp(s) != 0 {
346 err = errors.New("otr: ZKP c2 failed in SMP2 message")
347 return
348 }
349
350 r.Exp(g, d3, p)
351 s.Exp(g3b, c3, p)
352 r.Mul(r, s)
353 r.Mod(r, p)
354 s.SetBytes(hashMPIs(h, 4, r))
355 if c3.Cmp(s) != 0 {
356 err = errors.New("otr: ZKP c3 failed in SMP2 message")
357 return
358 }
359
360 c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
361 c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
362
363 r.Exp(g, d5, p)
364 s.Exp(c.smp.g2, d6, p)
365 r.Mul(r, s)
366 s.Exp(qb, cp, p)
367 r.Mul(r, s)
368 r.Mod(r, p)
369
370 s.Exp(c.smp.g3, d5, p)
371 t := new(big.Int).Exp(pb, cp, p)
372 s.Mul(s, t)
373 s.Mod(s, p)
374 t.SetBytes(hashMPIs(h, 5, s, r))
375 if cp.Cmp(t) != 0 {
376 err = errors.New("otr: ZKP cP failed in SMP2 message")
377 return
378 }
379
380 var randBuf [16]byte
381 r4 := c.randMPI(randBuf[:])
382 r5 := c.randMPI(randBuf[:])
383 r6 := c.randMPI(randBuf[:])
384 r7 := c.randMPI(randBuf[:])
385
386 pa := new(big.Int).Exp(c.smp.g3, r4, p)
387 r.Exp(c.smp.g2, c.smp.secret, p)
388 qa := new(big.Int).Exp(g, r4, p)
389 qa.Mul(qa, r)
390 qa.Mod(qa, p)
391
392 r.Exp(g, r5, p)
393 s.Exp(c.smp.g2, r6, p)
394 r.Mul(r, s)
395 r.Mod(r, p)
396
397 s.Exp(c.smp.g3, r5, p)
398 cp.SetBytes(hashMPIs(h, 6, s, r))
399
400 r.Mul(r4, cp)
401 d5 = new(big.Int).Sub(r5, r)
402 d5.Mod(d5, q)
403 if d5.Sign() < 0 {
404 d5.Add(d5, q)
405 }
406
407 r.Mul(c.smp.secret, cp)
408 d6 = new(big.Int).Sub(r6, r)
409 d6.Mod(d6, q)
410 if d6.Sign() < 0 {
411 d6.Add(d6, q)
412 }
413
414 r.ModInverse(qb, p)
415 qaqb := new(big.Int).Mul(qa, r)
416 qaqb.Mod(qaqb, p)
417
418 ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
419 r.Exp(qaqb, r7, p)
420 s.Exp(g, r7, p)
421 cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
422
423 r.Mul(c.smp.a3, cr)
424 d7 := new(big.Int).Sub(r7, r)
425 d7.Mod(d7, q)
426 if d7.Sign() < 0 {
427 d7.Add(d7, q)
428 }
429
430 c.smp.g3b = g3b
431 c.smp.qaqb = qaqb
432
433 r.ModInverse(pb, p)
434 c.smp.papb = new(big.Int).Mul(pa, r)
435 c.smp.papb.Mod(c.smp.papb, p)
436 c.smp.ra = ra
437
438 out.typ = tlvTypeSMP3
439 out.data = appendU32(out.data, 8)
440 out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
441 return
442 }
443
444 func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
445 if len(mpis) != 8 {
446 err = errors.New("otr: incorrect number of arguments in SMP3 message")
447 return
448 }
449 pa := mpis[0]
450 qa := mpis[1]
451 cp := mpis[2]
452 d5 := mpis[3]
453 d6 := mpis[4]
454 ra := mpis[5]
455 cr := mpis[6]
456 d7 := mpis[7]
457 h := sha256.New()
458
459 r := new(big.Int).Exp(g, d5, p)
460 s := new(big.Int).Exp(c.smp.g2, d6, p)
461 r.Mul(r, s)
462 s.Exp(qa, cp, p)
463 r.Mul(r, s)
464 r.Mod(r, p)
465
466 s.Exp(c.smp.g3, d5, p)
467 t := new(big.Int).Exp(pa, cp, p)
468 s.Mul(s, t)
469 s.Mod(s, p)
470 t.SetBytes(hashMPIs(h, 6, s, r))
471 if t.Cmp(cp) != 0 {
472 err = errors.New("otr: ZKP cP failed in SMP3 message")
473 return
474 }
475
476 r.ModInverse(c.smp.qb, p)
477 qaqb := new(big.Int).Mul(qa, r)
478 qaqb.Mod(qaqb, p)
479
480 r.Exp(qaqb, d7, p)
481 s.Exp(ra, cr, p)
482 r.Mul(r, s)
483 r.Mod(r, p)
484
485 s.Exp(g, d7, p)
486 t.Exp(c.smp.g3a, cr, p)
487 s.Mul(s, t)
488 s.Mod(s, p)
489 t.SetBytes(hashMPIs(h, 7, s, r))
490 if t.Cmp(cr) != 0 {
491 err = errors.New("otr: ZKP cR failed in SMP3 message")
492 return
493 }
494
495 var randBuf [16]byte
496 r7 := c.randMPI(randBuf[:])
497 rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
498
499 r.Exp(qaqb, r7, p)
500 s.Exp(g, r7, p)
501 cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
502
503 r.Mul(c.smp.b3, cr)
504 d7 = new(big.Int).Sub(r7, r)
505 d7.Mod(d7, q)
506 if d7.Sign() < 0 {
507 d7.Add(d7, q)
508 }
509
510 out.typ = tlvTypeSMP4
511 out.data = appendU32(out.data, 3)
512 out.data = appendMPIs(out.data, rb, cr, d7)
513
514 r.ModInverse(c.smp.pb, p)
515 r.Mul(pa, r)
516 r.Mod(r, p)
517 s.Exp(ra, c.smp.b3, p)
518 if r.Cmp(s) != 0 {
519 err = smpFailureError
520 }
521
522 return
523 }
524
525 func (c *Conversation) processSMP4(mpis []*big.Int) error {
526 if len(mpis) != 3 {
527 return errors.New("otr: incorrect number of arguments in SMP4 message")
528 }
529 rb := mpis[0]
530 cr := mpis[1]
531 d7 := mpis[2]
532 h := sha256.New()
533
534 r := new(big.Int).Exp(c.smp.qaqb, d7, p)
535 s := new(big.Int).Exp(rb, cr, p)
536 r.Mul(r, s)
537 r.Mod(r, p)
538
539 s.Exp(g, d7, p)
540 t := new(big.Int).Exp(c.smp.g3b, cr, p)
541 s.Mul(s, t)
542 s.Mod(s, p)
543 t.SetBytes(hashMPIs(h, 8, s, r))
544 if t.Cmp(cr) != 0 {
545 return errors.New("otr: ZKP cR failed in SMP4 message")
546 }
547
548 r.Exp(rb, c.smp.a3, p)
549 if r.Cmp(c.smp.papb) != 0 {
550 return smpFailureError
551 }
552
553 return nil
554 }
555
556 func (c *Conversation) generateSMPAbort() tlv {
557 return tlv{typ: tlvTypeSMPAbort}
558 }
559
560 func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
561 if h != nil {
562 h.Reset()
563 } else {
564 h = sha256.New()
565 }
566
567 h.Write([]byte{magic})
568 for _, mpi := range mpis {
569 h.Write(appendMPI(nil, mpi))
570 }
571 return h.Sum(nil)
572 }
573
View as plain text