ruslanmv's picture
Update app.py
f5fe03f verified
raw
history blame
3.92 kB
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import spaces
import os
# Runtime flags read from the environment.
# SPACES_ZERO_GPU == "1" marks the ZeroGPU flag; a set SPACE_ID marks a Space.
_environment = os.environ
IS_SPACES_ZERO = _environment.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = "SPACE_ID" in _environment
LOW_MEMORY = _environment.get("LOW_MEMORY", "0") == "1"

# Pick the compute device name: CUDA when torch reports it, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Report the chosen settings at start-up.
print(f"Using device: {device}")
print(f"low memory: {LOW_MEMORY}")
# 4-bit NF4 quantization settings with float16 compute, used when loading the
# model on a CUDA device.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Model name
model_name = "ruslanmv/Medical-Llama3-v2"

# Quantization settings apply to model weights only, so the tokenizer is
# loaded without them.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Device that inputs are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # The settings must be passed as `quantization_config`. The `config`
    # argument is reserved for the model's architecture config, so passing a
    # BitsAndBytesConfig there leaves the model unquantized.
    # NOTE(review): a bitsandbytes-quantized model is expected to be placed on
    # the GPU while loading, so `.to()` is deliberately not called afterwards
    # (some transformers versions reject it) -- confirm on the target runtime.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
    )
else:
    # No CUDA device: fall back to a regular, unquantized load on the CPU.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply to ``message`` for the chat UI.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs; empty
            entries are skipped.
        system_message: System prompt placed at the start of the conversation.
        max_tokens: Maximum number of tokens to generate for the reply (the
            prompt does not count against this budget).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The newly generated reply text, without the prompt.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    # Render the conversation with the model's chat template.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The templated text already carries its special tokens, so they are not
    # added a second time here. No padding is requested: there is only one
    # sequence, and padding needs a pad token the tokenizer may not define.
    # NOTE(review): truncation is presumably right-sided by default, which for
    # very long histories would cut the newest turn -- confirm and consider
    # trimming old turns instead.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1000,
        add_special_tokens=False,
    )

    # Move inputs to device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Budget for the reply only; `max_length` would also count the
            # prompt and could leave no room to generate anything.
            max_new_tokens=int(max_tokens),
            # temperature/top_p are only honoured when sampling is enabled.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            use_cache=True,
        )

    # Keep only the tokens produced after the prompt, so neither the system
    # message nor earlier turns leak into the returned text.
    prompt_length = input_ids.shape[-1]
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# Extra controls shown under the chat box; their values are handed to
# `respond` after the message and history, in this order.
_system_box = gr.Textbox(
    label="System message",
    lines=3,
    value="You are a Medical AI Assistant. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.",
)
_max_tokens_slider = gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512)
_temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
_top_p_slider = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95)

# Sample questions offered in the UI; each example is a one-element row.
_example_questions = (
    "I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, increased sensitivity to cold, and dry, itchy skin. Could these symptoms be related to hypothyroidism? If so, what steps should I take to get a proper diagnosis and discuss treatment options?",
    "What are the symptoms of diabetes?",
    "How can I improve my sleep?",
)

# Chat UI wired to `respond`.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
    title="Medical AI Assistant",
    description="Give me your symptoms and ask me a health problem. The AI will provide informative answers. If the AI doesn't know the answer, it will advise seeking professional help.",
    examples=[[question] for question in _example_questions],
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()