1 package chromem
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "path/filepath"
8 "slices"
9 "sync"
10 )
11
12
13
14
15 type Collection struct {
16 Name string
17
18 metadata map[string]string
19 documents map[string]*Document
20 documentsLock sync.RWMutex
21 embed EmbeddingFunc
22
23 persistDirectory string
24 compress bool
25
26
27
28 }
29
30
31
32 func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
33
34
35 m := make(map[string]string, len(metadata))
36 for k, v := range metadata {
37 m[k] = v
38 }
39
40 c := &Collection{
41 Name: name,
42
43 metadata: m,
44 documents: make(map[string]*Document),
45 embed: embed,
46 }
47
48
49 if dbDir != "" {
50 safeName := hash2hex(name)
51 c.persistDirectory = filepath.Join(dbDir, safeName)
52 c.compress = compress
53
54 metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
55 metadataPath += ".gob"
56 if c.compress {
57 metadataPath += ".gz"
58 }
59 pc := struct {
60 Name string
61 Metadata map[string]string
62 }{
63 Name: name,
64 Metadata: m,
65 }
66 err := persistToFile(metadataPath, pc, compress, "")
67 if err != nil {
68 return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
69 }
70 }
71
72 return c, nil
73 }
74
75
76
77
78
79
80
81
82
83
84
85 func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
86 return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
87 }
88
89
90
91
92
93
94 func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
95 if len(ids) == 0 {
96 return errors.New("ids are empty")
97 }
98 if len(embeddings) == 0 && len(contents) == 0 {
99 return errors.New("either embeddings or contents must be filled")
100 }
101 if len(embeddings) != 0 {
102 if len(embeddings) != len(ids) {
103 return errors.New("ids and embeddings must have the same length")
104 }
105 } else {
106
107 embeddings = make([][]float32, len(ids))
108 }
109 if len(metadatas) != 0 {
110 if len(ids) != len(metadatas) {
111 return errors.New("when metadatas is not empty it must have the same length as ids")
112 }
113 } else {
114
115 metadatas = make([]map[string]string, len(ids))
116 }
117 if len(contents) != 0 {
118 if len(contents) != len(ids) {
119 return errors.New("ids and contents must have the same length")
120 }
121 } else {
122
123 contents = make([]string, len(ids))
124 }
125 if concurrency < 1 {
126 return errors.New("concurrency must be at least 1")
127 }
128
129
130 docs := make([]Document, 0, len(ids))
131 for i, id := range ids {
132 docs = append(docs, Document{
133 ID: id,
134 Metadata: metadatas[i],
135 Embedding: embeddings[i],
136 Content: contents[i],
137 })
138 }
139
140 return c.AddDocuments(ctx, docs, concurrency)
141 }
142
143
144
145
146
147 func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
148 if len(documents) == 0 {
149
150 return errors.New("documents slice is nil or empty")
151 }
152 if concurrency < 1 {
153 return errors.New("concurrency must be at least 1")
154 }
155
156
157 var sharedErr error
158 sharedErrLock := sync.Mutex{}
159 ctx, cancel := context.WithCancelCause(ctx)
160 defer cancel(nil)
161 setSharedErr := func(err error) {
162 sharedErrLock.Lock()
163 defer sharedErrLock.Unlock()
164
165 if sharedErr == nil {
166 sharedErr = err
167
168 cancel(sharedErr)
169 }
170 }
171
172 var wg sync.WaitGroup
173 semaphore := make(chan struct{}, concurrency)
174 for _, doc := range documents {
175 wg.Add(1)
176 go func(doc Document) {
177 defer wg.Done()
178
179
180 if ctx.Err() != nil {
181 return
182 }
183
184
185 semaphore <- struct{}{}
186 defer func() { <-semaphore }()
187
188 err := c.AddDocument(ctx, doc)
189 if err != nil {
190 setSharedErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
191 return
192 }
193 }(doc)
194 }
195
196 wg.Wait()
197
198 return sharedErr
199 }
200
201
202
203
204 func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
205 if doc.ID == "" {
206 return errors.New("document ID is empty")
207 }
208 if len(doc.Embedding) == 0 && doc.Content == "" {
209 return errors.New("either document embedding or content must be filled")
210 }
211
212
213
214 m := make(map[string]string, len(doc.Metadata))
215 for k, v := range doc.Metadata {
216 m[k] = v
217 }
218
219
220 if len(doc.Embedding) == 0 {
221 embedding, err := c.embed(ctx, doc.Content)
222 if err != nil {
223 return fmt.Errorf("couldn't create embedding of document: %w", err)
224 }
225 doc.Embedding = embedding
226 } else {
227 if !isNormalized(doc.Embedding) {
228 doc.Embedding = normalizeVector(doc.Embedding)
229 }
230 }
231
232 c.documentsLock.Lock()
233
234 c.documents[doc.ID] = &doc
235 c.documentsLock.Unlock()
236
237
238 if c.persistDirectory != "" {
239 docPath := c.getDocPath(doc.ID)
240 err := persistToFile(docPath, doc, c.compress, "")
241 if err != nil {
242 return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
243 }
244 }
245
246 return nil
247 }
248
249
250
251
252
253
254 func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
255
256 if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
257 return fmt.Errorf("must have at least one of where, whereDocument or ids")
258 }
259
260 if len(c.documents) == 0 {
261 return nil
262 }
263
264 for k := range whereDocument {
265 if !slices.Contains(supportedFilters, k) {
266 return errors.New("unsupported whereDocument operator")
267 }
268 }
269
270 var docIDs []string
271
272 c.documentsLock.Lock()
273 defer c.documentsLock.Unlock()
274
275 if where != nil || whereDocument != nil {
276
277 filteredDocs := filterDocs(c.documents, where, whereDocument)
278 for _, doc := range filteredDocs {
279 docIDs = append(docIDs, doc.ID)
280 }
281 } else {
282 docIDs = ids
283 }
284
285
286 if len(docIDs) == 0 {
287 return nil
288 }
289
290 for _, docID := range docIDs {
291 delete(c.documents, docID)
292
293
294 if c.persistDirectory != "" {
295 docPath := c.getDocPath(docID)
296 err := removeFile(docPath)
297 if err != nil {
298 return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
299 }
300 }
301 }
302
303 return nil
304 }
305
306
307 func (c *Collection) Count() int {
308 c.documentsLock.RLock()
309 defer c.documentsLock.RUnlock()
310 return len(c.documents)
311 }
312
313
// Result is a single hit returned by a query: the matched document's data
// together with its similarity score.
type Result struct {
	// ID is the ID of the matched document.
	ID string
	// Metadata is the matched document's metadata.
	Metadata map[string]string
	// Embedding is the matched document's embedding.
	Embedding []float32
	// Content is the matched document's content.
	Content string

	// Similarity is the score the query assigned to this document relative to
	// the query embedding.
	Similarity float32
}
325
326
327
328
329
330
331
332
333
334 func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
335 if queryText == "" {
336 return nil, errors.New("queryText is empty")
337 }
338
339 queryVectors, err := c.embed(ctx, queryText)
340 if err != nil {
341 return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
342 }
343
344 return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
345 }
346
347
348
349
350
351
352
353
354
355
356 func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
357 if len(queryEmbedding) == 0 {
358 return nil, errors.New("queryEmbedding is empty")
359 }
360 if nResults <= 0 {
361 return nil, errors.New("nResults must be > 0")
362 }
363 c.documentsLock.RLock()
364 defer c.documentsLock.RUnlock()
365 if nResults > len(c.documents) {
366 return nil, errors.New("nResults must be <= the number of documents in the collection")
367 }
368
369 if len(c.documents) == 0 {
370 return nil, nil
371 }
372
373
374 for k := range whereDocument {
375 if !slices.Contains(supportedFilters, k) {
376 return nil, errors.New("unsupported operator")
377 }
378 }
379
380
381 filteredDocs := filterDocs(c.documents, where, whereDocument)
382
383
384 if len(filteredDocs) == 0 {
385 return nil, nil
386 }
387
388
389
390 if !isNormalized(queryEmbedding) {
391 queryEmbedding = normalizeVector(queryEmbedding)
392 }
393
394
395
396 resLen := nResults
397 if len(filteredDocs) < nResults {
398 resLen = len(filteredDocs)
399 }
400
401
402 nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
403 if err != nil {
404 return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
405 }
406
407
408 if resLen != len(nMaxDocs) {
409 return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
410 }
411
412 res := make([]Result, 0, resLen)
413 for i := 0; i < resLen; i++ {
414 res = append(res, Result{
415 ID: nMaxDocs[i].docID,
416 Metadata: c.documents[nMaxDocs[i].docID].Metadata,
417 Embedding: c.documents[nMaxDocs[i].docID].Embedding,
418 Content: c.documents[nMaxDocs[i].docID].Content,
419 Similarity: nMaxDocs[i].similarity,
420 })
421 }
422
423 return res, nil
424 }
425
426
427 func (c *Collection) getDocPath(docID string) string {
428 safeID := hash2hex(docID)
429 docPath := filepath.Join(c.persistDirectory, safeID)
430 docPath += ".gob"
431 if c.compress {
432 docPath += ".gz"
433 }
434 return docPath
435 }
436
View as plain text