From 76f99ef71885712b2a4a0597cd16c9d256fe32ba Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 4 Mar 2025 15:52:58 -0800 Subject: [PATCH] Fixed retrievers data types --- go/ai/retriever.go | 10 +++++----- go/internal/doc-snippets/pinecone.go | 4 ++-- go/internal/doc-snippets/rag/main.go | 6 +++--- go/plugins/firebase/retriever.go | 4 ++-- go/plugins/firebase/retriever_test.go | 2 +- go/plugins/localvec/localvec.go | 2 +- go/plugins/localvec/localvec_test.go | 12 ++++++------ go/plugins/pinecone/genkit.go | 2 +- go/plugins/weaviate/weaviate.go | 2 +- go/samples/firebase-retrievers/main.go | 2 +- go/samples/pgvector/main.go | 2 +- 11 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 018911f6f..69c15e87d 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -39,15 +39,15 @@ type ( // IndexerRequest is the data we pass to add documents to the database. // The Options field is specific to the actual retriever implementation. type IndexerRequest struct { - Documents []*Document `json:"docs"` + Documents []*Document `json:"documents"` Options any `json:"options,omitempty"` } // RetrieverRequest is the data we pass to retrieve documents from the database. // The Options field is specific to the actual retriever implementation. type RetrieverRequest struct { - Document *Document `json:"content"` - Options any `json:"options,omitempty"` + Query *Document `json:"query"` + Options any `json:"options,omitempty"` } // RetrieverResponse is the response to a document lookup. @@ -115,7 +115,7 @@ type RetrieveOption func(req *RetrieverRequest) error // WithRetrieverText adds a simple text as document to RetrieveRequest. func WithRetrieverText(text string) RetrieveOption { return func(req *RetrieverRequest) error { - req.Document = DocumentFromText(text, nil) + req.Query = DocumentFromText(text, nil) return nil } } @@ -123,7 +123,7 @@ func WithRetrieverText(text string) RetrieveOption { // WithRetrieverDoc adds a document to RetrieveRequest. func WithRetrieverDoc(doc *Document) RetrieveOption { return func(req *RetrieverRequest) error { - req.Document = doc + req.Query = doc return nil } } diff --git a/go/internal/doc-snippets/pinecone.go b/go/internal/doc-snippets/pinecone.go index 306c10761..a73554193 100644 --- a/go/internal/doc-snippets/pinecone.go +++ b/go/internal/doc-snippets/pinecone.go @@ -67,8 +67,8 @@ func pineconeEx(ctx context.Context) error { // [START retrieve] resp, err := menuRetriever.Retrieve(ctx, &ai.RetrieverRequest{ - Document: ai.DocumentFromText(userInput, nil), - Options: nil, + Query: ai.DocumentFromText(userInput, nil), + Options: nil, }) if err != nil { return err diff --git a/go/internal/doc-snippets/rag/main.go b/go/internal/doc-snippets/rag/main.go index bd9934f80..37e209e55 100644 --- a/go/internal/doc-snippets/rag/main.go +++ b/go/internal/doc-snippets/rag/main.go @@ -159,7 +159,7 @@ func menuQA() { func(ctx context.Context, question string) (string, error) { // Retrieve text relevant to the user's question. docs, err := menuPdfRetriever.Retrieve(ctx, &ai.RetrieverRequest{ - Document: ai.DocumentFromText(question, nil), + Query: ai.DocumentFromText(question, nil), }) if err != nil { return "", err @@ -225,8 +225,8 @@ func customret() { // Call the retriever as in the simple case. response, err := menuPDFRetriever.Retrieve(ctx, &ai.RetrieverRequest{ - Document: req.Document, - Options: localvec.RetrieverOptions{K: opts.PreRerankK}, + Query: req.Query, + Options: localvec.RetrieverOptions{K: opts.PreRerankK}, }) if err != nil { return nil, err diff --git a/go/plugins/firebase/retriever.go b/go/plugins/firebase/retriever.go index 9741f3352..4e2ce4032 100644 --- a/go/plugins/firebase/retriever.go +++ b/go/plugins/firebase/retriever.go @@ -54,12 +54,12 @@ func DefineFirestoreRetriever(g *genkit.Genkit, cfg RetrieverOptions) (ai.Retrie } Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { - if req.Document == nil { + if req.Query == nil { return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil") } // Generate query embedding using the Embedder - embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Document}} + embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Query}} embedResponse, err := cfg.Embedder.Embed(ctx, embedRequest) if err != nil { return nil, fmt.Errorf("DefineFirestoreRetriever: Embedding failed: %v", err) diff --git a/go/plugins/firebase/retriever_test.go b/go/plugins/firebase/retriever_test.go index 4f5e6d8d7..fdad44dd8 100644 --- a/go/plugins/firebase/retriever_test.go +++ b/go/plugins/firebase/retriever_test.go @@ -244,7 +244,7 @@ func TestFirestoreRetriever(t *testing.T) { inputDocument := ai.DocumentFromText(queryText, nil) req := &ai.RetrieverRequest{ - Document: inputDocument, + Query: inputDocument, } // Perform the retrieval diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 14f947894..5ed379ea0 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -181,7 +181,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai // Use the embedder to convert the document we want to // retrieve into a vector. ereq := &ai.EmbedRequest{ - Documents: []*ai.Document{req.Document}, + Documents: []*ai.Document{req.Query}, Options: ds.embedderOptions, } eres, err := ds.embedder.Embed(ctx, ereq) diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index afbdf5836..bf2054df0 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -64,8 +64,8 @@ func TestLocalVec(t *testing.T) { } retrieverReq := &ai.RetrieverRequest{ - Document: d1, - Options: retrieverOptions, + Query: d1, + Options: retrieverOptions, } retrieverResp, err := ds.retrieve(ctx, retrieverReq) if err != nil { @@ -132,8 +132,8 @@ func TestPersistentIndexing(t *testing.T) { } retrieverReq := &ai.RetrieverRequest{ - Document: d1, - Options: retrieverOptions, + Query: d1, + Options: retrieverOptions, } retrieverResp, err := ds.retrieve(ctx, retrieverReq) if err != nil { @@ -163,8 +163,8 @@ func TestPersistentIndexing(t *testing.T) { } retrieverReq = &ai.RetrieverRequest{ - Document: d1, - Options: retrieverOptions, + Query: d1, + Options: retrieverOptions, } retrieverResp, err = dsAnother.retrieve(ctx, retrieverReq) if err != nil { diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index a878131d5..456ab9679 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -282,7 +282,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai // Use the embedder to convert the document we want to // retrieve into a vector. ereq := &ai.EmbedRequest{ - Documents: []*ai.Document{req.Document}, + Documents: []*ai.Document{req.Query}, Options: ds.embedderOptions, } eres, err := ds.embedder.Embed(ctx, ereq) diff --git a/go/plugins/weaviate/weaviate.go b/go/plugins/weaviate/weaviate.go index 52cad76ac..2f24a5364 100644 --- a/go/plugins/weaviate/weaviate.go +++ b/go/plugins/weaviate/weaviate.go @@ -274,7 +274,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai // Use the embedder to convert the document to a vector. ereq := &ai.EmbedRequest{ - Documents: []*ai.Document{req.Document}, + Documents: []*ai.Document{req.Query}, Options: ds.embedderOptions, } eres, err := ds.embedder.Embed(ctx, ereq) diff --git a/go/samples/firebase-retrievers/main.go b/go/samples/firebase-retrievers/main.go index df046cfab..8f3462bc3 100644 --- a/go/samples/firebase-retrievers/main.go +++ b/go/samples/firebase-retrievers/main.go @@ -131,7 +131,7 @@ func main() { genkit.DefineFlow(g, "flow-retrieve-documents", func(ctx context.Context, query string) (string, error) { // Perform Firestore retrieval based on user input req := &ai.RetrieverRequest{ - Document: ai.DocumentFromText(query, nil), + Query: ai.DocumentFromText(query, nil), } log.Println("Starting retrieval with query:", query) resp, err := retriever.Retrieve(ctx, req) diff --git a/go/samples/pgvector/main.go b/go/samples/pgvector/main.go index d4f34d8d7..8ce43f6ae 100644 --- a/go/samples/pgvector/main.go +++ b/go/samples/pgvector/main.go @@ -110,7 +110,7 @@ const provider = "pgvector" // [START retr] func defineRetriever(g *genkit.Genkit, db *sql.DB, embedder ai.Embedder) ai.Retriever { f := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { - eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Document)) + eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Query)) if err != nil { return nil, err }