Skip to content

Commit

Permalink
feat(js/plugins/firebase): Expose vector distance and distance thresh…
Browse files Browse the repository at this point in the history
…old options in Firestore vector store retriever (#2246)
  • Loading branch information
ifielker authored Mar 5, 2025
1 parent a842e5e commit b3cd325
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 39 deletions.
2 changes: 1 addition & 1 deletion js/plugins/firebase/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"@genkit-ai/google-cloud": "workspace:^"
},
"peerDependencies": {
"@google-cloud/firestore": "^7.6.0",
"@google-cloud/firestore": "^7.11.0",
"firebase-admin": ">=12.2",
"genkit": "workspace:^"
},
Expand Down
40 changes: 37 additions & 3 deletions js/plugins/firebase/src/firestoreRetriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ export function defineFirestoreRetriever(
contentField: string | ((snap: QueryDocumentSnapshot) => Part[]);
/** The distance measure to use when comparing vectors. Defaults to 'COSINE'. */
distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';
/**
* Specifies a threshold for which no less similar documents will be returned. The behavior
* of the specified `distanceMeasure` will affect the meaning of the distance threshold.
*
* - For `distanceMeasure: "EUCLIDEAN"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE euclidean_distance <= distanceThreshold
* - For `distanceMeasure: "COSINE"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE cosine_distance <= distanceThreshold
* - For `distanceMeasure: "DOT_PRODUCT"`, the meaning of `distanceThreshold` is:
* SELECT docs WHERE dot_product_distance >= distanceThreshold
*/
distanceThreshold?: number;
/**
* Optionally specifies the name of a metadata field that will be set on each returned Document,
* which will contain the computed distance for the document.
*/
distanceResultField?: string;
/**
* A list of fields to include in the returned document metadata. If not supplied, all fields other
* than the vector are included. Alternatively, provide a transform function to extract the desired
Expand All @@ -108,6 +125,8 @@ export function defineFirestoreRetriever(
metadataFields,
contentField,
distanceMeasure,
distanceThreshold,
distanceResultField,
} = config;
return ai.defineRetriever(
{
Expand All @@ -118,6 +137,14 @@ export function defineFirestoreRetriever(
configSchema: z.object({
where: z.record(z.any()).optional(),
limit: z.number(),
/* Supply or override the distanceMeasure */
distanceMeasure: z
.enum(['COSINE', 'DOT_PRODUCT', 'EUCLIDEAN'])
.optional(),
/* Supply or override the distanceThreshold */
distanceThreshold: z.number().optional(),
/* Supply or override the metadata field where distances are stored. */
distanceResultField: z.string().optional(),
/* Supply or override the collection for retrieval. */
collection: z.string().optional(),
}),
Expand All @@ -135,11 +162,18 @@ export function defineFirestoreRetriever(
query = query.where(field, '==', options.where![field]);
}
// Single embedding for text input
const embedding = (await ai.embed({ embedder, content }))[0].embedding;
const queryVector = (await ai.embed({ embedder, content }))[0].embedding;

const result = await query
.findNearest(vectorField, embedding, {
.findNearest({
vectorField,
queryVector,
limit: options.limit || 10,
distanceMeasure: distanceMeasure || 'COSINE',
distanceMeasure:
options.distanceMeasure || distanceMeasure || 'COSINE',
distanceResultField:
options.distanceResultField || distanceResultField,
distanceThreshold: options.distanceThreshold || distanceThreshold,
})
.get();

Expand Down
42 changes: 14 additions & 28 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 72 additions & 7 deletions js/testapps/rag/src/pdf-rag-firebase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,59 @@ export const pdfChatRetrieverFirebase = defineFirestoreRetriever(ai, {
contentField: 'facts',
vectorField: 'embedding',
embedder: textEmbedding004,
distanceMeasure: 'COSINE',
//distanceMeasure: 'COSINE', // optional
//distanceResultField: 'vector_distance', // optional
//distanceThreshold: 0.8, // optional
});

// Define a simple RAG flow, we will evaluate this flow
export const pdfQAFirebase = ai.defineFlow(
{
name: 'pdfQAFirebase',
inputSchema: z.string(),
outputSchema: z.string(),
inputSchema: z.object({
distanceMeasure: z
.string()
.describe("One of 'COSINE', 'DOT_PRODUCT', 'EUCLIDEAN'")
.default('COSINE')
.optional(),
distanceThreshold: z
.number()
.describe(
'The numeric distance threshold. The significance depends on distanceMeasure'
)
.default(0.56)
.optional(),
distanceResultField: z
.string()
.describe('The name of the metadata field that stores distance results')
.default('vector_distance')
.optional(),
query: z
.string()
.describe('Ask questions about the pdf')
.default('Summarize the pdf'),
}),
outputSchema: z.object({
documentCount: z.string(),
distances: z.string(),
response: z.string(),
}),
},
async (query) => {
async ({
distanceMeasure,
distanceThreshold,
distanceResultField,
query,
}) => {
const docs = await ai.retrieve({
retriever: pdfChatRetrieverFirebase,
query,
options: { limit: 3 },
options: {
limit: 10,
distanceMeasure,
distanceThreshold,
distanceResultField,
},
});
console.log(docs);

Expand All @@ -90,7 +128,31 @@ export const pdfQAFirebase = ai.defineFlow(
model: gemini15Flash,
prompt: augmentedPrompt,
});
return llmResponse.text;

let distances: Array<number> = [];
let maxDistance = NaN;
let minDistance = NaN;
if (distanceResultField) {
// Note: if you change the default distanceResultField by setting it in
// defineFirestoreRetriever, then you need to change this code to look
// for that field as well i.e. distanceResultField || <default you set>
distances = docs
.map((d) => {
if (d.metadata && d.metadata[distanceResultField]) {
return d.metadata[distanceResultField];
}
return undefined;
})
.filter((n) => n !== undefined);
maxDistance = Math.max(...distances);
minDistance = Math.min(...distances);
}

return {
documentCount: `${docs.length} of 10`,
distances: `min: ${minDistance}, max: ${maxDistance}`,
response: llmResponse.text,
};
}
);

Expand All @@ -115,7 +177,10 @@ const chunkingConfig = {
export const indexPdfFirebase = ai.defineFlow(
{
name: 'indexPdfFirestore',
inputSchema: z.string().describe('PDF file path'),
inputSchema: z
.string()
.describe('PDF file path')
.default('./docs/flume-java.pdf'),
},
async (filePath) => {
filePath = path.resolve(filePath);
Expand Down

0 comments on commit b3cd325

Please sign in to comment.