Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 16, 2024
1 parent ea32802 commit abec2b6
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions optimum_benchmark/backends/py_txi/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from py_txi import TEI, TGI, TEIConfig, TGIConfig
from safetensors.torch import save_file

Expand Down Expand Up @@ -36,26 +36,13 @@ def load(self) -> None:

self.tmpdir.cleanup()

@property
def volume(self) -> str:
return list(self.config.volumes.keys())[0]

def download_pretrained_model(self) -> None:
# directly downloads pretrained model in volume (/data) to change generation config before loading model
with init_empty_weights(include_buffers=True):
self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs, cache_dir=self.volume)
model_snapshot_folder = snapshot_download(self.config.model, self.config.model_kwargs)

if self.config.task in TEXT_GENERATION_TASKS:
self.logger.info("\t+ Preparing generation config")
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
model_cache_folder = f"models/{self.config.model}".replace("/", "--")
model_cache_path = f"{self.volume}/{model_cache_folder}"
snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
snapshot_ref = open(snapshot_file, "r").read().strip()
model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
self.logger.info("\t+ Saving pretrained generation config")
self.generation_config.save_pretrained(save_directory=model_snapshot_path)
self.generation_config.save_pretrained(save_directory=model_snapshot_folder)

def create_no_weights_model(self) -> None:
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
Expand Down

0 comments on commit abec2b6

Please sign in to comment.