...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/db.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/fs"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"sync"
    13  )
    14  
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI's "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
// The function must return a *normalized* vector, i.e. the length of the vector
// must be 1. OpenAI's and Mistral's embedding models do this by default. Some
// others like Nomic's "nomic-embed-text-v1.5" don't.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
    22  
// DB is the chromem-go database. It holds collections, which hold documents.
//
//	+----+    1-n    +------------+    n-n    +----------+
//	| DB |-----------| Collection |-----------| Document |
//	+----+           +------------+           +----------+
type DB struct {
	// collections maps a collection's name to the collection itself.
	collections map[string]*Collection
	// collectionsLock guards the collections map.
	collectionsLock sync.RWMutex

	// persistDirectory is the root directory of a persistent DB. It's empty
	// for a pure in-memory DB.
	persistDirectory string
	// compress states whether persisted files are compressed with gzip.
	compress bool

	// ⚠️ When adding fields here, consider adding them to the persistence struct
	// versions in [DB.Export] and [DB.Import] as well!
}
    38  
    39  // NewDB creates a new in-memory chromem-go DB.
    40  // While it doesn't write files when you add collections and documents, you can
    41  // still use [DB.Export] and [DB.Import] to export and import the entire DB
    42  // from a file.
    43  func NewDB() *DB {
    44  	return &DB{
    45  		collections: make(map[string]*Collection),
    46  	}
    47  }
    48  
    49  // NewPersistentDB creates a new persistent chromem-go DB.
    50  // If the path is empty, it defaults to "./chromem-go".
    51  // If compress is true, the files are compressed with gzip.
    52  //
    53  // The persistence covers the collections (including their documents) and the metadata.
    54  // However, it doesn't cover the EmbeddingFunc, as functions can't be serialized.
    55  // When some data is persisted, and you create a new persistent DB with the same
    56  // path, you'll have to provide the same EmbeddingFunc as before when getting an
    57  // existing collection and adding more documents to it.
    58  //
    59  // Currently, the persistence is done synchronously on each write operation, and
    60  // each document addition leads to a new file, encoded as gob. In the future we
    61  // will make this configurable (encoding, async writes, WAL-based writes, etc.).
    62  //
    63  // In addition to persistence for each added collection and document you can use
    64  // [DB.Export] and [DB.Import] to export and import the entire DB to/from a file,
    65  // which also works for the pure in-memory DB.
    66  func NewPersistentDB(path string, compress bool) (*DB, error) {
    67  	if path == "" {
    68  		path = "./chromem-go"
    69  	} else {
    70  		// Clean in case the user provides something like "./db/../db"
    71  		path = filepath.Clean(path)
    72  	}
    73  
    74  	// We check for this file extension and skip others
    75  	ext := ".gob"
    76  	if compress {
    77  		ext += ".gz"
    78  	}
    79  
    80  	db := &DB{
    81  		collections:      make(map[string]*Collection),
    82  		persistDirectory: path,
    83  		compress:         compress,
    84  	}
    85  
    86  	// If the directory doesn't exist, create it and return an empty DB.
    87  	fi, err := os.Stat(path)
    88  	if err != nil {
    89  		if errors.Is(err, fs.ErrNotExist) {
    90  			err := os.MkdirAll(path, 0o700)
    91  			if err != nil {
    92  				return nil, fmt.Errorf("couldn't create persistence directory: %w", err)
    93  			}
    94  
    95  			return db, nil
    96  		}
    97  		return nil, fmt.Errorf("couldn't get info about persistence directory: %w", err)
    98  	} else if !fi.IsDir() {
    99  		return nil, fmt.Errorf("path is not a directory: %s", path)
   100  	}
   101  
   102  	// Otherwise, read all collections and their documents from the directory.
   103  	dirEntries, err := os.ReadDir(path)
   104  	if err != nil {
   105  		return nil, fmt.Errorf("couldn't read persistence directory: %w", err)
   106  	}
   107  	for _, dirEntry := range dirEntries {
   108  		// Collections are subdirectories, so skip any files (which the user might
   109  		// have placed).
   110  		if !dirEntry.IsDir() {
   111  			continue
   112  		}
   113  		// For each subdirectory, create a collection and read its name, metadata
   114  		// and documents.
   115  		// TODO: Parallelize this (e.g. chan with $numCPU buffer and $numCPU goroutines
   116  		// reading from it).
   117  		collectionPath := filepath.Join(path, dirEntry.Name())
   118  		collectionDirEntries, err := os.ReadDir(collectionPath)
   119  		if err != nil {
   120  			return nil, fmt.Errorf("couldn't read collection directory: %w", err)
   121  		}
   122  		c := &Collection{
   123  			documents:        make(map[string]*Document),
   124  			persistDirectory: collectionPath,
   125  			compress:         compress,
   126  			// We can fill Name and metadata only after reading
   127  			// the metadata.
   128  			// We can fill embed only when the user calls DB.GetCollection() or
   129  			// DB.GetOrCreateCollection().
   130  		}
   131  		for _, collectionDirEntry := range collectionDirEntries {
   132  			// Files should be metadata and documents; skip subdirectories which
   133  			// the user might have placed.
   134  			if collectionDirEntry.IsDir() {
   135  				continue
   136  			}
   137  
   138  			fPath := filepath.Join(collectionPath, collectionDirEntry.Name())
   139  			// Differentiate between collection metadata, documents and other files.
   140  			if collectionDirEntry.Name() == metadataFileName+ext {
   141  				// Read name and metadata
   142  				pc := struct {
   143  					Name     string
   144  					Metadata map[string]string
   145  				}{}
   146  				err := readFromFile(fPath, &pc, "")
   147  				if err != nil {
   148  					return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
   149  				}
   150  				c.Name = pc.Name
   151  				c.metadata = pc.Metadata
   152  			} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
   153  				// Read document
   154  				d := &Document{}
   155  				err := readFromFile(fPath, d, "")
   156  				if err != nil {
   157  					return nil, fmt.Errorf("couldn't read document: %w", err)
   158  				}
   159  				c.documents[d.ID] = d
   160  			} else {
   161  				// Might be a file that the user has placed
   162  				continue
   163  			}
   164  		}
   165  		// If we have neither name nor documents, it was likely a user-added
   166  		// directory, so skip it.
   167  		if c.Name == "" && len(c.documents) == 0 {
   168  			continue
   169  		}
   170  		// If we have no name, it means there was no metadata file
   171  		if c.Name == "" {
   172  			return nil, fmt.Errorf("collection metadata file not found: %s", collectionPath)
   173  		}
   174  
   175  		db.collections[c.Name] = c
   176  	}
   177  
   178  	return db, nil
   179  }
   180  
// Import imports the DB from a file at the given path. The file must be encoded
// as gob and can optionally be compressed with flate (as gzip) and encrypted
// with AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// It only forwards its arguments to [DB.ImportFromFile] and returns that
// method's error.
//
//   - filePath: Mandatory, must not be empty
//   - encryptionKey: Optional, must be 32 bytes long if provided
//
// Deprecated: Use [DB.ImportFromFile] instead.
func (db *DB) Import(filePath string, encryptionKey string) error {
	return db.ImportFromFile(filePath, encryptionKey)
}
   194  
   195  // ImportFromFile imports the DB from a file at the given path. The file must be
   196  // encoded as gob and can optionally be compressed with flate (as gzip) and encrypted
   197  // with AES-GCM.
   198  // This works for both the in-memory and persistent DBs.
   199  // Existing collections are overwritten.
   200  //
   201  // - filePath: Mandatory, must not be empty
   202  // - encryptionKey: Optional, must be 32 bytes long if provided
   203  func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
   204  	if filePath == "" {
   205  		return fmt.Errorf("file path is empty")
   206  	}
   207  	if encryptionKey != "" {
   208  		// AES 256 requires a 32 byte key
   209  		if len(encryptionKey) != 32 {
   210  			return errors.New("encryption key must be 32 bytes long")
   211  		}
   212  	}
   213  
   214  	// If the file doesn't exist or is a directory, return an error.
   215  	fi, err := os.Stat(filePath)
   216  	if err != nil {
   217  		if errors.Is(err, fs.ErrNotExist) {
   218  			return fmt.Errorf("file doesn't exist: %s", filePath)
   219  		}
   220  		return fmt.Errorf("couldn't get info about the file: %w", err)
   221  	} else if fi.IsDir() {
   222  		return fmt.Errorf("path is a directory: %s", filePath)
   223  	}
   224  
   225  	// Create persistence structs with exported fields so that they can be decoded
   226  	// from gob.
   227  	type persistenceCollection struct {
   228  		Name      string
   229  		Metadata  map[string]string
   230  		Documents map[string]*Document
   231  	}
   232  	persistenceDB := struct {
   233  		Collections map[string]*persistenceCollection
   234  	}{
   235  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   236  	}
   237  
   238  	db.collectionsLock.Lock()
   239  	defer db.collectionsLock.Unlock()
   240  
   241  	err = readFromFile(filePath, &persistenceDB, encryptionKey)
   242  	if err != nil {
   243  		return fmt.Errorf("couldn't read file: %w", err)
   244  	}
   245  
   246  	for _, pc := range persistenceDB.Collections {
   247  		c := &Collection{
   248  			Name: pc.Name,
   249  
   250  			metadata:  pc.Metadata,
   251  			documents: pc.Documents,
   252  		}
   253  		if db.persistDirectory != "" {
   254  			c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
   255  			c.compress = db.compress
   256  		}
   257  		db.collections[c.Name] = c
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  // ImportFromReader imports the DB from a reader. The stream must be encoded as
   264  // gob and can optionally be compressed with flate (as gzip) and encrypted with
   265  // AES-GCM.
   266  // This works for both the in-memory and persistent DBs.
   267  // Existing collections are overwritten.
   268  // If the writer has to be closed, it's the caller's responsibility.
   269  //
   270  // - reader: An implementation of [io.ReadSeeker]
   271  // - encryptionKey: Optional, must be 32 bytes long if provided
   272  func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
   273  	if encryptionKey != "" {
   274  		// AES 256 requires a 32 byte key
   275  		if len(encryptionKey) != 32 {
   276  			return errors.New("encryption key must be 32 bytes long")
   277  		}
   278  	}
   279  
   280  	// Create persistence structs with exported fields so that they can be decoded
   281  	// from gob.
   282  	type persistenceCollection struct {
   283  		Name      string
   284  		Metadata  map[string]string
   285  		Documents map[string]*Document
   286  	}
   287  	persistenceDB := struct {
   288  		Collections map[string]*persistenceCollection
   289  	}{
   290  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   291  	}
   292  
   293  	db.collectionsLock.Lock()
   294  	defer db.collectionsLock.Unlock()
   295  
   296  	err := readFromReader(reader, &persistenceDB, encryptionKey)
   297  	if err != nil {
   298  		return fmt.Errorf("couldn't read stream: %w", err)
   299  	}
   300  
   301  	for _, pc := range persistenceDB.Collections {
   302  		c := &Collection{
   303  			Name: pc.Name,
   304  
   305  			metadata:  pc.Metadata,
   306  			documents: pc.Documents,
   307  		}
   308  		if db.persistDirectory != "" {
   309  			c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
   310  			c.compress = db.compress
   311  		}
   312  		db.collections[c.Name] = c
   313  	}
   314  
   315  	return nil
   316  }
   317  
// Export exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the file exists, it's overwritten, otherwise created.
//
// It only forwards its arguments to [DB.ExportToFile] and returns that
// method's error.
//
//   - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc")
//   - compress: Optional. Compresses as gzip if true.
//   - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
//     long if provided.
//
// Deprecated: Use [DB.ExportToFile] instead.
func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
	return db.ExportToFile(filePath, compress, encryptionKey)
}
   332  
   333  // ExportToFile exports the DB to a file at the given path. The file is encoded as gob,
   334  // optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
   335  // This works for both the in-memory and persistent DBs.
   336  // If the file exists, it's overwritten, otherwise created.
   337  //
   338  //   - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc")
   339  //   - compress: Optional. Compresses as gzip if true.
   340  //   - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
   341  //     long if provided.
   342  func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
   343  	if filePath == "" {
   344  		filePath = "./chromem-go.gob"
   345  		if compress {
   346  			filePath += ".gz"
   347  		}
   348  		if encryptionKey != "" {
   349  			filePath += ".enc"
   350  		}
   351  	}
   352  	if encryptionKey != "" {
   353  		// AES 256 requires a 32 byte key
   354  		if len(encryptionKey) != 32 {
   355  			return errors.New("encryption key must be 32 bytes long")
   356  		}
   357  	}
   358  
   359  	// Create persistence structs with exported fields so that they can be encoded
   360  	// as gob.
   361  	type persistenceCollection struct {
   362  		Name      string
   363  		Metadata  map[string]string
   364  		Documents map[string]*Document
   365  	}
   366  	persistenceDB := struct {
   367  		Collections map[string]*persistenceCollection
   368  	}{
   369  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   370  	}
   371  
   372  	db.collectionsLock.RLock()
   373  	defer db.collectionsLock.RUnlock()
   374  
   375  	for k, v := range db.collections {
   376  		persistenceDB.Collections[k] = &persistenceCollection{
   377  			Name:      v.Name,
   378  			Metadata:  v.metadata,
   379  			Documents: v.documents,
   380  		}
   381  	}
   382  
   383  	err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
   384  	if err != nil {
   385  		return fmt.Errorf("couldn't export DB: %w", err)
   386  	}
   387  
   388  	return nil
   389  }
   390  
   391  // ExportToWriter exports the DB to a writer. The stream is encoded as gob,
   392  // optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
   393  // This works for both the in-memory and persistent DBs.
   394  // If the writer has to be closed, it's the caller's responsibility.
   395  //
   396  //   - writer: An implementation of [io.Writer]
   397  //   - compress: Optional. Compresses as gzip if true.
   398  //   - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
   399  //     long if provided.
   400  func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
   401  	if encryptionKey != "" {
   402  		// AES 256 requires a 32 byte key
   403  		if len(encryptionKey) != 32 {
   404  			return errors.New("encryption key must be 32 bytes long")
   405  		}
   406  	}
   407  
   408  	// Create persistence structs with exported fields so that they can be encoded
   409  	// as gob.
   410  	type persistenceCollection struct {
   411  		Name      string
   412  		Metadata  map[string]string
   413  		Documents map[string]*Document
   414  	}
   415  	persistenceDB := struct {
   416  		Collections map[string]*persistenceCollection
   417  	}{
   418  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   419  	}
   420  
   421  	db.collectionsLock.RLock()
   422  	defer db.collectionsLock.RUnlock()
   423  
   424  	for k, v := range db.collections {
   425  		persistenceDB.Collections[k] = &persistenceCollection{
   426  			Name:      v.Name,
   427  			Metadata:  v.metadata,
   428  			Documents: v.documents,
   429  		}
   430  	}
   431  
   432  	err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
   433  	if err != nil {
   434  		return fmt.Errorf("couldn't export DB: %w", err)
   435  	}
   436  
   437  	return nil
   438  }
   439  
   440  // CreateCollection creates a new collection with the given name and metadata.
   441  //
   442  //   - name: The name of the collection to create.
   443  //   - metadata: Optional metadata to associate with the collection.
   444  //   - embeddingFunc: Optional function to use to embed documents.
   445  //     Uses the default embedding function if not provided.
   446  func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
   447  	if name == "" {
   448  		return nil, errors.New("collection name is empty")
   449  	}
   450  	if embeddingFunc == nil {
   451  		embeddingFunc = NewEmbeddingFuncDefault()
   452  	}
   453  	collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress)
   454  	if err != nil {
   455  		return nil, fmt.Errorf("couldn't create collection: %w", err)
   456  	}
   457  
   458  	db.collectionsLock.Lock()
   459  	defer db.collectionsLock.Unlock()
   460  	db.collections[name] = collection
   461  	return collection, nil
   462  }
   463  
   464  // ListCollections returns all collections in the DB, mapping name->Collection.
   465  // The returned map is a copy of the internal map, so it's safe to directly modify
   466  // the map itself. Direct modifications of the map won't reflect on the DB's map.
   467  // To do that use the DB's methods like CreateCollection() and DeleteCollection().
   468  // The map is not an entirely deep clone, so the collections themselves are still
   469  // the original ones. Any methods on the collections like Add() for adding documents
   470  // will be reflected on the DB's collections and are concurrency-safe.
   471  func (db *DB) ListCollections() map[string]*Collection {
   472  	db.collectionsLock.RLock()
   473  	defer db.collectionsLock.RUnlock()
   474  
   475  	res := make(map[string]*Collection, len(db.collections))
   476  	for k, v := range db.collections {
   477  		res[k] = v
   478  	}
   479  
   480  	return res
   481  }
   482  
   483  // GetCollection returns the collection with the given name.
   484  // The embeddingFunc param is only used if the DB is persistent and was just loaded
   485  // from storage, in which case no embedding func is set yet (funcs are not (de-)serializable).
   486  // It can be nil, in which case the default one will be used.
   487  // The returned collection is a reference to the original collection, so any methods
   488  // on the collection like Add() will be reflected on the DB's collection. Those
   489  // operations are concurrency-safe.
   490  // If the collection doesn't exist, this returns nil.
   491  func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
   492  	db.collectionsLock.RLock()
   493  	defer db.collectionsLock.RUnlock()
   494  
   495  	c, ok := db.collections[name]
   496  	if !ok {
   497  		return nil
   498  	}
   499  
   500  	if c.embed == nil {
   501  		if embeddingFunc == nil {
   502  			c.embed = NewEmbeddingFuncDefault()
   503  		} else {
   504  			c.embed = embeddingFunc
   505  		}
   506  	}
   507  	return c
   508  }
   509  
   510  // GetOrCreateCollection returns the collection with the given name if it exists
   511  // in the DB, or otherwise creates it. When creating:
   512  //
   513  //   - name: The name of the collection to create.
   514  //   - metadata: Optional metadata to associate with the collection.
   515  //   - embeddingFunc: Optional function to use to embed documents.
   516  //     Uses the default embedding function if not provided.
   517  func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
   518  	// No need to lock here, because the methods we call do that.
   519  	collection := db.GetCollection(name, embeddingFunc)
   520  	if collection == nil {
   521  		var err error
   522  		collection, err = db.CreateCollection(name, metadata, embeddingFunc)
   523  		if err != nil {
   524  			return nil, fmt.Errorf("couldn't create collection: %w", err)
   525  		}
   526  	}
   527  	return collection, nil
   528  }
   529  
   530  // DeleteCollection deletes the collection with the given name.
   531  // If the collection doesn't exist, this is a no-op.
   532  // If the DB is persistent, it also removes the collection's directory.
   533  // You shouldn't hold any references to the collection after calling this method.
   534  func (db *DB) DeleteCollection(name string) error {
   535  	db.collectionsLock.Lock()
   536  	defer db.collectionsLock.Unlock()
   537  
   538  	col, ok := db.collections[name]
   539  	if !ok {
   540  		return nil
   541  	}
   542  
   543  	if db.persistDirectory != "" {
   544  		collectionPath := col.persistDirectory
   545  		err := os.RemoveAll(collectionPath)
   546  		if err != nil {
   547  			return fmt.Errorf("couldn't delete collection directory: %w", err)
   548  		}
   549  	}
   550  
   551  	delete(db.collections, name)
   552  	return nil
   553  }
   554  
   555  // Reset removes all collections from the DB.
   556  // If the DB is persistent, it also removes all contents of the DB directory.
   557  // You shouldn't hold any references to old collections after calling this method.
   558  func (db *DB) Reset() error {
   559  	db.collectionsLock.Lock()
   560  	defer db.collectionsLock.Unlock()
   561  
   562  	if db.persistDirectory != "" {
   563  		err := os.RemoveAll(db.persistDirectory)
   564  		if err != nil {
   565  			return fmt.Errorf("couldn't delete persistence directory: %w", err)
   566  		}
   567  		// Recreate empty root level directory
   568  		err = os.MkdirAll(db.persistDirectory, 0o700)
   569  		if err != nil {
   570  			return fmt.Errorf("couldn't recreate persistence directory: %w", err)
   571  		}
   572  	}
   573  
   574  	// Just assign a new map, the GC will take care of the rest.
   575  	db.collections = make(map[string]*Collection)
   576  	return nil
   577  }
   578  

View as plain text