import gc import os import gradio as gr from llama_cpp import Llama ALPACA_SYSTEM_PROMPT = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request' ALPACA_SYSTEM_PROMPT_NO_INPUT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' DEFAULT_MODEL = 'Med-Alpaca-2-7b-chat.Q4_K_M' model_paths = { 'Med-Alpaca-2-7b-chat.Q2_K': { 'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF', 'filename': 'Med-Alpaca-2-7B-chat.Q2_K.gguf', }, 'Med-Alpaca-2-7b-chat.Q4_K_M': { 'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF', 'filename': 'Med-Alpaca-2-7B-chat.Q4_K_M.gguf', }, 'Med-Alpaca-2-7b-chat.Q6_K': { 'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF', 'filename': 'Med-Alpaca-2-7B-chat.Q6_K.gguf', }, 'Med-Alpaca-2-7b-chat.Q8_0': { 'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF', 'filename': 'Med-Alpaca-2-7B-chat.Q8_0.gguf', }, 'Med-Alpaca-2-7b-chat.F16': { 'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF', 'filename': 'Med-Alpaca-2-7B-chat.F16.gguf', }, } model = None def generate_alpaca_prompt( instruction: str, input: str | None = None, response: str = '', ) -> str: prompt = '' if input is not None and input and input.strip() != '': prompt = ( f'{ALPACA_SYSTEM_PROMPT}\n\n' f'### Instruction:\n' f'{instruction}\n\n' f'### Input:\n' f'{input}\n\n' f'### Response: ' f'{response}' ) else: prompt = ( f'{ALPACA_SYSTEM_PROMPT_NO_INPUT}\n\n' f'### Instruction:\n' f'{instruction}\n\n' f'### Response: ' f'{response}' ) return prompt.strip() def chat_completion( message, history, seed: int, max_new_tokens: int, temperature: float, repeatition_penalty: float, top_k: int, top_p: float, ): if model is None: reload_model(DEFAULT_MODEL) prompt = generate_alpaca_prompt(instruction=message) response_iterator = model( prompt, stream=True, seed=seed, max_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repeat_penalty=repeatition_penalty, ) partial_response = '' for token in response_iterator: partial_response += token['choices'][0]['text'] yield partial_response def reload_model(model_name: str): global model if 'model' in globals(): del model gc.collect() model = Llama.from_pretrained( **model_paths[model_name], n_ctx=4096, n_threads=4, cache_dir='./.hf_cache' ) app_title_mark = gr.Markdown(f"""
{model_name}
""") chatbot = gr.Chatbot( type='messages', height=500, placeholder='Hi doctor, I have a headache, what should I do?', label=model_name, avatar_images=[None, 'https://raw.githubusercontent.com/minhnguyent546/medical-llama2/demo/assets/med_alpaca.png'], # pyright: ignore[reportArgumentType] ) return app_title_mark, chatbot def main() -> None: with gr.Blocks(theme=gr.themes.Ocean()) as demo: app_title_mark = gr.Markdown(f"""
{DEFAULT_MODEL}
""") model_options = list(model_paths.keys()) with gr.Row(): with gr.Column(scale=2): with gr.Row(): model_radio = gr.Radio(choices=model_options, label='Model', value=DEFAULT_MODEL) with gr.Row(): seed = gr.Number(value=998244353, label='Seed') max_new_tokens = gr.Number(value=512, minimum=64, maximum=2048, label='Max new tokens') with gr.Row(): temperature = gr.Slider(0, 2, step=0.01, label='Temperature', value=0.6) repeatition_penalty = gr.Slider(0.01, 5, step=0.05, label='Repetition penalty', value=1.1) with gr.Row(): top_k = gr.Slider(1, 100, step=1, label='Top k', value=40) top_p = gr.Slider(0, 1, step=0.01, label='Top p', value=0.9) with gr.Column(scale=5): chatbot = gr.Chatbot( type='messages', height=500, placeholder='Hi doctor, I have a headache, what should I do?', label=DEFAULT_MODEL, avatar_images=[None, 'https://raw.githubusercontent.com/minhnguyent546/medical-llama2/demo/assets/med_alpaca.png'], # pyright: ignore[reportArgumentType] ) textbox = gr.Textbox( placeholder='Hi doctor, I have a headache, what should I do?', container=False, submit_btn=True, stop_btn=True, ) chat_interface = gr.ChatInterface( chat_completion, type='messages', chatbot=chatbot, textbox=textbox, additional_inputs=[ seed, max_new_tokens, temperature, repeatition_penalty, top_k, top_p, ], ) model_radio.change(reload_model, inputs=[model_radio], outputs=[app_title_mark, chatbot]) demo.queue(api_open=False, default_concurrency_limit=20) demo.launch(max_threads=5, share=os.environ.get('GRADIO_SHARE', False)) if __name__ == '__main__': main()