OmniMed_SIA / app.py
analist's picture
Update app.py
12a623d verified
raw
history blame
7.49 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from peft import PeftModel, PeftConfig
# Configuration du modèle
ADAPTER_MODEL_NAME = "analist/llama3.1-8B-omnimed-rl"
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B" # Modèle de base pour Llama 3.1 8B
DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
Répondez de manière précise, concise et professionnelle aux questions médicales.
Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information."""
# Définition des paramètres
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.1
# Configuration pour la quantification 4-bit
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
# Chargement du modèle et du tokenizer
print("Chargement du modèle de base et du tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_NAME)
print("Chargement du modèle de base quantifié...")
base_model = AutoModelForCausalLM.from_pretrained(
ADAPTER_MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
print("Application des adaptateurs...")
model = PeftModel.from_pretrained(
base_model,
device_map="auto",
)
print("Modèle et tokenizer chargés avec succès!")
# Fonction pour générer une réponse
def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS):
# Construction du contexte de chat
chat_context = []
# Ajout du prompt système s'il existe
if system_prompt.strip():
chat_context.append({"role": "system", "content": system_prompt})
else:
chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
# Ajout de l'historique des conversations
for user_msg, assistant_msg in chat_history:
chat_context.append({"role": "user", "content": user_msg})
chat_context.append({"role": "assistant", "content": assistant_msg})
# Ajout du message actuel
chat_context.append({"role": "user", "content": message})
# Préparation du texte d'entrée avec le template de chat
input_text = tokenizer.apply_chat_template(
chat_context,
tokenize=False,
add_generation_prompt=True
)
# Tokenisation de l'entrée
model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# Mise à jour des paramètres de génération
model.generation_config.temperature = temperature
model.generation_config.max_new_tokens = max_tokens
# Génération de la réponse
with torch.no_grad():
generated_ids = model.generate(
**model_inputs,
use_cache=True,
)
# Extraction uniquement de la nouvelle partie générée
new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
return response
# Construction du contexte de chat
chat_context = []
# Ajout du prompt système s'il existe
if system_prompt.strip():
chat_context.append({"role": "system", "content": system_prompt})
else:
chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
# Ajout de l'historique des conversations
for user_msg, assistant_msg in chat_history:
chat_context.append({"role": "user", "content": user_msg})
chat_context.append({"role": "assistant", "content": assistant_msg})
# Ajout du message actuel
chat_context.append({"role": "user", "content": message})
# Tokenisation et génération
input_ids = tokenizer.apply_chat_template(
chat_context,
return_tensors="pt"
).to(model.device)
# Génération de la réponse
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=TOP_P,
repetition_penalty=REPETITION_PENALTY,
do_sample=temperature > 0,
)
# Décodage de la réponse
generated_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
return generated_text
# Interface utilisateur Gradio
with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown(
"""
# OmniMed - Assistant médical IA
Ce chatbot alimenté par le modèle Llama 3.1 fine-tuné est conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
Posez des questions médicales et recevez des réponses précises et professionnelles.
"""
)
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(height=500, show_copy_button=True)
with gr.Row():
msg = gr.Textbox(
placeholder="Posez votre question médicale ici...",
show_label=False,
container=False,
scale=12
)
submit_btn = gr.Button("Envoyer", scale=1)
with gr.Row():
clear_btn = gr.Button("Effacer la conversation")
with gr.Column(scale=1):
with gr.Accordion("Paramètres avancés", open=False):
system_prompt = gr.Textbox(
label="Prompt système",
placeholder="Instructions pour l'assistant...",
value=DEFAULT_SYSTEM_PROMPT,
lines=5
)
temperature_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=TEMPERATURE,
step=0.05,
label="Température"
)
max_tokens_slider = gr.Slider(
minimum=128,
maximum=2048,
value=MAX_NEW_TOKENS,
step=64,
label="Tokens max."
)
# Logique des événements
def respond(message, chat_history, system_prompt, temperature, max_tokens):
if not message.strip():
return "", chat_history
response = generate_response(message, chat_history, system_prompt, temperature, max_tokens)
chat_history.append((message, response))
return "", chat_history
msg.submit(
respond,
[msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
[msg, chatbot]
)
submit_btn.click(
respond,
[msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
[msg, chatbot]
)
clear_btn.click(lambda: [], None, chatbot)
gr.Markdown("""
### Note importante:
Ce chatbot est destiné uniquement à l'assistance des professionnels de santé et ne remplace pas l'avis médical d'un spécialiste.
Toutes les informations fournies doivent être vérifiées par un professionnel qualifié.
""")
# Lancement de l'application
if __name__ == "__main__":
demo.queue().launch()