...

Source file src/golang.org/x/crypto/otr/smp.go

Documentation: golang.org/x/crypto/otr

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file implements the Socialist Millionaires Protocol as described in
     6  // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
     7  // specification is required in order to understand this code and, where
     8  // possible, the variable names in the code match up with the spec.
     9  
    10  package otr
    11  
    12  import (
    13  	"bytes"
    14  	"crypto/sha256"
    15  	"errors"
    16  	"hash"
    17  	"math/big"
    18  )
    19  
    20  type smpFailure string
    21  
    22  func (s smpFailure) Error() string {
    23  	return string(s)
    24  }
    25  
    26  var smpFailureError = smpFailure("otr: SMP protocol failed")
    27  var smpSecretMissingError = smpFailure("otr: mutual secret needed")
    28  
    29  const smpVersion = 1
    30  
    31  const (
    32  	smpState1 = iota
    33  	smpState2
    34  	smpState3
    35  	smpState4
    36  )
    37  
    38  type smpState struct {
    39  	state                  int
    40  	a2, a3, b2, b3, pb, qb *big.Int
    41  	g2a, g3a               *big.Int
    42  	g2, g3                 *big.Int
    43  	g3b, papb, qaqb, ra    *big.Int
    44  	saved                  *tlv
    45  	secret                 *big.Int
    46  	question               string
    47  }
    48  
    49  func (c *Conversation) startSMP(question string) (tlvs []tlv) {
    50  	if c.smp.state != smpState1 {
    51  		tlvs = append(tlvs, c.generateSMPAbort())
    52  	}
    53  	tlvs = append(tlvs, c.generateSMP1(question))
    54  	c.smp.question = ""
    55  	c.smp.state = smpState2
    56  	return
    57  }
    58  
    59  func (c *Conversation) resetSMP() {
    60  	c.smp.state = smpState1
    61  	c.smp.secret = nil
    62  	c.smp.question = ""
    63  }
    64  
    65  func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
    66  	data := in.data
    67  
    68  	switch in.typ {
    69  	case tlvTypeSMPAbort:
    70  		if c.smp.state != smpState1 {
    71  			err = smpFailureError
    72  		}
    73  		c.resetSMP()
    74  		return
    75  	case tlvTypeSMP1WithQuestion:
    76  		// We preprocess this into a SMP1 message.
    77  		nulPos := bytes.IndexByte(data, 0)
    78  		if nulPos == -1 {
    79  			err = errors.New("otr: SMP message with question didn't contain a NUL byte")
    80  			return
    81  		}
    82  		c.smp.question = string(data[:nulPos])
    83  		data = data[nulPos+1:]
    84  	}
    85  
    86  	numMPIs, data, ok := getU32(data)
    87  	if !ok || numMPIs > 20 {
    88  		err = errors.New("otr: corrupt SMP message")
    89  		return
    90  	}
    91  
    92  	mpis := make([]*big.Int, numMPIs)
    93  	for i := range mpis {
    94  		var ok bool
    95  		mpis[i], data, ok = getMPI(data)
    96  		if !ok {
    97  			err = errors.New("otr: corrupt SMP message")
    98  			return
    99  		}
   100  	}
   101  
   102  	switch in.typ {
   103  	case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
   104  		if c.smp.state != smpState1 {
   105  			c.resetSMP()
   106  			out = c.generateSMPAbort()
   107  			return
   108  		}
   109  		if c.smp.secret == nil {
   110  			err = smpSecretMissingError
   111  			return
   112  		}
   113  		if err = c.processSMP1(mpis); err != nil {
   114  			return
   115  		}
   116  		c.smp.state = smpState3
   117  		out = c.generateSMP2()
   118  	case tlvTypeSMP2:
   119  		if c.smp.state != smpState2 {
   120  			c.resetSMP()
   121  			out = c.generateSMPAbort()
   122  			return
   123  		}
   124  		if out, err = c.processSMP2(mpis); err != nil {
   125  			out = c.generateSMPAbort()
   126  			return
   127  		}
   128  		c.smp.state = smpState4
   129  	case tlvTypeSMP3:
   130  		if c.smp.state != smpState3 {
   131  			c.resetSMP()
   132  			out = c.generateSMPAbort()
   133  			return
   134  		}
   135  		if out, err = c.processSMP3(mpis); err != nil {
   136  			return
   137  		}
   138  		c.smp.state = smpState1
   139  		c.smp.secret = nil
   140  		complete = true
   141  	case tlvTypeSMP4:
   142  		if c.smp.state != smpState4 {
   143  			c.resetSMP()
   144  			out = c.generateSMPAbort()
   145  			return
   146  		}
   147  		if err = c.processSMP4(mpis); err != nil {
   148  			out = c.generateSMPAbort()
   149  			return
   150  		}
   151  		c.smp.state = smpState1
   152  		c.smp.secret = nil
   153  		complete = true
   154  	default:
   155  		panic("unknown SMP message")
   156  	}
   157  
   158  	return
   159  }
   160  
   161  func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
   162  	h := sha256.New()
   163  	h.Write([]byte{smpVersion})
   164  	if weStarted {
   165  		h.Write(c.PrivateKey.PublicKey.Fingerprint())
   166  		h.Write(c.TheirPublicKey.Fingerprint())
   167  	} else {
   168  		h.Write(c.TheirPublicKey.Fingerprint())
   169  		h.Write(c.PrivateKey.PublicKey.Fingerprint())
   170  	}
   171  	h.Write(c.SSID[:])
   172  	h.Write(mutualSecret)
   173  	c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
   174  }
   175  
   176  func (c *Conversation) generateSMP1(question string) tlv {
   177  	var randBuf [16]byte
   178  	c.smp.a2 = c.randMPI(randBuf[:])
   179  	c.smp.a3 = c.randMPI(randBuf[:])
   180  	g2a := new(big.Int).Exp(g, c.smp.a2, p)
   181  	g3a := new(big.Int).Exp(g, c.smp.a3, p)
   182  	h := sha256.New()
   183  
   184  	r2 := c.randMPI(randBuf[:])
   185  	r := new(big.Int).Exp(g, r2, p)
   186  	c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
   187  	d2 := new(big.Int).Mul(c.smp.a2, c2)
   188  	d2.Sub(r2, d2)
   189  	d2.Mod(d2, q)
   190  	if d2.Sign() < 0 {
   191  		d2.Add(d2, q)
   192  	}
   193  
   194  	r3 := c.randMPI(randBuf[:])
   195  	r.Exp(g, r3, p)
   196  	c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
   197  	d3 := new(big.Int).Mul(c.smp.a3, c3)
   198  	d3.Sub(r3, d3)
   199  	d3.Mod(d3, q)
   200  	if d3.Sign() < 0 {
   201  		d3.Add(d3, q)
   202  	}
   203  
   204  	var ret tlv
   205  	if len(question) > 0 {
   206  		ret.typ = tlvTypeSMP1WithQuestion
   207  		ret.data = append(ret.data, question...)
   208  		ret.data = append(ret.data, 0)
   209  	} else {
   210  		ret.typ = tlvTypeSMP1
   211  	}
   212  	ret.data = appendU32(ret.data, 6)
   213  	ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
   214  	return ret
   215  }
   216  
   217  func (c *Conversation) processSMP1(mpis []*big.Int) error {
   218  	if len(mpis) != 6 {
   219  		return errors.New("otr: incorrect number of arguments in SMP1 message")
   220  	}
   221  	g2a := mpis[0]
   222  	c2 := mpis[1]
   223  	d2 := mpis[2]
   224  	g3a := mpis[3]
   225  	c3 := mpis[4]
   226  	d3 := mpis[5]
   227  	h := sha256.New()
   228  
   229  	r := new(big.Int).Exp(g, d2, p)
   230  	s := new(big.Int).Exp(g2a, c2, p)
   231  	r.Mul(r, s)
   232  	r.Mod(r, p)
   233  	t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
   234  	if c2.Cmp(t) != 0 {
   235  		return errors.New("otr: ZKP c2 incorrect in SMP1 message")
   236  	}
   237  	r.Exp(g, d3, p)
   238  	s.Exp(g3a, c3, p)
   239  	r.Mul(r, s)
   240  	r.Mod(r, p)
   241  	t.SetBytes(hashMPIs(h, 2, r))
   242  	if c3.Cmp(t) != 0 {
   243  		return errors.New("otr: ZKP c3 incorrect in SMP1 message")
   244  	}
   245  
   246  	c.smp.g2a = g2a
   247  	c.smp.g3a = g3a
   248  	return nil
   249  }
   250  
   251  func (c *Conversation) generateSMP2() tlv {
   252  	var randBuf [16]byte
   253  	b2 := c.randMPI(randBuf[:])
   254  	c.smp.b3 = c.randMPI(randBuf[:])
   255  	r2 := c.randMPI(randBuf[:])
   256  	r3 := c.randMPI(randBuf[:])
   257  	r4 := c.randMPI(randBuf[:])
   258  	r5 := c.randMPI(randBuf[:])
   259  	r6 := c.randMPI(randBuf[:])
   260  
   261  	g2b := new(big.Int).Exp(g, b2, p)
   262  	g3b := new(big.Int).Exp(g, c.smp.b3, p)
   263  
   264  	r := new(big.Int).Exp(g, r2, p)
   265  	h := sha256.New()
   266  	c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
   267  	d2 := new(big.Int).Mul(b2, c2)
   268  	d2.Sub(r2, d2)
   269  	d2.Mod(d2, q)
   270  	if d2.Sign() < 0 {
   271  		d2.Add(d2, q)
   272  	}
   273  
   274  	r.Exp(g, r3, p)
   275  	c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
   276  	d3 := new(big.Int).Mul(c.smp.b3, c3)
   277  	d3.Sub(r3, d3)
   278  	d3.Mod(d3, q)
   279  	if d3.Sign() < 0 {
   280  		d3.Add(d3, q)
   281  	}
   282  
   283  	c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
   284  	c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
   285  	c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
   286  	c.smp.qb = new(big.Int).Exp(g, r4, p)
   287  	r.Exp(c.smp.g2, c.smp.secret, p)
   288  	c.smp.qb.Mul(c.smp.qb, r)
   289  	c.smp.qb.Mod(c.smp.qb, p)
   290  
   291  	s := new(big.Int)
   292  	s.Exp(c.smp.g2, r6, p)
   293  	r.Exp(g, r5, p)
   294  	s.Mul(r, s)
   295  	s.Mod(s, p)
   296  	r.Exp(c.smp.g3, r5, p)
   297  	cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
   298  
   299  	// D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
   300  
   301  	s.Mul(r4, cp)
   302  	r.Sub(r5, s)
   303  	d5 := new(big.Int).Mod(r, q)
   304  	if d5.Sign() < 0 {
   305  		d5.Add(d5, q)
   306  	}
   307  
   308  	s.Mul(c.smp.secret, cp)
   309  	r.Sub(r6, s)
   310  	d6 := new(big.Int).Mod(r, q)
   311  	if d6.Sign() < 0 {
   312  		d6.Add(d6, q)
   313  	}
   314  
   315  	var ret tlv
   316  	ret.typ = tlvTypeSMP2
   317  	ret.data = appendU32(ret.data, 11)
   318  	ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
   319  	return ret
   320  }
   321  
   322  func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
   323  	if len(mpis) != 11 {
   324  		err = errors.New("otr: incorrect number of arguments in SMP2 message")
   325  		return
   326  	}
   327  	g2b := mpis[0]
   328  	c2 := mpis[1]
   329  	d2 := mpis[2]
   330  	g3b := mpis[3]
   331  	c3 := mpis[4]
   332  	d3 := mpis[5]
   333  	pb := mpis[6]
   334  	qb := mpis[7]
   335  	cp := mpis[8]
   336  	d5 := mpis[9]
   337  	d6 := mpis[10]
   338  	h := sha256.New()
   339  
   340  	r := new(big.Int).Exp(g, d2, p)
   341  	s := new(big.Int).Exp(g2b, c2, p)
   342  	r.Mul(r, s)
   343  	r.Mod(r, p)
   344  	s.SetBytes(hashMPIs(h, 3, r))
   345  	if c2.Cmp(s) != 0 {
   346  		err = errors.New("otr: ZKP c2 failed in SMP2 message")
   347  		return
   348  	}
   349  
   350  	r.Exp(g, d3, p)
   351  	s.Exp(g3b, c3, p)
   352  	r.Mul(r, s)
   353  	r.Mod(r, p)
   354  	s.SetBytes(hashMPIs(h, 4, r))
   355  	if c3.Cmp(s) != 0 {
   356  		err = errors.New("otr: ZKP c3 failed in SMP2 message")
   357  		return
   358  	}
   359  
   360  	c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
   361  	c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
   362  
   363  	r.Exp(g, d5, p)
   364  	s.Exp(c.smp.g2, d6, p)
   365  	r.Mul(r, s)
   366  	s.Exp(qb, cp, p)
   367  	r.Mul(r, s)
   368  	r.Mod(r, p)
   369  
   370  	s.Exp(c.smp.g3, d5, p)
   371  	t := new(big.Int).Exp(pb, cp, p)
   372  	s.Mul(s, t)
   373  	s.Mod(s, p)
   374  	t.SetBytes(hashMPIs(h, 5, s, r))
   375  	if cp.Cmp(t) != 0 {
   376  		err = errors.New("otr: ZKP cP failed in SMP2 message")
   377  		return
   378  	}
   379  
   380  	var randBuf [16]byte
   381  	r4 := c.randMPI(randBuf[:])
   382  	r5 := c.randMPI(randBuf[:])
   383  	r6 := c.randMPI(randBuf[:])
   384  	r7 := c.randMPI(randBuf[:])
   385  
   386  	pa := new(big.Int).Exp(c.smp.g3, r4, p)
   387  	r.Exp(c.smp.g2, c.smp.secret, p)
   388  	qa := new(big.Int).Exp(g, r4, p)
   389  	qa.Mul(qa, r)
   390  	qa.Mod(qa, p)
   391  
   392  	r.Exp(g, r5, p)
   393  	s.Exp(c.smp.g2, r6, p)
   394  	r.Mul(r, s)
   395  	r.Mod(r, p)
   396  
   397  	s.Exp(c.smp.g3, r5, p)
   398  	cp.SetBytes(hashMPIs(h, 6, s, r))
   399  
   400  	r.Mul(r4, cp)
   401  	d5 = new(big.Int).Sub(r5, r)
   402  	d5.Mod(d5, q)
   403  	if d5.Sign() < 0 {
   404  		d5.Add(d5, q)
   405  	}
   406  
   407  	r.Mul(c.smp.secret, cp)
   408  	d6 = new(big.Int).Sub(r6, r)
   409  	d6.Mod(d6, q)
   410  	if d6.Sign() < 0 {
   411  		d6.Add(d6, q)
   412  	}
   413  
   414  	r.ModInverse(qb, p)
   415  	qaqb := new(big.Int).Mul(qa, r)
   416  	qaqb.Mod(qaqb, p)
   417  
   418  	ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
   419  	r.Exp(qaqb, r7, p)
   420  	s.Exp(g, r7, p)
   421  	cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
   422  
   423  	r.Mul(c.smp.a3, cr)
   424  	d7 := new(big.Int).Sub(r7, r)
   425  	d7.Mod(d7, q)
   426  	if d7.Sign() < 0 {
   427  		d7.Add(d7, q)
   428  	}
   429  
   430  	c.smp.g3b = g3b
   431  	c.smp.qaqb = qaqb
   432  
   433  	r.ModInverse(pb, p)
   434  	c.smp.papb = new(big.Int).Mul(pa, r)
   435  	c.smp.papb.Mod(c.smp.papb, p)
   436  	c.smp.ra = ra
   437  
   438  	out.typ = tlvTypeSMP3
   439  	out.data = appendU32(out.data, 8)
   440  	out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
   441  	return
   442  }
   443  
   444  func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
   445  	if len(mpis) != 8 {
   446  		err = errors.New("otr: incorrect number of arguments in SMP3 message")
   447  		return
   448  	}
   449  	pa := mpis[0]
   450  	qa := mpis[1]
   451  	cp := mpis[2]
   452  	d5 := mpis[3]
   453  	d6 := mpis[4]
   454  	ra := mpis[5]
   455  	cr := mpis[6]
   456  	d7 := mpis[7]
   457  	h := sha256.New()
   458  
   459  	r := new(big.Int).Exp(g, d5, p)
   460  	s := new(big.Int).Exp(c.smp.g2, d6, p)
   461  	r.Mul(r, s)
   462  	s.Exp(qa, cp, p)
   463  	r.Mul(r, s)
   464  	r.Mod(r, p)
   465  
   466  	s.Exp(c.smp.g3, d5, p)
   467  	t := new(big.Int).Exp(pa, cp, p)
   468  	s.Mul(s, t)
   469  	s.Mod(s, p)
   470  	t.SetBytes(hashMPIs(h, 6, s, r))
   471  	if t.Cmp(cp) != 0 {
   472  		err = errors.New("otr: ZKP cP failed in SMP3 message")
   473  		return
   474  	}
   475  
   476  	r.ModInverse(c.smp.qb, p)
   477  	qaqb := new(big.Int).Mul(qa, r)
   478  	qaqb.Mod(qaqb, p)
   479  
   480  	r.Exp(qaqb, d7, p)
   481  	s.Exp(ra, cr, p)
   482  	r.Mul(r, s)
   483  	r.Mod(r, p)
   484  
   485  	s.Exp(g, d7, p)
   486  	t.Exp(c.smp.g3a, cr, p)
   487  	s.Mul(s, t)
   488  	s.Mod(s, p)
   489  	t.SetBytes(hashMPIs(h, 7, s, r))
   490  	if t.Cmp(cr) != 0 {
   491  		err = errors.New("otr: ZKP cR failed in SMP3 message")
   492  		return
   493  	}
   494  
   495  	var randBuf [16]byte
   496  	r7 := c.randMPI(randBuf[:])
   497  	rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
   498  
   499  	r.Exp(qaqb, r7, p)
   500  	s.Exp(g, r7, p)
   501  	cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
   502  
   503  	r.Mul(c.smp.b3, cr)
   504  	d7 = new(big.Int).Sub(r7, r)
   505  	d7.Mod(d7, q)
   506  	if d7.Sign() < 0 {
   507  		d7.Add(d7, q)
   508  	}
   509  
   510  	out.typ = tlvTypeSMP4
   511  	out.data = appendU32(out.data, 3)
   512  	out.data = appendMPIs(out.data, rb, cr, d7)
   513  
   514  	r.ModInverse(c.smp.pb, p)
   515  	r.Mul(pa, r)
   516  	r.Mod(r, p)
   517  	s.Exp(ra, c.smp.b3, p)
   518  	if r.Cmp(s) != 0 {
   519  		err = smpFailureError
   520  	}
   521  
   522  	return
   523  }
   524  
   525  func (c *Conversation) processSMP4(mpis []*big.Int) error {
   526  	if len(mpis) != 3 {
   527  		return errors.New("otr: incorrect number of arguments in SMP4 message")
   528  	}
   529  	rb := mpis[0]
   530  	cr := mpis[1]
   531  	d7 := mpis[2]
   532  	h := sha256.New()
   533  
   534  	r := new(big.Int).Exp(c.smp.qaqb, d7, p)
   535  	s := new(big.Int).Exp(rb, cr, p)
   536  	r.Mul(r, s)
   537  	r.Mod(r, p)
   538  
   539  	s.Exp(g, d7, p)
   540  	t := new(big.Int).Exp(c.smp.g3b, cr, p)
   541  	s.Mul(s, t)
   542  	s.Mod(s, p)
   543  	t.SetBytes(hashMPIs(h, 8, s, r))
   544  	if t.Cmp(cr) != 0 {
   545  		return errors.New("otr: ZKP cR failed in SMP4 message")
   546  	}
   547  
   548  	r.Exp(rb, c.smp.a3, p)
   549  	if r.Cmp(c.smp.papb) != 0 {
   550  		return smpFailureError
   551  	}
   552  
   553  	return nil
   554  }
   555  
   556  func (c *Conversation) generateSMPAbort() tlv {
   557  	return tlv{typ: tlvTypeSMPAbort}
   558  }
   559  
   560  func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
   561  	if h != nil {
   562  		h.Reset()
   563  	} else {
   564  		h = sha256.New()
   565  	}
   566  
   567  	h.Write([]byte{magic})
   568  	for _, mpi := range mpis {
   569  		h.Write(appendMPI(nil, mpi))
   570  	}
   571  	return h.Sum(nil)
   572  }
   573  

View as plain text