Commit 993c1a5
Enable TorchScript export, loading and inference for causal lm models (#283)
* add original example
* add jit model for generation
* enable generation
* add requirements and README
* enable caching
* move modeling to ipex subpackage
* fix import
* rename model class
* Update code and move modeling.py to utils folder. (#290)
* Update code and move modeling.py to utils folder.
* move modeling.py to generation folder
* Fixed typo
* Update code for support instance of fx model to TracedModelForCausalLM
* Update code
* Fixed typo
* Update readme
* Add generation for bloom models (#295)
* add bloom generation
* Fix generation
* refactorization
* set back input ids to prepare_inputs
* Add tests
* fix dates
* remove unecessary list comprehension
* remove unecessary list comprehension
* rename class
* fix dependencies
* fix dependency

---------

Co-authored-by: Cheng, Penghui <penghui.cheng@intel.com>
1 parent 08f6330 · commit 993c1a5

File tree: 8 files changed, +899 −1 lines


.github/workflows/test_generation.yml

@@ -0,0 +1,37 @@

# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Intel Generation Utils - Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-latest]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install optimum[exporters]
          pip install .[tests]
      - name: Test with Pytest
        run: |
          pytest tests/generation/
@@ -0,0 +1,30 @@

<!---
Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## Language generation

Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).

The original generation script only supported PyTorch eager-mode models. With the `TorchScriptModelForCausalLM` class, a TorchScript (jit-traced) model can now be used for generation tasks.

Example usage:

```bash
python run_generation.py \
    --model_type=gpt2 \
    --model_name_or_path=gpt2 \
    --jit
```
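
The same export-and-generate flow can also be driven directly from Python. Below is a minimal sketch mirroring the `run_generation.py` script added in this commit; the model name, prompt, and sampling settings are illustrative.

```python
from transformers import GPT2Tokenizer

from optimum.intel.generation.modeling import TorchScriptModelForCausalLM

# export=True jit-traces the eager model at load time and wraps it for generation
model = TorchScriptModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Encode an illustrative prompt, then sample a short continuation
encoded_prompt = tokenizer.encode("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=20 + encoded_prompt.shape[-1],
    do_sample=True,
    top_k=0,
    top_p=0.9,
)
print(tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True))
```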
@@ -0,0 +1,3 @@

sentencepiece != 0.1.92
protobuf
torch >= 2.0.0
@@ -0,0 +1,310 @@

#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import logging

import numpy as np
import torch
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)

from optimum.intel.generation.modeling import TorchScriptModelForCausalLM


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need masked token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory where to save the resulting model",
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    if args.jit:
        model = TorchScriptModelForCausalLM.from_pretrained(args.model_name_or_path, export=True)
    else:
        model = model_class.from_pretrained(args.model_name_or_path)

    if args.output_dir is not None and args.jit:
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

    model.to(args.device)

    args.length = adjust_length_to_model(
        args.length,
        max_sequence_length=model.config.max_position_embeddings
        if hasattr(model.config, "max_position_embeddings")
        else 0,
    )
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)

        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}

        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
        )
    else:
        prefix = args.prefix if args.prefix else args.padding_text
        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
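
The script above saves the traced model and tokenizer when both `--jit` and `--output_dir` are passed. Below is a minimal sketch of the same round trip in Python; the directory name is illustrative, and reloading the saved directory without `export=True` is assumed from the commit title ("export, loading and inference") rather than shown in this diff.

```python
from transformers import GPT2Tokenizer

from optimum.intel.generation.modeling import TorchScriptModelForCausalLM

save_dir = "gpt2_torchscript"  # illustrative output directory

# Export the eager model to TorchScript and save it together with its tokenizer
model = TorchScriptModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# Later: reload the saved TorchScript model for inference (assumed to avoid re-tracing)
model = TorchScriptModelForCausalLM.from_pretrained(save_dir)
tokenizer = GPT2Tokenizer.from_pretrained(save_dir)
```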
