single prepare inputs
IlyasMoutawwakil committed Dec 11, 2024
1 parent 6738b85 commit b4b8df7
Showing 13 changed files with 36 additions and 62 deletions.
12 changes: 1 addition & 11 deletions optimum_benchmark/backends/base.py
@@ -106,23 +106,13 @@ def create_no_weights_model(self) -> None:
         self.logger.info("\t+ Saving no weights model's config")
         self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        This method is used to prepare and register the inputs before passing them to the model.
-        It can be used to move the inputs to the correct device, or rename their keys.
-        """
-        return inputs
-
-    def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         """
        This method is used to prepare and register the inputs before passing them to the model.
        It can be used to move the inputs to the correct device, or rename their keys.
        """
         return inputs
 
-    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        return self.prepare_inputs_after_load(self.prepare_inputs_before_load(inputs))
-
     def load(self) -> None:
         raise NotImplementedError("Backend must implement load method")

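
Backends now override a single prepare_inputs hook instead of the old before/after pair. A minimal sketch of what an override looks like (the MyBackend class and its device attribute are illustrative, not part of this commit):

    from typing import Any, Dict

    import torch

    class MyBackend:
        device = "cpu"  # real backends read the target device from their config

        def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            # move tensors to the target device, leave non-tensor values untouched
            return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
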
13 changes: 6 additions & 7 deletions optimum_benchmark/backends/ipex/backend.py
@@ -38,26 +38,25 @@ def load(self) -> None:
             self.logger.info("\t+ Creating no weights IPEXModel")
             self.create_no_weights_model()
             self.logger.info("\t+ Loading no weights IPEXModel")
-            self._load_ipexmodel_with_no_weights()
+            self.load_ipexmodel_with_no_weights()
         else:
             self.logger.info("\t+ Loading pretrained IPEXModel")
-            self._load_ipexmodel_from_pretrained()
+            self.load_ipexmodel_from_pretrained()
 
         self.tmpdir.cleanup()
 
-    def _load_ipexmodel_from_pretrained(self) -> None:
+    def load_ipexmodel_from_pretrained(self) -> None:
         self.pretrained_model = self.ipexmodel_class.from_pretrained(
             self.config.model,
             **self.config.model_kwargs,
             **self.ipexmodel_kwargs,
         )
 
-    def _load_ipexmodel_with_no_weights(self) -> None:
+    def load_ipexmodel_with_no_weights(self) -> None:
         with fast_weights_init():
-            self.logger.info("\t+ Loading no weights IPEXModel")
             original_model, self.config.model = self.config.model, self.no_weights_model
             original_export, self.config.export = self.config.export, True
-            self._load_ipexmodel_from_pretrained()
+            self.load_ipexmodel_from_pretrained()
             self.config.export = original_export
             self.config.model = original_model
 
@@ -77,7 +76,7 @@ def ipexmodel_kwargs(self) -> Dict[str, Any]:
     def split_between_processes(self) -> bool:
         return is_torch_distributed_available() and torch.distributed.is_initialized()
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
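
The no-weights loader above temporarily swaps config values and restores them after loading. A generic sketch of that idiom, assuming nothing beyond the standard library (temporary_attr is an illustrative name):

    from contextlib import contextmanager

    @contextmanager
    def temporary_attr(obj, name, value):
        original = getattr(obj, name)
        setattr(obj, name, value)
        try:
            yield
        finally:
            setattr(obj, name, original)  # restored even if loading raises

A context manager like this would also restore the original values on failure, which the inline swap in load_ipexmodel_with_no_weights does not.
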
8 changes: 5 additions & 3 deletions optimum_benchmark/backends/llama_cpp/backend.py
@@ -41,7 +41,7 @@ def llama_cpp_kwargs(self) -> Dict[str, Any]:
             "echo": False,
         }
 
-    def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.config.task == "text-generation":
             if inputs["input_ids"].shape[0] != 1:
                 raise ValueError("Batch size must be 1 for Text Generation with llama-cpp-python")
@@ -55,9 +55,11 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
         self.pretrained_model.embed(**inputs)
 
     def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]:
-        next(self.pretrained_model.generate(**inputs))
+        generator = self.pretrained_model.generate(**inputs, reset=True)
+        for _ in range(kwargs["max_new_tokens"]):
+            next(generator)
 
     def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]:
-        generator = self.pretrained_model.generate(**inputs)
+        generator = self.pretrained_model.generate(**inputs, reset=True)
         for _ in range(kwargs["max_new_tokens"]):
             next(generator)
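
Both methods now drain a bounded number of tokens from llama-cpp-python's token-level generator, with reset=True clearing the model's cache before each run. A standalone sketch of that pattern (the model path is an assumption):

    from llama_cpp import Llama

    llm = Llama(model_path="model.gguf", verbose=False)  # hypothetical local GGUF file
    tokens = llm.tokenize(b"Hello, world")

    generator = llm.generate(tokens, reset=True)  # reset=True restarts from a clean cache
    new_tokens = [next(generator) for _ in range(8)]  # draw exactly 8 tokens
    print(llm.detokenize(new_tokens))
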
5 changes: 1 addition & 4 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -284,7 +284,7 @@ def quantize_onnx_files(self) -> None:
     def split_between_processes(self) -> bool:
         return is_torch_distributed_available() and torch.distributed.is_initialized()
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
@@ -293,9 +293,6 @@ def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
             if isinstance(value, torch.Tensor):
                 inputs[key] = value.to(self.config.device)
 
-        return inputs
-
-    def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         for key in list(inputs.keys()):
             if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
                 inputs.pop(key)
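
The merged prepare_inputs now both shards/moves tensors and drops any keys the exported model does not declare. The filtering step in isolation (function and argument names are illustrative):

    from typing import Any, Dict, Set

    def filter_to_session_inputs(inputs: Dict[str, Any], input_names: Set[str]) -> Dict[str, Any]:
        # ORTModel exposes input_names; keys the ONNX graph does not declare
        # (e.g. token_type_ids for a model exported without them) must be dropped
        return {key: value for key, value in inputs.items() if key in input_names}
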
28 changes: 6 additions & 22 deletions optimum_benchmark/backends/openvino/backend.py
@@ -1,4 +1,3 @@
-import inspect
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
 from typing import Any, Dict
@@ -48,14 +47,8 @@ def load(self) -> None:
             self.load_ovmodel_from_pretrained()
 
         if self.config.reshape:
-            static_shapes = {
-                key: value
-                for key, value in self.model_shapes.items()
-                if key in inspect.getfullargspec(self.pretrained_model.reshape).args
-            }
-
-            self.logger.info(f"\t+ Reshaping model with static shapes: {static_shapes}")
-            self.pretrained_model.reshape(**static_shapes)
+            self.logger.info("\t+ Reshaping model with static shapes")
+            self.pretrained_model.reshape(**self.config.reshape_kwargs)
 
         if self.config.half:
             self.logger.info("\t+ Converting model to half precision")
@@ -78,7 +71,6 @@ def load_ovmodel_with_no_weights(self) -> None:
         with fast_weights_init():
             original_model, self.config.model = self.config.model, self.no_weights_model
             original_export, self.config.export = self.config.export, True
-            self.logger.info("\t+ Loading no weights OVModel")
             self.load_ovmodel_from_pretrained()
             self.config.export = original_export
             self.config.model = original_model
@@ -102,28 +94,20 @@ def ovmodel_kwargs(self) -> Dict[str, Any]:
         if self.config.load_in_4bit is not None:
             kwargs["load_in_4bit"] = self.config.load_in_4bit
 
+        if self.config.ov_config:
+            kwargs["ov_config"] = self.config.ov_config
+
         return kwargs
 
     @property
     def split_between_processes(self) -> bool:
         return is_torch_distributed_available() and torch.distributed.is_initialized()
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
 
-        if "input_ids" in inputs:
-            self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))
-
-        if "pixel_values" in inputs:
-            self.model_shapes.update(
-                dict(zip(["batch_size", "num_channels", "height", "width"], inputs["pixel_values"].shape))
-            )
-
-        return inputs
-
-    def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         for key in list(inputs.keys()):
             if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
                 inputs.pop(key)
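
Reshaping now relies on user-provided reshape_kwargs instead of shapes inferred from the first batch via inspect. Roughly what the new path does, in terms of the optimum-intel API (model name and shape values are illustrative):

    from optimum.intel import OVModelForSequenceClassification

    model = OVModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english", export=True
    )
    model.reshape(batch_size=1, sequence_length=128)  # fix static shapes before compiling
    model.compile()
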
5 changes: 2 additions & 3 deletions optimum_benchmark/backends/openvino/config.py
@@ -19,14 +19,13 @@ class OVConfig(BackendConfig):
     use_merged: Optional[bool] = None
     load_in_8bit: Optional[bool] = None
     load_in_4bit: Optional[bool] = None
-    ov_config: Dict[str, Any] = field(default_factory=dict)
 
     # compilation options
     half: bool = False
     compile: bool = False
     reshape: bool = False
 
     # openvino config
+    ov_config: Dict[str, Any] = field(default_factory=dict)
+    reshape_kwargs: Dict[str, int] = field(default_factory=dict)
 
     def __post_init__(self):
         super().__post_init__()
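
With the fields above, a run that reshapes to static shapes would be configured along these lines (values are illustrative; remaining fields come from the benchmark's defaults):

    from optimum_benchmark.backends.openvino.config import OVConfig

    config = OVConfig(
        model="distilbert-base-uncased-finetuned-sst-2-english",
        reshape=True,
        reshape_kwargs={"batch_size": 1, "sequence_length": 128},
        ov_config={"PERFORMANCE_HINT": "LATENCY"},
    )
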
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/py_txi/backend.py
@@ -139,7 +139,7 @@ def load_model_from_pretrained(self) -> None:
         else:
             raise NotImplementedError(f"TXI does not support task {self.config.task}")
 
-    def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.config.task in TEXT_GENERATION_TASKS:
             inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
         elif self.config.task in TEXT_EMBEDDING_TASKS:
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/py_txi/config.py
@@ -51,8 +51,8 @@ class PyTXIConfig(BackendConfig):
     num_shard: Optional[int] = None
     speculate: Optional[int] = None
     cuda_graphs: Optional[int] = None
-    disable_custom_kernels: Optional[bool] = None
     trust_remote_code: Optional[bool] = None
+    disable_custom_kernels: Optional[bool] = None
 
     # TEI specific
     pooling: Optional[str] = None
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/pytorch/backend.py
@@ -416,7 +416,7 @@ def split_between_processes(self) -> bool:
             and not self.config.deepspeed_inference
         )
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
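
The split_between_processes path shared by several backends hands each rank its slice of the batch. A self-contained sketch of the accelerate utility it relies on (sizes illustrative; run under torchrun for multiple ranks):

    from accelerate import Accelerator

    accelerator = Accelerator()
    batch = {"input_ids": list(range(8))}  # stand-in for a batch of 8 samples
    with accelerator.split_between_processes(batch, apply_padding=False) as process_batch:
        print(process_batch)  # with 2 processes, each rank sees 4 samples
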
8 changes: 6 additions & 2 deletions optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -146,7 +146,11 @@ def trtllm_kwargs(self):
         return kwargs
 
     def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs)
+        return self.pretrained_model.generate(
+            inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs
+        )
 
     def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs)
+        return self.pretrained_model.generate(
+            inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs
+        )
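
Passing the attention mask matters once batches are padded: the engine must know which positions hold real tokens. Illustrative tensors only:

    import torch

    input_ids = torch.tensor([[5, 6, 7, 0]])  # 0 acting as the pad token here (assumption)
    attention_mask = torch.tensor([[1, 1, 1, 0]])  # zeros mark the padded positions
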
4 changes: 3 additions & 1 deletion optimum_benchmark/backends/torch_ort/config.py
@@ -14,7 +14,9 @@ class TorchORTConfig(BackendConfig):
     # load options
     no_weights: bool = False
     torch_dtype: Optional[str] = None
-    attn_implementation: Optional[str] = None
+    attn_implementation: Optional[str] = (
+        "eager"  # we pin eager because sdpa became default of many architectures, which fails with torch-ort
+    )
 
     # peft options
     peft_type: Optional[str] = None
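
The pinned default is ultimately forwarded to the transformers loader; in isolation it is equivalent to (model name illustrative):

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        attn_implementation="eager",  # avoid sdpa, which fails under torch-ort
    )
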
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/vllm/backend.py
@@ -113,7 +113,7 @@ def vllm_kwargs(self):
             **self.config.engine_args,
         }
 
-    def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.config.task in TEXT_GENERATION_TASKS:
             inputs = {"prompts": self.pretrained_processor.batch_decode(inputs["input_ids"])}
         else:
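
The vLLM backend feeds decoded text prompts rather than token ids, so prepare_inputs round-trips through the tokenizer. The conversion in isolation (tokenizer choice and ids are illustrative):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    input_ids = torch.tensor([[15496, 11, 995]])  # example GPT-2 token ids
    prompts = tokenizer.batch_decode(input_ids)
    print(prompts)  # ['Hello, world']
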
7 changes: 2 additions & 5 deletions optimum_benchmark/scenarios/inference/scenario.py
@@ -126,13 +126,10 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             input_shapes=self.config.input_shapes,
         )()
 
-        self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} before model loading.")
-        self.inputs = self.backend.prepare_inputs_before_load(inputs=self.inputs)
-
         self.run_model_loading_tracking()
 
-        self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} after model loading.")
-        self.inputs = self.backend.prepare_inputs_after_load(inputs=self.inputs)
+        self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name}")
+        self.inputs = self.backend.prepare_inputs(inputs=self.inputs)
 
         if self.config.warmup_runs > 0:
             if self.backend.config.task in TEXT_GENERATION_TASKS:
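
The scenario now prepares inputs exactly once, after the model-loading phase; this ordering matters because some backends above (onnxruntime, openvino) filter inputs against the loaded model's input_names. A minimal runnable sketch of the flow (DummyBackend is illustrative):

    class DummyBackend:
        def load(self):
            self.ready = True  # stands in for real model loading

        def prepare_inputs(self, inputs):
            assert self.ready  # preparation may rely on the loaded model
            return inputs

    backend = DummyBackend()
    backend.load()  # model loading is tracked first
    inputs = backend.prepare_inputs(inputs={"input_ids": [[1, 2, 3]]})  # then a single prepare step
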

