Skip to content

Commit

Permalink
fix(go): Fixed retrievers data types (firebase#2242)
Browse files Browse the repository at this point in the history
  • Loading branch information
apascal07 authored and MarkTechson committed Mar 5, 2025
1 parent 7041edb commit 6d3be0d
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 24 deletions.
10 changes: 5 additions & 5 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ type (
// IndexerRequest is the data we pass to add documents to the database.
// The Options field is specific to the actual retriever implementation.
type IndexerRequest struct {
Documents []*Document `json:"docs"`
Documents []*Document `json:"documents"`
Options any `json:"options,omitempty"`
}

// RetrieverRequest is the data we pass to retrieve documents from the database.
// The Options field is specific to the actual retriever implementation.
type RetrieverRequest struct {
Document *Document `json:"content"`
Options any `json:"options,omitempty"`
Query *Document `json:"query"`
Options any `json:"options,omitempty"`
}

// RetrieverResponse is the response to a document lookup.
Expand Down Expand Up @@ -115,15 +115,15 @@ type RetrieveOption func(req *RetrieverRequest) error
// WithRetrieverText adds a simple text as document to RetrieveRequest.
func WithRetrieverText(text string) RetrieveOption {
return func(req *RetrieverRequest) error {
req.Document = DocumentFromText(text, nil)
req.Query = DocumentFromText(text, nil)
return nil
}
}

// WithRetrieverDoc adds a document to RetrieveRequest.
func WithRetrieverDoc(doc *Document) RetrieveOption {
return func(req *RetrieverRequest) error {
req.Document = doc
req.Query = doc
return nil
}
}
Expand Down
4 changes: 2 additions & 2 deletions go/internal/doc-snippets/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func pineconeEx(ctx context.Context) error {

// [START retrieve]
resp, err := menuRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: ai.DocumentFromText(userInput, nil),
Options: nil,
Query: ai.DocumentFromText(userInput, nil),
Options: nil,
})
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions go/internal/doc-snippets/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func menuQA() {
func(ctx context.Context, question string) (string, error) {
// Retrieve text relevant to the user's question.
docs, err := menuPdfRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: ai.DocumentFromText(question, nil),
Query: ai.DocumentFromText(question, nil),
})
if err != nil {
return "", err
Expand Down Expand Up @@ -225,8 +225,8 @@ func customret() {

// Call the retriever as in the simple case.
response, err := menuPDFRetriever.Retrieve(ctx, &ai.RetrieverRequest{
Document: req.Document,
Options: localvec.RetrieverOptions{K: opts.PreRerankK},
Query: req.Query,
Options: localvec.RetrieverOptions{K: opts.PreRerankK},
})
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/firebase/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ func DefineFirestoreRetriever(g *genkit.Genkit, cfg RetrieverOptions) (ai.Retrie
}

Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
if req.Document == nil {
if req.Query == nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil")
}

// Generate query embedding using the Embedder
embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Document}}
embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Query}}
embedResponse, err := cfg.Embedder.Embed(ctx, embedRequest)
if err != nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Embedding failed: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/firebase/retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func TestFirestoreRetriever(t *testing.T) {
inputDocument := ai.DocumentFromText(queryText, nil)

req := &ai.RetrieverRequest{
Document: inputDocument,
Query: inputDocument,
}

// Perform the retrieval
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
12 changes: 6 additions & 6 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func TestLocalVec(t *testing.T) {
}

retrieverReq := &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down Expand Up @@ -132,8 +132,8 @@ func TestPersistentIndexing(t *testing.T) {
}

retrieverReq := &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down Expand Up @@ -163,8 +163,8 @@ func TestPersistentIndexing(t *testing.T) {
}

retrieverReq = &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
Query: d1,
Options: retrieverOptions,
}
retrieverResp, err = dsAnother.retrieve(ctx, retrieverReq)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai

// Use the embedder to convert the document to a vector.
ereq := &ai.EmbedRequest{
Documents: []*ai.Document{req.Document},
Documents: []*ai.Document{req.Query},
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
Expand Down
2 changes: 1 addition & 1 deletion go/samples/firebase-retrievers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func main() {
genkit.DefineFlow(g, "flow-retrieve-documents", func(ctx context.Context, query string) (string, error) {
// Perform Firestore retrieval based on user input
req := &ai.RetrieverRequest{
Document: ai.DocumentFromText(query, nil),
Query: ai.DocumentFromText(query, nil),
}
log.Println("Starting retrieval with query:", query)
resp, err := retriever.Retrieve(ctx, req)
Expand Down
2 changes: 1 addition & 1 deletion go/samples/pgvector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ const provider = "pgvector"
// [START retr]
func defineRetriever(g *genkit.Genkit, db *sql.DB, embedder ai.Embedder) ai.Retriever {
f := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Document))
eres, err := ai.Embed(ctx, embedder, ai.WithEmbedDocs(req.Query))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 6d3be0d

Please sign in to comment.