1  
     2  
     3  
     4  
     5  package packet
     6  
     7  import (
     8  	"crypto/cipher"
     9  	"crypto/sha1"
    10  	"crypto/subtle"
    11  	"golang.org/x/crypto/openpgp/errors"
    12  	"hash"
    13  	"io"
    14  	"strconv"
    15  )
    16  
    17  
    18  
    19  
    20  type SymmetricallyEncrypted struct {
    21  	MDC      bool 
    22  	contents io.Reader
    23  	prefix   []byte
    24  }
    25  
    26  const symmetricallyEncryptedVersion = 1
    27  
    28  func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
    29  	if se.MDC {
    30  		
    31  		var buf [1]byte
    32  		_, err := readFull(r, buf[:])
    33  		if err != nil {
    34  			return err
    35  		}
    36  		if buf[0] != symmetricallyEncryptedVersion {
    37  			return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
    38  		}
    39  	}
    40  	se.contents = r
    41  	return nil
    42  }
    43  
    44  
    45  
    46  
    47  func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
    48  	keySize := c.KeySize()
    49  	if keySize == 0 {
    50  		return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
    51  	}
    52  	if len(key) != keySize {
    53  		return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
    54  	}
    55  
    56  	if se.prefix == nil {
    57  		se.prefix = make([]byte, c.blockSize()+2)
    58  		_, err := readFull(se.contents, se.prefix)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  	} else if len(se.prefix) != c.blockSize()+2 {
    63  		return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
    64  	}
    65  
    66  	ocfbResync := OCFBResync
    67  	if se.MDC {
    68  		
    69  		ocfbResync = OCFBNoResync
    70  	}
    71  
    72  	s := NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
    73  	if s == nil {
    74  		return nil, errors.ErrKeyIncorrect
    75  	}
    76  
    77  	plaintext := cipher.StreamReader{S: s, R: se.contents}
    78  
    79  	if se.MDC {
    80  		
    81  		h := sha1.New()
    82  		h.Write(se.prefix)
    83  		return &seMDCReader{in: plaintext, h: h}, nil
    84  	}
    85  
    86  	
    87  	return seReader{plaintext}, nil
    88  }
    89  
    90  
    91  type seReader struct {
    92  	in io.Reader
    93  }
    94  
    95  func (ser seReader) Read(buf []byte) (int, error) {
    96  	return ser.in.Read(buf)
    97  }
    98  
    99  func (ser seReader) Close() error {
   100  	return nil
   101  }
   102  
   103  const mdcTrailerSize = 1  + 1  + sha1.Size
   104  
   105  
   106  
   107  
   108  
   109  type seMDCReader struct {
   110  	in          io.Reader
   111  	h           hash.Hash
   112  	trailer     [mdcTrailerSize]byte
   113  	scratch     [mdcTrailerSize]byte
   114  	trailerUsed int
   115  	error       bool
   116  	eof         bool
   117  }
   118  
   119  func (ser *seMDCReader) Read(buf []byte) (n int, err error) {
   120  	if ser.error {
   121  		err = io.ErrUnexpectedEOF
   122  		return
   123  	}
   124  	if ser.eof {
   125  		err = io.EOF
   126  		return
   127  	}
   128  
   129  	
   130  	
   131  	for ser.trailerUsed < mdcTrailerSize {
   132  		n, err = ser.in.Read(ser.trailer[ser.trailerUsed:])
   133  		ser.trailerUsed += n
   134  		if err == io.EOF {
   135  			if ser.trailerUsed != mdcTrailerSize {
   136  				n = 0
   137  				err = io.ErrUnexpectedEOF
   138  				ser.error = true
   139  				return
   140  			}
   141  			ser.eof = true
   142  			n = 0
   143  			return
   144  		}
   145  
   146  		if err != nil {
   147  			n = 0
   148  			return
   149  		}
   150  	}
   151  
   152  	
   153  	
   154  	if len(buf) <= mdcTrailerSize {
   155  		n, err = readFull(ser.in, ser.scratch[:len(buf)])
   156  		copy(buf, ser.trailer[:n])
   157  		ser.h.Write(buf[:n])
   158  		copy(ser.trailer[:], ser.trailer[n:])
   159  		copy(ser.trailer[mdcTrailerSize-n:], ser.scratch[:])
   160  		if n < len(buf) {
   161  			ser.eof = true
   162  			err = io.EOF
   163  		}
   164  		return
   165  	}
   166  
   167  	n, err = ser.in.Read(buf[mdcTrailerSize:])
   168  	copy(buf, ser.trailer[:])
   169  	ser.h.Write(buf[:n])
   170  	copy(ser.trailer[:], buf[n:])
   171  
   172  	if err == io.EOF {
   173  		ser.eof = true
   174  	}
   175  	return
   176  }
   177  
   178  
   179  const mdcPacketTagByte = byte(0x80) | 0x40 | 19
   180  
   181  func (ser *seMDCReader) Close() error {
   182  	if ser.error {
   183  		return errors.SignatureError("error during reading")
   184  	}
   185  
   186  	for !ser.eof {
   187  		
   188  		var buf [1024]byte
   189  		_, err := ser.Read(buf[:])
   190  		if err == io.EOF {
   191  			break
   192  		}
   193  		if err != nil {
   194  			return errors.SignatureError("error during reading")
   195  		}
   196  	}
   197  
   198  	if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
   199  		return errors.SignatureError("MDC packet not found")
   200  	}
   201  	ser.h.Write(ser.trailer[:2])
   202  
   203  	final := ser.h.Sum(nil)
   204  	if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
   205  		return errors.SignatureError("hash mismatch")
   206  	}
   207  	return nil
   208  }
   209  
   210  
   211  
   212  
   213  type seMDCWriter struct {
   214  	w io.WriteCloser
   215  	h hash.Hash
   216  }
   217  
   218  func (w *seMDCWriter) Write(buf []byte) (n int, err error) {
   219  	w.h.Write(buf)
   220  	return w.w.Write(buf)
   221  }
   222  
   223  func (w *seMDCWriter) Close() (err error) {
   224  	var buf [mdcTrailerSize]byte
   225  
   226  	buf[0] = mdcPacketTagByte
   227  	buf[1] = sha1.Size
   228  	w.h.Write(buf[:2])
   229  	digest := w.h.Sum(nil)
   230  	copy(buf[2:], digest)
   231  
   232  	_, err = w.w.Write(buf[:])
   233  	if err != nil {
   234  		return
   235  	}
   236  	return w.w.Close()
   237  }
   238  
   239  
   240  type noOpCloser struct {
   241  	w io.Writer
   242  }
   243  
   244  func (c noOpCloser) Write(data []byte) (n int, err error) {
   245  	return c.w.Write(data)
   246  }
   247  
   248  func (c noOpCloser) Close() error {
   249  	return nil
   250  }
   251  
   252  
   253  
   254  
   255  
   256  func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte, config *Config) (contents io.WriteCloser, err error) {
   257  	if c.KeySize() != len(key) {
   258  		return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
   259  	}
   260  	writeCloser := noOpCloser{w}
   261  	ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
   262  	if err != nil {
   263  		return
   264  	}
   265  
   266  	_, err = ciphertext.Write([]byte{symmetricallyEncryptedVersion})
   267  	if err != nil {
   268  		return
   269  	}
   270  
   271  	block := c.new(key)
   272  	blockSize := block.BlockSize()
   273  	iv := make([]byte, blockSize)
   274  	_, err = config.Random().Read(iv)
   275  	if err != nil {
   276  		return
   277  	}
   278  	s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync)
   279  	_, err = ciphertext.Write(prefix)
   280  	if err != nil {
   281  		return
   282  	}
   283  	plaintext := cipher.StreamWriter{S: s, W: ciphertext}
   284  
   285  	h := sha1.New()
   286  	h.Write(iv)
   287  	h.Write(iv[blockSize-2:])
   288  	contents = &seMDCWriter{w: plaintext, h: h}
   289  	return
   290  }
   291  
View as plain text