1 package chromem
2
3 import (
4 "context"
5 "math/rand"
6 "os"
7 "path/filepath"
8 "reflect"
9 "slices"
10 "testing"
11 )
12
13 func TestNewPersistentDB(t *testing.T) {
14 t.Run("Create directory", func(t *testing.T) {
15 r := rand.New(rand.NewSource(rand.Int63()))
16 randString := randomString(r, 10)
17 path := filepath.Join(os.TempDir(), randString)
18 defer os.RemoveAll(path)
19
20
21 if _, err := os.Stat(path); !os.IsNotExist(err) {
22 t.Fatal("expected path to not exist, got", err)
23 }
24
25 db, err := NewPersistentDB(path, false)
26 if err != nil {
27 t.Fatal("expected no error, got", err)
28 }
29 if db == nil {
30 t.Fatal("expected DB, got nil")
31 }
32
33
34 if _, err := os.Stat(path); err != nil {
35 t.Fatal("expected path to exist, got", err)
36 }
37 })
38 t.Run("Existing directory", func(t *testing.T) {
39 path, err := os.MkdirTemp(os.TempDir(), "")
40 if err != nil {
41 t.Fatal("couldn't create temp dir:", err)
42 }
43 defer os.RemoveAll(path)
44
45 db, err := NewPersistentDB(path, false)
46 if err != nil {
47 t.Fatal("expected no error, got", err)
48 }
49 if db == nil {
50 t.Fatal("expected DB, got nil")
51 }
52 })
53 }
54
55 func TestNewPersistentDB_Errors(t *testing.T) {
56 t.Run("Path is an existing file", func(t *testing.T) {
57 f, err := os.CreateTemp(os.TempDir(), "")
58 if err != nil {
59 t.Fatal("couldn't create temp file:", err)
60 }
61 defer os.RemoveAll(f.Name())
62
63 _, err = NewPersistentDB(f.Name(), false)
64 if err == nil {
65 t.Fatal("expected error, got nil")
66 }
67 })
68 }
69
70 func TestDB_ImportExport(t *testing.T) {
71 r := rand.New(rand.NewSource(rand.Int63()))
72 randString := randomString(r, 10)
73 path := filepath.Join(os.TempDir(), randString)
74 defer os.RemoveAll(path)
75
76
77 name := "test"
78 metadata := map[string]string{"foo": "bar"}
79 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
80 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
81 return vectors, nil
82 }
83
84 tt := []struct {
85 name string
86 filePath string
87 compress bool
88 encryptionKey string
89 }{
90 {
91 name: "gob",
92 filePath: path + ".gob",
93 compress: false,
94 encryptionKey: "",
95 },
96 {
97 name: "gob compressed",
98 filePath: path + ".gob.gz",
99 compress: true,
100 encryptionKey: "",
101 },
102 {
103 name: "gob compressed encrypted",
104 filePath: path + ".gob.gz.enc",
105 compress: true,
106 encryptionKey: randomString(r, 32),
107 },
108 {
109 name: "gob encrypted",
110 filePath: path + ".gob.enc",
111 compress: false,
112 encryptionKey: randomString(r, 32),
113 },
114 }
115
116 for _, tc := range tt {
117 t.Run(tc.name, func(t *testing.T) {
118
119 orig := NewDB()
120
121
122 c, err := orig.CreateCollection(name, metadata, embeddingFunc)
123 if err != nil {
124 t.Fatal("expected no error, got", err)
125 }
126 if c == nil {
127 t.Fatal("expected collection, got nil")
128 }
129
130 doc := Document{
131 ID: name,
132 Metadata: metadata,
133 Embedding: vectors,
134 Content: "test",
135 }
136 err = c.AddDocument(context.Background(), doc)
137 if err != nil {
138 t.Fatal("expected no error, got", err)
139 }
140
141
142 err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
143 if err != nil {
144 t.Fatal("expected no error, got", err)
145 }
146
147 new := NewDB()
148
149
150 err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
151 if err != nil {
152 t.Fatal("expected no error, got", err)
153 }
154
155
156
157
158 c.embed = nil
159 if !reflect.DeepEqual(orig, new) {
160 t.Fatalf("expected DB %+v, got %+v", orig, new)
161 }
162 })
163 }
164 }
165
166 func TestDB_CreateCollection(t *testing.T) {
167
168 name := "test"
169 metadata := map[string]string{"foo": "bar"}
170 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
171 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
172 return vectors, nil
173 }
174
175 db := NewDB()
176
177 t.Run("OK", func(t *testing.T) {
178 c, err := db.CreateCollection(name, metadata, embeddingFunc)
179 if err != nil {
180 t.Fatal("expected no error, got", err)
181 }
182 if c == nil {
183 t.Fatal("expected collection, got nil")
184 }
185
186
187
188
189 if len(db.collections) != 1 {
190 t.Fatal("expected 1 collection, got", len(db.collections))
191 }
192
193 c2, ok := db.collections[name]
194 if !ok {
195 t.Fatal("expected collection", name, "not found")
196 }
197
198 gotVectors, err := c.embed(context.Background(), "test")
199 if err != nil {
200 t.Fatal("expected no error, got", err)
201 }
202 if !slices.Equal(gotVectors, vectors) {
203 t.Fatal("expected vectors", vectors, "got", gotVectors)
204 }
205 c.embed, c2.embed = nil, nil
206 if !reflect.DeepEqual(c, c2) {
207 t.Fatalf("expected collection %+v, got %+v", c, c2)
208 }
209 })
210
211 t.Run("NOK - Empty name", func(t *testing.T) {
212 _, err := db.CreateCollection("", metadata, embeddingFunc)
213 if err == nil {
214 t.Fatal("expected error, got nil")
215 }
216 })
217 }
218
219 func TestDB_ListCollections(t *testing.T) {
220
221 name := "test"
222 metadata := map[string]string{"foo": "bar"}
223 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
224 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
225 return vectors, nil
226 }
227
228
229 db := NewDB()
230 orig, err := db.CreateCollection(name, metadata, embeddingFunc)
231 if err != nil {
232 t.Fatal("expected no error, got", err)
233 }
234
235
236 res := db.ListCollections()
237
238
239
240
241 if len(res) != 1 {
242 t.Fatal("expected 1 collection, got", len(res))
243 }
244
245 c, ok := res[name]
246 if !ok {
247 t.Fatal("expected collection", name, "not found")
248 }
249
250 gotVectors, err := c.embed(context.Background(), "test")
251 if err != nil {
252 t.Fatal("expected no error, got", err)
253 }
254 if !slices.Equal(gotVectors, vectors) {
255 t.Fatal("expected vectors", vectors, "got", gotVectors)
256 }
257 orig.embed, c.embed = nil, nil
258 if !reflect.DeepEqual(orig, c) {
259 t.Fatalf("expected collection %+v, got %+v", orig, c)
260 }
261
262
263
264 res["foo"] = &Collection{}
265 if len(db.ListCollections()) != 1 {
266 t.Fatal("expected 1 collection, got", len(db.ListCollections()))
267 }
268 }
269
270 func TestDB_GetCollection(t *testing.T) {
271
272 name := "test"
273 metadata := map[string]string{"foo": "bar"}
274 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
275 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
276 return vectors, nil
277 }
278
279
280 db := NewDB()
281 orig, err := db.CreateCollection(name, metadata, embeddingFunc)
282 if err != nil {
283 t.Fatal("expected no error, got", err)
284 }
285
286
287 c := db.GetCollection(name, nil)
288
289
290 gotVectors, err := c.embed(context.Background(), "test")
291 if err != nil {
292 t.Fatal("expected no error, got", err)
293 }
294 if !slices.Equal(gotVectors, vectors) {
295 t.Fatal("expected vectors", vectors, "got", gotVectors)
296 }
297 orig.embed, c.embed = nil, nil
298 if !reflect.DeepEqual(orig, c) {
299 t.Fatalf("expected collection %+v, got %+v", orig, c)
300 }
301 }
302
303 func TestDB_GetOrCreateCollection(t *testing.T) {
304
305 name := "test"
306 metadata := map[string]string{"foo": "bar"}
307 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
308 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
309 return vectors, nil
310 }
311
312 t.Run("Get", func(t *testing.T) {
313
314 db := NewDB()
315
316
317 orig, err := db.CreateCollection(name, metadata, embeddingFunc)
318 if err != nil {
319 t.Fatal("expected no error, got", err)
320 }
321
322
323
324
325 c, err := db.GetOrCreateCollection(name, nil, nil)
326 if err != nil {
327 t.Fatal("expected no error, got", err)
328 }
329 if c == nil {
330 t.Fatal("expected collection, got nil")
331 }
332
333
334 gotVectors, err := c.embed(context.Background(), "test")
335 if err != nil {
336 t.Fatal("expected no error, got", err)
337 }
338 if !slices.Equal(gotVectors, vectors) {
339 t.Fatal("expected vectors", vectors, "got", gotVectors)
340 }
341 orig.embed, c.embed = nil, nil
342 if !reflect.DeepEqual(orig, c) {
343 t.Fatalf("expected collection %+v, got %+v", orig, c)
344 }
345 })
346
347 t.Run("Create", func(t *testing.T) {
348
349 db := NewDB()
350
351
352 c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc)
353 if err != nil {
354 t.Fatal("expected no error, got", err)
355 }
356 if c == nil {
357 t.Fatal("expected collection, got nil")
358 }
359
360
361 c2, ok := db.collections[name]
362 if !ok {
363 t.Fatal("expected collection", name, "not found")
364 }
365 gotVectors, err := c.embed(context.Background(), "test")
366 if err != nil {
367 t.Fatal("expected no error, got", err)
368 }
369 if !slices.Equal(gotVectors, vectors) {
370 t.Fatal("expected vectors", vectors, "got", gotVectors)
371 }
372 c.embed, c2.embed = nil, nil
373 if !reflect.DeepEqual(c, c2) {
374 t.Fatalf("expected collection %+v, got %+v", c, c2)
375 }
376 })
377 }
378
379 func TestDB_DeleteCollection(t *testing.T) {
380
381 name := "test"
382 metadata := map[string]string{"foo": "bar"}
383 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
384 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
385 return vectors, nil
386 }
387
388
389 db := NewDB()
390
391 _, err := db.CreateCollection(name, metadata, embeddingFunc)
392 if err != nil {
393 t.Fatal("expected no error, got", err)
394 }
395
396
397 if err := db.DeleteCollection(name); err != nil {
398 t.Fatal("expected no error, got", err)
399 }
400
401
402
403
404 if len(db.ListCollections()) != 0 {
405 t.Fatal("expected 0 collections, got", len(db.ListCollections()))
406 }
407
408 if len(db.collections) != 0 {
409 t.Fatal("expected 0 collections, got", len(db.collections))
410 }
411 }
412
413 func TestDB_Reset(t *testing.T) {
414
415 name := "test"
416 metadata := map[string]string{"foo": "bar"}
417 vectors := []float32{-0.40824828, 0.40824828, 0.81649655}
418 embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
419 return vectors, nil
420 }
421
422
423 db := NewDB()
424
425 _, err := db.CreateCollection(name, metadata, embeddingFunc)
426 if err != nil {
427 t.Fatal("expected no error, got", err)
428 }
429
430
431 if err := db.Reset(); err != nil {
432 t.Fatal("expected no error, got", err)
433 }
434
435
436
437
438 if len(db.ListCollections()) != 0 {
439 t.Fatal("expected 0 collections, got", len(db.ListCollections()))
440 }
441
442 if len(db.collections) != 0 {
443 t.Fatal("expected 0 collections, got", len(db.collections))
444 }
445 }
446
View as plain text