Commit bec4bb1

refactor: add core module
Parent: f84d845

3 files changed (+18, -7 lines)
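
This commit replaces per-call generation arguments with a shared GenerationConfig object from the new toyllm.core module. The module itself is not part of this diff; judging only from the fields the call sites below exercise (max_new_tokens, temperature, top_k), a minimal sketch of the config might look like this (hypothetical; the defaults are assumptions, not source):

# Hypothetical sketch of toyllm.core's GenerationConfig, inferred from the
# call sites in this commit; field defaults are assumptions, not source.
from dataclasses import dataclass


@dataclass
class GenerationConfig:
    max_new_tokens: int = 40        # replaces the old max_gen_tokens argument
    temperature: float = 1.0        # 0 would mean greedy decoding
    top_k: int | None = None        # sample only from the k most likely tokens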

toyllm/cli/run_gpt2.py (+3, -2)

@@ -2,21 +2,22 @@
 
 import typer
 
+from toyllm.core import GenerationConfig
 from toyllm.gpt2 import GPTModel, GPTModelSize, GPTTextGenerator
 
 
 def main(
     prompt: str = "Alan Turing theorized that computers would one day become",
     model_size: GPTModelSize = GPTModelSize.SMALL,
-    max_gen_tokens: int = 40,
+    max_new_tokens: int = 40,
 ) -> None:
     gpt_model = GPTModel(model_size).load()
     text_generator = GPTTextGenerator(gpt_model=gpt_model)
 
     start_time = time.time()
     generate_text = text_generator.generate(
         prompt=prompt,
-        max_gen_tokens=max_gen_tokens,
+        config=GenerationConfig(max_new_tokens=max_new_tokens),
     )
     print(generate_text)
     end_time = time.time()
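
A side effect of the rename: typer derives CLI option names from main's parameter names, so the flag changes from --max-gen-tokens to --max-new-tokens. This assumes the file ends with the usual entry point, which this hunk does not show:

# Assumed entry point (not shown in the diff). With it, typer exposes
# main's parameters as flags, e.g. max_new_tokens -> --max-new-tokens.
if __name__ == "__main__":
    typer.run(main)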

toyllm/cli/run_speculative_sampling.py (+6, -4)

@@ -1,16 +1,19 @@
 import typer
 from rich.console import Console
 
+from toyllm.core import GenerationConfig
 from toyllm.gpt2 import GPTModel, GPTModelSize, GPTTextGenerator, gpt2_tokenizer
 from toyllm.sps import GPTSpsModel, SpsTextGenerator
 from toyllm.util import Timer
 
 
 def main(
     prompt_text: str = "Alan Turing theorized that computers would one day become",
-    generate_tokens: int = 256,
+    max_new_tokens: int = 256,
     k: int = 4,  # K in sps paper
 ) -> None:
+    generate_config = GenerationConfig(max_new_tokens=max_new_tokens)
+
     console = Console()
     console.print(f"Prompt: {prompt_text}", style="bold blue")
     # Test the speculative sampling
@@ -25,8 +28,7 @@ def main(
     with Timer(name="Speculative Sampling"):
         generate_text = sps_text_generator.generate(
             prompt=prompt_text,
-            target_seq_len=generate_tokens,
-            temperature=0,
+            config=generate_config,
         )
     console.print(f"Generated: {generate_text[:200]}", style="bold green")
     console.print(f"{'-' * 20} Speculative Sampling {'-' * 20}", style="bold blue")
@@ -39,7 +41,7 @@ def main(
     with Timer(name="Naive GPT2 Auto-Regressive"):
         generate_text = gpt_text_generator.generate(
             prompt=prompt_text,
-            max_gen_tokens=generate_tokens,
+            config=generate_config,
         )
     console.print(f"Generated: {generate_text[:200]}", style="bold green")
     console.print(f"{'-' * 20} Naive GPT2 Auto-Regressive {'-' * 20}", style="bold blue")

toyllm/gpt2/train.py (+9, -1)

@@ -10,6 +10,7 @@
 import torch
 from torch.utils.data import DataLoader
 
+from toyllm.core import GenerationConfig
 from toyllm.device import current_device
 from toyllm.gpt2.config import GPTModelSize, GPTTrainingConfig
 from toyllm.gpt2.dataset import GPTDataloader
@@ -70,7 +71,14 @@ def evaluate_model(
 def generate_and_print_sample(model: GPTModel, tokenizer: tiktoken.Encoding, start_context: str) -> None:
     model.eval()
     text_generate = GPTTextGenerator(gpt_model=model, tokenizer=tokenizer)
-    generate_text = text_generate.generate(prompt=start_context, max_gen_tokens=50, temperature=0.9, top_k=10)
+    generate_text = text_generate.generate(
+        prompt=start_context,
+        config=GenerationConfig(
+            max_new_tokens=50,
+            temperature=0.9,
+            top_k=10,
+        ),
+    )
     print(generate_text.replace("\n", " "))  # Compact print format
     model.train()
 
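For readers new to these sampling knobs, here is an illustrative, self-contained snippet (not toyllm source) of what temperature and top_k conventionally control when drawing the next token from raw logits:

# Illustrative only: standard temperature + top-k sampling on fake logits.
import torch

logits = torch.randn(50257)               # pretend next-token logits (GPT-2 vocab size)
temperature, top_k = 0.9, 10              # the values used in train.py above

scaled = logits / temperature             # <1.0 sharpens the distribution
top_vals, top_idx = torch.topk(scaled, top_k)
probs = torch.softmax(top_vals, dim=-1)   # renormalize over the k best tokens
next_token = top_idx[torch.multinomial(probs, 1)]  # sampled token id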
