...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/collection.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"path/filepath"
     8  	"slices"
     9  	"sync"
    10  )
    11  
// Collection represents a collection of documents.
// It also has a configured embedding function, which is used when adding documents
// that don't have embeddings yet.
type Collection struct {
	// Name is the collection name given at creation. When persistence is
	// enabled, a hash of it (not the raw name) is used as the directory name.
	Name string

	// metadata is a private copy of the metadata passed at creation time.
	metadata map[string]string
	// documents maps document IDs to stored documents; documentsLock guards it.
	documents     map[string]*Document
	documentsLock sync.RWMutex
	// embed creates embeddings for documents and queries that don't have one.
	embed EmbeddingFunc

	// persistDirectory is this collection's directory on disk; an empty value
	// means persistence is disabled.
	persistDirectory string
	// compress selects compressed persistence; persisted file names then get
	// an additional ".gz" suffix.
	compress bool

	// ⚠️ When adding fields here, consider adding them to the persistence struct
	// versions in [DB.Export] and [DB.Import] as well!
}
    29  
    30  // We don't export this yet to keep the API surface to the bare minimum.
    31  // Users create collections via [Client.CreateCollection].
    32  func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
    33  	// We copy the metadata to avoid data races in case the caller modifies the
    34  	// map after creating the collection while we range over it.
    35  	m := make(map[string]string, len(metadata))
    36  	for k, v := range metadata {
    37  		m[k] = v
    38  	}
    39  
    40  	c := &Collection{
    41  		Name: name,
    42  
    43  		metadata:  m,
    44  		documents: make(map[string]*Document),
    45  		embed:     embed,
    46  	}
    47  
    48  	// Persistence
    49  	if dbDir != "" {
    50  		safeName := hash2hex(name)
    51  		c.persistDirectory = filepath.Join(dbDir, safeName)
    52  		c.compress = compress
    53  		// Persist name and metadata
    54  		metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
    55  		metadataPath += ".gob"
    56  		if c.compress {
    57  			metadataPath += ".gz"
    58  		}
    59  		pc := struct {
    60  			Name     string
    61  			Metadata map[string]string
    62  		}{
    63  			Name:     name,
    64  			Metadata: m,
    65  		}
    66  		err := persistToFile(metadataPath, pc, compress, "")
    67  		if err != nil {
    68  			return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
    69  		}
    70  	}
    71  
    72  	return c, nil
    73  }
    74  
    75  // Add embeddings to the datastore.
    76  //
    77  //   - ids: The ids of the embeddings you wish to add
    78  //   - embeddings: The embeddings to add. If nil, embeddings will be computed based
    79  //     on the contents using the embeddingFunc set for the Collection. Optional.
    80  //   - metadatas: The metadata to associate with the embeddings. When querying,
    81  //     you can filter on this metadata. Optional.
    82  //   - contents: The contents to associate with the embeddings.
    83  //
    84  // This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
    85  func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
    86  	return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
    87  }
    88  
    89  // AddConcurrently is like Add, but adds embeddings concurrently.
    90  // This is mostly useful when you don't pass any embeddings, so they have to be created.
    91  // Upon error, concurrently running operations are canceled and the error is returned.
    92  //
    93  // This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
    94  func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
    95  	if len(ids) == 0 {
    96  		return errors.New("ids are empty")
    97  	}
    98  	if len(embeddings) == 0 && len(contents) == 0 {
    99  		return errors.New("either embeddings or contents must be filled")
   100  	}
   101  	if len(embeddings) != 0 {
   102  		if len(embeddings) != len(ids) {
   103  			return errors.New("ids and embeddings must have the same length")
   104  		}
   105  	} else {
   106  		// Assign empty slice, so we can simply access via index later
   107  		embeddings = make([][]float32, len(ids))
   108  	}
   109  	if len(metadatas) != 0 {
   110  		if len(ids) != len(metadatas) {
   111  			return errors.New("when metadatas is not empty it must have the same length as ids")
   112  		}
   113  	} else {
   114  		// Assign empty slice, so we can simply access via index later
   115  		metadatas = make([]map[string]string, len(ids))
   116  	}
   117  	if len(contents) != 0 {
   118  		if len(contents) != len(ids) {
   119  			return errors.New("ids and contents must have the same length")
   120  		}
   121  	} else {
   122  		// Assign empty slice, so we can simply access via index later
   123  		contents = make([]string, len(ids))
   124  	}
   125  	if concurrency < 1 {
   126  		return errors.New("concurrency must be at least 1")
   127  	}
   128  
   129  	// Convert Chroma-style parameters into a slice of documents.
   130  	docs := make([]Document, 0, len(ids))
   131  	for i, id := range ids {
   132  		docs = append(docs, Document{
   133  			ID:        id,
   134  			Metadata:  metadatas[i],
   135  			Embedding: embeddings[i],
   136  			Content:   contents[i],
   137  		})
   138  	}
   139  
   140  	return c.AddDocuments(ctx, docs, concurrency)
   141  }
   142  
   143  // AddDocuments adds documents to the collection with the specified concurrency.
   144  // If the documents don't have embeddings, they will be created using the collection's
   145  // embedding function.
   146  // Upon error, concurrently running operations are canceled and the error is returned.
   147  func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
   148  	if len(documents) == 0 {
   149  		// TODO: Should this be a no-op instead?
   150  		return errors.New("documents slice is nil or empty")
   151  	}
   152  	if concurrency < 1 {
   153  		return errors.New("concurrency must be at least 1")
   154  	}
   155  	// For other validations we rely on AddDocument.
   156  
   157  	var sharedErr error
   158  	sharedErrLock := sync.Mutex{}
   159  	ctx, cancel := context.WithCancelCause(ctx)
   160  	defer cancel(nil)
   161  	setSharedErr := func(err error) {
   162  		sharedErrLock.Lock()
   163  		defer sharedErrLock.Unlock()
   164  		// Another goroutine might have already set the error.
   165  		if sharedErr == nil {
   166  			sharedErr = err
   167  			// Cancel the operation for all other goroutines.
   168  			cancel(sharedErr)
   169  		}
   170  	}
   171  
   172  	var wg sync.WaitGroup
   173  	semaphore := make(chan struct{}, concurrency)
   174  	for _, doc := range documents {
   175  		wg.Add(1)
   176  		go func(doc Document) {
   177  			defer wg.Done()
   178  
   179  			// Don't even start if another goroutine already failed.
   180  			if ctx.Err() != nil {
   181  				return
   182  			}
   183  
   184  			// Wait here while $concurrency other goroutines are creating documents.
   185  			semaphore <- struct{}{}
   186  			defer func() { <-semaphore }()
   187  
   188  			err := c.AddDocument(ctx, doc)
   189  			if err != nil {
   190  				setSharedErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
   191  				return
   192  			}
   193  		}(doc)
   194  	}
   195  
   196  	wg.Wait()
   197  
   198  	return sharedErr
   199  }
   200  
   201  // AddDocument adds a document to the collection.
   202  // If the document doesn't have an embedding, it will be created using the collection's
   203  // embedding function.
   204  func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
   205  	if doc.ID == "" {
   206  		return errors.New("document ID is empty")
   207  	}
   208  	if len(doc.Embedding) == 0 && doc.Content == "" {
   209  		return errors.New("either document embedding or content must be filled")
   210  	}
   211  
   212  	// We copy the metadata to avoid data races in case the caller modifies the
   213  	// map after creating the document while we range over it.
   214  	m := make(map[string]string, len(doc.Metadata))
   215  	for k, v := range doc.Metadata {
   216  		m[k] = v
   217  	}
   218  
   219  	// Create embedding if they don't exist, otherwise normalize if necessary
   220  	if len(doc.Embedding) == 0 {
   221  		embedding, err := c.embed(ctx, doc.Content)
   222  		if err != nil {
   223  			return fmt.Errorf("couldn't create embedding of document: %w", err)
   224  		}
   225  		doc.Embedding = embedding
   226  	} else {
   227  		if !isNormalized(doc.Embedding) {
   228  			doc.Embedding = normalizeVector(doc.Embedding)
   229  		}
   230  	}
   231  
   232  	c.documentsLock.Lock()
   233  	// We don't defer the unlock because we want to do it earlier.
   234  	c.documents[doc.ID] = &doc
   235  	c.documentsLock.Unlock()
   236  
   237  	// Persist the document
   238  	if c.persistDirectory != "" {
   239  		docPath := c.getDocPath(doc.ID)
   240  		err := persistToFile(docPath, doc, c.compress, "")
   241  		if err != nil {
   242  			return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
   243  		}
   244  	}
   245  
   246  	return nil
   247  }
   248  
   249  // Delete removes document(s) from the collection.
   250  //
   251  //   - where: Conditional filtering on metadata. Optional.
   252  //   - whereDocument: Conditional filtering on documents. Optional.
   253  //   - ids: The ids of the documents to delete. If empty, all documents are deleted.
   254  func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
   255  	// must have at least one of where, whereDocument or ids
   256  	if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
   257  		return fmt.Errorf("must have at least one of where, whereDocument or ids")
   258  	}
   259  
   260  	if len(c.documents) == 0 {
   261  		return nil
   262  	}
   263  
   264  	for k := range whereDocument {
   265  		if !slices.Contains(supportedFilters, k) {
   266  			return errors.New("unsupported whereDocument operator")
   267  		}
   268  	}
   269  
   270  	var docIDs []string
   271  
   272  	c.documentsLock.Lock()
   273  	defer c.documentsLock.Unlock()
   274  
   275  	if where != nil || whereDocument != nil {
   276  		// metadata + content filters
   277  		filteredDocs := filterDocs(c.documents, where, whereDocument)
   278  		for _, doc := range filteredDocs {
   279  			docIDs = append(docIDs, doc.ID)
   280  		}
   281  	} else {
   282  		docIDs = ids
   283  	}
   284  
   285  	// No-op if no docs are left
   286  	if len(docIDs) == 0 {
   287  		return nil
   288  	}
   289  
   290  	for _, docID := range docIDs {
   291  		delete(c.documents, docID)
   292  
   293  		// Remove the document from disk
   294  		if c.persistDirectory != "" {
   295  			docPath := c.getDocPath(docID)
   296  			err := removeFile(docPath)
   297  			if err != nil {
   298  				return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
   299  			}
   300  		}
   301  	}
   302  
   303  	return nil
   304  }
   305  
   306  // Count returns the number of documents in the collection.
   307  func (c *Collection) Count() int {
   308  	c.documentsLock.RLock()
   309  	defer c.documentsLock.RUnlock()
   310  	return len(c.documents)
   311  }
   312  
// Result represents a single result from a query.
//
// ID, Metadata, Embedding and Content describe the matched document as it is
// stored in the collection.
type Result struct {
	ID        string
	Metadata  map[string]string
	Embedding []float32
	Content   string

	// The cosine similarity between the query and the document.
	// The higher the value, the more similar the document is to the query.
	// The value is in the range [-1, 1].
	Similarity float32
}
   325  
   326  // Query performs an exhaustive nearest neighbor search on the collection.
   327  //
   328  //   - queryText: The text to search for. Its embedding will be created using the
   329  //     collection's embedding function.
   330  //   - nResults: The maximum number of results to return. Must be > 0.
   331  //     There can be fewer results if a filter is applied.
   332  //   - where: Conditional filtering on metadata. Optional.
   333  //   - whereDocument: Conditional filtering on documents. Optional.
   334  func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
   335  	if queryText == "" {
   336  		return nil, errors.New("queryText is empty")
   337  	}
   338  
   339  	queryVectors, err := c.embed(ctx, queryText)
   340  	if err != nil {
   341  		return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
   342  	}
   343  
   344  	return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
   345  }
   346  
   347  // QueryEmbedding performs an exhaustive nearest neighbor search on the collection.
   348  //
   349  //   - queryEmbedding: The embedding of the query to search for. It must be created
   350  //     with the same embedding model as the document embeddings in the collection.
   351  //     The embedding will be normalized if it's not the case yet.
   352  //   - nResults: The maximum number of results to return. Must be > 0.
   353  //     There can be fewer results if a filter is applied.
   354  //   - where: Conditional filtering on metadata. Optional.
   355  //   - whereDocument: Conditional filtering on documents. Optional.
   356  func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
   357  	if len(queryEmbedding) == 0 {
   358  		return nil, errors.New("queryEmbedding is empty")
   359  	}
   360  	if nResults <= 0 {
   361  		return nil, errors.New("nResults must be > 0")
   362  	}
   363  	c.documentsLock.RLock()
   364  	defer c.documentsLock.RUnlock()
   365  	if nResults > len(c.documents) {
   366  		return nil, errors.New("nResults must be <= the number of documents in the collection")
   367  	}
   368  
   369  	if len(c.documents) == 0 {
   370  		return nil, nil
   371  	}
   372  
   373  	// Validate whereDocument operators
   374  	for k := range whereDocument {
   375  		if !slices.Contains(supportedFilters, k) {
   376  			return nil, errors.New("unsupported operator")
   377  		}
   378  	}
   379  
   380  	// Filter docs by metadata and content
   381  	filteredDocs := filterDocs(c.documents, where, whereDocument)
   382  
   383  	// No need to continue if the filters got rid of all documents
   384  	if len(filteredDocs) == 0 {
   385  		return nil, nil
   386  	}
   387  
   388  	// Normalize embedding if not the case yet. We only support cosine similarity
   389  	// for now and all documents were already normalized when added to the collection.
   390  	if !isNormalized(queryEmbedding) {
   391  		queryEmbedding = normalizeVector(queryEmbedding)
   392  	}
   393  
   394  	// If the filtering already reduced the number of documents to fewer than nResults,
   395  	// we only need to find the most similar docs among the filtered ones.
   396  	resLen := nResults
   397  	if len(filteredDocs) < nResults {
   398  		resLen = len(filteredDocs)
   399  	}
   400  
   401  	// For the remaining documents, get the most similar docs.
   402  	nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
   403  	if err != nil {
   404  		return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
   405  	}
   406  
   407  	// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
   408  	if resLen != len(nMaxDocs) {
   409  		return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
   410  	}
   411  
   412  	res := make([]Result, 0, resLen)
   413  	for i := 0; i < resLen; i++ {
   414  		res = append(res, Result{
   415  			ID:         nMaxDocs[i].docID,
   416  			Metadata:   c.documents[nMaxDocs[i].docID].Metadata,
   417  			Embedding:  c.documents[nMaxDocs[i].docID].Embedding,
   418  			Content:    c.documents[nMaxDocs[i].docID].Content,
   419  			Similarity: nMaxDocs[i].similarity,
   420  		})
   421  	}
   422  
   423  	return res, nil
   424  }
   425  
   426  // getDocPath generates the path to the document file.
   427  func (c *Collection) getDocPath(docID string) string {
   428  	safeID := hash2hex(docID)
   429  	docPath := filepath.Join(c.persistDirectory, safeID)
   430  	docPath += ".gob"
   431  	if c.compress {
   432  		docPath += ".gz"
   433  	}
   434  	return docPath
   435  }
   436  

View as plain text