...
1 package chromem
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "net/http"
11 "strings"
12 "sync"
13 )
14
// EmbeddingModelCohere is the name of an embedding model offered by Cohere.
// NewEmbeddingFuncCohere sends the value as the "model" field of each embed
// request.
type EmbeddingModelCohere string

// Embedding models that can be passed to NewEmbeddingFuncCohere.
const (
	// v2.0 models.
	EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0"
	EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0"
	EmbeddingModelCohereEnglishV2      EmbeddingModelCohere = "embed-english-v2.0"

	// v3.0 models.
	EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0"
	EmbeddingModelCohereEnglishLightV3      EmbeddingModelCohere = "embed-english-light-v3.0"
	EmbeddingModelCohereMultilingualV3      EmbeddingModelCohere = "embed-multilingual-v3.0"
	EmbeddingModelCohereEnglishV3           EmbeddingModelCohere = "embed-english-v3.0"
)
26
27
// Prefixes that tell the function returned by NewEmbeddingFuncCohere which
// Cohere input type to use for a text. A text must start with exactly one of
// them (input type name, colon, space); the prefix is cut off before the text
// is sent to the API.
const (
	InputTypeCohereSearchDocumentPrefix string = "search_document: "
	InputTypeCohereSearchQueryPrefix    string = "search_query: "
	InputTypeCohereClassificationPrefix string = "classification: "
	InputTypeCohereClusteringPrefix     string = "clustering: "
)
34
35
// Input type names as they are sent in the "input_type" field of an embed
// request. They are the keys of validInputTypesCohere.
const (
	inputTypeCohereSearchDocument string = "search_document"
	inputTypeCohereSearchQuery    string = "search_query"
	inputTypeCohereClassification string = "classification"
	inputTypeCohereClustering     string = "clustering"
)
42
// baseURLCohere is the base URL of Cohere's v1 API. Endpoint paths such as
// "/embed" are appended to it.
const baseURLCohere = "https://api.cohere.ai/v1"
44
45 var validInputTypesCohere = map[string]string{
46 inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
47 inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix,
48 inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
49 inputTypeCohereClustering: InputTypeCohereClusteringPrefix,
50 }
51
// cohereResponse is the part of the embed endpoint's JSON response that is
// decoded; all other fields of the response are ignored.
type cohereResponse struct {
	// Embeddings holds the returned vectors. NewEmbeddingFuncCohere sends a
	// single text per request and uses the first entry.
	Embeddings [][]float32 `json:"embeddings"`
}
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83 func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc {
84
85
86
87 client := &http.Client{}
88
89 var checkedNormalized bool
90 checkNormalized := sync.Once{}
91
92 return func(ctx context.Context, text string) ([]float32, error) {
93 var inputType string
94 for validInputType, validInputTypePrefix := range validInputTypesCohere {
95 if strings.HasPrefix(text, validInputTypePrefix) {
96 inputType = validInputType
97 text = strings.TrimPrefix(text, validInputTypePrefix)
98 break
99 }
100 }
101 if inputType == "" {
102 return nil, errors.New("text must start with a valid input type plus colon and space")
103 }
104
105
106 reqBody, err := json.Marshal(map[string]any{
107 "model": model,
108 "texts": []string{text},
109 "input_type": inputType,
110 })
111 if err != nil {
112 return nil, fmt.Errorf("couldn't marshal request body: %w", err)
113 }
114
115
116
117 req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody))
118 if err != nil {
119 return nil, fmt.Errorf("couldn't create request: %w", err)
120 }
121 req.Header.Set("Accept", "application/json")
122 req.Header.Set("Content-Type", "application/json")
123 req.Header.Set("Authorization", "Bearer "+apiKey)
124
125
126 resp, err := client.Do(req)
127 if err != nil {
128 return nil, fmt.Errorf("couldn't send request: %w", err)
129 }
130 defer resp.Body.Close()
131
132
133 if resp.StatusCode != http.StatusOK {
134 return nil, errors.New("error response from the embedding API: " + resp.Status)
135 }
136
137
138 body, err := io.ReadAll(resp.Body)
139 if err != nil {
140 return nil, fmt.Errorf("couldn't read response body: %w", err)
141 }
142 var embeddingResponse cohereResponse
143 err = json.Unmarshal(body, &embeddingResponse)
144 if err != nil {
145 return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
146 }
147
148
149 if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 {
150 return nil, errors.New("no embeddings found in the response")
151 }
152
153 v := embeddingResponse.Embeddings[0]
154 checkNormalized.Do(func() {
155 if isNormalized(v) {
156 checkedNormalized = true
157 } else {
158 checkedNormalized = false
159 }
160 })
161 if !checkedNormalized {
162 v = normalizeVector(v)
163 }
164
165 return v, nil
166 }
167 }
168
View as plain text