|
| 1 | +import argparse |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import app |
| 5 | +import convert_and_optimize_asr as asr |
| 6 | +import convert_and_optimize_chat as chat |
| 7 | + |
| 8 | + |
| 9 | +def main(args): |
| 10 | + asr_model_dir = asr.convert_asr_model(args.asr_model_type, args.asr_precision, Path(args.model_dir)) |
| 11 | + chat_model_dir = chat.convert_chat_model(args.chat_model_type, args.chat_precision, Path(args.model_dir)) |
| 12 | + |
| 13 | + app.run(asr_model_dir, chat_model_dir, args.public) |
| 14 | + |
| 15 | + |
| 16 | +if __name__ == '__main__': |
| 17 | + parser = argparse.ArgumentParser() |
| 18 | + |
| 19 | + parser.add_argument("--asr_model_type", type=str, choices=["distil-whisper-large-v3", "belle-distilwhisper-large-v2-zh"], |
| 20 | + default="distil-whisper-large-v3", help="Speech recognition model to be converted") |
| 21 | + parser.add_argument("--asr_precision", type=str, default="fp16", choices=["fp16", "int8"], help="ASR model precision") |
| 22 | + parser.add_argument("--chat_model_type", type=str, choices=["llama3.1-8B", "llama3-8B", "qwen2-7B", "llama3.2-3B"], |
| 23 | + default="llama3.2-3B", help="Chat model to be converted") |
| 24 | + parser.add_argument("--chat_precision", type=str, default="int4", choices=["fp16", "int8", "int4"], help="Chat model precision") |
| 25 | + parser.add_argument("--model_dir", type=str, default="model", help="Directory to place the model in") |
| 26 | + parser.add_argument('--public', default=False, action="store_true", help="Whether interface should be available publicly") |
| 27 | + |
| 28 | + main(parser.parse_args()) |
0 commit comments