@@ -1,9 +1,12 @@
 import logging
 import json
+import torch

 from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
 from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting

+from .utils import mock_torch_cuda_is_available
+

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -82,18 +85,39 @@ def load_text_llamacpp_pipeline(model_dir):
     return model


+def load_text_hf_pipeline(model_id, device):
+    model_kwargs = {}
+
+    if not torch.cuda.is_available() or device.lower() == "cpu":
+        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        is_gptq = False
+        is_awq = False
+        if getattr(config, "quantization_config", None):
+            is_gptq = config.quantization_config["quant_method"] == "gptq"
+            is_awq = config.quantization_config["quant_method"] == "awq"
+        if is_gptq or is_awq:
+            # GPTQ/AWQ kernels target CUDA, so infer in FP32 on CPU
+            model_kwargs["torch_dtype"] = torch.float32
+        with mock_torch_cuda_is_available(is_gptq or is_awq):
+            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="cpu", **model_kwargs)
+        if is_awq:
+            model.is_awq = is_awq
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, trust_remote_code=True, device_map=device.lower(), **model_kwargs
+        )
+    model.eval()
+    return model
+
+
 def load_text_model(
     model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs,
 ):
     if use_hf:
         logger.info("Using HF Transformers API")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id, trust_remote_code=True, device_map=device.lower()
-        )
-        model.eval()
+        model = load_text_hf_pipeline(model_id, device)
     elif use_genai:
         model = load_text_genai_pipeline(model_id, device, ov_config, **kwargs)
-
     elif use_llamacpp:
         logger.info("Using llama.cpp API")
         model = load_text_llamacpp_pipeline(model_id)
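
Note: mock_torch_cuda_is_available is imported from .utils but defined outside this diff. A minimal sketch of what such a helper could look like, assuming its only job is to patch torch.cuda.is_available so that GPTQ/AWQ loading paths (which probe for CUDA) can run on CPU-only hosts:

import contextlib

import torch


@contextlib.contextmanager
def mock_torch_cuda_is_available(to_patch):
    # Temporarily report CUDA as available so quantized-model integrations
    # (GPTQ/AWQ) do not refuse to load on a CPU-only machine; restore on exit.
    original_is_available = torch.cuda.is_available
    if to_patch:
        torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        if to_patch:
            torch.cuda.is_available = original_is_available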
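With that in place, a hypothetical call through the updated use_hf branch (the model id below is illustrative, not from this change) would be:

# Assumed usage: a GPTQ checkpoint on a machine without CUDA now loads
# via the FP32 + mocked-CUDA path instead of raising.
model = load_text_model("org/some-7b-gptq", device="CPU", use_hf=True)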