...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/persistence.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/rand"
     9  	"crypto/sha256"
    10  	"encoding/gob"
    11  	"encoding/hex"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"io/fs"
    16  	"os"
    17  	"path/filepath"
    18  )
    19  
    20  const metadataFileName = "00000000"
    21  
    22  func hash2hex(name string) string {
    23  	hash := sha256.Sum256([]byte(name))
    24  	// We encode 4 of the 32 bytes (32 out of 256 bits), so 8 hex characters.
    25  	// It's enough to avoid collisions in reasonable amounts of documents per collection
    26  	// and being shorter is better for file paths.
    27  	return hex.EncodeToString(hash[:4])
    28  }
    29  
    30  // persistToFile persists an object to a file at the given path. The object is serialized
    31  // as gob, optionally compressed with flate (as gzip) and optionally encrypted with
    32  // AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's
    33  // overwritten, otherwise created.
    34  func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error {
    35  	if filePath == "" {
    36  		return fmt.Errorf("file path is empty")
    37  	}
    38  	// AES 256 requires a 32 byte key
    39  	if encryptionKey != "" {
    40  		if len(encryptionKey) != 32 {
    41  			return errors.New("encryption key must be 32 bytes long")
    42  		}
    43  	}
    44  
    45  	// If path doesn't exist, create the parent path.
    46  	// If path exists, and it's a directory, return an error.
    47  	fi, err := os.Stat(filePath)
    48  	if err != nil {
    49  		if !errors.Is(err, fs.ErrNotExist) {
    50  			return fmt.Errorf("couldn't get info about the path: %w", err)
    51  		} else {
    52  			// If the file doesn't exist, create the parent path
    53  			err := os.MkdirAll(filepath.Dir(filePath), 0o700)
    54  			if err != nil {
    55  				return fmt.Errorf("couldn't create parent directories to path: %w", err)
    56  			}
    57  		}
    58  	} else if fi.IsDir() {
    59  		return fmt.Errorf("path is a directory: %s", filePath)
    60  	}
    61  
    62  	// Open file for writing
    63  	f, err := os.Create(filePath)
    64  	if err != nil {
    65  		return fmt.Errorf("couldn't create file: %w", err)
    66  	}
    67  	defer f.Close()
    68  
    69  	return persistToWriter(f, obj, compress, encryptionKey)
    70  }
    71  
    72  // persistToWriter persists an object to a writer. The object is serialized
    73  // as gob, optionally compressed with flate (as gzip) and optionally encrypted with
    74  // AES-GCM. The encryption key must be 32 bytes long.
    75  // If the writer has to be closed, it's the caller's responsibility.
    76  func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error {
    77  	// AES 256 requires a 32 byte key
    78  	if encryptionKey != "" {
    79  		if len(encryptionKey) != 32 {
    80  			return errors.New("encryption key must be 32 bytes long")
    81  		}
    82  	}
    83  
    84  	// We want to:
    85  	// Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to
    86  	// passed writer.
    87  	// To reduce memory usage we chain the writers instead of buffering, so we start
    88  	// from the end. For AES GCM sealing the stdlib doesn't provide a writer though.
    89  
    90  	var chainedWriter io.Writer
    91  	if encryptionKey == "" {
    92  		chainedWriter = w
    93  	} else {
    94  		chainedWriter = &bytes.Buffer{}
    95  	}
    96  
    97  	var gzw *gzip.Writer
    98  	var enc *gob.Encoder
    99  	if compress {
   100  		gzw = gzip.NewWriter(chainedWriter)
   101  		enc = gob.NewEncoder(gzw)
   102  	} else {
   103  		enc = gob.NewEncoder(chainedWriter)
   104  	}
   105  
   106  	// Start encoding, it will write to the chain of writers.
   107  	if err := enc.Encode(obj); err != nil {
   108  		return fmt.Errorf("couldn't encode or write object: %w", err)
   109  	}
   110  
   111  	// If compressing, close the gzip writer. Otherwise, the gzip footer won't be
   112  	// written yet. When using encryption (and chainedWriter is a buffer) then
   113  	// we'll encrypt an incomplete stream. Without encryption when we return here and having
   114  	// a deferred Close(), there might be a silenced error.
   115  	if compress {
   116  		err := gzw.Close()
   117  		if err != nil {
   118  			return fmt.Errorf("couldn't close gzip writer: %w", err)
   119  		}
   120  	}
   121  
   122  	// Without encyrption, the chain is done and the writing is finished.
   123  	if encryptionKey == "" {
   124  		return nil
   125  	}
   126  
   127  	// Otherwise, encrypt and then write to the unchained target writer.
   128  	block, err := aes.NewCipher([]byte(encryptionKey))
   129  	if err != nil {
   130  		return fmt.Errorf("couldn't create new AES cipher: %w", err)
   131  	}
   132  	gcm, err := cipher.NewGCM(block)
   133  	if err != nil {
   134  		return fmt.Errorf("couldn't create GCM wrapper: %w", err)
   135  	}
   136  	nonce := make([]byte, gcm.NonceSize())
   137  	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
   138  		return fmt.Errorf("couldn't read random bytes for nonce: %w", err)
   139  	}
   140  	// chainedWriter is a *bytes.Buffer
   141  	buf := chainedWriter.(*bytes.Buffer)
   142  	encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil)
   143  	_, err = w.Write(encrypted)
   144  	if err != nil {
   145  		return fmt.Errorf("couldn't write encrypted data: %w", err)
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  // readFromFile reads an object from a file at the given path. The object is deserialized
   152  // from gob. `obj` must be a pointer to an instantiated object. The file may
   153  // optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption
   154  // key must be 32 bytes long.
   155  func readFromFile(filePath string, obj any, encryptionKey string) error {
   156  	if filePath == "" {
   157  		return fmt.Errorf("file path is empty")
   158  	}
   159  	// AES 256 requires a 32 byte key
   160  	if encryptionKey != "" {
   161  		if len(encryptionKey) != 32 {
   162  			return errors.New("encryption key must be 32 bytes long")
   163  		}
   164  	}
   165  
   166  	r, err := os.Open(filePath)
   167  	if err != nil {
   168  		return fmt.Errorf("couldn't open file: %w", err)
   169  	}
   170  	defer r.Close()
   171  
   172  	return readFromReader(r, obj, encryptionKey)
   173  }
   174  
   175  // readFromReader reads an object from a Reader. The object is deserialized from gob.
   176  // `obj` must be a pointer to an instantiated object. The stream may optionally
   177  // be compressed as gzip and/or encrypted with AES-GCM. The encryption key must
   178  // be 32 bytes long.
   179  // If the reader has to be closed, it's the caller's responsibility.
   180  func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
   181  	// AES 256 requires a 32 byte key
   182  	if encryptionKey != "" {
   183  		if len(encryptionKey) != 32 {
   184  			return errors.New("encryption key must be 32 bytes long")
   185  		}
   186  	}
   187  
   188  	// We want to:
   189  	// Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode
   190  	// as gob.
   191  	// To reduce memory usage we chain the readers instead of buffering, so we start
   192  	// from the end. For the decryption there's no reader though.
   193  
   194  	// For the chainedReader we don't declare it as ReadSeeker, so we can reassign
   195  	// the gzip reader to it.
   196  	var chainedReader io.Reader
   197  
   198  	// Decrypt if an encryption key is provided
   199  	if encryptionKey != "" {
   200  		encrypted, err := io.ReadAll(r)
   201  		if err != nil {
   202  			return fmt.Errorf("couldn't read from reader: %w", err)
   203  		}
   204  		block, err := aes.NewCipher([]byte(encryptionKey))
   205  		if err != nil {
   206  			return fmt.Errorf("couldn't create AES cipher: %w", err)
   207  		}
   208  		gcm, err := cipher.NewGCM(block)
   209  		if err != nil {
   210  			return fmt.Errorf("couldn't create GCM wrapper: %w", err)
   211  		}
   212  		nonceSize := gcm.NonceSize()
   213  		if len(encrypted) < nonceSize {
   214  			return fmt.Errorf("encrypted data too short")
   215  		}
   216  		nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]
   217  		data, err := gcm.Open(nil, nonce, ciphertext, nil)
   218  		if err != nil {
   219  			return fmt.Errorf("couldn't decrypt data: %w", err)
   220  		}
   221  
   222  		chainedReader = bytes.NewReader(data)
   223  	} else {
   224  		chainedReader = r
   225  	}
   226  
   227  	// Determine if the stream is compressed
   228  	magicNumber := make([]byte, 2)
   229  	_, err := chainedReader.Read(magicNumber)
   230  	if err != nil {
   231  		return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
   232  	}
   233  	var compressed bool
   234  	if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b {
   235  		compressed = true
   236  	}
   237  
   238  	// Reset reader. Both the reader from the param and bytes.Reader support seeking.
   239  	if s, ok := chainedReader.(io.Seeker); !ok {
   240  		return fmt.Errorf("reader doesn't support seeking")
   241  	} else {
   242  		_, err := s.Seek(0, 0)
   243  		if err != nil {
   244  			return fmt.Errorf("couldn't reset reader: %w", err)
   245  		}
   246  	}
   247  
   248  	if compressed {
   249  		gzr, err := gzip.NewReader(chainedReader)
   250  		if err != nil {
   251  			return fmt.Errorf("couldn't create gzip reader: %w", err)
   252  		}
   253  		defer gzr.Close()
   254  		chainedReader = gzr
   255  	}
   256  
   257  	dec := gob.NewDecoder(chainedReader)
   258  	err = dec.Decode(obj)
   259  	if err != nil {
   260  		return fmt.Errorf("couldn't decode object: %w", err)
   261  	}
   262  
   263  	return nil
   264  }
   265  
   266  // removeFile removes a file at the given path. If the file doesn't exist, it's a no-op.
   267  func removeFile(filePath string) error {
   268  	if filePath == "" {
   269  		return fmt.Errorf("file path is empty")
   270  	}
   271  
   272  	err := os.Remove(filePath)
   273  	if err != nil {
   274  		if !errors.Is(err, fs.ErrNotExist) {
   275  			return fmt.Errorf("couldn't remove file %q: %w", filePath, err)
   276  		}
   277  	}
   278  
   279  	return nil
   280  }
   281  

View as plain text