|
121 | 121 | " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n",
|
122 | 122 | " },\n",
|
123 | 123 | " \"compile\": False,\n",
|
124 | | - " \"quantization_config\": quantization_config,\n", |
| 124 | + " \"quantization_config\": quantization_config\n", |
125 | 125 | "}\n",
|
126 | 126 | "\n",
|
127 | 127 | "# Check whether the model was already exported\n",
|
|
143 | 143 | "\n",
|
144 | 144 | "# TODO Optional: export to huggingface/hub\n",
|
145 | 145 | "\n",
|
146 | | - "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024**3\n", |
147 | | - "print(f\"Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB\")" |
| 146 | + "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024 ** 3\n", |
| 147 | + "print(f'Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB')" |
148 | 148 | ]
|
149 | 149 | },
|
150 | 150 | {
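For readers following the diff, the cell touched above assembles the export arguments for a 4-bit weight-only OpenVINO conversion through optimum-intel and then reports the size of the resulting openvino_model.bin. A minimal sketch of that flow is shown below; the model id, export directory name, and the group_size/ratio values are assumptions for illustration, not values taken from this diff:

    import os

    from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
    from transformers import AutoTokenizer

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # assumption: any causal LM supported by optimum-intel
    save_name = model_id.split("/")[-1] + "-ov-int4"  # assumption: local export directory

    # 4-bit weight-only quantization; group_size and ratio are illustrative, not from the diff
    quantization_config = OVWeightQuantizationConfig(bits=4, sym=False, group_size=128, ratio=0.8)

    load_kwargs = {
        "export": True,
        "ov_config": {
            "CACHE_DIR": os.path.join(save_name, "model_cache"),  # OpenVINO will use this directory as cache
        },
        "compile": False,
        "quantization_config": quantization_config,
    }

    # Export (and quantize) only if the model has not been saved yet
    if not os.path.exists(os.path.join(save_name, "openvino_model.xml")):
        model = OVModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        model.save_pretrained(save_name)
        AutoTokenizer.from_pretrained(model_id).save_pretrained(save_name)

    model_size = os.stat(os.path.join(save_name, "openvino_model.bin")).st_size / 1024**3
    print(f"4-bit model size on disk: {model_size:.2f} GB")

The size check at the end mirrors the print statement visible in the hunk above.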
|
|
212 | 212 | "from transformers import TextStreamer\n",
|
213 | 213 | "\n",
|
214 | 214 | "# Tokenize the sample\n",
|
215 | | - "inputs = tokenizer([sample], return_tensors=\"pt\")\n", |
| 215 | + "inputs = tokenizer([sample], return_tensors='pt')\n", |
216 | 216 | "\n",
|
217 | 217 | "# Call generate on the inputs\n",
|
218 | 218 | "out = model.generate(\n",
|
|
294 | 294 | "\n",
|
295 | 295 | "\n",
|
296 | 296 | "# Tokenize the sample\n",
|
297 | | - "inputs = tokenizer([sample], return_tensors=\"pt\")\n", |
| 297 | + "inputs = tokenizer([sample], return_tensors='pt') \n", |
298 | 298 | "\n",
|
299 | 299 | "out = stateless_model.generate(\n",
|
300 | 300 | " **inputs,\n",
|
301 | 301 | " max_new_tokens=128,\n",
|
302 | 302 | " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n",
|
303 | 303 | " pad_token_id=tokenizer.eos_token_id,\n",
|
304 | 304 | " prompt_lookup_num_tokens=3,\n",
|
305 | | - ")" |
| 305 | + ") " |
306 | 306 | ]
|
307 | 307 | },
|
308 | 308 | {
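The generate call in the hunk above passes prompt_lookup_num_tokens=3, which enables prompt-lookup decoding: transformers drafts candidate tokens by matching n-grams already present in the prompt, so no separate draft model is needed. A rough, hypothetical way to compare it against regular decoding, reusing the notebook's stateless_model, tokenizer and inputs, is sketched here:

    import time

    def tokens_per_second(**extra_generate_kwargs):
        # Illustrative wall-clock helper, not part of the notebook
        start = time.perf_counter()
        out = stateless_model.generate(
            **inputs,
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            **extra_generate_kwargs,
        )
        new_tokens = out.shape[-1] - inputs["input_ids"].shape[1]
        return new_tokens / (time.perf_counter() - start)

    print(f"regular decoding:       {tokens_per_second():.1f} tok/s")
    print(f"prompt lookup decoding: {tokens_per_second(prompt_lookup_num_tokens=3):.1f} tok/s")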
|
|
358 | 358 | " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n",
|
359 | 359 | " },\n",
|
360 | 360 | " \"compile\": False,\n",
|
361 | | - " \"quantization_config\": quantization_config,\n", |
| 361 | + " \"quantization_config\": quantization_config\n", |
362 | 362 | "}\n",
|
363 | 363 | "\n",
|
364 | 364 | "# Check whether the model was already exported\n",
|
|
458 | 458 | " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n",
|
459 | 459 | " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n",
|
460 | 460 | " self.model_forward = self.model.forward\n",
|
461 | | - "\n", |
| 461 | + " \n", |
462 | 462 | " @wraps(self.model_forward)\n",
|
463 | 463 | " def forward_wrapper(**kwargs):\n",
|
464 | 464 | " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n",
|
465 | 465 | " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n",
|
466 | 466 | " return self.model_forward(**kwargs)\n",
|
467 | | - "\n", |
| 467 | + " \n", |
468 | 468 | " self.model.forward = forward_wrapper\n",
|
469 | | - "\n", |
| 469 | + " \n", |
470 | 470 | " # wrap generate method\n",
|
471 | 471 | " self.model_generate = self.model.generate\n",
|
472 | 472 | "\n",
|
|
479 | 479 | " out = self.model_generate(*args, **kwargs)\n",
|
480 | 480 | " self.seq_lens[-1].append(out.shape[-1])\n",
|
481 | 481 | " return out\n",
|
482 | | - "\n", |
483 | 482 | " self.model.generate = generate_wrapper\n",
|
484 | 483 | " return self\n",
|
485 | 484 | "\n",
|
486 | | - " def __exit__(self, type, value, traceback):\n", |
| 485 | + " def __exit__(self, type, value, traceback):\n", |
487 | 486 | " self.model.forward = self.model_forward\n",
|
488 | 487 | " self.model.generate = self.model_generate\n",
|
489 | 488 | " self.model_forward = None\n",
|
|
495 | 494 | " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n",
|
496 | 495 | " # Add window size for output to ease calculation later\n",
|
497 | 496 | " for ws, sl in zip(self.win_sizes, self.seq_lens):\n",
|
498 | | - " ws.append(0)\n", |
| 497 | + " ws.append(0) \n", |
499 | 498 | "\n",
|
500 | 499 | " def acceptance_rate(self, return_mean=True, normalize=False):\n",
|
501 | 500 | " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n",
|
|
504 | 503 | " sl = np.array(sl, dtype=np.float64)\n",
|
505 | 504 | " ws = np.array(ws, dtype=np.float64)\n",
|
506 | 505 | " out_lens = sl - ws\n",
|
507 | | - " accepted = out_lens[1:] - out_lens[:-1] - 1\n", |
508 | | - " ar_per_win.append(np.divide(accepted, ws[:-1], out=np.zeros_like(accepted), where=ws[:-1] != 0))\n", |
| 506 | + " accepted = (out_lens[1:] - out_lens[:-1] - 1)\n", |
| 507 | + " ar_per_win.append(np.divide(accepted, ws[:-1],\n", |
| 508 | + " out=np.zeros_like(accepted),where=ws[:-1] != 0))\n", |
509 | 509 | " ar_per_win = np.hstack(ar_per_win)\n",
|
510 | 510 | " # Normalized AR doesn't take into account windows with size 0\n",
|
511 | 511 | " if normalize:\n",
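To make the acceptance-rate arithmetic above concrete, here is a small made-up trace pushed through the same formula (accepted = (cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1, divided by the previous window size):

    import numpy as np

    # Made-up trace for one generation: three candidate windows of 3 draft tokens each,
    # plus the final output length with a window size of 0 appended at the end.
    seq_lens = np.array([20.0, 24.0, 27.0, 28.0])
    win_sizes = np.array([3.0, 3.0, 3.0, 0.0])

    out_lens = seq_lens - win_sizes              # tokens already committed at each step
    accepted = out_lens[1:] - out_lens[:-1] - 1  # accepted draft tokens (the target model always adds 1)
    rates = np.divide(accepted, win_sizes[:-1], out=np.zeros_like(accepted), where=win_sizes[:-1] != 0)
    print(rates, rates.mean())                   # rates ≈ [1.0, 0.67, 1.0], mean ≈ 0.89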
|
|
544 | 544 | "samples_number = 30\n",
|
545 | 545 | "with AcceptanceRateRecorder(stateless_model) as ar_recorder:\n",
|
546 | 546 | " for text in tqdm(dataset[:samples_number]):\n",
|
547 | | - " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors=\"pt\")\n", |
| 547 | + " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors='pt')\n", |
548 | 548 | " stateless_model.generate(\n",
|
549 | 549 | " **tokenized_prompt,\n",
|
550 | 550 | " max_new_tokens=128,\n",
|
|
623 | 623 | " return False\n",
|
624 | 624 | "\n",
|
625 | 625 | "\n",
|
| 626 | + "\n", |
626 | 627 | "# Set the chat template to the tokenizer. The chat template implements the simple template of\n",
|
627 | 628 | "# User: content\n",
|
628 | 629 | "# Assistant: content\n",
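The comments above describe a plain "User: content / Assistant: content" format. As a rough illustration only (the Jinja string below is an assumption, not the notebook's exact template), such a template can be attached to the notebook's tokenizer so that apply_chat_template() renders the history in that form:

    # Hypothetical chat template implementing "Role: content" turns
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ message['role'] + ': ' + message['content'] + '\\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
    )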
|
|
650 | 651 | " if model_msg:\n",
|
651 | 652 | " messages.append({\"role\": \"Assistant\", \"content\": model_msg})\n",
|
652 | 653 | " input_token = tokenizer.apply_chat_template(\n",
|
653 | | - " messages, add_generation_prompt=True, tokenize=True, return_tensors=\"pt\", return_dict=True\n", |
| 654 | + " messages,\n", |
| 655 | + " add_generation_prompt=True,\n", |
| 656 | + " tokenize=True,\n", |
| 657 | + " return_tensors=\"pt\",\n", |
| 658 | + " return_dict=True\n", |
654 | 659 | " )\n",
|
655 | 660 | " return input_token\n",
|
656 | 661 | "\n",
|
|
674 | 679 | " # Construct the input message string for the model by concatenating the current system message and conversation history\n",
|
675 | 680 | " # Tokenize the messages string\n",
|
676 | 681 | " inputs = prepare_history_for_model(history)\n",
|
677 | | - " input_length = inputs[\"input_ids\"].shape[1]\n", |
| 682 | + " input_length = inputs['input_ids'].shape[1]\n", |
678 | 683 | " # truncate input in case it is too long.\n",
|
679 | 684 | " # TODO improve this\n",
|
680 | 685 | " if input_length > 2000:\n",
|
681 | 686 | " history = [history[-1]]\n",
|
682 | 687 | " inputs = prepare_history_for_model(history)\n",
|
683 | | - " input_length = inputs[\"input_ids\"].shape[1]\n", |
| 688 | + " input_length = inputs['input_ids'].shape[1]\n", |
684 | 689 | "\n",
|
685 | 690 | " prompt_char = \"▌\"\n",
|
686 | 691 | " history[-1][1] = prompt_char\n",
|
687 | 692 | " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n",
|
688 | | - "\n", |
| 693 | + " \n", |
689 | 694 | " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
|
690 | 695 | "\n",
|
691 | 696 | " # Create stopping criteria to prevent the model from playing the role of the user as well.\n",
|
|
701 | 706 | " eos_token_id=[tokenizer.eos_token_id],\n",
|
702 | 707 | " pad_token_id=tokenizer.eos_token_id,\n",
|
703 | 708 | " )\n",
|
704 | | - " generate_kwargs = (\n", |
705 | | - " dict(\n", |
706 | | - " streamer=streamer,\n", |
707 | | - " generation_config=generation_config,\n", |
708 | | - " stopping_criteria=stopping_criteria,\n", |
709 | | - " )\n", |
710 | | - " | inputs\n", |
711 | | - " )\n", |
| 709 | + " generate_kwargs = dict(\n", |
| 710 | + " streamer=streamer,\n", |
| 711 | + " generation_config=generation_config,\n", |
| 712 | + " stopping_criteria=stopping_criteria,\n", |
| 713 | + " ) | inputs\n", |
712 | 714 | "\n",
|
713 | 715 | " if assisted:\n",
|
714 | 716 | " target_generate = stateless_model.generate\n",
|
|
735 | 737 | " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n",
|
736 | 738 | " history[-1][1] = partial_text\n",
|
737 | 739 | " generation_time = time.perf_counter() - start\n",
|
738 | | - " yield history, f\"Generation time: {generation_time:.2f} sec\", *([gr.update(interactive=True)] * 4)" |
| 740 | + " yield history, f'Generation time: {generation_time:.2f} sec', *([gr.update(interactive=True)] * 4)" |
739 | 741 | ]
|
740 | 742 | },
|
741 | 743 | {
|
|
779 | 781 | " [\"Can you explain to me briefly what is Python programming language?\"],\n",
|
780 | 782 | " [\"Explain the plot of Cinderella in a sentence.\"],\n",
|
781 | 783 | " [\"Write a Python function to perform binary search over a sorted list. Use markdown to write code\"],\n",
|
782 | | - " [\n", |
783 | | - " \"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?\"\n", |
784 | | - " ],\n", |
| 784 | + " [\"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?\"],\n", |
785 | 785 | "]\n",
|
786 | 786 | "\n",
|
787 | 787 | "\n",
|
|
797 | 797 | " \"\"\"\n",
|
798 | 798 | " # Append current user message to history with a blank assistant message which will be generated by the model\n",
|
799 | 799 | " history.append([message, None])\n",
|
800 | | - " return (\"\", history)\n", |
| 800 | + " return ('', history)\n", |
801 | 801 | "\n",
|
802 | 802 | "\n",
|
803 | 803 | "def prepare_for_regenerate(history):\n",
|
|
808 | 808 | " history: conversation history\n",
|
809 | 809 | " Returns:\n",
|
810 | 810 | " updated history\n",
|
811 | | - " \"\"\"\n", |
| 811 | + " \"\"\" \n", |
812 | 812 | " history[-1][1] = None\n",
|
813 | 813 | " return history\n",
|
814 | 814 | "\n",
|
|
821 | 821 | " msg = gr.Textbox(placeholder=\"Enter message here...\", show_label=False, autofocus=True, scale=75)\n",
|
822 | 822 | " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=15)\n",
|
823 | 823 | " with gr.Row():\n",
|
824 | | - " submit = gr.Button(\"Submit\", variant=\"primary\")\n", |
| 824 | + " submit = gr.Button(\"Submit\", variant='primary')\n", |
825 | 825 | " regenerate = gr.Button(\"Regenerate\")\n",
|
826 | 826 | " clear = gr.Button(\"Clear\")\n",
|
827 | 827 | " with gr.Accordion(\"Advanced Options:\", open=False):\n",
|
|
860 | 860 | " step=0.1,\n",
|
861 | 861 | " interactive=True,\n",
|
862 | 862 | " )\n",
|
863 | | - " gr.Examples(EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\")\n", |
| 863 | + " gr.Examples(\n", |
| 864 | + " EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\"\n", |
| 865 | + " )\n", |
864 | 866 | "\n",
|
865 | 867 | " # Set the generate function to be triggered when the user submits a new message\n",
|
866 | 868 | " gr.on(\n",
|
|
874 | 876 | " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n",
|
875 | 877 | " outputs=[chatbot, status, msg, submit, regenerate, clear],\n",
|
876 | 878 | " concurrency_limit=1,\n",
|
877 | | - " queue=True,\n", |
| 879 | + " queue=True\n", |
878 | 880 | " )\n",
|
879 | | - " regenerate.click(fn=prepare_for_regenerate, inputs=chatbot, outputs=chatbot, queue=True, concurrency_limit=1).then(\n", |
| 881 | + " regenerate.click(\n", |
| 882 | + " fn=prepare_for_regenerate,\n", |
| 883 | + " inputs=chatbot,\n", |
| 884 | + " outputs=chatbot,\n", |
| 885 | + " queue=True,\n", |
| 886 | + " concurrency_limit=1\n", |
| 887 | + " ).then(\n", |
880 | 888 | " fn=generate,\n",
|
881 | 889 | " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n",
|
882 | 890 | " outputs=[chatbot, status, msg, submit, regenerate, clear],\n",
|
883 | 891 | " concurrency_limit=1,\n",
|
884 | | - " queue=True,\n", |
| 892 | + " queue=True\n", |
885 | 893 | " )\n",
|
886 | 894 | " clear.click(fn=lambda: (None, \"Status: Idle\"), inputs=None, outputs=[chatbot, status], queue=False)"
|
887 | 895 | ]
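Not visible in this diff is how the Blocks UI wired up above actually gets served. Assuming the surrounding gr.Blocks context is bound to a variable named demo (a hypothetical name, not shown here), the usual pattern is:

    if __name__ == "__main__":
        demo.queue()   # enable the request queue relied on by the event handlers above
        demo.launch()  # pass share=True for a temporary public link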
|
|