# Spaces: Sleeping  (page-status text captured with this file; not part of the program)
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# 4-bit NF4 quantization settings applied when the model weights are loaded.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,  # computations run in float16
)

# One checkpoint name supplies both the tokenizer and the model weights.
model_name = "llSourcell/medllama2_7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the quantized model; device_map="auto" delegates weight placement to the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)
model.eval()

# Prompts longer than this many tokens are trimmed by generate() before inference.
max_token_length = 4096
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    system_prompt: str = "",
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier chat turns.

    Args:
        message: The new user message to answer.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        system_prompt: Optional instruction placed first in the conversation
            as a ``system`` message. Empty or whitespace-only values are
            skipped. Added last so existing positional callers keep working.

    Yields:
        The reply text accumulated so far; each yield extends the previous one.
    """
    conversation = []
    if system_prompt and system_prompt.strip():
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > max_token_length:
        # Keep only the trailing tokens, so the oldest part of the prompt
        # (including a system message, if any) is what gets dropped.
        input_ids = input_ids[:, -max_token_length:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs in a worker thread so this generator can consume the
    # streamer and yield partial text while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def _chat_with_system_prompt(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
) -> Iterator[str]:
    """Adapt Gradio's positional call to ``generate``'s keyword argument.

    ChatInterface calls its ``fn`` with ``(message, history, *additional_inputs)``.
    Registering ``generate`` directly would bind the "System Prompt" textbox
    value to ``max_new_tokens``; this wrapper routes it to ``system_prompt``.
    It stays a generator function so responses are still streamed.
    """
    yield from generate(message, chat_history, system_prompt=system_prompt)


# Define the Gradio ChatInterface
chatbot = gr.ChatInterface(
    fn=_chat_with_system_prompt,
    chatbot=gr.Chatbot(
        height="64vh"
    ),
    # Each component listed here is passed to `fn` positionally, in order,
    # after the message and the history.
    additional_inputs=[
        gr.Textbox(
            "Behave as if you are a medical doctor providing answers for patients' clinical questions.",
            label="System Prompt"
        )
    ],
    title="Medical QA Chat",
    description="Feel free to ask any question to Medllama2 Chatbot.",
    theme="soft",
    submit_btn="Send",
    retry_btn="Regenerate Response",
    undo_btn="Delete Previous",
    clear_btn="Clear Chat"
)
# Following line is important to queue the messages
chatbot.queue()
# Pass share=True to launch() if you want to create a public link for people to use your application
chatbot.launch()