...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/embed_ollama_test.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"slices"
    12  	"strings"
    13  	"testing"
    14  )
    15  
    16  func TestNewEmbeddingFuncOllama(t *testing.T) {
    17  	model := "model-small"
    18  	baseURLSuffix := "/api"
    19  	prompt := "hello world"
    20  
    21  	wantBody, err := json.Marshal(map[string]string{
    22  		"model":  model,
    23  		"prompt": prompt,
    24  	})
    25  	if err != nil {
    26  		t.Fatal("unexpected error:", err)
    27  	}
    28  	wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
    29  
    30  	// Mock server
    31  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		// Check URL
    33  		if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
    34  			t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path)
    35  		}
    36  		// Check method
    37  		if r.Method != "POST" {
    38  			t.Fatal("expected method POST, got", r.Method)
    39  		}
    40  		// Check headers
    41  		if r.Header.Get("Content-Type") != "application/json" {
    42  			t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
    43  		}
    44  		// Check body
    45  		body, err := io.ReadAll(r.Body)
    46  		if err != nil {
    47  			t.Fatal("unexpected error:", err)
    48  		}
    49  		if !bytes.Equal(body, wantBody) {
    50  			t.Fatal("expected body", wantBody, "got", body)
    51  		}
    52  
    53  		// Write response
    54  		resp := ollamaResponse{
    55  			Embedding: wantRes,
    56  		}
    57  		w.WriteHeader(http.StatusOK)
    58  		_ = json.NewEncoder(w).Encode(resp)
    59  	}))
    60  	defer ts.Close()
    61  
    62  	// Get port from URL
    63  	u, err := url.Parse(ts.URL)
    64  	if err != nil {
    65  		t.Fatal("unexpected error:", err)
    66  	}
    67  
    68  	f := NewEmbeddingFuncOllama(model, strings.Replace(defaultBaseURLOllama, "11434", u.Port(), 1))
    69  	res, err := f(context.Background(), prompt)
    70  	if err != nil {
    71  		t.Fatal("expected nil, got", err)
    72  	}
    73  	if slices.Compare(wantRes, res) != 0 {
    74  		t.Fatal("expected res", wantRes, "got", res)
    75  	}
    76  }
    77  

View as plain text