from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
import torch
from PIL import Image


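# Download the original model snapshot and patch its remote modeling code so it
# can be imported without the flash_attn package.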
def download_original_model(model_id, model_local_dir):
    if not model_local_dir.exists():
        snapshot_download(repo_id=model_id, local_dir=model_local_dir)

    modeling_file = model_local_dir / "modeling_llava_qwen2.py"
    orig_modeling_file = model_local_dir / f"orig_{modeling_file.name}"

    # The model code depends on the flash_attn package, which may be problematic
    # to install. Patch the modeling code to avoid importing this package.
    if not orig_modeling_file.exists():
        modeling_file.rename(orig_modeling_file)
        with orig_modeling_file.open("r") as f:
            content = f.read()
        replacement_lines = [
            ("from flash_attn import flash_attn_func, flash_attn_varlen_func", ""),
            ("from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input", ""),
            (' _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)', "pass"),
        ]

        for replace_pair in replacement_lines:
            content = content.replace(*replace_pair)

        with modeling_file.open("w") as f:
            f.write(content)


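# Check that all OpenVINO IR files (both the .xml and .bin parts) produced by
# the conversion are present in model_dir.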
def converted_model_exists(model_dir):
    for file_name in ["openvino_language_model.xml", "openvino_text_embeddings_model.xml", "openvino_vision_embeddings_model.xml"]:
        if not (Path(model_dir) / file_name).exists() or not (Path(model_dir) / file_name.replace(".xml", ".bin")).exists():
            return False

    return True


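# Copy model files from src_dir to dst_dir, optionally skipping the language
# model and/or vision encoder IR files.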
def copy_model_files(src_dir, dst_dir, ignore_llm=True, ignore_vision_encoder=True):
    ignore_files = []
    if ignore_llm:
        ignore_files.extend(["openvino_language_model.xml", "openvino_language_model.bin"])
    if ignore_vision_encoder:
        ignore_files.extend(["openvino_vision_embeddings_model.xml", "openvino_vision_embeddings_model.bin"])

    for file_name in src_dir.glob("*"):
        if file_name.name in ignore_files:
            continue
        shutil.copy(file_name, dst_dir)


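# Pad a PIL image with the given background color so it becomes square,
# keeping the original content centered.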
def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


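# Preprocess a list of PIL images with the model's image processor. When the
# model config requests the "pad" aspect-ratio mode, images are first padded to
# squares using the processor's mean pixel value as the background color.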
def process_images(images, model_cfg, processor):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
            image = processor.preprocess(images=image, return_tensors="pt")["pixel_values"][0]
            new_images.append(image)
    else:
        return processor(images=images, return_tensors="pt")["pixel_values"]
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


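# Tokenize a prompt containing a single "<image>" placeholder: the text on each
# side is tokenized separately and joined with the image placeholder token
# index (-200) used by LLaVA-style models.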
def process_text_input(text, tokenizer):
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
    return input_ids, attention_mask
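

# Illustrative usage sketch (not part of the original helpers); the repo id and
# directory names below are assumptions:
#
#     model_id = "qnguyen3/nanoLLaVA"
#     original_dir = Path("nanoLLaVA")
#     ov_model_dir = Path("nanoLLaVA-ov")
#     download_original_model(model_id, original_dir)
#     if not converted_model_exists(ov_model_dir):
#         ...  # convert the model to OpenVINO IR here, then use copy_model_files()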