...
1 package chromem
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "io"
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "slices"
12 "strings"
13 "testing"
14 )
15
16 func TestNewEmbeddingFuncOllama(t *testing.T) {
17 model := "model-small"
18 baseURLSuffix := "/api"
19 prompt := "hello world"
20
21 wantBody, err := json.Marshal(map[string]string{
22 "model": model,
23 "prompt": prompt,
24 })
25 if err != nil {
26 t.Fatal("unexpected error:", err)
27 }
28 wantRes := []float32{-0.40824828, 0.40824828, 0.81649655}
29
30
31 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32
33 if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
34 t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path)
35 }
36
37 if r.Method != "POST" {
38 t.Fatal("expected method POST, got", r.Method)
39 }
40
41 if r.Header.Get("Content-Type") != "application/json" {
42 t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
43 }
44
45 body, err := io.ReadAll(r.Body)
46 if err != nil {
47 t.Fatal("unexpected error:", err)
48 }
49 if !bytes.Equal(body, wantBody) {
50 t.Fatal("expected body", wantBody, "got", body)
51 }
52
53
54 resp := ollamaResponse{
55 Embedding: wantRes,
56 }
57 w.WriteHeader(http.StatusOK)
58 _ = json.NewEncoder(w).Encode(resp)
59 }))
60 defer ts.Close()
61
62
63 u, err := url.Parse(ts.URL)
64 if err != nil {
65 t.Fatal("unexpected error:", err)
66 }
67
68 f := NewEmbeddingFuncOllama(model, strings.Replace(defaultBaseURLOllama, "11434", u.Port(), 1))
69 res, err := f(context.Background(), prompt)
70 if err != nil {
71 t.Fatal("expected nil, got", err)
72 }
73 if slices.Compare(wantRes, res) != 0 {
74 t.Fatal("expected res", wantRes, "got", res)
75 }
76 }
77
View as plain text