import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
import spaces
import os
from threading import Thread
from typing import Iterator
# Define the 4-bit quantization configuration
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # Load weights in 4-bit precision
    bnb_4bit_use_double_quant=True,        # Use double quantization for better memory efficiency
    bnb_4bit_quant_type="nf4",             # Set the quantization type to NF4
    bnb_4bit_compute_dtype=torch.float16,  # Run computations in float16
)
# Load the tokenizer and the quantized model from the Hugging Face Hub
model_name = "llSourcell/medllama2_7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the model with 4-bit quantization; device_map="auto" places it on the available GPU (or CPU)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
)
model.eval()

# Maximum number of prompt tokens passed to the model
max_token_length = 4096
@spaces.GPU(duration=15)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
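    """Stream a reply from the quantized MedLlama2 model.

    The running chat history (optionally prefixed with the system prompt from
    the UI textbox) is rebuilt into chat-template messages, trimmed to the
    model's context window, and the partially generated answer is yielded as
    new tokens arrive from the streamer.
    """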
    # Rebuild the conversation in the chat-template message format. The system
    # prompt comes from the "System Prompt" textbox (assumption: the model's
    # chat template accepts a "system" role, as Llama-2-style templates do).
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > max_token_length:
        # Keep only the most recent tokens so the prompt fits the context window
        input_ids = input_ids[:, -max_token_length:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream newly generated text as it is produced, skipping the prompt and special tokens
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so the streamer can be consumed here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
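    # Accumulate the streamed chunks and yield the full partial response each
    # time, so the Gradio chat window updates incrementally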
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Define the Gradio ChatInterface
chatbot = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height="64vh"),
    additional_inputs=[
        gr.Textbox(
            "Behave as if you are a medical doctor providing answers for patients' clinical questions.",
            label="System Prompt",
        )
    ],
    title="Medical QA Chat",
    description="Feel free to ask the Medllama2 chatbot any question.",
    theme="soft",
    submit_btn="Send",
    retry_btn="Regenerate Response",
    undo_btn="Delete Previous",
    clear_btn="Clear Chat",
)
# Queue incoming requests so that concurrent users are served in order
chatbot.queue()
# Pass share=True to launch() if you want a public link to the application
chatbot.launch()
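# To run this app outside of Spaces (assumption: 4-bit bitsandbytes loading
# requires a CUDA GPU), something like the following should work:
#   pip install gradio spaces transformers accelerate bitsandbytes torch
#   python app.py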