Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import torch | |
from peft import PeftModel, PeftConfig | |
# Configuration du modèle | |
ADAPTER_MODEL_NAME = "analist/llama3.1-8B-omnimed-rl" | |
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B" # Modèle de base pour Llama 3.1 8B | |
DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes. | |
Répondez de manière précise, concise et professionnelle aux questions médicales. | |
Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information.""" | |
# Définition des paramètres | |
MAX_NEW_TOKENS = 1024 | |
TEMPERATURE = 0.7 | |
TOP_P = 0.9 | |
REPETITION_PENALTY = 1.1 | |
# Configuration pour la quantification 4-bit | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.float16, | |
bnb_4bit_use_double_quant=True, | |
) | |
# Chargement du modèle et du tokenizer | |
print("Chargement du modèle de base et du tokenizer...") | |
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_NAME) | |
print("Chargement du modèle de base quantifié...") | |
base_model = AutoModelForCausalLM.from_pretrained( | |
ADAPTER_MODEL_NAME, | |
quantization_config=bnb_config, | |
device_map="auto", | |
trust_remote_code=True | |
) | |
print("Application des adaptateurs...") | |
model = PeftModel.from_pretrained( | |
base_model, | |
device_map="auto", | |
) | |
print("Modèle et tokenizer chargés avec succès!") | |
# Fonction pour générer une réponse | |
def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS): | |
# Construction du contexte de chat | |
chat_context = [] | |
# Ajout du prompt système s'il existe | |
if system_prompt.strip(): | |
chat_context.append({"role": "system", "content": system_prompt}) | |
else: | |
chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT}) | |
# Ajout de l'historique des conversations | |
for user_msg, assistant_msg in chat_history: | |
chat_context.append({"role": "user", "content": user_msg}) | |
chat_context.append({"role": "assistant", "content": assistant_msg}) | |
# Ajout du message actuel | |
chat_context.append({"role": "user", "content": message}) | |
# Préparation du texte d'entrée avec le template de chat | |
input_text = tokenizer.apply_chat_template( | |
chat_context, | |
tokenize=False, | |
add_generation_prompt=True | |
) | |
# Tokenisation de l'entrée | |
model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) | |
# Mise à jour des paramètres de génération | |
model.generation_config.temperature = temperature | |
model.generation_config.max_new_tokens = max_tokens | |
# Génération de la réponse | |
with torch.no_grad(): | |
generated_ids = model.generate( | |
**model_inputs, | |
use_cache=True, | |
) | |
# Extraction uniquement de la nouvelle partie générée | |
new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:] | |
response = tokenizer.decode(new_tokens, skip_special_tokens=True) | |
return response | |
# Construction du contexte de chat | |
chat_context = [] | |
# Ajout du prompt système s'il existe | |
if system_prompt.strip(): | |
chat_context.append({"role": "system", "content": system_prompt}) | |
else: | |
chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT}) | |
# Ajout de l'historique des conversations | |
for user_msg, assistant_msg in chat_history: | |
chat_context.append({"role": "user", "content": user_msg}) | |
chat_context.append({"role": "assistant", "content": assistant_msg}) | |
# Ajout du message actuel | |
chat_context.append({"role": "user", "content": message}) | |
# Tokenisation et génération | |
input_ids = tokenizer.apply_chat_template( | |
chat_context, | |
return_tensors="pt" | |
).to(model.device) | |
# Génération de la réponse | |
with torch.no_grad(): | |
outputs = model.generate( | |
input_ids=input_ids, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
top_p=TOP_P, | |
repetition_penalty=REPETITION_PENALTY, | |
do_sample=temperature > 0, | |
) | |
# Décodage de la réponse | |
generated_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True) | |
return generated_text | |
# Interface utilisateur Gradio | |
with gr.Blocks(css="footer {visibility: hidden}") as demo: | |
gr.Markdown( | |
""" | |
# OmniMed - Assistant médical IA | |
Ce chatbot alimenté par le modèle Llama 3.1 fine-tuné est conçu pour aider les professionnels de santé dans leurs tâches quotidiennes. | |
Posez des questions médicales et recevez des réponses précises et professionnelles. | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=4): | |
chatbot = gr.Chatbot(height=500, show_copy_button=True) | |
with gr.Row(): | |
msg = gr.Textbox( | |
placeholder="Posez votre question médicale ici...", | |
show_label=False, | |
container=False, | |
scale=12 | |
) | |
submit_btn = gr.Button("Envoyer", scale=1) | |
with gr.Row(): | |
clear_btn = gr.Button("Effacer la conversation") | |
with gr.Column(scale=1): | |
with gr.Accordion("Paramètres avancés", open=False): | |
system_prompt = gr.Textbox( | |
label="Prompt système", | |
placeholder="Instructions pour l'assistant...", | |
value=DEFAULT_SYSTEM_PROMPT, | |
lines=5 | |
) | |
temperature_slider = gr.Slider( | |
minimum=0.0, | |
maximum=1.0, | |
value=TEMPERATURE, | |
step=0.05, | |
label="Température" | |
) | |
max_tokens_slider = gr.Slider( | |
minimum=128, | |
maximum=2048, | |
value=MAX_NEW_TOKENS, | |
step=64, | |
label="Tokens max." | |
) | |
# Logique des événements | |
def respond(message, chat_history, system_prompt, temperature, max_tokens): | |
if not message.strip(): | |
return "", chat_history | |
response = generate_response(message, chat_history, system_prompt, temperature, max_tokens) | |
chat_history.append((message, response)) | |
return "", chat_history | |
msg.submit( | |
respond, | |
[msg, chatbot, system_prompt, temperature_slider, max_tokens_slider], | |
[msg, chatbot] | |
) | |
submit_btn.click( | |
respond, | |
[msg, chatbot, system_prompt, temperature_slider, max_tokens_slider], | |
[msg, chatbot] | |
) | |
clear_btn.click(lambda: [], None, chatbot) | |
gr.Markdown(""" | |
### Note importante: | |
Ce chatbot est destiné uniquement à l'assistance des professionnels de santé et ne remplace pas l'avis médical d'un spécialiste. | |
Toutes les informations fournies doivent être vérifiées par un professionnel qualifié. | |
""") | |
# Lancement de l'application | |
if __name__ == "__main__": | |
demo.queue().launch() |