1 package chromem
2
3 import (
4 "context"
5 "errors"
6 "math/rand"
7 "os"
8 "slices"
9 "strconv"
10 "testing"
11 )
12
13 func TestCollection_Add(t *testing.T) {
14 ctx := context.Background()
15 name := "test"
16 metadata := map[string]string{"foo": "bar"}
17 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
18 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
19 return vectors, nil
20 }
21
22
23 db := NewDB()
24 c, err := db.CreateCollection(name, metadata, embeddingFunc)
25 if err != nil {
26 t.Fatal("expected no error, got", err)
27 }
28 if c == nil {
29 t.Fatal("expected collection, got nil")
30 }
31
32
33
34 ids := []string{"1", "2"}
35 embeddings := [][]float32{vectors, vectors}
36 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
37 contents := []string{"hello world", "hallo welt"}
38
39 tt := []struct {
40 name string
41 ids []string
42 embeddings [][]float32
43 metadatas []map[string]string
44 contents []string
45 }{
46 {
47 name: "No embeddings",
48 ids: ids,
49 embeddings: nil,
50 metadatas: metadatas,
51 contents: contents,
52 },
53 {
54 name: "With embeddings",
55 ids: ids,
56 embeddings: embeddings,
57 metadatas: metadatas,
58 contents: contents,
59 },
60 {
61 name: "With embeddings but no contents",
62 ids: ids,
63 embeddings: embeddings,
64 metadatas: metadatas,
65 contents: nil,
66 },
67 }
68
69 for _, tc := range tt {
70 t.Run(tc.name, func(t *testing.T) {
71 err = c.Add(ctx, ids, nil, metadatas, contents)
72 if err != nil {
73 t.Fatal("expected nil, got", err)
74 }
75
76
77 if len(c.documents) != 2 {
78 t.Fatal("expected 2, got", len(c.documents))
79 }
80 for i, id := range ids {
81 doc, ok := c.documents[id]
82 if !ok {
83 t.Fatal("expected document, got nil")
84 }
85 if doc.ID != id {
86 t.Fatal("expected", id, "got", doc.ID)
87 }
88 if len(doc.Metadata) != 1 {
89 t.Fatal("expected 1, got", len(doc.Metadata))
90 }
91 if !slices.Equal(doc.Embedding, vectors) {
92 t.Fatal("expected", vectors, "got", doc.Embedding)
93 }
94 if doc.Content != contents[i] {
95 t.Fatal("expected", contents[i], "got", doc.Content)
96 }
97 }
98
99 if c.documents[ids[0]].Metadata["foo"] != "bar" {
100 t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
101 }
102 if c.documents[ids[1]].Metadata["a"] != "b" {
103 t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
104 }
105 })
106 }
107 }
108
109 func TestCollection_Add_Error(t *testing.T) {
110 ctx := context.Background()
111 name := "test"
112 metadata := map[string]string{"foo": "bar"}
113 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
114 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
115 return vectors, nil
116 }
117
118
119 db := NewDB()
120 c, err := db.CreateCollection(name, metadata, embeddingFunc)
121 if err != nil {
122 t.Fatal("expected no error, got", err)
123 }
124 if c == nil {
125 t.Fatal("expected collection, got nil")
126 }
127
128
129 ids := []string{"1", "2"}
130 embeddings := [][]float32{vectors, vectors}
131 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
132 contents := []string{"hello world", "hallo welt"}
133
134
135 err = c.Add(ctx, []string{}, embeddings, metadatas, contents)
136 if err == nil {
137 t.Fatal("expected error, got nil")
138 }
139
140 err = c.Add(ctx, ids, [][]float32{}, metadatas, []string{})
141 if err == nil {
142 t.Fatal("expected error, got nil")
143 }
144
145 err = c.Add(ctx, ids, [][]float32{vectors}, metadatas, contents)
146 if err == nil {
147 t.Fatal("expected error, got nil")
148 }
149
150 err = c.Add(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents)
151 if err == nil {
152 t.Fatal("expected error, got nil")
153 }
154
155 err = c.Add(ctx, ids, embeddings, metadatas, []string{"hello world"})
156 if err == nil {
157 t.Fatal("expected error, got nil")
158 }
159 }
160
161 func TestCollection_AddConcurrently(t *testing.T) {
162 ctx := context.Background()
163 name := "test"
164 metadata := map[string]string{"foo": "bar"}
165 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
166 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
167 return vectors, nil
168 }
169
170
171 db := NewDB()
172 c, err := db.CreateCollection(name, metadata, embeddingFunc)
173 if err != nil {
174 t.Fatal("expected no error, got", err)
175 }
176 if c == nil {
177 t.Fatal("expected collection, got nil")
178 }
179
180
181
182 ids := []string{"1", "2"}
183 embeddings := [][]float32{vectors, vectors}
184 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
185 contents := []string{"hello world", "hallo welt"}
186
187 tt := []struct {
188 name string
189 ids []string
190 embeddings [][]float32
191 metadatas []map[string]string
192 contents []string
193 }{
194 {
195 name: "No embeddings",
196 ids: ids,
197 embeddings: nil,
198 metadatas: metadatas,
199 contents: contents,
200 },
201 {
202 name: "With embeddings",
203 ids: ids,
204 embeddings: embeddings,
205 metadatas: metadatas,
206 contents: contents,
207 },
208 {
209 name: "With embeddings but no contents",
210 ids: ids,
211 embeddings: embeddings,
212 metadatas: metadatas,
213 contents: nil,
214 },
215 }
216
217 for _, tc := range tt {
218 t.Run(tc.name, func(t *testing.T) {
219 err = c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2)
220 if err != nil {
221 t.Fatal("expected nil, got", err)
222 }
223
224
225 if len(c.documents) != 2 {
226 t.Fatal("expected 2, got", len(c.documents))
227 }
228 for i, id := range ids {
229 doc, ok := c.documents[id]
230 if !ok {
231 t.Fatal("expected document, got nil")
232 }
233 if doc.ID != id {
234 t.Fatal("expected", id, "got", doc.ID)
235 }
236 if len(doc.Metadata) != 1 {
237 t.Fatal("expected 1, got", len(doc.Metadata))
238 }
239 if !slices.Equal(doc.Embedding, vectors) {
240 t.Fatal("expected", vectors, "got", doc.Embedding)
241 }
242 if doc.Content != contents[i] {
243 t.Fatal("expected", contents[i], "got", doc.Content)
244 }
245 }
246
247 if c.documents[ids[0]].Metadata["foo"] != "bar" {
248 t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
249 }
250 if c.documents[ids[1]].Metadata["a"] != "b" {
251 t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
252 }
253 })
254 }
255 }
256
257 func TestCollection_AddConcurrently_Error(t *testing.T) {
258 ctx := context.Background()
259 name := "test"
260 metadata := map[string]string{"foo": "bar"}
261 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
262 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
263 return vectors, nil
264 }
265
266
267 db := NewDB()
268 c, err := db.CreateCollection(name, metadata, embeddingFunc)
269 if err != nil {
270 t.Fatal("expected no error, got", err)
271 }
272 if c == nil {
273 t.Fatal("expected collection, got nil")
274 }
275
276
277 ids := []string{"1", "2"}
278 embeddings := [][]float32{vectors, vectors}
279 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
280 contents := []string{"hello world", "hallo welt"}
281
282 err = c.AddConcurrently(ctx, []string{}, embeddings, metadatas, contents, 2)
283 if err == nil {
284 t.Fatal("expected error, got nil")
285 }
286
287 err = c.AddConcurrently(ctx, ids, [][]float32{}, metadatas, []string{}, 2)
288 if err == nil {
289 t.Fatal("expected error, got nil")
290 }
291
292 err = c.AddConcurrently(ctx, ids, [][]float32{vectors}, metadatas, contents, 2)
293 if err == nil {
294 t.Fatal("expected error, got nil")
295 }
296
297 err = c.AddConcurrently(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents, 2)
298 if err == nil {
299 t.Fatal("expected error, got nil")
300 }
301
302 err = c.AddConcurrently(ctx, ids, embeddings, metadatas, []string{"hello world"}, 2)
303 if err == nil {
304 t.Fatal("expected error, got nil")
305 }
306
307 err = c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 0)
308 if err == nil {
309 t.Fatal("expected error, got nil")
310 }
311 }
312
313 func TestCollection_QueryError(t *testing.T) {
314
315 db := NewDB()
316 name := "test"
317 metadata := map[string]string{"foo": "bar"}
318 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
319 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
320 return vectors, nil
321 }
322 c, err := db.CreateCollection(name, metadata, embeddingFunc)
323 if err != nil {
324 t.Fatal("expected no error, got", err)
325 }
326 if c == nil {
327 t.Fatal("expected collection, got nil")
328 }
329
330 err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
331 if err != nil {
332 t.Fatal("expected nil, got", err)
333 }
334
335 tt := []struct {
336 name string
337 query func() error
338 expErr string
339 }{
340 {
341 name: "Empty query",
342 query: func() error {
343 _, err := c.Query(context.Background(), "", 1, nil, nil)
344 return err
345 },
346 expErr: "queryText is empty",
347 },
348 {
349 name: "Negative limit",
350 query: func() error {
351 _, err := c.Query(context.Background(), "foo", -1, nil, nil)
352 return err
353 },
354 expErr: "nResults must be > 0",
355 },
356 {
357 name: "Zero limit",
358 query: func() error {
359 _, err := c.Query(context.Background(), "foo", 0, nil, nil)
360 return err
361 },
362 expErr: "nResults must be > 0",
363 },
364 {
365 name: "Limit greater than number of documents",
366 query: func() error {
367 _, err := c.Query(context.Background(), "foo", 2, nil, nil)
368 return err
369 },
370 expErr: "nResults must be <= the number of documents in the collection",
371 },
372 {
373 name: "Bad content filter",
374 query: func() error {
375 _, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
376 return err
377 },
378 expErr: "unsupported operator",
379 },
380 }
381
382 for _, tc := range tt {
383 t.Run(tc.name, func(t *testing.T) {
384 err := tc.query()
385 if err == nil {
386 t.Fatal("expected error, got nil")
387 } else if err.Error() != tc.expErr {
388 t.Fatal("expected", tc.expErr, "got", err)
389 }
390 })
391 }
392 }
393
394 func TestCollection_Count(t *testing.T) {
395
396 db := NewDB()
397 name := "test"
398 metadata := map[string]string{"foo": "bar"}
399 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
400 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
401 return vectors, nil
402 }
403 c, err := db.CreateCollection(name, metadata, embeddingFunc)
404 if err != nil {
405 t.Fatal("expected no error, got", err)
406 }
407 if c == nil {
408 t.Fatal("expected collection, got nil")
409 }
410
411
412 ids := []string{"1", "2"}
413 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
414 contents := []string{"hello world", "hallo welt"}
415 err = c.Add(context.Background(), ids, nil, metadatas, contents)
416 if err != nil {
417 t.Fatal("expected nil, got", err)
418 }
419
420
421 if c.Count() != 2 {
422 t.Fatal("expected 2, got", c.Count())
423 }
424 }
425
426 func TestCollection_Delete(t *testing.T) {
427
428 tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*")
429 if err != nil {
430 t.Fatal("expected no error, got", err)
431 }
432 db, err := NewPersistentDB(tmpdir, false)
433 if err != nil {
434 t.Fatal("expected no error, got", err)
435 }
436 name := "test"
437 metadata := map[string]string{"foo": "bar"}
438 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
439 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
440 return vectors, nil
441 }
442 c, err := db.CreateCollection(name, metadata, embeddingFunc)
443 if err != nil {
444 t.Fatal("expected no error, got", err)
445 }
446 if c == nil {
447 t.Fatal("expected collection, got nil")
448 }
449
450
451 ids := []string{"1", "2", "3", "4"}
452 metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}}
453 contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"}
454 err = c.Add(context.Background(), ids, nil, metadatas, contents)
455 if err != nil {
456 t.Fatal("expected nil, got", err)
457 }
458
459
460 if c.Count() != 4 {
461 t.Fatal("expected 4 documents, got", c.Count())
462 }
463
464
465 d, err := os.ReadDir(c.persistDirectory)
466 if err != nil {
467 t.Fatal("expected nil, got", err)
468 }
469 if len(d) != 5 {
470 t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d))
471 }
472
473 checkCount := func(expected int) {
474
475 if c.Count() != expected {
476 t.Fatalf("expected %d documents, got %d", expected, c.Count())
477 }
478
479
480 d, err = os.ReadDir(c.persistDirectory)
481 if err != nil {
482 t.Fatal("expected nil, got", err)
483 }
484 if len(d) != expected+1 {
485 t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d))
486 }
487 }
488
489
490 err = c.Delete(context.Background(), nil, nil, "4")
491 if err != nil {
492 t.Fatal("expected nil, got", err)
493 }
494 checkCount(3)
495
496
497 err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil)
498 if err != nil {
499 t.Fatal("expected nil, got", err)
500 }
501
502 checkCount(1)
503
504
505 err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
506 if err != nil {
507 t.Fatal("expected nil, got", err)
508 }
509
510 checkCount(0)
511 }
512
513
514 var globalRes []Result
515
516 func BenchmarkCollection_Query_NoContent_100(b *testing.B) {
517 benchmarkCollection_Query(b, 100, false)
518 }
519
520 func BenchmarkCollection_Query_NoContent_1000(b *testing.B) {
521 benchmarkCollection_Query(b, 1000, false)
522 }
523
524 func BenchmarkCollection_Query_NoContent_5000(b *testing.B) {
525 benchmarkCollection_Query(b, 5000, false)
526 }
527
528 func BenchmarkCollection_Query_NoContent_25000(b *testing.B) {
529 benchmarkCollection_Query(b, 25000, false)
530 }
531
532 func BenchmarkCollection_Query_NoContent_100000(b *testing.B) {
533 benchmarkCollection_Query(b, 100_000, false)
534 }
535
536 func BenchmarkCollection_Query_100(b *testing.B) {
537 benchmarkCollection_Query(b, 100, true)
538 }
539
540 func BenchmarkCollection_Query_1000(b *testing.B) {
541 benchmarkCollection_Query(b, 1000, true)
542 }
543
544 func BenchmarkCollection_Query_5000(b *testing.B) {
545 benchmarkCollection_Query(b, 5000, true)
546 }
547
548 func BenchmarkCollection_Query_25000(b *testing.B) {
549 benchmarkCollection_Query(b, 25000, true)
550 }
551
552 func BenchmarkCollection_Query_100000(b *testing.B) {
553 benchmarkCollection_Query(b, 100_000, true)
554 }
555
556
557 func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
558 ctx := context.Background()
559
560
561 r := rand.New(rand.NewSource(42))
562
563 d := 1536
564
565 qv := make([]float32, d)
566 for j := 0; j < d; j++ {
567 qv[j] = r.Float32()
568 }
569
570 qv = normalizeVector(qv)
571
572
573 db := NewDB()
574 name := "test"
575 embeddingFunc := func(_ context.Context, text string) ([]float32, error) {
576 return nil, errors.New("embedding func not expected to be called")
577 }
578 c, err := db.CreateCollection(name, nil, embeddingFunc)
579 if err != nil {
580 b.Fatal("expected no error, got", err)
581 }
582 if c == nil {
583 b.Fatal("expected collection, got nil")
584 }
585
586
587 for i := 0; i < n; i++ {
588
589 v := make([]float32, d)
590 for j := 0; j < d; j++ {
591 v[j] = r.Float32()
592 }
593 v = normalizeVector(v)
594
595
596
597 is := strconv.Itoa(i)
598 doc := Document{
599 ID: is,
600 Metadata: map[string]string{"i": is, "foo": "bar" + is},
601 Embedding: v,
602 }
603 if withContent {
604
605 doc.Content = randomString(r, 1875)
606 }
607
608 if err := c.AddDocument(ctx, doc); err != nil {
609 b.Fatal("expected nil, got", err)
610 }
611 }
612
613 b.ResetTimer()
614
615
616 var res []Result
617 for i := 0; i < b.N; i++ {
618 res, err = c.QueryEmbedding(ctx, qv, 10, nil, nil)
619 }
620 if err != nil {
621 b.Fatal("expected nil, got", err)
622 }
623 globalRes = res
624 }
625
626
// randomString returns n characters drawn uniformly from the lowercase ASCII
// letters and the space character, using r as the randomness source.
func randomString(r *rand.Rand, n int) string {
	// Every symbol is single-byte ASCII, so indexing bytes is equivalent to
	// indexing runes.
	const alphabet = "abcdefghijklmnopqrstuvwxyz "

	out := make([]byte, n)
	for pos := 0; pos < n; pos++ {
		out[pos] = alphabet[r.Intn(len(alphabet))]
	}
	return string(out)
}
637
View as plain text