|
| 1 | +from pathlib import Path |
| 2 | +import requests |
| 3 | +import gradio as gr |
| 4 | +from PIL import Image |
| 5 | +from threading import Thread |
| 6 | +from transformers import TextIteratorStreamer |
| 7 | + |
# Fallback Jinja chat template in the Mistral/Pixtral "[INST] ... [/INST]" format,
# applied only when the processor ships without its own template (see make_demo).
# It accepts an optional leading system message, enforces strict user/assistant
# alternation (raise_exception otherwise), renders image content chunks as the
# "[IMG]" placeholder token, and injects the system message into the final user
# turn. Do not edit the string itself: every character ends up in the prompt.
chat_template = """
{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\n\n\" }}\n {%- else %}\n {{- \"[INST]\" }}\n {%- endif %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"content\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message[\"content\"] }}\n {%- endif %}\n {{- \"[/INST]\" }}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}
"""
| 11 | + |
| 12 | + |
def resize_with_aspect_ratio(image: "Image.Image", dst_height=512, dst_width=512):
    """Shrink *image* so it fits inside a dst_width x dst_height box, keeping its aspect ratio.

    Images that already fit are returned unchanged — this never upscales.

    :param image: PIL image (any object exposing ``size`` and ``resize`` works)
    :param dst_height: maximum height of the result, defaults to 512
    :param dst_width: maximum width of the result, defaults to 512
    :return: the downscaled image, or the original object when it already fits
    """
    width, height = image.size
    if width > dst_width or height > dst_height:
        # Use a single scale factor for both axes so proportions are preserved.
        im_scale = min(dst_height / height, dst_width / width)
        resize_size = (int(width * im_scale), int(height * im_scale))
        return image.resize(resize_size)
    return image
| 20 | + |
| 21 | + |
def make_demo(model, processor):
    """Build a multimodal Gradio chat demo around an OpenVINO image-text model.

    :param model: model exposing ``generate()`` and ``config._name_or_path``
    :param processor: multimodal processor with a tokenizer and (optionally) a chat template
    :return: a ``gr.ChatInterface`` ready to be launched
    """
    model_name = Path(model.config._name_or_path).parent.name

    example_image_urls = [
        ("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/dd5105d6-6a64-4935-8a34-3058a82c8d5d", "small.png"),
        ("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/1221e2a8-a6da-413a-9af6-f04d56af3754", "chart.png"),
    ]

    # Download the example images once; reruns reuse the local copies.
    for url, file_name in example_image_urls:
        if not Path(file_name).exists():
            Image.open(requests.get(url, stream=True).raw).save(file_name)

    # Fall back to the bundled "[INST]"-style template when the processor ships without one.
    if processor.chat_template is None:
        processor.set_chat_template(chat_template)

    def bot_streaming(message, history):
        """Stream a reply for the latest user turn, reusing the most recent image."""
        print(f"message is - {message}")
        print(f"history is - {history}")
        # Depending on the gradio version, `message` is a dict or an object with attributes.
        files = message["files"] if isinstance(message, dict) else message.files
        message_text = message["text"] if isinstance(message, dict) else message.text

        image = None
        if files:
            # The last uploaded file may be a dict, a list/tuple, or an object with `.path`.
            last_file = files[-1]
            if isinstance(last_file, dict):
                image = last_file["path"]
            else:
                image = last_file if isinstance(last_file, (list, tuple)) else last_file.path
        else:
            # No upload this turn: fall back to the newest image from past turns
            # (gradio stores image-bearing user turns as tuples).
            for user_turn, _ in history:
                if isinstance(user_turn, tuple):
                    image = user_turn[0]
        if image is None:
            raise gr.Error(f"You need to upload an image for {model_name} to work. Close the error and try again with an Image.")

        # Rebuild the conversation in the chat-template format. An image-only turn
        # arrives as (image, None); its text arrives with the following turn.
        conversation = []
        pending_image_turn = False
        for user, assistant in history:
            if assistant is None:
                # Open an empty user message for the image; text is attached next iteration.
                pending_image_turn = True
                conversation.append({"role": "user", "content": []})
                continue
            if pending_image_turn:
                # Attach the text to the turn just opened for the image. Indexing the
                # LAST entry (not the first) keeps earlier turns intact in long chats.
                conversation[-1]["content"] = [{"type": "text", "content": f"{user}"}]
                conversation.append({"role": "assistant", "content": assistant})
                pending_image_turn = False
                continue
            conversation.extend(
                [
                    {"role": "user", "content": [{"type": "text", "content": user}]},
                    {"role": "assistant", "content": assistant},
                ]
            )

        conversation.append({"role": "user", "content": [{"type": "text", "content": f"{message_text}"}, {"type": "image"}]})
        prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        print(f"prompt is -\n{prompt}")

        image = Image.open(image)
        image = resize_with_aspect_ratio(image)
        inputs = processor(prompt, image, return_tensors="pt")

        streamer = TextIteratorStreamer(
            processor,
            skip_special_tokens=True,
            skip_prompt=True,
            clean_up_tokenization_spaces=False,
        )
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding; a sampling temperature would be ignored
            eos_token_id=processor.tokenizer.eos_token_id,
        )

        # generate() blocks until completion, so run it in a worker thread and
        # stream partial text back to the UI as it is produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer

    demo = gr.ChatInterface(
        fn=bot_streaming,
        title=f"{model_name} with OpenVINO",
        examples=[
            {"text": "What is the text saying?", "files": ["./small.png"]},
            {"text": "What does the chart display?", "files": ["./chart.png"]},
        ],
        description=f"{model_name} with OpenVINO. Upload an image and start chatting about it, or simply try one of the examples below. If you won't upload an image, you will receive an error.",
        stop_btn=None,
        retry_btn=None,
        undo_btn=None,
        multimodal=True,
    )

    return demo
0 commit comments