1 package chromem
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "io/fs"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 )
14
15
16
17
18
19
20
// EmbeddingFunc computes an embedding vector for the given text.
//
// A DB stores the function on each collection (Collection.embed). When none
// is supplied, CreateCollection and GetCollection fall back to
// NewEmbeddingFuncDefault().
// NOTE(review): whether implementations must return normalized vectors is not
// visible here — confirm against the collection/query code.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
22
23
24
25
26
27
// DB is a chromem-go database holding a set of named collections.
//
// Create it with NewDB for a purely in-memory DB, or with NewPersistentDB for
// a DB whose collections live in a directory on disk.
type DB struct {
	// collections maps each collection's name to the collection.
	// It must only be accessed while holding collectionsLock.
	collections     map[string]*Collection
	collectionsLock sync.RWMutex

	// persistDirectory is the directory collections are persisted in, or ""
	// for an in-memory DB. Only the constructors in this file set it.
	persistDirectory string
	// compress is handed to every collection this DB creates, loads, or
	// imports. When loading, it selects the ".gob.gz" extension over ".gob".
	compress bool
}
38
39
40
41
42
43 func NewDB() *DB {
44 return &DB{
45 collections: make(map[string]*Collection),
46 }
47 }
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// NewPersistentDB creates a DB that persists its collections below path and
// loads any collections already stored there.
//
// If path is empty it defaults to "./chromem-go"; otherwise it is cleaned with
// filepath.Clean. If the directory doesn't exist, it is created with mode 0o700
// and an empty DB is returned. If path exists but isn't a directory, an error
// is returned.
//
// compress selects which files are loaded: names ending in ".gob.gz" when true,
// ".gob" when false. Files with the other extension are ignored. The flag is
// also stored on the DB and on every loaded collection.
//
// Every subdirectory of path is treated as one collection:
//   - the file named metadataFileName+ext provides the collection's name and
//     metadata;
//   - every other file ending in ext is decoded as a Document.
//
// A subdirectory yielding neither a name nor documents is skipped. One that
// has documents but no name (missing metadata file) is an error. All files
// are read with an empty encryption key. Loaded collections have no embedding
// function; GetCollection assigns one on first access.
func NewPersistentDB(path string, compress bool) (*DB, error) {
	if path == "" {
		path = "./chromem-go"
	} else {
		// Normalize caller-supplied paths (redundant separators, "." and "..").
		path = filepath.Clean(path)
	}

	// Only files with this extension are considered when loading below.
	ext := ".gob"
	if compress {
		ext += ".gz"
	}

	db := &DB{
		collections:      make(map[string]*Collection),
		persistDirectory: path,
		compress:         compress,
	}

	// A missing directory means there is nothing to load: create it and
	// return the empty DB right away.
	fi, err := os.Stat(path)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			err := os.MkdirAll(path, 0o700)
			if err != nil {
				return nil, fmt.Errorf("couldn't create persistence directory: %w", err)
			}

			return db, nil
		}
		return nil, fmt.Errorf("couldn't get info about persistence directory: %w", err)
	} else if !fi.IsDir() {
		return nil, fmt.Errorf("path is not a directory: %s", path)
	}

	// Load existing collections, one per subdirectory.
	dirEntries, err := os.ReadDir(path)
	if err != nil {
		return nil, fmt.Errorf("couldn't read persistence directory: %w", err)
	}
	for _, dirEntry := range dirEntries {
		// Plain files directly inside the persistence directory aren't
		// collections.
		if !dirEntry.IsDir() {
			continue
		}

		// The collection name comes from the metadata file, not from the
		// directory name. (Imports name the directory hash2hex(name), but this
		// loader doesn't depend on that.)
		collectionPath := filepath.Join(path, dirEntry.Name())
		collectionDirEntries, err := os.ReadDir(collectionPath)
		if err != nil {
			return nil, fmt.Errorf("couldn't read collection directory: %w", err)
		}
		c := &Collection{
			documents:        make(map[string]*Document),
			persistDirectory: collectionPath,
			compress:         compress,
			// Name and metadata are filled in from the metadata file below.
			// The embedding function stays unset until GetCollection assigns
			// one.
		}
		for _, collectionDirEntry := range collectionDirEntries {
			// Nested directories aren't part of what this loader reads.
			if collectionDirEntry.IsDir() {
				continue
			}

			fPath := filepath.Join(collectionPath, collectionDirEntry.Name())
			// The metadata file must be checked first, because its name also
			// ends in ext and would otherwise be read as a document.
			if collectionDirEntry.Name() == metadataFileName+ext {
				// Collection metadata: the name plus user-supplied metadata.
				pc := struct {
					Name     string
					Metadata map[string]string
				}{}
				err := readFromFile(fPath, &pc, "")
				if err != nil {
					return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
				}
				c.Name = pc.Name
				c.metadata = pc.Metadata
			} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
				// Any other file with the extension holds one document.
				d := &Document{}
				err := readFromFile(fPath, d, "")
				if err != nil {
					return nil, fmt.Errorf("couldn't read document: %w", err)
				}
				c.documents[d.ID] = d
			} else {
				// Unrelated file, e.g. one written with the other compression
				// setting.
				continue
			}
		}

		// Nothing usable in this directory: skip it.
		if c.Name == "" && len(c.documents) == 0 {
			continue
		}
		// Documents without a metadata file can't be attributed to a named
		// collection.
		if c.Name == "" {
			return nil, fmt.Errorf("collection metadata file not found: %s", collectionPath)
		}

		db.collections[c.Name] = c
	}

	return db, nil
}
180
181
182
183
184
185
186
187
188
189
190
191 func (db *DB) Import(filePath string, encryptionKey string) error {
192 return db.ImportFromFile(filePath, encryptionKey)
193 }
194
195
196
197
198
199
200
201
202
203 func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
204 if filePath == "" {
205 return fmt.Errorf("file path is empty")
206 }
207 if encryptionKey != "" {
208
209 if len(encryptionKey) != 32 {
210 return errors.New("encryption key must be 32 bytes long")
211 }
212 }
213
214
215 fi, err := os.Stat(filePath)
216 if err != nil {
217 if errors.Is(err, fs.ErrNotExist) {
218 return fmt.Errorf("file doesn't exist: %s", filePath)
219 }
220 return fmt.Errorf("couldn't get info about the file: %w", err)
221 } else if fi.IsDir() {
222 return fmt.Errorf("path is a directory: %s", filePath)
223 }
224
225
226
227 type persistenceCollection struct {
228 Name string
229 Metadata map[string]string
230 Documents map[string]*Document
231 }
232 persistenceDB := struct {
233 Collections map[string]*persistenceCollection
234 }{
235 Collections: make(map[string]*persistenceCollection, len(db.collections)),
236 }
237
238 db.collectionsLock.Lock()
239 defer db.collectionsLock.Unlock()
240
241 err = readFromFile(filePath, &persistenceDB, encryptionKey)
242 if err != nil {
243 return fmt.Errorf("couldn't read file: %w", err)
244 }
245
246 for _, pc := range persistenceDB.Collections {
247 c := &Collection{
248 Name: pc.Name,
249
250 metadata: pc.Metadata,
251 documents: pc.Documents,
252 }
253 if db.persistDirectory != "" {
254 c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
255 c.compress = db.compress
256 }
257 db.collections[c.Name] = c
258 }
259
260 return nil
261 }
262
263
264
265
266
267
268
269
270
271
272 func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
273 if encryptionKey != "" {
274
275 if len(encryptionKey) != 32 {
276 return errors.New("encryption key must be 32 bytes long")
277 }
278 }
279
280
281
282 type persistenceCollection struct {
283 Name string
284 Metadata map[string]string
285 Documents map[string]*Document
286 }
287 persistenceDB := struct {
288 Collections map[string]*persistenceCollection
289 }{
290 Collections: make(map[string]*persistenceCollection, len(db.collections)),
291 }
292
293 db.collectionsLock.Lock()
294 defer db.collectionsLock.Unlock()
295
296 err := readFromReader(reader, &persistenceDB, encryptionKey)
297 if err != nil {
298 return fmt.Errorf("couldn't read stream: %w", err)
299 }
300
301 for _, pc := range persistenceDB.Collections {
302 c := &Collection{
303 Name: pc.Name,
304
305 metadata: pc.Metadata,
306 documents: pc.Documents,
307 }
308 if db.persistDirectory != "" {
309 c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
310 c.compress = db.compress
311 }
312 db.collections[c.Name] = c
313 }
314
315 return nil
316 }
317
318
319
320
321
322
323
324
325
326
327
328
329 func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
330 return db.ExportToFile(filePath, compress, encryptionKey)
331 }
332
333
334
335
336
337
338
339
340
341
342 func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
343 if filePath == "" {
344 filePath = "./chromem-go.gob"
345 if compress {
346 filePath += ".gz"
347 }
348 if encryptionKey != "" {
349 filePath += ".enc"
350 }
351 }
352 if encryptionKey != "" {
353
354 if len(encryptionKey) != 32 {
355 return errors.New("encryption key must be 32 bytes long")
356 }
357 }
358
359
360
361 type persistenceCollection struct {
362 Name string
363 Metadata map[string]string
364 Documents map[string]*Document
365 }
366 persistenceDB := struct {
367 Collections map[string]*persistenceCollection
368 }{
369 Collections: make(map[string]*persistenceCollection, len(db.collections)),
370 }
371
372 db.collectionsLock.RLock()
373 defer db.collectionsLock.RUnlock()
374
375 for k, v := range db.collections {
376 persistenceDB.Collections[k] = &persistenceCollection{
377 Name: v.Name,
378 Metadata: v.metadata,
379 Documents: v.documents,
380 }
381 }
382
383 err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
384 if err != nil {
385 return fmt.Errorf("couldn't export DB: %w", err)
386 }
387
388 return nil
389 }
390
391
392
393
394
395
396
397
398
399
400 func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
401 if encryptionKey != "" {
402
403 if len(encryptionKey) != 32 {
404 return errors.New("encryption key must be 32 bytes long")
405 }
406 }
407
408
409
410 type persistenceCollection struct {
411 Name string
412 Metadata map[string]string
413 Documents map[string]*Document
414 }
415 persistenceDB := struct {
416 Collections map[string]*persistenceCollection
417 }{
418 Collections: make(map[string]*persistenceCollection, len(db.collections)),
419 }
420
421 db.collectionsLock.RLock()
422 defer db.collectionsLock.RUnlock()
423
424 for k, v := range db.collections {
425 persistenceDB.Collections[k] = &persistenceCollection{
426 Name: v.Name,
427 Metadata: v.metadata,
428 Documents: v.documents,
429 }
430 }
431
432 err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
433 if err != nil {
434 return fmt.Errorf("couldn't export DB: %w", err)
435 }
436
437 return nil
438 }
439
440
441
442
443
444
445
446 func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
447 if name == "" {
448 return nil, errors.New("collection name is empty")
449 }
450 if embeddingFunc == nil {
451 embeddingFunc = NewEmbeddingFuncDefault()
452 }
453 collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress)
454 if err != nil {
455 return nil, fmt.Errorf("couldn't create collection: %w", err)
456 }
457
458 db.collectionsLock.Lock()
459 defer db.collectionsLock.Unlock()
460 db.collections[name] = collection
461 return collection, nil
462 }
463
464
465
466
467
468
469
470
471 func (db *DB) ListCollections() map[string]*Collection {
472 db.collectionsLock.RLock()
473 defer db.collectionsLock.RUnlock()
474
475 res := make(map[string]*Collection, len(db.collections))
476 for k, v := range db.collections {
477 res[k] = v
478 }
479
480 return res
481 }
482
483
484
485
486
487
488
489
490
491 func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
492 db.collectionsLock.RLock()
493 defer db.collectionsLock.RUnlock()
494
495 c, ok := db.collections[name]
496 if !ok {
497 return nil
498 }
499
500 if c.embed == nil {
501 if embeddingFunc == nil {
502 c.embed = NewEmbeddingFuncDefault()
503 } else {
504 c.embed = embeddingFunc
505 }
506 }
507 return c
508 }
509
510
511
512
513
514
515
516
517 func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
518
519 collection := db.GetCollection(name, embeddingFunc)
520 if collection == nil {
521 var err error
522 collection, err = db.CreateCollection(name, metadata, embeddingFunc)
523 if err != nil {
524 return nil, fmt.Errorf("couldn't create collection: %w", err)
525 }
526 }
527 return collection, nil
528 }
529
530
531
532
533
534 func (db *DB) DeleteCollection(name string) error {
535 db.collectionsLock.Lock()
536 defer db.collectionsLock.Unlock()
537
538 col, ok := db.collections[name]
539 if !ok {
540 return nil
541 }
542
543 if db.persistDirectory != "" {
544 collectionPath := col.persistDirectory
545 err := os.RemoveAll(collectionPath)
546 if err != nil {
547 return fmt.Errorf("couldn't delete collection directory: %w", err)
548 }
549 }
550
551 delete(db.collections, name)
552 return nil
553 }
554
555
556
557
558 func (db *DB) Reset() error {
559 db.collectionsLock.Lock()
560 defer db.collectionsLock.Unlock()
561
562 if db.persistDirectory != "" {
563 err := os.RemoveAll(db.persistDirectory)
564 if err != nil {
565 return fmt.Errorf("couldn't delete persistence directory: %w", err)
566 }
567
568 err = os.MkdirAll(db.persistDirectory, 0o700)
569 if err != nil {
570 return fmt.Errorf("couldn't recreate persistence directory: %w", err)
571 }
572 }
573
574
575 db.collections = make(map[string]*Collection)
576 return nil
577 }
578
View as plain text