...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/query.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"cmp"
     5  	"container/heap"
     6  	"context"
     7  	"fmt"
     8  	"runtime"
     9  	"slices"
    10  	"strings"
    11  	"sync"
    12  )
    13  
// supportedFilters lists the whereDocument operators that
// documentMatchesFilters evaluates. Presumably callers validate filter keys
// against this list before filtering — confirm at the call sites.
var supportedFilters = []string{"$contains", "$not_contains"}
    15  
    16  type docSim struct {
    17  	docID      string
    18  	similarity float32
    19  }
    20  
    21  // docMaxHeap is a max-heap of docSims, based on similarity.
    22  // See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap
    23  type docMaxHeap []docSim
    24  
    25  func (h docMaxHeap) Len() int           { return len(h) }
    26  func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
    27  func (h docMaxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
    28  
    29  func (h *docMaxHeap) Push(x any) {
    30  	// Push and Pop use pointer receivers because they modify the slice's length,
    31  	// not just its contents.
    32  	*h = append(*h, x.(docSim))
    33  }
    34  
    35  func (h *docMaxHeap) Pop() any {
    36  	old := *h
    37  	n := len(old)
    38  	x := old[n-1]
    39  	*h = old[0 : n-1]
    40  	return x
    41  }
    42  
    43  // maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest
    44  // similarities. It's safe for concurrent use, but not the result of values().
    45  // In our benchmarks this was faster than sorting a slice of docSims at the end.
    46  type maxDocSims struct {
    47  	h    docMaxHeap
    48  	lock sync.RWMutex
    49  	size int
    50  }
    51  
    52  // newMaxDocSims creates a new nMaxDocs with a fixed size.
    53  func newMaxDocSims(size int) *maxDocSims {
    54  	return &maxDocSims{
    55  		h:    make(docMaxHeap, 0, size),
    56  		size: size,
    57  	}
    58  }
    59  
    60  // add inserts a new docSim into the heap, keeping only the top n similarities.
    61  func (d *maxDocSims) add(doc docSim) {
    62  	d.lock.Lock()
    63  	defer d.lock.Unlock()
    64  	if d.h.Len() < d.size {
    65  		heap.Push(&d.h, doc)
    66  	} else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity {
    67  		// Replace the smallest similarity if the new doc's similarity is higher
    68  		heap.Pop(&d.h)
    69  		heap.Push(&d.h, doc)
    70  	}
    71  }
    72  
    73  // values returns the docSims in the heap, sorted by similarity (descending).
    74  // The call itself is safe for concurrent use with add(), but the result isn't.
    75  // Only work with the result after all calls to add() have finished.
    76  func (d *maxDocSims) values() []docSim {
    77  	d.lock.RLock()
    78  	defer d.lock.RUnlock()
    79  	slices.SortFunc(d.h, func(i, j docSim) int {
    80  		return cmp.Compare(j.similarity, i.similarity)
    81  	})
    82  	return d.h
    83  }
    84  
    85  // filterDocs filters a map of documents by metadata and content.
    86  // It does this concurrently.
    87  func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
    88  	filteredDocs := make([]*Document, 0, len(docs))
    89  	filteredDocsLock := sync.Mutex{}
    90  
    91  	// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
    92  	numCPUs := runtime.NumCPU()
    93  	numDocs := len(docs)
    94  	concurrency := numCPUs
    95  	if numDocs < numCPUs {
    96  		concurrency = numDocs
    97  	}
    98  
    99  	docChan := make(chan *Document, concurrency*2)
   100  
   101  	wg := sync.WaitGroup{}
   102  	for i := 0; i < concurrency; i++ {
   103  		wg.Add(1)
   104  		go func() {
   105  			defer wg.Done()
   106  			for doc := range docChan {
   107  				if documentMatchesFilters(doc, where, whereDocument) {
   108  					filteredDocsLock.Lock()
   109  					filteredDocs = append(filteredDocs, doc)
   110  					filteredDocsLock.Unlock()
   111  				}
   112  			}
   113  		}()
   114  	}
   115  
   116  	for _, doc := range docs {
   117  		docChan <- doc
   118  	}
   119  	close(docChan)
   120  
   121  	wg.Wait()
   122  
   123  	// With filteredDocs being initialized as potentially large slice, let's return
   124  	// nil instead of the empty slice.
   125  	if len(filteredDocs) == 0 {
   126  		filteredDocs = nil
   127  	}
   128  	return filteredDocs
   129  }
   130  
   131  // documentMatchesFilters checks if a document matches the given filters.
   132  // When calling this function, the whereDocument keys must already be validated!
   133  func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
   134  	// A document's metadata must have *all* the fields in the where clause.
   135  	for k, v := range where {
   136  		// TODO: Do we want to check for existence of the key? I.e. should
   137  		// a where clause with empty string as value match a document's
   138  		// metadata that doesn't have the key at all?
   139  		if document.Metadata[k] != v {
   140  			return false
   141  		}
   142  	}
   143  
   144  	// A document must satisfy *all* filters, until we support the `$or` operator.
   145  	for k, v := range whereDocument {
   146  		switch k {
   147  		case "$contains":
   148  			if !strings.Contains(document.Content, v) {
   149  				return false
   150  			}
   151  		case "$not_contains":
   152  			if strings.Contains(document.Content, v) {
   153  				return false
   154  			}
   155  		default:
   156  			// No handling (error) required because we already validated the
   157  			// operators. This simplifies the concurrency logic (no err var
   158  			// and lock, no context to cancel).
   159  		}
   160  	}
   161  
   162  	return true
   163  }
   164  
   165  func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) {
   166  	nMaxDocs := newMaxDocSims(n)
   167  
   168  	// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
   169  	numCPUs := runtime.NumCPU()
   170  	numDocs := len(docs)
   171  	concurrency := numCPUs
   172  	if numDocs < numCPUs {
   173  		concurrency = numDocs
   174  	}
   175  
   176  	var sharedErr error
   177  	sharedErrLock := sync.Mutex{}
   178  	ctx, cancel := context.WithCancelCause(ctx)
   179  	defer cancel(nil)
   180  	setSharedErr := func(err error) {
   181  		sharedErrLock.Lock()
   182  		defer sharedErrLock.Unlock()
   183  		// Another goroutine might have already set the error.
   184  		if sharedErr == nil {
   185  			sharedErr = err
   186  			// Cancel the operation for all other goroutines.
   187  			cancel(sharedErr)
   188  		}
   189  	}
   190  
   191  	wg := sync.WaitGroup{}
   192  	// Instead of using a channel to pass documents into the goroutines, we just
   193  	// split the slice into sub-slices and pass those to the goroutines.
   194  	// This turned out to be faster in the query benchmarks.
   195  	subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 10/3 = 3; leaves 1
   196  	rem := len(docs) % concurrency
   197  	for i := 0; i < concurrency; i++ {
   198  		start := i * subSliceSize
   199  		end := start + subSliceSize
   200  		// Add remainder to last goroutine
   201  		if i == concurrency-1 {
   202  			end += rem
   203  		}
   204  
   205  		wg.Add(1)
   206  		go func(subSlice []*Document) {
   207  			defer wg.Done()
   208  			for _, doc := range subSlice {
   209  				// Stop work if another goroutine encountered an error.
   210  				if ctx.Err() != nil {
   211  					return
   212  				}
   213  
   214  				// As the vectors are normalized, the dot product is the cosine similarity.
   215  				sim, err := dotProduct(queryVectors, doc.Embedding)
   216  				if err != nil {
   217  					setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
   218  					return
   219  				}
   220  
   221  				nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
   222  			}
   223  		}(docs[start:end])
   224  	}
   225  
   226  	wg.Wait()
   227  
   228  	if sharedErr != nil {
   229  		return nil, sharedErr
   230  	}
   231  
   232  	return nMaxDocs.values(), nil
   233  }
   234  

View as plain text