Skip to content

Commit

Permalink
feat(sdk-js): choosing model in generate
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Jun 18, 2023
1 parent a717dbd commit 1bd5aa3
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 44 deletions.
1 change: 1 addition & 0 deletions sdk/embedbase-js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type Chat = {
export interface GenerateOptions {
history: Chat[]
url?: string
model?: 'gpt-3.5-turbo-16k' | 'falcon'
}

export interface RangeOptions {
Expand Down
68 changes: 46 additions & 22 deletions sdk/embedbase-py/embedbase_client/async_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import asyncio
import itertools
Expand Down Expand Up @@ -579,12 +579,13 @@ def batch_chunks(l, n):

async def create_max_context(
self,
dataset: str,
dataset: Union[str, List[str]],
query: str,
max_tokens: int,
max_tokens: Union[int, List[int]],
) -> str:
"""
Create a context from a query by searching for similar documents and concatenating them up to the specified max tokens.
Create a context from a query by searching for similar documents and
concatenating them up to the specified max tokens.
Args:
dataset: The name of the dataset to search.
Expand All @@ -595,32 +596,55 @@ async def create_max_context(
A string containing the context.
Example usage:
context = create_max_context("Python is a programming language.", max_tokens=30)
context = await create_max_context("programming", "Python is a programming language.", 30)
print(context)
# Python is a programming language.
# Python is a high-level, general-purpose programming language.
# Python is interpreted, dynamically typed and garbage-collected.
# Python is designed to be highly extensible.
# Python is a multi-paradig
# or
context = await create_max_context(["programming", "science"], "Python lives planet earth.", [3, 30])
print(context)
# Pyt
# The earth orbits the sun.
# The earth is the third planet from the sun.
# The earth is the only planet known to support life.
# The earth formed approximately 4.5 billion years ago.
# The earth's gravity interacts with other objects in space, especially the sun and the moon.
"""

# try to build a context until it's big enough by incrementing top_k
top_k = 100
context = await self.create_context(dataset, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)

tries = 0
max_tries = 3
while size < max_tokens and tries < max_tries:
top_k *= 3
context = await self.create_context(dataset, query, top_k)
async def create_context_for_dataset(d, max_tokens):
top_k = 100
context = await self.create_context(d, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)
tries += 1

if size < max_tokens:
# warn the user that the context is smaller than the max tokens
print(
f"Warning: context is smaller than the max tokens ({size} < {max_tokens})"
)
tries = 0
max_tries = 3
while size < max_tokens and tries < max_tries:
top_k *= 3
context = await self.create_context(dataset, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)
tries += 1

if size < max_tokens:
print(
f"Warning: context for dataset '{dataset}' is smaller than the max tokens ({size} < {max_tokens})"
)
return merged_context

if not isinstance(dataset, list):
dataset = [dataset]

if not isinstance(max_tokens, list):
max_tokens = [max_tokens for _ in range(len(dataset))]

if len(dataset) != len(max_tokens):
raise ValueError("The number of datasets and max_tokens should be equal.")

contexts = []
for ds, mt in zip(dataset, max_tokens):
context = await create_context_for_dataset(ds, mt)
contexts.append(context)

return merged_context
return "\n\n".join(contexts)
68 changes: 46 additions & 22 deletions sdk/embedbase-py/embedbase_client/sync_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional, Union

import itertools
import json
Expand Down Expand Up @@ -600,12 +600,13 @@ def add_batch(batch):

def create_max_context(
self,
dataset: str,
dataset: Union[str, List[str]],
query: str,
max_tokens: int,
max_tokens: Union[int, List[int]],
) -> str:
"""
Create a context from a query by searching for similar documents and concatenating them up to the specified max tokens.
Create a context from a query by searching for similar documents and
concatenating them up to the specified max tokens.
Args:
dataset: The name of the dataset to search.
Expand All @@ -616,32 +617,55 @@ def create_max_context(
A string containing the context.
Example usage:
context = embedbase.create_max_context("my_dataset", "What is Python?", max_tokens=30)
context = create_max_context("programming", "Python is a programming language.", 30)
print(context)
# Python is a programming language.
# Python is a high-level, general-purpose programming language.
# Python is interpreted, dynamically typed and garbage-collected.
# Python is designed to be highly extensible.
# Python is a multi-paradig...
# Python is a multi-paradig
# or
context = create_max_context(["programming", "science"], "Python lives planet earth.", [3, 30])
print(context)
# Pyt
# The earth orbits the sun.
# The earth is the third planet from the sun.
# The earth is the only planet known to support life.
# The earth formed approximately 4.5 billion years ago.
# The earth's gravity interacts with other objects in space, especially the sun and the moon.
"""

# try to build a context until it's big enough by incrementing top_k
top_k = 100
context = self.create_context(dataset, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)

tries = 0
max_tries = 3
while size < max_tokens and tries < max_tries:
top_k *= 3
def create_context_for_dataset(dataset, max_tokens):
top_k = 100
context = self.create_context(dataset, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)
tries += 1

if size < max_tokens:
# warn the user that the context is smaller than the max tokens
print(
f"Warning: context is smaller than the max tokens ({size} < {max_tokens})"
)
tries = 0
max_tries = 3
while size < max_tokens and tries < max_tries:
top_k *= 3
context = self.create_context(dataset, query, top_k)
merged_context, size = merge_and_return_tokens(context, max_tokens)
tries += 1

if size < max_tokens:
print(
f"Warning: context for dataset '{dataset}' is smaller than the max tokens ({size} < {max_tokens})"
)
return merged_context

if not isinstance(dataset, list):
dataset = [dataset]

if not isinstance(max_tokens, list):
max_tokens = [max_tokens for _ in range(len(dataset))]

if len(dataset) != len(max_tokens):
raise ValueError("The number of datasets and max_tokens should be equal.")

contexts = []
for ds, mt in zip(dataset, max_tokens):
context = create_context_for_dataset(ds, mt)
contexts.append(context)

return merged_context
return "\n\n".join(contexts)
48 changes: 48 additions & 0 deletions sdk/embedbase-py/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,51 @@ async def test_create_max_context_async():

assert isinstance(context, str)
assert len(tokenizer.encode(context)) <= max_tokens


@pytest.mark.asyncio
async def test_create_max_context_multiple_datasets_async():
query = "What is Python?"
dataset1 = "programming"
dataset2 = "animals"
max_tokens1 = 20
max_tokens2 = 25
await client.dataset(dataset1).clear()
await client.dataset(dataset2).clear()
programming_documents = [
"Python is a programming language.",
"Java is another popular programming language.",
"JavaScript is widely used for web development.",
"C++ is commonly used for system programming.",
"Ruby is known for its simplicity and readability.",
"Go is a statically typed language developed by Google.",
"Rust is a systems programming language that focuses on safety and performance.",
"TypeScript is a superset of JavaScript that adds static typing.",
"PHP is a server-side scripting language used for web development.",
"Swift is a modern programming language developed by Apple for iOS app development.",
]
animal_documents = [
"Python is a type of snake.",
"Lions are known as the king of the jungle.",
"Elephants are the largest land animals.",
"Giraffes are known for their long necks.",
"Kangaroos are native to Australia.",
"Pandas are native to China and primarily eat bamboo.",
"Penguins live primarily in the Southern Hemisphere.",
"Tigers are carnivorous mammals found in Asia.",
"Whales are large marine mammals.",
"Zebras are part of the horse family and native to Africa.",
]

await client.dataset(dataset1).batch_add([{"data": d} for d in programming_documents])
await client.dataset(dataset2).batch_add([{"data": d} for d in animal_documents])
context = await client.create_max_context(
[dataset1, dataset2], query, [max_tokens1, max_tokens2]
)
tokenizer = get_encoding("cl100k_base")

assert isinstance(context, str)
context_parts = context.split("\n")
assert len(context_parts) == 2
assert len(tokenizer.encode(context_parts[0])) <= max_tokens1
assert len(tokenizer.encode(context_parts[1])) <= max_tokens2

2 comments on commit 1bd5aa3

@vercel
Copy link

@vercel vercel bot commented on 1bd5aa3 Jun 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

embedbase – ./dashboard

embedbase-prologe.vercel.app
embedbase-git-main-prologe.vercel.app
app.embedbase.xyz
embedbase.vercel.app

@vercel
Copy link

@vercel vercel bot commented on 1bd5aa3 Jun 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.