import gradio as gr
import traceback
import re
from PIL import Image


ERROR_MSG = "Error, please retry"
model_name = "MiniCPM-V 2.0"

form_radio = {"choices": ["Beam Search", "Sampling"], "value": "Sampling", "interactive": True, "label": "Decode Type"}
# Beam Form
num_beams_slider = {"minimum": 0, "maximum": 5, "value": 1, "step": 1, "interactive": True, "label": "Num Beams"}
repetition_penalty_slider = {"minimum": 0, "maximum": 3, "value": 1.2, "step": 0.01, "interactive": True, "label": "Repetition Penalty"}
repetition_penalty_slider2 = {"minimum": 0, "maximum": 3, "value": 1.05, "step": 0.01, "interactive": True, "label": "Repetition Penalty"}
max_new_tokens_slider = {"minimum": 1, "maximum": 4096, "value": 1024, "step": 1, "interactive": True, "label": "Max New Tokens"}

top_p_slider = {"minimum": 0, "maximum": 1, "value": 0.8, "step": 0.05, "interactive": True, "label": "Top P"}
top_k_slider = {"minimum": 0, "maximum": 200, "value": 100, "step": 1, "interactive": True, "label": "Top K"}
temperature_slider = {"minimum": 0, "maximum": 2, "value": 0.7, "step": 0.05, "interactive": True, "label": "Temperature"}

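# Helper that builds a Gradio control (Slider, Radio, or Button) from one of the parameter dicts above.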
def create_component(params, comp="Slider"):
    if comp == "Slider":
        return gr.Slider(
            minimum=params["minimum"],
            maximum=params["maximum"],
            value=params["value"],
            step=params["step"],
            interactive=params["interactive"],
            label=params["label"],
        )
    elif comp == "Radio":
        return gr.Radio(choices=params["choices"], value=params["value"], interactive=params["interactive"], label=params["label"])
    elif comp == "Button":
        return gr.Button(value=params["value"], interactive=True)


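# Reset the per-session state and greet the user whenever a new image is uploaded.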
def upload_img(image, _chatbot, _app_session):
    image = Image.fromarray(image)

    _app_session["sts"] = None
    _app_session["ctx"] = []
    _app_session["img"] = image
    _chatbot.append(("", "Image uploaded successfully, you can talk to me now"))
    return _chatbot, _app_session


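# Wrap an initialized MiniCPM-V model into a Gradio Blocks demo with streaming chat,
# decode-parameter controls, and a regenerate button.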
def make_demo(model):
    def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
        tokenizer = model.processor.tokenizer
        default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
        if params is None:
            params = default_params
        if img is None:
            # chat is a generator, so the error must be yielded (a plain return would be silently swallowed)
            yield -1, "Error, invalid image, please upload a new image", None, None
            return
        try:
            image = img.convert("RGB")
            generation_params = {"image": image, "msgs": msgs, "context": None, "tokenizer": tokenizer, "stream": True, **params}
            streamer = model.chat(**generation_params)

            buffer = ""

            for res in streamer:
                # Strip grounding markup (<ref>/<box> tags) from the streamed text before showing it
                res = re.sub(r"(<box>.*</box>)", "", res)
                res = res.replace("<ref>", "")
                res = res.replace("</ref>", "")
                res = res.replace("<box>", "")
                new_text = res.replace("</box>", "")
                buffer += new_text
                yield -1, buffer, None, None
        except Exception as err:
            print(err)
            traceback.print_exc()
            yield -1, ERROR_MSG, None, None

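    # Handler for submitted text: builds the message context, picks decoding
    # parameters from the UI, and streams the model answer into the chatbot.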
    def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
        if _app_cfg.get("ctx", None) is None:
            _chat_bot.append((_question, "Please upload an image to start"))
            # respond is a generator, so yield the prompt to upload an image so it reaches the UI
            yield "", _chat_bot, _app_cfg
            return

        _context = _app_cfg["ctx"].copy()
        if _context:
            _context.append({"role": "user", "content": _question})
        else:
            _context = [{"role": "user", "content": _question}]

        if params_form == "Beam Search":
            params = {"sampling": False, "num_beams": num_beams, "repetition_penalty": repetition_penalty, "max_new_tokens": 896}
        else:
            params = {
                "sampling": True,
                "top_p": top_p,
                "top_k": top_k,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty_2,
                "max_new_tokens": 896,
            }

        _context.append({"role": "assistant", "content": ""})
        _chat_bot.append([_question, ""])
        for code, _answer, _, sts in chat(_app_cfg["img"], _context, None, params):
            _context[-1]["content"] = _answer
            _chat_bot[-1][-1] = _answer

            # chat yields code == -1 while streaming; context and state are only persisted on a final code == 0
            if code == 0:
                _app_cfg["ctx"] = _context
                _app_cfg["sts"] = sts
            yield "", _chat_bot, _app_cfg

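    # Re-runs the last user question: drops the previous answer from the chat
    # history and context, then delegates to respond with the same question.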
    def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
        if len(_chat_bot) <= 1:
            _chat_bot.append(("Regenerate", "No question for regeneration."))
            # Yield (not return) so the message actually reaches the UI from this generator
            yield "", _chat_bot, _app_cfg
            return
        elif _chat_bot[-1][0] == "Regenerate":
            yield "", _chat_bot, _app_cfg
            return
        else:
            _question = _chat_bot[-1][0]
            _chat_bot = _chat_bot[:-1]
            _app_cfg["ctx"] = _app_cfg["ctx"][:-2]
            for text, _chatbot, _app_cfg in respond(
                _question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature
            ):
                yield text, _chatbot, _app_cfg

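    # UI layout: decoding controls on the left, image upload + chatbot + text box on the right.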
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                params_form = create_component(form_radio, comp="Radio")
                with gr.Accordion("Beam Search") as beams_accordion:
                    num_beams = create_component(num_beams_slider)
                    repetition_penalty = create_component(repetition_penalty_slider)
                with gr.Accordion("Sampling") as sampling_accordion:
                    top_p = create_component(top_p_slider)
                    top_k = create_component(top_k_slider)
                    temperature = create_component(temperature_slider)
                    repetition_penalty_2 = create_component(repetition_penalty_slider2)
                regenerate = create_component({"value": "Regenerate"}, comp="Button")
            with gr.Column(scale=3, min_width=500):
                app_session = gr.State({"sts": None, "ctx": None, "img": None})
                bt_pic = gr.Image(label="Upload an image to start")
                chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
                txt_message = gr.Textbox(label="Input text")

        regenerate.click(
            regenerate_button_clicked,
            [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
            [txt_message, chat_bot, app_session],
        )
        txt_message.submit(
            respond,
            [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
            [txt_message, chat_bot, app_session],
        )
        bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(
            upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session]
        )

    return demo
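
# Example usage (a minimal sketch; assumes `model` is an already-initialized
# MiniCPM-V 2.0 wrapper exposing `.chat(...)` and `.processor.tokenizer`, e.g.
# the object created earlier in the surrounding notebook or pipeline):
#
#     demo = make_demo(model)
#     demo.queue().launch()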