...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/db_test.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"os"
     7  	"path/filepath"
     8  	"reflect"
     9  	"slices"
    10  	"testing"
    11  )
    12  
    13  func TestNewPersistentDB(t *testing.T) {
    14  	t.Run("Create directory", func(t *testing.T) {
    15  		r := rand.New(rand.NewSource(rand.Int63()))
    16  		randString := randomString(r, 10)
    17  		path := filepath.Join(os.TempDir(), randString)
    18  		defer os.RemoveAll(path)
    19  
    20  		// Path shouldn't exist yet
    21  		if _, err := os.Stat(path); !os.IsNotExist(err) {
    22  			t.Fatal("expected path to not exist, got", err)
    23  		}
    24  
    25  		db, err := NewPersistentDB(path, false)
    26  		if err != nil {
    27  			t.Fatal("expected no error, got", err)
    28  		}
    29  		if db == nil {
    30  			t.Fatal("expected DB, got nil")
    31  		}
    32  
    33  		// Path should exist now
    34  		if _, err := os.Stat(path); err != nil {
    35  			t.Fatal("expected path to exist, got", err)
    36  		}
    37  	})
    38  	t.Run("Existing directory", func(t *testing.T) {
    39  		path, err := os.MkdirTemp(os.TempDir(), "")
    40  		if err != nil {
    41  			t.Fatal("couldn't create temp dir:", err)
    42  		}
    43  		defer os.RemoveAll(path)
    44  
    45  		db, err := NewPersistentDB(path, false)
    46  		if err != nil {
    47  			t.Fatal("expected no error, got", err)
    48  		}
    49  		if db == nil {
    50  			t.Fatal("expected DB, got nil")
    51  		}
    52  	})
    53  }
    54  
    55  func TestNewPersistentDB_Errors(t *testing.T) {
    56  	t.Run("Path is an existing file", func(t *testing.T) {
    57  		f, err := os.CreateTemp(os.TempDir(), "")
    58  		if err != nil {
    59  			t.Fatal("couldn't create temp file:", err)
    60  		}
    61  		defer os.RemoveAll(f.Name())
    62  
    63  		_, err = NewPersistentDB(f.Name(), false)
    64  		if err == nil {
    65  			t.Fatal("expected error, got nil")
    66  		}
    67  	})
    68  }
    69  
    70  func TestDB_ImportExport(t *testing.T) {
    71  	r := rand.New(rand.NewSource(rand.Int63()))
    72  	randString := randomString(r, 10)
    73  	path := filepath.Join(os.TempDir(), randString)
    74  	defer os.RemoveAll(path)
    75  
    76  	// Values in the collection
    77  	name := "test"
    78  	metadata := map[string]string{"foo": "bar"}
    79  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
    80  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
    81  		return vectors, nil
    82  	}
    83  
    84  	tt := []struct {
    85  		name          string
    86  		filePath      string
    87  		compress      bool
    88  		encryptionKey string
    89  	}{
    90  		{
    91  			name:          "gob",
    92  			filePath:      path + ".gob",
    93  			compress:      false,
    94  			encryptionKey: "",
    95  		},
    96  		{
    97  			name:          "gob compressed",
    98  			filePath:      path + ".gob.gz",
    99  			compress:      true,
   100  			encryptionKey: "",
   101  		},
   102  		{
   103  			name:          "gob compressed encrypted",
   104  			filePath:      path + ".gob.gz.enc",
   105  			compress:      true,
   106  			encryptionKey: randomString(r, 32),
   107  		},
   108  		{
   109  			name:          "gob encrypted",
   110  			filePath:      path + ".gob.enc",
   111  			compress:      false,
   112  			encryptionKey: randomString(r, 32),
   113  		},
   114  	}
   115  
   116  	for _, tc := range tt {
   117  		t.Run(tc.name, func(t *testing.T) {
   118  			// Create DB, can just be in-memory
   119  			orig := NewDB()
   120  
   121  			// Create collection
   122  			c, err := orig.CreateCollection(name, metadata, embeddingFunc)
   123  			if err != nil {
   124  				t.Fatal("expected no error, got", err)
   125  			}
   126  			if c == nil {
   127  				t.Fatal("expected collection, got nil")
   128  			}
   129  			// Add document
   130  			doc := Document{
   131  				ID:        name,
   132  				Metadata:  metadata,
   133  				Embedding: vectors,
   134  				Content:   "test",
   135  			}
   136  			err = c.AddDocument(context.Background(), doc)
   137  			if err != nil {
   138  				t.Fatal("expected no error, got", err)
   139  			}
   140  
   141  			// Export
   142  			err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
   143  			if err != nil {
   144  				t.Fatal("expected no error, got", err)
   145  			}
   146  
   147  			new := NewDB()
   148  
   149  			// Import
   150  			err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
   151  			if err != nil {
   152  				t.Fatal("expected no error, got", err)
   153  			}
   154  
   155  			// Check expectations
   156  			// We have to reset the embed function, but otherwise the DB objects
   157  			// should be deep equal.
   158  			c.embed = nil
   159  			if !reflect.DeepEqual(orig, new) {
   160  				t.Fatalf("expected DB %+v, got %+v", orig, new)
   161  			}
   162  		})
   163  	}
   164  }
   165  
   166  func TestDB_CreateCollection(t *testing.T) {
   167  	// Values in the collection
   168  	name := "test"
   169  	metadata := map[string]string{"foo": "bar"}
   170  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   171  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   172  		return vectors, nil
   173  	}
   174  
   175  	db := NewDB()
   176  
   177  	t.Run("OK", func(t *testing.T) {
   178  		c, err := db.CreateCollection(name, metadata, embeddingFunc)
   179  		if err != nil {
   180  			t.Fatal("expected no error, got", err)
   181  		}
   182  		if c == nil {
   183  			t.Fatal("expected collection, got nil")
   184  		}
   185  
   186  		// Check expectations
   187  
   188  		// DB should have one collection now
   189  		if len(db.collections) != 1 {
   190  			t.Fatal("expected 1 collection, got", len(db.collections))
   191  		}
   192  		// The collection should be the one we just created
   193  		c2, ok := db.collections[name]
   194  		if !ok {
   195  			t.Fatal("expected collection", name, "not found")
   196  		}
   197  		// Check the embedding function first, then the rest with DeepEqual
   198  		gotVectors, err := c.embed(context.Background(), "test")
   199  		if err != nil {
   200  			t.Fatal("expected no error, got", err)
   201  		}
   202  		if !slices.Equal(gotVectors, vectors) {
   203  			t.Fatal("expected vectors", vectors, "got", gotVectors)
   204  		}
   205  		c.embed, c2.embed = nil, nil
   206  		if !reflect.DeepEqual(c, c2) {
   207  			t.Fatalf("expected collection %+v, got %+v", c, c2)
   208  		}
   209  	})
   210  
   211  	t.Run("NOK - Empty name", func(t *testing.T) {
   212  		_, err := db.CreateCollection("", metadata, embeddingFunc)
   213  		if err == nil {
   214  			t.Fatal("expected error, got nil")
   215  		}
   216  	})
   217  }
   218  
   219  func TestDB_ListCollections(t *testing.T) {
   220  	// Values in the collection
   221  	name := "test"
   222  	metadata := map[string]string{"foo": "bar"}
   223  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   224  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   225  		return vectors, nil
   226  	}
   227  
   228  	// Create initial collection
   229  	db := NewDB()
   230  	orig, err := db.CreateCollection(name, metadata, embeddingFunc)
   231  	if err != nil {
   232  		t.Fatal("expected no error, got", err)
   233  	}
   234  
   235  	// List collections
   236  	res := db.ListCollections()
   237  
   238  	// Check expectations
   239  
   240  	// Should've returned a map with one collection
   241  	if len(res) != 1 {
   242  		t.Fatal("expected 1 collection, got", len(res))
   243  	}
   244  	// The collection should be the one we just created
   245  	c, ok := res[name]
   246  	if !ok {
   247  		t.Fatal("expected collection", name, "not found")
   248  	}
   249  	// Check the embedding function first, then the rest with DeepEqual
   250  	gotVectors, err := c.embed(context.Background(), "test")
   251  	if err != nil {
   252  		t.Fatal("expected no error, got", err)
   253  	}
   254  	if !slices.Equal(gotVectors, vectors) {
   255  		t.Fatal("expected vectors", vectors, "got", gotVectors)
   256  	}
   257  	orig.embed, c.embed = nil, nil
   258  	if !reflect.DeepEqual(orig, c) {
   259  		t.Fatalf("expected collection %+v, got %+v", orig, c)
   260  	}
   261  
   262  	// And it should be a copy. Adding a value here should not reflect on the DB's
   263  	// collection.
   264  	res["foo"] = &Collection{}
   265  	if len(db.ListCollections()) != 1 {
   266  		t.Fatal("expected 1 collection, got", len(db.ListCollections()))
   267  	}
   268  }
   269  
   270  func TestDB_GetCollection(t *testing.T) {
   271  	// Values in the collection
   272  	name := "test"
   273  	metadata := map[string]string{"foo": "bar"}
   274  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   275  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   276  		return vectors, nil
   277  	}
   278  
   279  	// Create initial collection
   280  	db := NewDB()
   281  	orig, err := db.CreateCollection(name, metadata, embeddingFunc)
   282  	if err != nil {
   283  		t.Fatal("expected no error, got", err)
   284  	}
   285  
   286  	// Get collection
   287  	c := db.GetCollection(name, nil)
   288  
   289  	// Check the embedding function first, then the rest with DeepEqual
   290  	gotVectors, err := c.embed(context.Background(), "test")
   291  	if err != nil {
   292  		t.Fatal("expected no error, got", err)
   293  	}
   294  	if !slices.Equal(gotVectors, vectors) {
   295  		t.Fatal("expected vectors", vectors, "got", gotVectors)
   296  	}
   297  	orig.embed, c.embed = nil, nil
   298  	if !reflect.DeepEqual(orig, c) {
   299  		t.Fatalf("expected collection %+v, got %+v", orig, c)
   300  	}
   301  }
   302  
   303  func TestDB_GetOrCreateCollection(t *testing.T) {
   304  	// Values in the collection
   305  	name := "test"
   306  	metadata := map[string]string{"foo": "bar"}
   307  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   308  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   309  		return vectors, nil
   310  	}
   311  
   312  	t.Run("Get", func(t *testing.T) {
   313  		// Create initial collection
   314  		db := NewDB()
   315  		// Create collection so that the GetOrCreateCollection() call below only
   316  		// gets it.
   317  		orig, err := db.CreateCollection(name, metadata, embeddingFunc)
   318  		if err != nil {
   319  			t.Fatal("expected no error, got", err)
   320  		}
   321  
   322  		// Call GetOrCreateCollection() with the same name to only get it. We pass
   323  		// nil for the metadata and embeddingFunc, so we can check that the returned
   324  		// collection is the original one, and not a new one.
   325  		c, err := db.GetOrCreateCollection(name, nil, nil)
   326  		if err != nil {
   327  			t.Fatal("expected no error, got", err)
   328  		}
   329  		if c == nil {
   330  			t.Fatal("expected collection, got nil")
   331  		}
   332  
   333  		// Check the embedding function first, then the rest with DeepEqual
   334  		gotVectors, err := c.embed(context.Background(), "test")
   335  		if err != nil {
   336  			t.Fatal("expected no error, got", err)
   337  		}
   338  		if !slices.Equal(gotVectors, vectors) {
   339  			t.Fatal("expected vectors", vectors, "got", gotVectors)
   340  		}
   341  		orig.embed, c.embed = nil, nil
   342  		if !reflect.DeepEqual(orig, c) {
   343  			t.Fatalf("expected collection %+v, got %+v", orig, c)
   344  		}
   345  	})
   346  
   347  	t.Run("Create", func(t *testing.T) {
   348  		// Create initial collection
   349  		db := NewDB()
   350  
   351  		// Call GetOrCreateCollection()
   352  		c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc)
   353  		if err != nil {
   354  			t.Fatal("expected no error, got", err)
   355  		}
   356  		if c == nil {
   357  			t.Fatal("expected collection, got nil")
   358  		}
   359  
   360  		// Check like we check CreateCollection()
   361  		c2, ok := db.collections[name]
   362  		if !ok {
   363  			t.Fatal("expected collection", name, "not found")
   364  		}
   365  		gotVectors, err := c.embed(context.Background(), "test")
   366  		if err != nil {
   367  			t.Fatal("expected no error, got", err)
   368  		}
   369  		if !slices.Equal(gotVectors, vectors) {
   370  			t.Fatal("expected vectors", vectors, "got", gotVectors)
   371  		}
   372  		c.embed, c2.embed = nil, nil
   373  		if !reflect.DeepEqual(c, c2) {
   374  			t.Fatalf("expected collection %+v, got %+v", c, c2)
   375  		}
   376  	})
   377  }
   378  
   379  func TestDB_DeleteCollection(t *testing.T) {
   380  	// Values in the collection
   381  	name := "test"
   382  	metadata := map[string]string{"foo": "bar"}
   383  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   384  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   385  		return vectors, nil
   386  	}
   387  
   388  	// Create initial collection
   389  	db := NewDB()
   390  	// We ignore the return value. CreateCollection is tested elsewhere.
   391  	_, err := db.CreateCollection(name, metadata, embeddingFunc)
   392  	if err != nil {
   393  		t.Fatal("expected no error, got", err)
   394  	}
   395  
   396  	// Delete collection
   397  	if err := db.DeleteCollection(name); err != nil {
   398  		t.Fatal("expected no error, got", err)
   399  	}
   400  
   401  	// Check expectations
   402  	// We don't have access to the documents field, but we can rely on DB.ListCollections()
   403  	// because it's tested elsewhere.
   404  	if len(db.ListCollections()) != 0 {
   405  		t.Fatal("expected 0 collections, got", len(db.ListCollections()))
   406  	}
   407  	// Also check internally
   408  	if len(db.collections) != 0 {
   409  		t.Fatal("expected 0 collections, got", len(db.collections))
   410  	}
   411  }
   412  
   413  func TestDB_Reset(t *testing.T) {
   414  	// Values in the collection
   415  	name := "test"
   416  	metadata := map[string]string{"foo": "bar"}
   417  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   418  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   419  		return vectors, nil
   420  	}
   421  
   422  	// Create initial collection
   423  	db := NewDB()
   424  	// We ignore the return value. CreateCollection is tested elsewhere.
   425  	_, err := db.CreateCollection(name, metadata, embeddingFunc)
   426  	if err != nil {
   427  		t.Fatal("expected no error, got", err)
   428  	}
   429  
   430  	// Reset DB
   431  	if err := db.Reset(); err != nil {
   432  		t.Fatal("expected no error, got", err)
   433  	}
   434  
   435  	// Check expectations
   436  	// We don't have access to the documents field, but we can rely on DB.ListCollections()
   437  	// because it's tested elsewhere.
   438  	if len(db.ListCollections()) != 0 {
   439  		t.Fatal("expected 0 collections, got", len(db.ListCollections()))
   440  	}
   441  	// Also check internally
   442  	if len(db.collections) != 0 {
   443  		t.Fatal("expected 0 collections, got", len(db.collections))
   444  	}
   445  }
   446  

View as plain text