import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
import spaces
from threading import Thread
from typing import Iterator
# Define quantization configuration
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, # Specify 4-bit quantization
bnb_4bit_use_double_quant=True, # Use double quantization for better efficiency
bnb_4bit_quant_type="nf4", # Set the quantization type to NF4
bnb_4bit_compute_dtype=torch.float16 # Use float16 for computations
)
# Load the tokenizer and quantized model from Hugging Face
model_name = "llSourcell/medllama2_7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
)
model.eval()
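# Maximum prompt length in tokens; older context beyond this is trimmed before generation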
max_token_length = 4096
@spaces.GPU(duration=15)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend(
[
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
)
conversation.append({"role": "user", "content": message})
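    # Format the conversation with the tokenizer's chat template and append the generation prompt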
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > max_token_length:
input_ids = input_ids[:, -max_token_length:]
gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
input_ids = input_ids.to(model.device)
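    # Stream decoded tokens as they are generated, skipping the prompt and special tokens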
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
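    # Run generation in a background thread so tokens can be yielded as they stream in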
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
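    # Accumulate streamed chunks and yield the progressively growing response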
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# Define the Gradio ChatInterface
chatbot = gr.ChatInterface(
fn=generate,
chatbot=gr.Chatbot(
height="64vh"
),
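    # Additional inputs are passed to generate() after (message, chat_history), so this textbox maps to system_prompt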
additional_inputs=[
gr.Textbox(
"Behave as if you are a medical doctor providing answers for patients' clinical questions.",
label="System Prompt"
)
],
title="Medical QA Chat",
description="Feel free to ask any question to Medllama2 Chatbot.",
theme="soft",
submit_btn="Send",
retry_btn="Regenerate Response",
undo_btn="Delete Previous",
clear_btn="Clear Chat"
)
# Enable the request queue so streamed responses are delivered incrementally
chatbot.queue()
# Pass share=True to launch() to create a public link to the app
chatbot.launch()