...
1 package chromem
2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
)
13
// defaultBaseURLOllama is the Ollama API base URL that NewEmbeddingFuncOllama
// falls back to when it is given an empty base URL.
const (
	defaultBaseURLOllama = "http://localhost:11434/api"
)
15
16 type ollamaResponse struct {
17 Embedding []float32 `json:"embedding"`
18 }
19
20
21
22
23
24
25
26 func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
27 if baseURLOllama == "" {
28 baseURLOllama = defaultBaseURLOllama
29 }
30
31
32
33
34 client := &http.Client{}
35
36 var checkedNormalized bool
37 checkNormalized := sync.Once{}
38
39 return func(ctx context.Context, text string) ([]float32, error) {
40
41 reqBody, err := json.Marshal(map[string]string{
42 "model": model,
43 "prompt": text,
44 })
45 if err != nil {
46 return nil, fmt.Errorf("couldn't marshal request body: %w", err)
47 }
48
49
50
51 req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
52 if err != nil {
53 return nil, fmt.Errorf("couldn't create request: %w", err)
54 }
55 req.Header.Set("Content-Type", "application/json")
56
57
58 resp, err := client.Do(req)
59 if err != nil {
60 return nil, fmt.Errorf("couldn't send request: %w", err)
61 }
62 defer resp.Body.Close()
63
64
65 if resp.StatusCode != http.StatusOK {
66 return nil, errors.New("error response from the embedding API: " + resp.Status)
67 }
68
69
70 body, err := io.ReadAll(resp.Body)
71 if err != nil {
72 return nil, fmt.Errorf("couldn't read response body: %w", err)
73 }
74 var embeddingResponse ollamaResponse
75 err = json.Unmarshal(body, &embeddingResponse)
76 if err != nil {
77 return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
78 }
79
80
81 if len(embeddingResponse.Embedding) == 0 {
82 return nil, errors.New("no embeddings found in the response")
83 }
84
85 v := embeddingResponse.Embedding
86 checkNormalized.Do(func() {
87 if isNormalized(v) {
88 checkedNormalized = true
89 } else {
90 checkedNormalized = false
91 }
92 })
93 if !checkedNormalized {
94 v = normalizeVector(v)
95 }
96
97 return v, nil
98 }
99 }
100
View as plain text