File size: 3,920 Bytes
91b7019
 
d0be98e
 
161002f
 
 
 
 
 
 
 
 
d0be98e
 
 
 
91b7019
d0be98e
 
91b7019
d0be98e
 
 
 
 
 
 
161002f
d0be98e
91b7019
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0be98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91b7019
d0be98e
91b7019
 
 
d0be98e
91b7019
 
d0be98e
91b7019
d0be98e
 
 
f5fe03f
91b7019
d0be98e
91b7019
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import spaces  
import os
# --- Runtime environment flags -------------------------------------------
# Both flags are read from the process environment; nothing in this part of
# the file consumes them yet, they are kept for callers/other modules.
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = os.environ.get("SPACE_ID", None) is not None

# Single source of truth for the target device (previously defined twice,
# once as a str and once as a torch.device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
print(f"Using device: {device}")
print(f"low memory: {LOW_MEMORY}")

# 4-bit NF4 quantization settings with float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Model name
model_name = "ruslanmv/Medical-Llama3-v2"

# The tokenizer has nothing to quantize, so it takes no quantization config.
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository -- keep only if that repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

if device.type == "cuda":
    # The quantization settings must be passed as `quantization_config`;
    # `config=` expects a model configuration object, so passing the
    # BitsAndBytesConfig there does not quantize the model.
    # Weight placement is delegated to `device_map`; a quantized model is
    # not moved afterwards with `.to()`.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
else:
    # bitsandbytes 4-bit loading targets CUDA, so on CPU load the regular
    # weights and place them explicitly.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply to ``message``.

    Args:
        message: The latest user message.
        history: Previous turns as ``(user_text, assistant_text)`` pairs;
            empty/None entries within a pair are skipped.
        system_message: System prompt placed at the start of the conversation.
        max_tokens: Maximum number of *newly generated* tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated reply text only (prompt excluded), stripped of
        surrounding whitespace.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    # Render the conversation with the model's own chat template.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # No padding: there is a single sequence, so padding would be a no-op and
    # would fail outright for tokenizers that define no pad token.
    # NOTE(review): truncation at 1000 tokens may cut the end of the prompt
    # (the newest message) if the tokenizer truncates on the right -- confirm.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1000)

    # Put the inputs wherever the model's weights actually live.
    target_device = model.device
    input_ids = inputs["input_ids"].to(target_device)
    attention_mask = inputs["attention_mask"].to(target_device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # `max_new_tokens` bounds only the completion; `max_length` would
            # also count the prompt and could leave no room to generate.
            max_new_tokens=int(max_tokens),
            # Without sampling enabled, temperature/top_p are not applied.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            use_cache=True,
        )

    # The returned sequence starts with the prompt tokens; keep only what was
    # generated after them instead of string-stripping the decoded text.
    prompt_length = input_ids.shape[-1]
    generated_ids = outputs[0][prompt_length:]
    response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return response_text.strip()

# --- Gradio chat UI --------------------------------------------------------
# Extra controls shown alongside the chat box. Their order must match the
# extra parameters of `respond`: system_message, max_tokens, temperature, top_p.
_system_message_box = gr.Textbox(
    value="You are a Medical AI Assistant. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.",
    label="System message",
    lines=3,
)
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
)

# Sample questions offered to the user; each becomes a one-element example row.
_example_questions = (
    "I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, increased sensitivity to cold, and dry, itchy skin. Could these symptoms be related to hypothyroidism? If so, what steps should I take to get a proper diagnosis and discuss treatment options?",
    "What are the symptoms of diabetes?",
    "How can I improve my sleep?",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
    title="Medical AI Assistant",
    description="Give me your symptoms and ask me a health problem. The AI will provide informative answers. If the AI doesn't know the answer, it will advise seeking professional help.",
    examples=[[question] for question in _example_questions],
)

# Start the web app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()