1
1
import typer
2
2
from rich .console import Console
3
3
4
+ from toyllm .core import GenerationConfig
4
5
from toyllm .gpt2 import GPTModel , GPTModelSize , GPTTextGenerator , gpt2_tokenizer
5
6
from toyllm .sps import GPTSpsModel , SpsTextGenerator
6
7
from toyllm .util import Timer
7
8
8
9
9
10
def main (
10
11
prompt_text : str = "Alan Turing theorized that computers would one day become" ,
11
- generate_tokens : int = 256 ,
12
+ max_new_tokens : int = 256 ,
12
13
k : int = 4 , # K in sps paper
13
14
) -> None :
15
+ generate_config = GenerationConfig (max_new_tokens = max_new_tokens )
16
+
14
17
console = Console ()
15
18
console .print (f"Prompt: { prompt_text } " , style = "bold blue" )
16
19
# Test the speculative sampling
@@ -25,8 +28,7 @@ def main(
25
28
with Timer (name = "Speculative Sampling" ):
26
29
generate_text = sps_text_generator .generate (
27
30
prompt = prompt_text ,
28
- target_seq_len = generate_tokens ,
29
- temperature = 0 ,
31
+ config = generate_config ,
30
32
)
31
33
console .print (f"Generated: { generate_text [:200 ]} " , style = "bold green" )
32
34
console .print (f"{ '-' * 20 } Speculative Sampling { '-' * 20 } " , style = "bold blue" )
@@ -39,7 +41,7 @@ def main(
39
41
with Timer (name = "Naive GPT2 Auto-Regressive" ):
40
42
generate_text = gpt_text_generator .generate (
41
43
prompt = prompt_text ,
42
- max_gen_tokens = generate_tokens ,
44
+ config = generate_config ,
43
45
)
44
46
console .print (f"Generated: { generate_text [:200 ]} " , style = "bold green" )
45
47
console .print (f"{ '-' * 20 } Naive GPT2 Auto-Regressive { '-' * 20 } " , style = "bold blue" )
0 commit comments