Commit

fixes
IlyasMoutawwakil committed Dec 11, 2024
1 parent b4b8df7 commit 795badb
Showing 10 changed files with 49 additions and 35 deletions.
4 changes: 2 additions & 2 deletions examples/cpu_ipex_bert.yaml
@@ -17,8 +17,8 @@ launcher:
 backend:
   device: cpu
   export: true
-  no_weights: false # because on multi-node machines, initializing weights could harm performance
-  torch_dtype: float32 # but use bfloat16 on compatible Intel CPUs
+  no_weights: false # on multi-node machines, initializing weights in the benchmark could harm performance
+  torch_dtype: float32 # use bfloat16 on compatible Intel CPUs
   model: google-bert/bert-base-uncased

 scenario:
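Note on the torch_dtype comment: whether bfloat16 is worth enabling depends on the host CPU (it is generally only fast with AVX512-BF16 or AMX). A rough way to check from PyTorch itself, as a hint only (this is not something optimum-benchmark runs for you):

import torch

# Reports the highest SIMD level PyTorch detected on this host, e.g. "AVX2" or "AVX512".
# Anything below AVX512 is a sign to stay on float32.
print(torch.backends.cpu.get_cpu_capability())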
4 changes: 2 additions & 2 deletions examples/cpu_ipex_llama.yaml
@@ -17,8 +17,8 @@ launcher:
 backend:
   device: cpu
   export: true
-  no_weights: false # because on multi-node machines, initializing weights could harm performance
-  torch_dtype: float32 # but use bfloat16 on compatible Intel CPUs
+  no_weights: false # on multi-node machines, initializing weights in the benchmark could harm performance
+  torch_dtype: float32 # use bfloat16 on compatible Intel CPUs
   model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

 scenario:
20 changes: 0 additions & 20 deletions examples/cpu_onnxruntime_timm.yaml

This file was deleted.

5 changes: 4 additions & 1 deletion examples/cpu_openvino_8bit_bert.yaml
@@ -12,8 +12,11 @@ backend:
   device: cpu
   reshape: true
   no_weights: true
-  load_in_8bit: false # enable 8bit on compatible Intel CPU machines
+  load_in_8bit: true
   model: google-bert/bert-base-uncased
+  reshape_kwargs:
+    batch_size: 1
+    sequence_length: 128

 scenario:
   memory: true
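For reference, the options above map roughly onto the following optimum-intel usage (a sketch; the task class is an assumption, and the backend's actual code path may differ):

from optimum.intel import OVModelForSequenceClassification

# load_in_8bit: true -> 8-bit weight compression when exporting/loading the model
model = OVModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased", export=True, load_in_8bit=True
)

# reshape: true with reshape_kwargs -> pin static input shapes so OpenVINO
# can compile for exactly batch_size=1, sequence_length=128
model.reshape(batch_size=1, sequence_length=128)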
1 change: 1 addition & 0 deletions examples/cuda_tgi_llama.yaml
@@ -16,6 +16,7 @@ backend:
   device: cuda
   device_ids: 0
   cuda_graphs: 0 # remove for better perf but bigger memory footprint
+  no_weights: true
   model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

 scenario:
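no_weights: true makes the benchmark instantiate the model with random weights instead of downloading the checkpoint, which is usually fine for latency and throughput measurements. Conceptually it resembles this transformers pattern (a sketch of the idea, not the backend's actual implementation):

from transformers import AutoConfig, AutoModelForCausalLM

# Build the architecture from its config alone: same shapes and compute,
# randomly initialized weights, no checkpoint download.
config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_config(config)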
1 change: 1 addition & 0 deletions examples/cuda_trt_llama.yaml
@@ -15,6 +15,7 @@ launcher:
 backend:
   device: cuda
   device_ids: 0
+  no_weights: true
   max_batch_size: 4
   max_new_tokens: 32
   max_prompt_length: 64
3 changes: 2 additions & 1 deletion examples/cuda_vllm_llama.yaml
@@ -15,7 +15,8 @@ launcher:
 backend:
   device: cuda
   device_ids: 0
-  serving_mode: online # server-like
+  no_weights: true
+  serving_mode: online
   model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
   engine_args:
     enforce_eager: true # remove for better perf but bigger memory footprint
12 changes: 6 additions & 6 deletions examples/mps_pytorch_bert.yaml
@@ -8,14 +8,14 @@ defaults:

 name: mps_pytorch_bert

-backend:
-  device: mps
-  no_weights: true
-  model: bert-base-uncased
-
 scenario:
-  latency: true
   memory: true
+  latency: true
   input_shapes:
     batch_size: 1
     sequence_length: 128
+
+backend:
+  device: mps
+  no_weights: true
+  model: bert-base-uncased
4 changes: 3 additions & 1 deletion optimum_benchmark/backends/onnxruntime/backend.py
@@ -223,7 +223,7 @@ def quantize_onnx_files(self) -> None:

         if self.is_calibrated:
             self.logger.info("\t+ Generating calibration dataset")
-            dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
+            dataset_shapes = {"dataset_size": 2, "sequence_length": 2, "num_choices": 2}
             calibration_dataset = DatasetGenerator(
                 task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes
             )()
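The dataset_shapes change is subtle: in the old line, **self.model_shapes was merged last, so a sequence_length reported by the model silently overrode the deliberately tiny calibration values. A toy demonstration of that dict precedence (illustrative values, not backend code):

# Later keys win in a dict literal, so the model's shapes clobbered the small ones.
model_shapes = {"sequence_length": 512, "vocab_size": 30522}
old = {"dataset_size": 1, "sequence_length": 1, **model_shapes}
new = {"dataset_size": 2, "sequence_length": 2, "num_choices": 2}
print(old["sequence_length"])  # 512 -- calibration ran at full model length
print(new["sequence_length"])  # 2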
@@ -275,8 +275,10 @@ def quantize_onnx_files(self) -> None:
                 preprocessor=None,
                 file_suffix="",
             )
+
             if self.pretrained_processor is not None:
                 self.pretrained_processor.save_pretrained(self.quantized_model)
+
             if self.pretrained_config is not None:
                 self.pretrained_config.save_pretrained(self.quantized_model)
30 changes: 28 additions & 2 deletions optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -147,10 +147,36 @@ def trtllm_kwargs(self):

     def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
         return self.pretrained_model.generate(
-            inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs
+            input_ids=inputs.get("input_ids"),
+            attention_mask=inputs.get("attention_mask"),
+            min_length=kwargs.get("min_new_tokens", None),
+            max_new_tokens=kwargs.get("max_new_tokens", None),
+            repetition_penalty=kwargs.get("repetition_penalty", None),
+            length_penalty=kwargs.get("length_penalty", None),
+            pad_token_id=kwargs.get("pad_token_id", None),
+            bos_token_id=kwargs.get("bos_token_id", None),
+            eos_token_id=kwargs.get("eos_token_id", None),
+            temperature=kwargs.get("temperature", None),
+            num_beams=kwargs.get("num_beams", None),
+            top_p=kwargs.get("top_p", None),
+            top_k=kwargs.get("top_k", None),
+            seed=kwargs.get("seed", None),
         )

     def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
         return self.pretrained_model.generate(
-            inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs
+            input_ids=inputs.get("input_ids"),
+            attention_mask=inputs.get("attention_mask"),
+            min_length=kwargs.get("min_new_tokens", None),
+            max_new_tokens=kwargs.get("max_new_tokens", None),
+            repetition_penalty=kwargs.get("repetition_penalty", None),
+            length_penalty=kwargs.get("length_penalty", None),
+            pad_token_id=kwargs.get("pad_token_id", None),
+            bos_token_id=kwargs.get("bos_token_id", None),
+            eos_token_id=kwargs.get("eos_token_id", None),
+            temperature=kwargs.get("temperature", None),
+            num_beams=kwargs.get("num_beams", None),
+            top_p=kwargs.get("top_p", None),
+            top_k=kwargs.get("top_k", None),
+            seed=kwargs.get("seed", None),
         )
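Both methods now pass an explicit allowlist of generation options instead of forwarding **kwargs, translating the HF-style min_new_tokens into TRT-LLM's min_length along the way. Since prefill and generate duplicate the same mapping, it could be factored into a table; a possible refactor sketch (hypothetical helper, not part of this commit):

# Hypothetical helper: translate HF-style generation kwargs into the
# parameter names accepted by TRT-LLM's generate(), defaulting to None.
_HF_TO_TRTLLM = {
    "min_new_tokens": "min_length",
    "max_new_tokens": "max_new_tokens",
    "repetition_penalty": "repetition_penalty",
    "length_penalty": "length_penalty",
    "pad_token_id": "pad_token_id",
    "bos_token_id": "bos_token_id",
    "eos_token_id": "eos_token_id",
    "temperature": "temperature",
    "num_beams": "num_beams",
    "top_p": "top_p",
    "top_k": "top_k",
    "seed": "seed",
}

def translate_generation_kwargs(kwargs: dict) -> dict:
    return {trt: kwargs.get(hf) for hf, trt in _HF_TO_TRTLLM.items()}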
