...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/embed_ollama.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"sync"
    12  )
    13  
// defaultBaseURLOllama is the Ollama API base URL that NewEmbeddingFuncOllama
// falls back to when it is called with an empty base URL.
const defaultBaseURLOllama = "http://localhost:11434/api"
    15  
// ollamaResponse is the JSON decoding target for the body returned by the
// Ollama embeddings endpoint. Only the "embedding" field is read; any other
// fields in the response are ignored by the decoder.
type ollamaResponse struct {
	// Embedding is the vector decoded from the "embedding" JSON field.
	Embedding []float32 `json:"embedding"`
}
    19  
    20  // NewEmbeddingFuncOllama returns a function that creates embeddings for a text
    21  // using Ollama's embedding API. You can pass any model that Ollama supports and
    22  // that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
    23  // See https://ollama.com/library/nomic-embed-text
    24  // baseURLOllama is the base URL of the Ollama API. If it's empty,
    25  // "http://localhost:11434/api" is used.
    26  func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
    27  	if baseURLOllama == "" {
    28  		baseURLOllama = defaultBaseURLOllama
    29  	}
    30  
    31  	// We don't set a default timeout here, although it's usually a good idea.
    32  	// In our case though, the library user can set the timeout on the context,
    33  	// and it might have to be a long timeout, depending on the text length.
    34  	client := &http.Client{}
    35  
    36  	var checkedNormalized bool
    37  	checkNormalized := sync.Once{}
    38  
    39  	return func(ctx context.Context, text string) ([]float32, error) {
    40  		// Prepare the request body.
    41  		reqBody, err := json.Marshal(map[string]string{
    42  			"model":  model,
    43  			"prompt": text,
    44  		})
    45  		if err != nil {
    46  			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
    47  		}
    48  
    49  		// Create the request. Creating it with context is important for a timeout
    50  		// to be possible, because the client is configured without a timeout.
    51  		req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
    52  		if err != nil {
    53  			return nil, fmt.Errorf("couldn't create request: %w", err)
    54  		}
    55  		req.Header.Set("Content-Type", "application/json")
    56  
    57  		// Send the request.
    58  		resp, err := client.Do(req)
    59  		if err != nil {
    60  			return nil, fmt.Errorf("couldn't send request: %w", err)
    61  		}
    62  		defer resp.Body.Close()
    63  
    64  		// Check the response status.
    65  		if resp.StatusCode != http.StatusOK {
    66  			return nil, errors.New("error response from the embedding API: " + resp.Status)
    67  		}
    68  
    69  		// Read and decode the response body.
    70  		body, err := io.ReadAll(resp.Body)
    71  		if err != nil {
    72  			return nil, fmt.Errorf("couldn't read response body: %w", err)
    73  		}
    74  		var embeddingResponse ollamaResponse
    75  		err = json.Unmarshal(body, &embeddingResponse)
    76  		if err != nil {
    77  			return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
    78  		}
    79  
    80  		// Check if the response contains embeddings.
    81  		if len(embeddingResponse.Embedding) == 0 {
    82  			return nil, errors.New("no embeddings found in the response")
    83  		}
    84  
    85  		v := embeddingResponse.Embedding
    86  		checkNormalized.Do(func() {
    87  			if isNormalized(v) {
    88  				checkedNormalized = true
    89  			} else {
    90  				checkedNormalized = false
    91  			}
    92  		})
    93  		if !checkedNormalized {
    94  			v = normalizeVector(v)
    95  		}
    96  
    97  		return v, nil
    98  	}
    99  }
   100  

View as plain text