@@ -3051,22 +3051,22 @@ def patched_forward(self, fn):
 
     def __enter__(self):
         if is_torch_version(">=", "2.1.0"):
-            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
-                if is_transformers_version("<", "4.48"):
+            if (
+                self._model.config.model_type in ["qwen2", "llama"]
+                and self._model.config._attn_implementation != "sdpa"
+            ):
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+                if self._model.config.model_type == "qwen2" and is_transformers_version("<", "4.48"):
                     from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
 
                     sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
-                    self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                    self._model.config._attn_implementation = "sdpa"
 
                     for layer in self._model.model.layers:
                         layer.self_attn._orig_forward = layer.self_attn.forward
                         layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
 
-            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
-                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                self._model.config._attn_implementation = "sdpa"
-                if is_transformers_version("<", "4.47"):
+                if self._model.config.model_type == "llama" and is_transformers_version("<", "4.47"):
                     from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
 
                     sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
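For context on the pattern this hunk refactors: on entering the patcher, the model's attention implementation is temporarily switched to "sdpa" (the original value is stashed in config._orig_attn_implementation), and on older transformers releases each decoder layer's self_attn.forward is rebound to the SDPA attention class's forward via types.MethodType. The sketch below is a minimal, self-contained illustration of that enter/restore pattern, not the library's actual code; the class name SdpaAttentionPatcher, the sdpa_attn_cls parameter, and the restore logic in __exit__ are illustrative assumptions, while the config / model.layers attribute layout follows the diff.

import types


class SdpaAttentionPatcher:
    """Hypothetical sketch: force "sdpa" attention while the context is active,
    then restore the original implementation on exit."""

    def __init__(self, model, sdpa_attn_cls):
        # `sdpa_attn_cls` stands in for something like QWEN2_ATTENTION_CLASSES["sdpa"].
        self._model = model
        self._sdpa_attn_cls = sdpa_attn_cls

    def __enter__(self):
        config = self._model.config
        if config._attn_implementation != "sdpa":
            # Remember the user's setting so it can be restored later.
            config._orig_attn_implementation = config._attn_implementation
            config._attn_implementation = "sdpa"
            for layer in self._model.model.layers:
                # Keep the original bound forward, then bind the SDPA forward
                # onto the existing attention module instance.
                layer.self_attn._orig_forward = layer.self_attn.forward
                layer.self_attn.forward = types.MethodType(self._sdpa_attn_cls.forward, layer.self_attn)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        config = self._model.config
        if hasattr(config, "_orig_attn_implementation"):
            config._attn_implementation = config._orig_attn_implementation
            del config._orig_attn_implementation
            for layer in self._model.model.layers:
                if hasattr(layer.self_attn, "_orig_forward"):
                    layer.self_attn.forward = layer.self_attn._orig_forward
        return False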