Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit transformers version for bettertransformer support #2198

Merged
merged 4 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_bettertransformer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
pip install .[tests]
pip install --no-cache-dir --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install accelerate
pip install accelerate transformers==4.48.*

- name: Test with stable pytorch
working-directory: tests
Expand Down
11 changes: 11 additions & 0 deletions optimum/bettertransformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..utils.import_utils import _transformers_version, is_transformers_version


if is_transformers_version(">=", "4.49"):
raise RuntimeError(
f"BetterTransformer requires transformers<4.49 but found {_transformers_version}. "
"`optimum.bettertransformer` is deprecated and will be removed in optimum v2.0."
)

from .models import BetterTransformerManager
from .transformation import BetterTransformer
2 changes: 1 addition & 1 deletion optimum/bettertransformer/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def transform(
"""

logger.warning(
"The class `optimum.bettertransformers.transformation.BetterTransformer` is deprecated and will be removed in a future release."
"The class `optimum.bettertransformers.transformation.BetterTransformer` is deprecated and will be removed in optimum v2.0."
)

hf_config = model.config
Expand Down
3 changes: 2 additions & 1 deletion optimum/pipelines/pipelines_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from transformers.pipelines import SUPPORTED_TASKS as TRANSFORMERS_SUPPORTED_TASKS
from transformers.pipelines import infer_framework_load_model

from ..bettertransformer import BetterTransformer
from ..utils import is_onnxruntime_available, is_transformers_version


Expand Down Expand Up @@ -185,6 +184,8 @@ def load_bettertransformer(
hub_kwargs: Optional[Dict] = None,
**kwargs,
):
from ..bettertransformer import BetterTransformer

if model_kwargs is None:
# the argument was first introduced in 4.36.0 but most models didn't have an sdpa implementation then
# see https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/modeling_utils.py#L1258
Expand Down
2 changes: 1 addition & 1 deletion tests/bettertransformer/Dockerfile_bettertransformer_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ RUN apt-get autoremove -y
RUN python3 -m pip install -U pip

RUN pip install torch torchvision torchaudio
RUN pip install transformers accelerate datasets
RUN pip install transformers==4.48.* accelerate datasets

# Install Optimum
COPY . /workspace/optimum
Expand Down
Loading