Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go): Fixed retrievers data types (Dev UI mismatch). #2242

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ type (
// IndexerRequest is the data we pass to add documents to the database.
// The Options field is specific to the actual retriever implementation.
type IndexerRequest struct {
Documents []*Document `json:"docs"`
Documents []*Document `json:"documents"`
Options any `json:"options,omitempty"`
}

// RetrieverRequest is the data we pass to retrieve documents from the database.
// The Options field is specific to the actual retriever implementation.
type RetrieverRequest struct {
Document *Document `json:"content"`
Options any `json:"options,omitempty"`
Query *Document `json:"query"`
Options any `json:"options,omitempty"`
}

// RetrieverResponse is the response to a document lookup.
Expand Down Expand Up @@ -115,15 +115,15 @@ type RetrieveOption func(req *RetrieverRequest) error
// WithRetrieverText adds a simple text as document to RetrieveRequest.
func WithRetrieverText(text string) RetrieveOption {
return func(req *RetrieverRequest) error {
req.Document = DocumentFromText(text, nil)
req.Query = DocumentFromText(text, nil)
return nil
}
}

// WithRetrieverDoc adds a document to RetrieveRequest.
func WithRetrieverDoc(doc *Document) RetrieveOption {
return func(req *RetrieverRequest) error {
req.Document = doc
req.Query = doc
return nil
}
}
Expand Down
4 changes: 2 additions & 2 deletions go/internal/doc-snippets/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func pineconeEx(ctx context.Context) error {

// [START retrieve]
resp, err := menuRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: ai.DocumentFromText(userInput, nil),
Options: nil,
Query: ai.DocumentFromText(userInput, nil),
Options: nil,
})
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions go/internal/doc-snippets/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func menuQA() {
func(ctx context.Context, question string) (string, error) {
// Retrieve text relevant to the user's question.
docs, err := menuPdfRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: ai.DocumentFromText(question, nil),
Query: ai.DocumentFromText(question, nil),
})
if err != nil {
return "", err
Expand Down Expand Up @@ -225,8 +225,8 @@ func customret() {

// Call the retriever as in the simple case.
response, err := menuPDFRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: req.Document,
Options: localvec.RetrieverOptions{K: opts.PreRerankK},
Query: req.Query,
Options: localvec.RetrieverOptions{K: opts.PreRerankK},
})
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/firebase/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ func DefineFirestoreRetriever(g *genkit.Genkit, cfg RetrieverOptions) (ai.Retrie
}

Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
if req.Document == nil {
if req.Query == nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil")
}

// Generate query embedding using the Embedder
embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Document}}
embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Query}}
embedResponse, err := cfg.Embedder.Embed(ctx, embedRequest)
if err != nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Embedding failed: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/firebase/retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func TestFirestoreRetriever(t *testing.T) {
inputDocument := ai.DocumentFromText(queryText, nil)

req := &ai.RetrieverRequest{
Document: inputDocument,
Query: inputDocument,
}

// Perform the retrieval
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
12 changes: 6 additions & 6 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func TestLocalVec(t *testing.T) {
}

retrieverReq := &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down Expand Up @@ -132,8 +132,8 @@ func TestPersistentIndexing(t *testing.T) {
}

retrieverReq := &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down Expand Up @@ -163,8 +163,8 @@ func TestPersistentIndexing(t *testing.T) {
}

retrieverReq = &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err = dsAnother.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai

// Use the embedder to convert the document to a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
2 changes: 1 addition & 1 deletion go/samples/firebase-retrievers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func main() {
genkit.DefineFlow(g, "flow-retrieve-documents", func(ctx context.Context, query string) (string, error) {
// Perform Firestore retrieval based on user input
req := &ai.RetrieverRequest{
Document: ai.DocumentFromText(query, nil),
Query: ai.DocumentFromText(query, nil),
}
log.Println("Starting retrieval with query:", query)
resp, err := retriever.Retrieve(ctx, req)
Expand Down
2 changes: 1 addition & 1 deletion go/samples/pgvector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ const provider = "pgvector"
// [START retr]
func defineRetriever(g *genkit.Genkit, db *sql.DB, embedder ai.Embedder) ai.Retriever {
f := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Document))
eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Query))
if err != nil {
return nil, err
}
Expand Down
Loading