
Commit 4f02485

add ipex model example for text-generation
1 parent 72b0630 commit 4f02485

File tree

3 files changed: +245 -0 lines changed

+31
@@ -0,0 +1,31 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## Language generation

Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).

The original generation script supports only the PyTorch eager and graph (TorchScript) modes. By using the `IPEXModelForCausalLM` class, IPEX optimizations can now be applied to both the eager and graph models for generation tasks.

Example usage:

### Use bf16 and JIT model

```bash
python run_generation.py \
    --model_name_or_path=gpt2 \
    --bf16 \
    --jit
```
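
For reference, the command above roughly corresponds to the following Python sketch that uses `IPEXModelForCausalLM` directly; the prompt text and generation parameters here are illustrative and not part of this commit:

```python
# Minimal sketch (illustrative): load a causal LM with IPEX optimizations in
# bfloat16 and JIT-trace it, mirroring the `--bf16 --jit` flags above.
import torch
from transformers import AutoTokenizer

from optimum.intel.ipex import IPEXModelForCausalLM

model_id = "gpt2"  # any causal LM checkpoint should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(
    model_id,
    export=True,                 # corresponds to --jit (trace the model)
    torch_dtype=torch.bfloat16,  # corresponds to --bf16
)

# Illustrative prompt and sampling settings.
inputs = tokenizer("Here is a short story:", return_tensors="pt")
output = model.generate(**inputs, max_length=50, do_sample=True, top_p=0.9)
print(tokenizer.decode(output[0]))
```

The full command-line workflow (device placement, prompt handling, sampling options) is implemented in the `run_generation.py` script added by this commit.
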
@@ -0,0 +1,3 @@
sentencepiece
protobuf
torch >= 2.1.0
@@ -0,0 +1,211 @@
```python
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, INTEL CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import logging

import torch
from accelerate import PartialState
from accelerate.utils import set_seed
from transformers import AutoTokenizer

from optimum.intel.ipex import IPEXModelForCausalLM


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def sparse_model_config(model_config):
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    if hasattr(model_config, "n_layer"):
        num_layer = model_config.n_layer
    elif hasattr(model_config, "num_hidden_layers"):
        num_layer = model_config.num_hidden_layers
    else:
        raise ValueError("Number of hidden layers couldn't be determined from the model config")

    return num_layer, num_head, num_embedding_size_per_head


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--use_cpu",
        action="store_true",
        help="Whether or not to use cpu. If set to False, we will use gpu/npu or mps device if available",
    )
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        help="Whether to use bfloat 16-bit precision (through INTEL AMX or AVX_512) instead of 32-bit",
    )
    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
    args = parser.parse_args()

    if args.fp16 and args.bf16:
        raise ValueError("You can only choose one of {fp16, bf16}")

    torch_dtype = torch.float32
    if args.fp16:
        torch_dtype = torch.float16
    if args.bf16:
        torch_dtype = torch.bfloat16

    # Initialize the distributed state.
    distributed_state = PartialState(cpu=args.use_cpu)

    logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16 or args.bf16}")

    if args.seed is not None:
        set_seed(args.seed)

    # Initialize the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = IPEXModelForCausalLM.from_pretrained(args.model_name_or_path, export=args.jit, torch_dtype=torch_dtype)

    # Set the model to the right device
    model.to(distributed_state.device)

    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    prefix = args.prefix if args.prefix else args.padding_text
    encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(distributed_state.device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
```
