...
1 package chromem
2
3 import (
4 "cmp"
5 "container/heap"
6 "context"
7 "fmt"
8 "runtime"
9 "slices"
10 "strings"
11 "sync"
12 )
13
// supportedFilters lists the whereDocument operators that
// documentMatchesFilters knows how to evaluate.
var supportedFilters = []string{
	"$contains",
	"$not_contains",
}
15
// docSim pairs a document ID with its similarity score against a query.
type docSim struct {
	docID      string
	similarity float32
}

// docMaxHeap implements container/heap's heap.Interface over docSim values.
//
// Despite its name, Less orders by ascending similarity, so the root (index 0)
// holds the LEAST similar entry. That is what a bounded "keep the n most
// similar" collector needs: once full, the root is the entry to evict.
type docMaxHeap []docSim

func (h docMaxHeap) Len() int           { return len(h) }
func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
func (h docMaxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

// Push appends x to the backing slice. x must be a docSim; any other type
// panics on the type assertion. Intended to be called via heap.Push.
func (h *docMaxHeap) Push(x any) {
	*h = append(*h, x.(docSim))
}

// Pop removes and returns the last element of the backing slice. Intended to
// be called via heap.Pop, which first moves the root to the end.
func (h *docMaxHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	// Clear the vacated slot so the backing array no longer references the
	// removed entry's docID string.
	old[n-1] = docSim{}
	*h = old[:n-1]
	return x
}
42
43
44
45
46 type maxDocSims struct {
47 h docMaxHeap
48 lock sync.RWMutex
49 size int
50 }
51
52
53 func newMaxDocSims(size int) *maxDocSims {
54 return &maxDocSims{
55 h: make(docMaxHeap, 0, size),
56 size: size,
57 }
58 }
59
60
61 func (d *maxDocSims) add(doc docSim) {
62 d.lock.Lock()
63 defer d.lock.Unlock()
64 if d.h.Len() < d.size {
65 heap.Push(&d.h, doc)
66 } else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity {
67
68 heap.Pop(&d.h)
69 heap.Push(&d.h, doc)
70 }
71 }
72
73
74
75
76 func (d *maxDocSims) values() []docSim {
77 d.lock.RLock()
78 defer d.lock.RUnlock()
79 slices.SortFunc(d.h, func(i, j docSim) int {
80 return cmp.Compare(j.similarity, i.similarity)
81 })
82 return d.h
83 }
84
85
86
87 func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
88 filteredDocs := make([]*Document, 0, len(docs))
89 filteredDocsLock := sync.Mutex{}
90
91
92 numCPUs := runtime.NumCPU()
93 numDocs := len(docs)
94 concurrency := numCPUs
95 if numDocs < numCPUs {
96 concurrency = numDocs
97 }
98
99 docChan := make(chan *Document, concurrency*2)
100
101 wg := sync.WaitGroup{}
102 for i := 0; i < concurrency; i++ {
103 wg.Add(1)
104 go func() {
105 defer wg.Done()
106 for doc := range docChan {
107 if documentMatchesFilters(doc, where, whereDocument) {
108 filteredDocsLock.Lock()
109 filteredDocs = append(filteredDocs, doc)
110 filteredDocsLock.Unlock()
111 }
112 }
113 }()
114 }
115
116 for _, doc := range docs {
117 docChan <- doc
118 }
119 close(docChan)
120
121 wg.Wait()
122
123
124
125 if len(filteredDocs) == 0 {
126 filteredDocs = nil
127 }
128 return filteredDocs
129 }
130
131
132
133 func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
134
135 for k, v := range where {
136
137
138
139 if document.Metadata[k] != v {
140 return false
141 }
142 }
143
144
145 for k, v := range whereDocument {
146 switch k {
147 case "$contains":
148 if !strings.Contains(document.Content, v) {
149 return false
150 }
151 case "$not_contains":
152 if strings.Contains(document.Content, v) {
153 return false
154 }
155 default:
156
157
158
159 }
160 }
161
162 return true
163 }
164
165 func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) {
166 nMaxDocs := newMaxDocSims(n)
167
168
169 numCPUs := runtime.NumCPU()
170 numDocs := len(docs)
171 concurrency := numCPUs
172 if numDocs < numCPUs {
173 concurrency = numDocs
174 }
175
176 var sharedErr error
177 sharedErrLock := sync.Mutex{}
178 ctx, cancel := context.WithCancelCause(ctx)
179 defer cancel(nil)
180 setSharedErr := func(err error) {
181 sharedErrLock.Lock()
182 defer sharedErrLock.Unlock()
183
184 if sharedErr == nil {
185 sharedErr = err
186
187 cancel(sharedErr)
188 }
189 }
190
191 wg := sync.WaitGroup{}
192
193
194
195 subSliceSize := len(docs) / concurrency
196 rem := len(docs) % concurrency
197 for i := 0; i < concurrency; i++ {
198 start := i * subSliceSize
199 end := start + subSliceSize
200
201 if i == concurrency-1 {
202 end += rem
203 }
204
205 wg.Add(1)
206 go func(subSlice []*Document) {
207 defer wg.Done()
208 for _, doc := range subSlice {
209
210 if ctx.Err() != nil {
211 return
212 }
213
214
215 sim, err := dotProduct(queryVectors, doc.Embedding)
216 if err != nil {
217 setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
218 return
219 }
220
221 nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
222 }
223 }(docs[start:end])
224 }
225
226 wg.Wait()
227
228 if sharedErr != nil {
229 return nil, sharedErr
230 }
231
232 return nMaxDocs.values(), nil
233 }
234
View as plain text