...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/collection_test.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math/rand"
     7  	"os"
     8  	"slices"
     9  	"strconv"
    10  	"testing"
    11  )
    12  
    13  func TestCollection_Add(t *testing.T) {
    14  	ctx := context.Background()
    15  	name := "test"
    16  	metadata := map[string]string{"foo": "bar"}
    17  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
    18  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
    19  		return vectors, nil
    20  	}
    21  
    22  	// Create collection
    23  	db := NewDB()
    24  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
    25  	if err != nil {
    26  		t.Fatal("expected no error, got", err)
    27  	}
    28  	if c == nil {
    29  		t.Fatal("expected collection, got nil")
    30  	}
    31  
    32  	// Add documents
    33  
    34  	ids := []string{"1", "2"}
    35  	embeddings := [][]float32{vectors, vectors}
    36  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
    37  	contents := []string{"hello world", "hallo welt"}
    38  
    39  	tt := []struct {
    40  		name       string
    41  		ids        []string
    42  		embeddings [][]float32
    43  		metadatas  []map[string]string
    44  		contents   []string
    45  	}{
    46  		{
    47  			name:       "No embeddings",
    48  			ids:        ids,
    49  			embeddings: nil,
    50  			metadatas:  metadatas,
    51  			contents:   contents,
    52  		},
    53  		{
    54  			name:       "With embeddings",
    55  			ids:        ids,
    56  			embeddings: embeddings,
    57  			metadatas:  metadatas,
    58  			contents:   contents,
    59  		},
    60  		{
    61  			name:       "With embeddings but no contents",
    62  			ids:        ids,
    63  			embeddings: embeddings,
    64  			metadatas:  metadatas,
    65  			contents:   nil,
    66  		},
    67  	}
    68  
    69  	for _, tc := range tt {
    70  		t.Run(tc.name, func(t *testing.T) {
    71  			err = c.Add(ctx, ids, nil, metadatas, contents)
    72  			if err != nil {
    73  				t.Fatal("expected nil, got", err)
    74  			}
    75  
    76  			// Check documents
    77  			if len(c.documents) != 2 {
    78  				t.Fatal("expected 2, got", len(c.documents))
    79  			}
    80  			for i, id := range ids {
    81  				doc, ok := c.documents[id]
    82  				if !ok {
    83  					t.Fatal("expected document, got nil")
    84  				}
    85  				if doc.ID != id {
    86  					t.Fatal("expected", id, "got", doc.ID)
    87  				}
    88  				if len(doc.Metadata) != 1 {
    89  					t.Fatal("expected 1, got", len(doc.Metadata))
    90  				}
    91  				if !slices.Equal(doc.Embedding, vectors) {
    92  					t.Fatal("expected", vectors, "got", doc.Embedding)
    93  				}
    94  				if doc.Content != contents[i] {
    95  					t.Fatal("expected", contents[i], "got", doc.Content)
    96  				}
    97  			}
    98  			// Metadata can't be accessed with the loop's i
    99  			if c.documents[ids[0]].Metadata["foo"] != "bar" {
   100  				t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
   101  			}
   102  			if c.documents[ids[1]].Metadata["a"] != "b" {
   103  				t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
   104  			}
   105  		})
   106  	}
   107  }
   108  
   109  func TestCollection_Add_Error(t *testing.T) {
   110  	ctx := context.Background()
   111  	name := "test"
   112  	metadata := map[string]string{"foo": "bar"}
   113  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   114  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   115  		return vectors, nil
   116  	}
   117  
   118  	// Create collection
   119  	db := NewDB()
   120  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   121  	if err != nil {
   122  		t.Fatal("expected no error, got", err)
   123  	}
   124  	if c == nil {
   125  		t.Fatal("expected collection, got nil")
   126  	}
   127  
   128  	// Add documents, provoking errors
   129  	ids := []string{"1", "2"}
   130  	embeddings := [][]float32{vectors, vectors}
   131  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
   132  	contents := []string{"hello world", "hallo welt"}
   133  
   134  	// Empty IDs
   135  	err = c.Add(ctx, []string{}, embeddings, metadatas, contents)
   136  	if err == nil {
   137  		t.Fatal("expected error, got nil")
   138  	}
   139  	// Empty embeddings and contents (both at the same time!)
   140  	err = c.Add(ctx, ids, [][]float32{}, metadatas, []string{})
   141  	if err == nil {
   142  		t.Fatal("expected error, got nil")
   143  	}
   144  	// Bad embeddings length
   145  	err = c.Add(ctx, ids, [][]float32{vectors}, metadatas, contents)
   146  	if err == nil {
   147  		t.Fatal("expected error, got nil")
   148  	}
   149  	// Bad metadatas length
   150  	err = c.Add(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents)
   151  	if err == nil {
   152  		t.Fatal("expected error, got nil")
   153  	}
   154  	// Bad contents length
   155  	err = c.Add(ctx, ids, embeddings, metadatas, []string{"hello world"})
   156  	if err == nil {
   157  		t.Fatal("expected error, got nil")
   158  	}
   159  }
   160  
   161  func TestCollection_AddConcurrently(t *testing.T) {
   162  	ctx := context.Background()
   163  	name := "test"
   164  	metadata := map[string]string{"foo": "bar"}
   165  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   166  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   167  		return vectors, nil
   168  	}
   169  
   170  	// Create collection
   171  	db := NewDB()
   172  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   173  	if err != nil {
   174  		t.Fatal("expected no error, got", err)
   175  	}
   176  	if c == nil {
   177  		t.Fatal("expected collection, got nil")
   178  	}
   179  
   180  	// Add documents
   181  
   182  	ids := []string{"1", "2"}
   183  	embeddings := [][]float32{vectors, vectors}
   184  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
   185  	contents := []string{"hello world", "hallo welt"}
   186  
   187  	tt := []struct {
   188  		name       string
   189  		ids        []string
   190  		embeddings [][]float32
   191  		metadatas  []map[string]string
   192  		contents   []string
   193  	}{
   194  		{
   195  			name:       "No embeddings",
   196  			ids:        ids,
   197  			embeddings: nil,
   198  			metadatas:  metadatas,
   199  			contents:   contents,
   200  		},
   201  		{
   202  			name:       "With embeddings",
   203  			ids:        ids,
   204  			embeddings: embeddings,
   205  			metadatas:  metadatas,
   206  			contents:   contents,
   207  		},
   208  		{
   209  			name:       "With embeddings but no contents",
   210  			ids:        ids,
   211  			embeddings: embeddings,
   212  			metadatas:  metadatas,
   213  			contents:   nil,
   214  		},
   215  	}
   216  
   217  	for _, tc := range tt {
   218  		t.Run(tc.name, func(t *testing.T) {
   219  			err = c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2)
   220  			if err != nil {
   221  				t.Fatal("expected nil, got", err)
   222  			}
   223  
   224  			// Check documents
   225  			if len(c.documents) != 2 {
   226  				t.Fatal("expected 2, got", len(c.documents))
   227  			}
   228  			for i, id := range ids {
   229  				doc, ok := c.documents[id]
   230  				if !ok {
   231  					t.Fatal("expected document, got nil")
   232  				}
   233  				if doc.ID != id {
   234  					t.Fatal("expected", id, "got", doc.ID)
   235  				}
   236  				if len(doc.Metadata) != 1 {
   237  					t.Fatal("expected 1, got", len(doc.Metadata))
   238  				}
   239  				if !slices.Equal(doc.Embedding, vectors) {
   240  					t.Fatal("expected", vectors, "got", doc.Embedding)
   241  				}
   242  				if doc.Content != contents[i] {
   243  					t.Fatal("expected", contents[i], "got", doc.Content)
   244  				}
   245  			}
   246  			// Metadata can't be accessed with the loop's i
   247  			if c.documents[ids[0]].Metadata["foo"] != "bar" {
   248  				t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
   249  			}
   250  			if c.documents[ids[1]].Metadata["a"] != "b" {
   251  				t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
   252  			}
   253  		})
   254  	}
   255  }
   256  
   257  func TestCollection_AddConcurrently_Error(t *testing.T) {
   258  	ctx := context.Background()
   259  	name := "test"
   260  	metadata := map[string]string{"foo": "bar"}
   261  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   262  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   263  		return vectors, nil
   264  	}
   265  
   266  	// Create collection
   267  	db := NewDB()
   268  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   269  	if err != nil {
   270  		t.Fatal("expected no error, got", err)
   271  	}
   272  	if c == nil {
   273  		t.Fatal("expected collection, got nil")
   274  	}
   275  
   276  	// Add documents, provoking errors
   277  	ids := []string{"1", "2"}
   278  	embeddings := [][]float32{vectors, vectors}
   279  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
   280  	contents := []string{"hello world", "hallo welt"}
   281  	// Empty IDs
   282  	err = c.AddConcurrently(ctx, []string{}, embeddings, metadatas, contents, 2)
   283  	if err == nil {
   284  		t.Fatal("expected error, got nil")
   285  	}
   286  	// Empty embeddings and contents (both at the same time!)
   287  	err = c.AddConcurrently(ctx, ids, [][]float32{}, metadatas, []string{}, 2)
   288  	if err == nil {
   289  		t.Fatal("expected error, got nil")
   290  	}
   291  	// Bad embeddings length
   292  	err = c.AddConcurrently(ctx, ids, [][]float32{vectors}, metadatas, contents, 2)
   293  	if err == nil {
   294  		t.Fatal("expected error, got nil")
   295  	}
   296  	// Bad metadatas length
   297  	err = c.AddConcurrently(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents, 2)
   298  	if err == nil {
   299  		t.Fatal("expected error, got nil")
   300  	}
   301  	// Bad contents length
   302  	err = c.AddConcurrently(ctx, ids, embeddings, metadatas, []string{"hello world"}, 2)
   303  	if err == nil {
   304  		t.Fatal("expected error, got nil")
   305  	}
   306  	// Bad concurrency
   307  	err = c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 0)
   308  	if err == nil {
   309  		t.Fatal("expected error, got nil")
   310  	}
   311  }
   312  
   313  func TestCollection_QueryError(t *testing.T) {
   314  	// Create collection
   315  	db := NewDB()
   316  	name := "test"
   317  	metadata := map[string]string{"foo": "bar"}
   318  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   319  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   320  		return vectors, nil
   321  	}
   322  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   323  	if err != nil {
   324  		t.Fatal("expected no error, got", err)
   325  	}
   326  	if c == nil {
   327  		t.Fatal("expected collection, got nil")
   328  	}
   329  	// Add a document
   330  	err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
   331  	if err != nil {
   332  		t.Fatal("expected nil, got", err)
   333  	}
   334  
   335  	tt := []struct {
   336  		name   string
   337  		query  func() error
   338  		expErr string
   339  	}{
   340  		{
   341  			name: "Empty query",
   342  			query: func() error {
   343  				_, err := c.Query(context.Background(), "", 1, nil, nil)
   344  				return err
   345  			},
   346  			expErr: "queryText is empty",
   347  		},
   348  		{
   349  			name: "Negative limit",
   350  			query: func() error {
   351  				_, err := c.Query(context.Background(), "foo", -1, nil, nil)
   352  				return err
   353  			},
   354  			expErr: "nResults must be > 0",
   355  		},
   356  		{
   357  			name: "Zero limit",
   358  			query: func() error {
   359  				_, err := c.Query(context.Background(), "foo", 0, nil, nil)
   360  				return err
   361  			},
   362  			expErr: "nResults must be > 0",
   363  		},
   364  		{
   365  			name: "Limit greater than number of documents",
   366  			query: func() error {
   367  				_, err := c.Query(context.Background(), "foo", 2, nil, nil)
   368  				return err
   369  			},
   370  			expErr: "nResults must be <= the number of documents in the collection",
   371  		},
   372  		{
   373  			name: "Bad content filter",
   374  			query: func() error {
   375  				_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
   376  				return err
   377  			},
   378  			expErr: "unsupported operator",
   379  		},
   380  	}
   381  
   382  	for _, tc := range tt {
   383  		t.Run(tc.name, func(t *testing.T) {
   384  			err := tc.query()
   385  			if err == nil {
   386  				t.Fatal("expected error, got nil")
   387  			} else if err.Error() != tc.expErr {
   388  				t.Fatal("expected", tc.expErr, "got", err)
   389  			}
   390  		})
   391  	}
   392  }
   393  
   394  func TestCollection_Count(t *testing.T) {
   395  	// Create collection
   396  	db := NewDB()
   397  	name := "test"
   398  	metadata := map[string]string{"foo": "bar"}
   399  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   400  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   401  		return vectors, nil
   402  	}
   403  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   404  	if err != nil {
   405  		t.Fatal("expected no error, got", err)
   406  	}
   407  	if c == nil {
   408  		t.Fatal("expected collection, got nil")
   409  	}
   410  
   411  	// Add documents
   412  	ids := []string{"1", "2"}
   413  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
   414  	contents := []string{"hello world", "hallo welt"}
   415  	err = c.Add(context.Background(), ids, nil, metadatas, contents)
   416  	if err != nil {
   417  		t.Fatal("expected nil, got", err)
   418  	}
   419  
   420  	// Check count
   421  	if c.Count() != 2 {
   422  		t.Fatal("expected 2, got", c.Count())
   423  	}
   424  }
   425  
   426  func TestCollection_Delete(t *testing.T) {
   427  	// Create persistent collection
   428  	tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*")
   429  	if err != nil {
   430  		t.Fatal("expected no error, got", err)
   431  	}
   432  	db, err := NewPersistentDB(tmpdir, false)
   433  	if err != nil {
   434  		t.Fatal("expected no error, got", err)
   435  	}
   436  	name := "test"
   437  	metadata := map[string]string{"foo": "bar"}
   438  	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
   439  	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
   440  		return vectors, nil
   441  	}
   442  	c, err := db.CreateCollection(name, metadata, embeddingFunc)
   443  	if err != nil {
   444  		t.Fatal("expected no error, got", err)
   445  	}
   446  	if c == nil {
   447  		t.Fatal("expected collection, got nil")
   448  	}
   449  
   450  	// Add documents
   451  	ids := []string{"1", "2", "3", "4"}
   452  	metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}}
   453  	contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"}
   454  	err = c.Add(context.Background(), ids, nil, metadatas, contents)
   455  	if err != nil {
   456  		t.Fatal("expected nil, got", err)
   457  	}
   458  
   459  	// Check count
   460  	if c.Count() != 4 {
   461  		t.Fatal("expected 4 documents, got", c.Count())
   462  	}
   463  
   464  	// Check number of files in the persist directory
   465  	d, err := os.ReadDir(c.persistDirectory)
   466  	if err != nil {
   467  		t.Fatal("expected nil, got", err)
   468  	}
   469  	if len(d) != 5 { // 4 documents + 1 metadata file
   470  		t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d))
   471  	}
   472  
   473  	checkCount := func(expected int) {
   474  		// Check count
   475  		if c.Count() != expected {
   476  			t.Fatalf("expected %d documents, got %d", expected, c.Count())
   477  		}
   478  
   479  		// Check number of files in the persist directory
   480  		d, err = os.ReadDir(c.persistDirectory)
   481  		if err != nil {
   482  			t.Fatal("expected nil, got", err)
   483  		}
   484  		if len(d) != expected+1 { // 3 document + 1 metadata file
   485  			t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d))
   486  		}
   487  	}
   488  
   489  	// Test 1 - Remove document by ID: should delete one document
   490  	err = c.Delete(context.Background(), nil, nil, "4")
   491  	if err != nil {
   492  		t.Fatal("expected nil, got", err)
   493  	}
   494  	checkCount(3)
   495  
   496  	// Test 2 - Remove document by metadata
   497  	err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil)
   498  	if err != nil {
   499  		t.Fatal("expected nil, got", err)
   500  	}
   501  
   502  	checkCount(1)
   503  
   504  	// Test 3 - Remove document by content
   505  	err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
   506  	if err != nil {
   507  		t.Fatal("expected nil, got", err)
   508  	}
   509  
   510  	checkCount(0)
   511  }
   512  
   513  // Global var for assignment in the benchmark to avoid compiler optimizations.
   514  var globalRes []Result
   515  
   516  func BenchmarkCollection_Query_NoContent_100(b *testing.B) {
   517  	benchmarkCollection_Query(b, 100, false)
   518  }
   519  
   520  func BenchmarkCollection_Query_NoContent_1000(b *testing.B) {
   521  	benchmarkCollection_Query(b, 1000, false)
   522  }
   523  
   524  func BenchmarkCollection_Query_NoContent_5000(b *testing.B) {
   525  	benchmarkCollection_Query(b, 5000, false)
   526  }
   527  
   528  func BenchmarkCollection_Query_NoContent_25000(b *testing.B) {
   529  	benchmarkCollection_Query(b, 25000, false)
   530  }
   531  
   532  func BenchmarkCollection_Query_NoContent_100000(b *testing.B) {
   533  	benchmarkCollection_Query(b, 100_000, false)
   534  }
   535  
   536  func BenchmarkCollection_Query_100(b *testing.B) {
   537  	benchmarkCollection_Query(b, 100, true)
   538  }
   539  
   540  func BenchmarkCollection_Query_1000(b *testing.B) {
   541  	benchmarkCollection_Query(b, 1000, true)
   542  }
   543  
   544  func BenchmarkCollection_Query_5000(b *testing.B) {
   545  	benchmarkCollection_Query(b, 5000, true)
   546  }
   547  
   548  func BenchmarkCollection_Query_25000(b *testing.B) {
   549  	benchmarkCollection_Query(b, 25000, true)
   550  }
   551  
   552  func BenchmarkCollection_Query_100000(b *testing.B) {
   553  	benchmarkCollection_Query(b, 100_000, true)
   554  }
   555  
   556  // n is number of documents in the collection
   557  func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
   558  	ctx := context.Background()
   559  
   560  	// Seed to make deterministic
   561  	r := rand.New(rand.NewSource(42))
   562  
   563  	d := 1536 // dimensions, same as text-embedding-3-small
   564  	// Random query vector
   565  	qv := make([]float32, d)
   566  	for j := 0; j < d; j++ {
   567  		qv[j] = r.Float32()
   568  	}
   569  	// The document embeddings are normalized, so the query must be normalized too.
   570  	qv = normalizeVector(qv)
   571  
   572  	// Create collection
   573  	db := NewDB()
   574  	name := "test"
   575  	embeddingFunc := func(_ context.Context, text string) ([]float32, error) {
   576  		return nil, errors.New("embedding func not expected to be called")
   577  	}
   578  	c, err := db.CreateCollection(name, nil, embeddingFunc)
   579  	if err != nil {
   580  		b.Fatal("expected no error, got", err)
   581  	}
   582  	if c == nil {
   583  		b.Fatal("expected collection, got nil")
   584  	}
   585  
   586  	// Add documents
   587  	for i := 0; i < n; i++ {
   588  		// Random embedding
   589  		v := make([]float32, d)
   590  		for j := 0; j < d; j++ {
   591  			v[j] = r.Float32()
   592  		}
   593  		v = normalizeVector(v)
   594  
   595  		// Add document with some metadata and content depending on parameter.
   596  		// When providing embeddings, the embedding func is not called.
   597  		is := strconv.Itoa(i)
   598  		doc := Document{
   599  			ID:        is,
   600  			Metadata:  map[string]string{"i": is, "foo": "bar" + is},
   601  			Embedding: v,
   602  		}
   603  		if withContent {
   604  			// Let's say we embed 500 tokens, that's ~375 words, ~1875 characters
   605  			doc.Content = randomString(r, 1875)
   606  		}
   607  
   608  		if err := c.AddDocument(ctx, doc); err != nil {
   609  			b.Fatal("expected nil, got", err)
   610  		}
   611  	}
   612  
   613  	b.ResetTimer()
   614  
   615  	// Query
   616  	var res []Result
   617  	for i := 0; i < b.N; i++ {
   618  		res, err = c.QueryEmbedding(ctx, qv, 10, nil, nil)
   619  	}
   620  	if err != nil {
   621  		b.Fatal("expected nil, got", err)
   622  	}
   623  	globalRes = res
   624  }
   625  
   626  // randomString returns a random string of length n using lowercase letters and space.
   627  func randomString(r *rand.Rand, n int) string {
   628  	// We add 5 spaces to get roughly one space every 5 characters
   629  	characters := []rune("abcdefghijklmnopqrstuvwxyz     ")
   630  
   631  	b := make([]rune, n)
   632  	for i := range b {
   633  		b[i] = characters[r.Intn(len(characters))]
   634  	}
   635  	return string(b)
   636  }
   637  

View as plain text