import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
import spaces

from threading import Thread
from typing import Iterator

# Define quantization configuration
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Specify 4-bit quantization
    bnb_4bit_use_double_quant=True,  # Use double quantization for better efficiency
    bnb_4bit_quant_type="nf4",  # Set the quantization type to NF4
    bnb_4bit_compute_dtype=torch.float16  # Use float16 for computations
)
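# Note: 4-bit NF4 weights shrink a 7B model from roughly 14 GB in fp16 to
# roughly 4 GB of GPU memory (an approximation; actual usage also depends on
# activations and the KV cache).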

# Load the tokenizer and quantized model from Hugging Face
model_name = "llSourcell/medllama2_7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(model_name, 
                                             quantization_config=quantization_config, 
                                             device_map="auto")
model.eval()
max_token_length = 4096  # Maximum prompt length (in tokens) kept after trimming

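# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for each call to
# the decorated function; duration is the expected runtime in seconds.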
@spaces.GPU(duration=15)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # ChatInterface passes additional_inputs (here, the system prompt textbox)
    # positionally after the history, so generate() must accept it.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > max_token_length:
        input_ids = input_ids[:, -max_token_length:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream decoded text incrementally instead of waiting for the full generation
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so tokens can be yielded as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Define the Gradio ChatInterface
chatbot = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height="64vh"),
    additional_inputs=[
        gr.Textbox(
            "Behave as if you are a medical doctor providing answers for patients' clinical questions.",
            label="System Prompt",
        )
    ],
    title="Medical QA Chat",
    description="Ask the Medllama2 chatbot any medical question.",
    theme="soft",
    submit_btn="Send",
    retry_btn="Regenerate Response",
    undo_btn="Delete Previous",
    clear_btn="Clear Chat",
)

# Queue incoming requests so the streaming generator works for concurrent users
chatbot.queue()
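# Optional: cap the number of pending requests, e.g. chatbot.queue(max_size=20)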

# Pass share=True to launch() to create a temporary public link, e.g. chatbot.launch(share=True)
chatbot.launch()