...

Source file src/gitlab.hexacode.org/go-libs/chromem-go/embed_cohere.go

Documentation: gitlab.hexacode.org/go-libs/chromem-go

     1  package chromem
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  )
    14  
// EmbeddingModelCohere is the name of a Cohere embedding model. Its string
// value is sent as the "model" field of the embed request built by
// [NewEmbeddingFuncCohere].
type EmbeddingModelCohere string
    16  
// Model names that can be passed to [NewEmbeddingFuncCohere]. The string
// values are sent to the API unchanged.
const (
	EmbeddingModelCohereMultilingualV2      EmbeddingModelCohere = "embed-multilingual-v2.0"
	EmbeddingModelCohereEnglishLightV2      EmbeddingModelCohere = "embed-english-light-v2.0"
	EmbeddingModelCohereEnglishV2           EmbeddingModelCohere = "embed-english-v2.0"
	EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0"
	EmbeddingModelCohereEnglishLightV3      EmbeddingModelCohere = "embed-english-light-v3.0"
	EmbeddingModelCohereMultilingualV3      EmbeddingModelCohere = "embed-multilingual-v3.0"
	EmbeddingModelCohereEnglishV3           EmbeddingModelCohere = "embed-english-v3.0"
)
    26  
// Prefixes for external use.
//
// Prepend exactly one of them (including the trailing colon and space) to the
// text passed to the function returned by [NewEmbeddingFuncCohere]. That
// function strips the prefix from the text and uses it to pick the matching
// input type for the request.
const (
	InputTypeCohereSearchDocumentPrefix string = "search_document: "
	InputTypeCohereSearchQueryPrefix    string = "search_query: "
	InputTypeCohereClassificationPrefix string = "classification: "
	InputTypeCohereClusteringPrefix     string = "clustering: "
)
    34  
// Input types for internal use.
//
// These are the values sent as the "input_type" field of the request body;
// each one corresponds to an exported prefix via validInputTypesCohere.
const (
	inputTypeCohereSearchDocument string = "search_document"
	inputTypeCohereSearchQuery    string = "search_query"
	inputTypeCohereClassification string = "classification"
	inputTypeCohereClustering     string = "clustering"
)
    42  
// baseURLCohere is the base URL of the Cohere API. The embedding endpoint path
// ("/embed") is appended to it when the request is created.
const baseURLCohere = "https://api.cohere.ai/v1"
    44  
// validInputTypesCohere maps each input type (the value sent to the API) to
// the text prefix that selects it. The embedding function ranges over this map
// to find which prefix, if any, the given text starts with.
var validInputTypesCohere = map[string]string{
	inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
	inputTypeCohereSearchQuery:    InputTypeCohereSearchQueryPrefix,
	inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
	inputTypeCohereClustering:     InputTypeCohereClusteringPrefix,
}
    51  
// cohereResponse holds the part of the embed endpoint's JSON response body
// that this package decodes. Any other fields in the response are ignored.
type cohereResponse struct {
	// Embeddings is decoded from the "embeddings" field. The embedding
	// function sends a single text per request and uses only the first entry.
	Embeddings [][]float32 `json:"embeddings"`
}
    55  
    56  // NewEmbeddingFuncCohere returns a function that creates embeddings for a text
    57  // using Cohere's API. One important difference to OpenAI's and other's APIs is
    58  // that Cohere differentiates between document embeddings and search/query embeddings.
    59  // In order for this embedding func to do the differentiation, you have to prepend
    60  // the text with either "search_document" or "search_query". We'll cut off that
    61  // prefix before sending the document/query body to the API, we'll just use the
    62  // prefix to choose the right "input type" as they call it.
    63  //
    64  // When you set up a chromem-go collection with this embedding function, you might
    65  // want to create the document separately with [NewDocument] and then cut off the
    66  // prefix before adding the document to the collection. Otherwise, when you query
    67  // the collection, the returned documents will still have the prefix in their content.
    68  //
    69  //	cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3)
    70  //	content := "The sky is blue because of Rayleigh scattering."
    71  //	// Create the document with the prefix.
    72  //	contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content
    73  //	doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc)
    74  //	// Remove the prefix so that later query results don't have it.
    75  //	doc.Content = content
    76  //	_ = collection.AddDocument(ctx, doc)
    77  //
    78  // This is not necessary if you don't keep the content in the documents, as chromem-go
    79  // also works when documents only have embeddings.
    80  // You can also keep the prefix in the document, and only remove it after querying.
    81  //
    82  // We plan to improve this in the future.
    83  func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc {
    84  	// We don't set a default timeout here, although it's usually a good idea.
    85  	// In our case though, the library user can set the timeout on the context,
    86  	// and it might have to be a long timeout, depending on the text length.
    87  	client := &http.Client{}
    88  
    89  	var checkedNormalized bool
    90  	checkNormalized := sync.Once{}
    91  
    92  	return func(ctx context.Context, text string) ([]float32, error) {
    93  		var inputType string
    94  		for validInputType, validInputTypePrefix := range validInputTypesCohere {
    95  			if strings.HasPrefix(text, validInputTypePrefix) {
    96  				inputType = validInputType
    97  				text = strings.TrimPrefix(text, validInputTypePrefix)
    98  				break
    99  			}
   100  		}
   101  		if inputType == "" {
   102  			return nil, errors.New("text must start with a valid input type plus colon and space")
   103  		}
   104  
   105  		// Prepare the request body.
   106  		reqBody, err := json.Marshal(map[string]any{
   107  			"model":      model,
   108  			"texts":      []string{text},
   109  			"input_type": inputType,
   110  		})
   111  		if err != nil {
   112  			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
   113  		}
   114  
   115  		// Create the request. Creating it with context is important for a timeout
   116  		// to be possible, because the client is configured without a timeout.
   117  		req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody))
   118  		if err != nil {
   119  			return nil, fmt.Errorf("couldn't create request: %w", err)
   120  		}
   121  		req.Header.Set("Accept", "application/json")
   122  		req.Header.Set("Content-Type", "application/json")
   123  		req.Header.Set("Authorization", "Bearer "+apiKey)
   124  
   125  		// Send the request.
   126  		resp, err := client.Do(req)
   127  		if err != nil {
   128  			return nil, fmt.Errorf("couldn't send request: %w", err)
   129  		}
   130  		defer resp.Body.Close()
   131  
   132  		// Check the response status.
   133  		if resp.StatusCode != http.StatusOK {
   134  			return nil, errors.New("error response from the embedding API: " + resp.Status)
   135  		}
   136  
   137  		// Read and decode the response body.
   138  		body, err := io.ReadAll(resp.Body)
   139  		if err != nil {
   140  			return nil, fmt.Errorf("couldn't read response body: %w", err)
   141  		}
   142  		var embeddingResponse cohereResponse
   143  		err = json.Unmarshal(body, &embeddingResponse)
   144  		if err != nil {
   145  			return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
   146  		}
   147  
   148  		// Check if the response contains embeddings.
   149  		if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 {
   150  			return nil, errors.New("no embeddings found in the response")
   151  		}
   152  
   153  		v := embeddingResponse.Embeddings[0]
   154  		checkNormalized.Do(func() {
   155  			if isNormalized(v) {
   156  				checkedNormalized = true
   157  			} else {
   158  				checkedNormalized = false
   159  			}
   160  		})
   161  		if !checkedNormalized {
   162  			v = normalizeVector(v)
   163  		}
   164  
   165  		return v, nil
   166  	}
   167  }
   168  

View as plain text