1  package chromem
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/fs"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"sync"
    13  )
    14  
    15  
    16  
    17  
    18  
    19  
    20  
// EmbeddingFunc converts text into its embedding vector. The context is
// passed through to the implementation (e.g. so a remote call can be
// cancelled); a non-nil error means no usable vector was produced.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
    22  
    23  
    24  
    25  
    26  
    27  
// DB holds a set of collections keyed by collection name. A DB created with
// NewDB has no persistence directory; one created with NewPersistentDB keeps
// its collections in subdirectories of persistDirectory.
type DB struct {
	// collections maps collection name -> collection. Once the DB is shared,
	// accesses are guarded by collectionsLock.
	collections     map[string]*Collection
	collectionsLock sync.RWMutex

	// persistDirectory is the root directory for persisted collections. It is
	// empty for a DB created with NewDB, and the methods in this file skip
	// their disk operations in that case.
	persistDirectory string
	// compress selects ".gob.gz" instead of ".gob" for persisted files (see
	// NewPersistentDB) and is handed to collections created in this DB.
	compress         bool

	// When adding fields here, consider adding them to the persistence structs
	// declared inside the Import*/Export* methods as well; otherwise they will
	// not survive an export/import round trip.
}
    38  
    39  
    40  
    41  
    42  
    43  func NewDB() *DB {
    44  	return &DB{
    45  		collections: make(map[string]*Collection),
    46  	}
    47  }
    48  
    49  
    50  
    51  
    52  
    53  
    54  
    55  
    56  
    57  
    58  
    59  
    60  
    61  
    62  
    63  
    64  
    65  
    66  func NewPersistentDB(path string, compress bool) (*DB, error) {
    67  	if path == "" {
    68  		path = "./chromem-go"
    69  	} else {
    70  		
    71  		path = filepath.Clean(path)
    72  	}
    73  
    74  	
    75  	ext := ".gob"
    76  	if compress {
    77  		ext += ".gz"
    78  	}
    79  
    80  	db := &DB{
    81  		collections:      make(map[string]*Collection),
    82  		persistDirectory: path,
    83  		compress:         compress,
    84  	}
    85  
    86  	
    87  	fi, err := os.Stat(path)
    88  	if err != nil {
    89  		if errors.Is(err, fs.ErrNotExist) {
    90  			err := os.MkdirAll(path, 0o700)
    91  			if err != nil {
    92  				return nil, fmt.Errorf("couldn't create persistence directory: %w", err)
    93  			}
    94  
    95  			return db, nil
    96  		}
    97  		return nil, fmt.Errorf("couldn't get info about persistence directory: %w", err)
    98  	} else if !fi.IsDir() {
    99  		return nil, fmt.Errorf("path is not a directory: %s", path)
   100  	}
   101  
   102  	
   103  	dirEntries, err := os.ReadDir(path)
   104  	if err != nil {
   105  		return nil, fmt.Errorf("couldn't read persistence directory: %w", err)
   106  	}
   107  	for _, dirEntry := range dirEntries {
   108  		
   109  		
   110  		if !dirEntry.IsDir() {
   111  			continue
   112  		}
   113  		
   114  		
   115  		
   116  		
   117  		collectionPath := filepath.Join(path, dirEntry.Name())
   118  		collectionDirEntries, err := os.ReadDir(collectionPath)
   119  		if err != nil {
   120  			return nil, fmt.Errorf("couldn't read collection directory: %w", err)
   121  		}
   122  		c := &Collection{
   123  			documents:        make(map[string]*Document),
   124  			persistDirectory: collectionPath,
   125  			compress:         compress,
   126  			
   127  			
   128  			
   129  			
   130  		}
   131  		for _, collectionDirEntry := range collectionDirEntries {
   132  			
   133  			
   134  			if collectionDirEntry.IsDir() {
   135  				continue
   136  			}
   137  
   138  			fPath := filepath.Join(collectionPath, collectionDirEntry.Name())
   139  			
   140  			if collectionDirEntry.Name() == metadataFileName+ext {
   141  				
   142  				pc := struct {
   143  					Name     string
   144  					Metadata map[string]string
   145  				}{}
   146  				err := readFromFile(fPath, &pc, "")
   147  				if err != nil {
   148  					return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
   149  				}
   150  				c.Name = pc.Name
   151  				c.metadata = pc.Metadata
   152  			} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
   153  				
   154  				d := &Document{}
   155  				err := readFromFile(fPath, d, "")
   156  				if err != nil {
   157  					return nil, fmt.Errorf("couldn't read document: %w", err)
   158  				}
   159  				c.documents[d.ID] = d
   160  			} else {
   161  				
   162  				continue
   163  			}
   164  		}
   165  		
   166  		
   167  		if c.Name == "" && len(c.documents) == 0 {
   168  			continue
   169  		}
   170  		
   171  		if c.Name == "" {
   172  			return nil, fmt.Errorf("collection metadata file not found: %s", collectionPath)
   173  		}
   174  
   175  		db.collections[c.Name] = c
   176  	}
   177  
   178  	return db, nil
   179  }
   180  
   181  
   182  
   183  
   184  
   185  
   186  
   187  
   188  
   189  
   190  
   191  func (db *DB) Import(filePath string, encryptionKey string) error {
   192  	return db.ImportFromFile(filePath, encryptionKey)
   193  }
   194  
   195  
   196  
   197  
   198  
   199  
   200  
   201  
   202  
   203  func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
   204  	if filePath == "" {
   205  		return fmt.Errorf("file path is empty")
   206  	}
   207  	if encryptionKey != "" {
   208  		
   209  		if len(encryptionKey) != 32 {
   210  			return errors.New("encryption key must be 32 bytes long")
   211  		}
   212  	}
   213  
   214  	
   215  	fi, err := os.Stat(filePath)
   216  	if err != nil {
   217  		if errors.Is(err, fs.ErrNotExist) {
   218  			return fmt.Errorf("file doesn't exist: %s", filePath)
   219  		}
   220  		return fmt.Errorf("couldn't get info about the file: %w", err)
   221  	} else if fi.IsDir() {
   222  		return fmt.Errorf("path is a directory: %s", filePath)
   223  	}
   224  
   225  	
   226  	
   227  	type persistenceCollection struct {
   228  		Name      string
   229  		Metadata  map[string]string
   230  		Documents map[string]*Document
   231  	}
   232  	persistenceDB := struct {
   233  		Collections map[string]*persistenceCollection
   234  	}{
   235  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   236  	}
   237  
   238  	db.collectionsLock.Lock()
   239  	defer db.collectionsLock.Unlock()
   240  
   241  	err = readFromFile(filePath, &persistenceDB, encryptionKey)
   242  	if err != nil {
   243  		return fmt.Errorf("couldn't read file: %w", err)
   244  	}
   245  
   246  	for _, pc := range persistenceDB.Collections {
   247  		c := &Collection{
   248  			Name: pc.Name,
   249  
   250  			metadata:  pc.Metadata,
   251  			documents: pc.Documents,
   252  		}
   253  		if db.persistDirectory != "" {
   254  			c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
   255  			c.compress = db.compress
   256  		}
   257  		db.collections[c.Name] = c
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  
   264  
   265  
   266  
   267  
   268  
   269  
   270  
   271  
   272  func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
   273  	if encryptionKey != "" {
   274  		
   275  		if len(encryptionKey) != 32 {
   276  			return errors.New("encryption key must be 32 bytes long")
   277  		}
   278  	}
   279  
   280  	
   281  	
   282  	type persistenceCollection struct {
   283  		Name      string
   284  		Metadata  map[string]string
   285  		Documents map[string]*Document
   286  	}
   287  	persistenceDB := struct {
   288  		Collections map[string]*persistenceCollection
   289  	}{
   290  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   291  	}
   292  
   293  	db.collectionsLock.Lock()
   294  	defer db.collectionsLock.Unlock()
   295  
   296  	err := readFromReader(reader, &persistenceDB, encryptionKey)
   297  	if err != nil {
   298  		return fmt.Errorf("couldn't read stream: %w", err)
   299  	}
   300  
   301  	for _, pc := range persistenceDB.Collections {
   302  		c := &Collection{
   303  			Name: pc.Name,
   304  
   305  			metadata:  pc.Metadata,
   306  			documents: pc.Documents,
   307  		}
   308  		if db.persistDirectory != "" {
   309  			c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
   310  			c.compress = db.compress
   311  		}
   312  		db.collections[c.Name] = c
   313  	}
   314  
   315  	return nil
   316  }
   317  
   318  
   319  
   320  
   321  
   322  
   323  
   324  
   325  
   326  
   327  
   328  
   329  func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
   330  	return db.ExportToFile(filePath, compress, encryptionKey)
   331  }
   332  
   333  
   334  
   335  
   336  
   337  
   338  
   339  
   340  
   341  
   342  func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
   343  	if filePath == "" {
   344  		filePath = "./chromem-go.gob"
   345  		if compress {
   346  			filePath += ".gz"
   347  		}
   348  		if encryptionKey != "" {
   349  			filePath += ".enc"
   350  		}
   351  	}
   352  	if encryptionKey != "" {
   353  		
   354  		if len(encryptionKey) != 32 {
   355  			return errors.New("encryption key must be 32 bytes long")
   356  		}
   357  	}
   358  
   359  	
   360  	
   361  	type persistenceCollection struct {
   362  		Name      string
   363  		Metadata  map[string]string
   364  		Documents map[string]*Document
   365  	}
   366  	persistenceDB := struct {
   367  		Collections map[string]*persistenceCollection
   368  	}{
   369  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   370  	}
   371  
   372  	db.collectionsLock.RLock()
   373  	defer db.collectionsLock.RUnlock()
   374  
   375  	for k, v := range db.collections {
   376  		persistenceDB.Collections[k] = &persistenceCollection{
   377  			Name:      v.Name,
   378  			Metadata:  v.metadata,
   379  			Documents: v.documents,
   380  		}
   381  	}
   382  
   383  	err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
   384  	if err != nil {
   385  		return fmt.Errorf("couldn't export DB: %w", err)
   386  	}
   387  
   388  	return nil
   389  }
   390  
   391  
   392  
   393  
   394  
   395  
   396  
   397  
   398  
   399  
   400  func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
   401  	if encryptionKey != "" {
   402  		
   403  		if len(encryptionKey) != 32 {
   404  			return errors.New("encryption key must be 32 bytes long")
   405  		}
   406  	}
   407  
   408  	
   409  	
   410  	type persistenceCollection struct {
   411  		Name      string
   412  		Metadata  map[string]string
   413  		Documents map[string]*Document
   414  	}
   415  	persistenceDB := struct {
   416  		Collections map[string]*persistenceCollection
   417  	}{
   418  		Collections: make(map[string]*persistenceCollection, len(db.collections)),
   419  	}
   420  
   421  	db.collectionsLock.RLock()
   422  	defer db.collectionsLock.RUnlock()
   423  
   424  	for k, v := range db.collections {
   425  		persistenceDB.Collections[k] = &persistenceCollection{
   426  			Name:      v.Name,
   427  			Metadata:  v.metadata,
   428  			Documents: v.documents,
   429  		}
   430  	}
   431  
   432  	err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
   433  	if err != nil {
   434  		return fmt.Errorf("couldn't export DB: %w", err)
   435  	}
   436  
   437  	return nil
   438  }
   439  
   440  
   441  
   442  
   443  
   444  
   445  
   446  func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
   447  	if name == "" {
   448  		return nil, errors.New("collection name is empty")
   449  	}
   450  	if embeddingFunc == nil {
   451  		embeddingFunc = NewEmbeddingFuncDefault()
   452  	}
   453  	collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress)
   454  	if err != nil {
   455  		return nil, fmt.Errorf("couldn't create collection: %w", err)
   456  	}
   457  
   458  	db.collectionsLock.Lock()
   459  	defer db.collectionsLock.Unlock()
   460  	db.collections[name] = collection
   461  	return collection, nil
   462  }
   463  
   464  
   465  
   466  
   467  
   468  
   469  
   470  
   471  func (db *DB) ListCollections() map[string]*Collection {
   472  	db.collectionsLock.RLock()
   473  	defer db.collectionsLock.RUnlock()
   474  
   475  	res := make(map[string]*Collection, len(db.collections))
   476  	for k, v := range db.collections {
   477  		res[k] = v
   478  	}
   479  
   480  	return res
   481  }
   482  
   483  
   484  
   485  
   486  
   487  
   488  
   489  
   490  
   491  func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
   492  	db.collectionsLock.RLock()
   493  	defer db.collectionsLock.RUnlock()
   494  
   495  	c, ok := db.collections[name]
   496  	if !ok {
   497  		return nil
   498  	}
   499  
   500  	if c.embed == nil {
   501  		if embeddingFunc == nil {
   502  			c.embed = NewEmbeddingFuncDefault()
   503  		} else {
   504  			c.embed = embeddingFunc
   505  		}
   506  	}
   507  	return c
   508  }
   509  
   510  
   511  
   512  
   513  
   514  
   515  
   516  
   517  func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
   518  	
   519  	collection := db.GetCollection(name, embeddingFunc)
   520  	if collection == nil {
   521  		var err error
   522  		collection, err = db.CreateCollection(name, metadata, embeddingFunc)
   523  		if err != nil {
   524  			return nil, fmt.Errorf("couldn't create collection: %w", err)
   525  		}
   526  	}
   527  	return collection, nil
   528  }
   529  
   530  
   531  
   532  
   533  
   534  func (db *DB) DeleteCollection(name string) error {
   535  	db.collectionsLock.Lock()
   536  	defer db.collectionsLock.Unlock()
   537  
   538  	col, ok := db.collections[name]
   539  	if !ok {
   540  		return nil
   541  	}
   542  
   543  	if db.persistDirectory != "" {
   544  		collectionPath := col.persistDirectory
   545  		err := os.RemoveAll(collectionPath)
   546  		if err != nil {
   547  			return fmt.Errorf("couldn't delete collection directory: %w", err)
   548  		}
   549  	}
   550  
   551  	delete(db.collections, name)
   552  	return nil
   553  }
   554  
   555  
   556  
   557  
   558  func (db *DB) Reset() error {
   559  	db.collectionsLock.Lock()
   560  	defer db.collectionsLock.Unlock()
   561  
   562  	if db.persistDirectory != "" {
   563  		err := os.RemoveAll(db.persistDirectory)
   564  		if err != nil {
   565  			return fmt.Errorf("couldn't delete persistence directory: %w", err)
   566  		}
   567  		
   568  		err = os.MkdirAll(db.persistDirectory, 0o700)
   569  		if err != nil {
   570  			return fmt.Errorf("couldn't recreate persistence directory: %w", err)
   571  		}
   572  	}
   573  
   574  	
   575  	db.collections = make(map[string]*Collection)
   576  	return nil
   577  }
   578  
View as plain text