Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import torch | |
import spaces | |
# bitsandbytes settings used when loading the model:
#   * weights stored as 4-bit NF4 values,
#   * nested ("double") quantization of the quantization constants,
#   * matmuls computed in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
# Hugging Face Hub repository of the MedLlama2 7B checkpoint.
model_name = "llSourcell/medllama2_7b"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights through bitsandbytes with `quantization_config`;
# device_map="auto" lets accelerate choose where each layer is placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)
# Inference only: put modules such as dropout into evaluation mode.
model.eval()
def format_history(msg: str, history: list[list[str]], system_prompt: str) -> str:
    """Build the plain-text prompt sent to the model.

    The prompt is the system prompt, then one "User:" line and one
    "Assistant:" line per past turn, then the new message on a "User:" line,
    ending with an open "Assistant:" for the model to complete.

    Args:
        msg: The newest user message.
        history: Previous turns as ``[query, response]`` pairs (the format
            Gradio's ChatInterface passes by default). A ``None`` response is
            rendered as empty text.
        system_prompt: Instruction text placed before the first turn.

    Returns:
        The full prompt string.
    """
    parts = [system_prompt]
    for query, response in history:
        # Don't leak the literal text "None" into the prompt for a missing reply.
        reply = "" if response is None else response
        parts.append(f"\nUser: {query}\nAssistant: {reply}")
    parts.append(f"\nUser: {msg}\nAssistant:")
    # One join instead of repeated string concatenation in the loop.
    return "".join(parts)
@spaces.GPU  # ZeroGPU only grants CUDA access inside functions decorated like this
def generate_response(msg: str, history: list[list[str]], system_prompt: str):
    """Generate the assistant's reply to *msg* given the prior conversation.

    Used as the ``fn`` of ``gr.ChatInterface``: Gradio passes the new user
    message, the chat history, and the value of the "System Prompt" textbox.

    Args:
        msg: The latest user message.
        history: Previous ``[user, assistant]`` turns as supplied by Gradio.
        system_prompt: Instruction text placed at the start of the prompt.

    Yields:
        The decoded reply text, yielded once (the output is not streamed).
    """
    prompt = format_history(msg, history, system_prompt)

    # Keep the attention mask together with the ids: pad_token_id is set to
    # eos_token_id below, so generate() cannot reliably infer the mask itself.
    # Move tensors to wherever the model actually lives instead of assuming "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        # Limit only the *new* tokens. The previous max_length=1024 also counted
        # the prompt, so a long chat history left little or no room for a reply.
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Skip the echoed prompt tokens and decode only the continuation.
    response = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
    # The prompt uses "\nUser:" as its turn separator; if the model starts
    # inventing the next user turn, cut the reply there.
    response = response.split("\nUser:", 1)[0].strip()
    yield response
# Chat UI. Gradio calls generate_response(message, history, system_prompt),
# where system_prompt is the current value of the editable textbox below.
chatbot = gr.ChatInterface(
    fn=generate_response,
    title="Medical QA Chat",
    description="Feel free to ask any question to Medllama2 Chatbot.",
    theme="soft",
    chatbot=gr.Chatbot(height="64vh"),
    additional_inputs=[
        gr.Textbox(
            value="Behave as if you are a medical doctor providing answers for patients' clinical questions.",
            label="System Prompt",
        ),
    ],
    submit_btn="Send",
    retry_btn="Regenerate Response",
    undo_btn="Delete Previous",
    clear_btn="Clear Chat",
)

# Enable Gradio's request queue for incoming chat messages.
chatbot.queue()
# Start the server; call launch(share=True) instead to get a public link.
chatbot.launch()