analist commited on
Commit
97581f5
·
verified ·
1 Parent(s): 80df497

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Configuration du modèle
6
+ MODEL_NAME = "analist/llama3.1-8B-omnimed-rl"
7
+ DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
8
+ Répondez de manière précise, concise et professionnelle aux questions médicales.
9
+ Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information."""
10
+
11
+ # Définition des paramètres
12
+ MAX_NEW_TOKENS = 1024
13
+ TEMPERATURE = 0.7
14
+ TOP_P = 0.9
15
+ REPETITION_PENALTY = 1.1
16
+
17
+ # Chargement du modèle et du tokenizer
18
+ @gr.on_startup
19
+ def load_model():
20
+ print("Chargement du modèle et du tokenizer...")
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ MODEL_NAME,
23
+ torch_dtype=torch.float16,
24
+ device_map="auto",
25
+ trust_remote_code=True
26
+ )
27
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
+ print("Modèle et tokenizer chargés avec succès!")
29
+ return model, tokenizer
30
+
31
+ # Fonction pour générer une réponse
32
+ def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS):
33
+ model, tokenizer = load_model.value
34
+
35
+ # Construction du contexte de chat
36
+ chat_context = []
37
+
38
+ # Ajout du prompt système s'il existe
39
+ if system_prompt.strip():
40
+ chat_context.append({"role": "system", "content": system_prompt})
41
+ else:
42
+ chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
43
+
44
+ # Ajout de l'historique des conversations
45
+ for user_msg, assistant_msg in chat_history:
46
+ chat_context.append({"role": "user", "content": user_msg})
47
+ chat_context.append({"role": "assistant", "content": assistant_msg})
48
+
49
+ # Ajout du message actuel
50
+ chat_context.append({"role": "user", "content": message})
51
+
52
+ # Tokenisation et génération
53
+ input_ids = tokenizer.apply_chat_template(
54
+ chat_context,
55
+ return_tensors="pt"
56
+ ).to(model.device)
57
+
58
+ # Génération de la réponse
59
+ with torch.no_grad():
60
+ outputs = model.generate(
61
+ input_ids=input_ids,
62
+ max_new_tokens=max_tokens,
63
+ temperature=temperature,
64
+ top_p=TOP_P,
65
+ repetition_penalty=REPETITION_PENALTY,
66
+ do_sample=temperature > 0,
67
+ )
68
+
69
+ # Décodage de la réponse
70
+ generated_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
71
+ return generated_text
72
+
73
+ # Interface utilisateur Gradio
74
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
75
+ gr.Markdown(
76
+ """
77
+ # OmniMed - Assistant médical IA
78
+
79
+ Ce chatbot alimenté par le modèle Llama 3.1 fine-tuné est conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
80
+ Posez des questions médicales et recevez des réponses précises et professionnelles.
81
+ """
82
+ )
83
+
84
+ with gr.Row():
85
+ with gr.Column(scale=4):
86
+ chatbot = gr.Chatbot(height=500, show_copy_button=True)
87
+
88
+ with gr.Row():
89
+ msg = gr.Textbox(
90
+ placeholder="Posez votre question médicale ici...",
91
+ show_label=False,
92
+ container=False,
93
+ scale=12
94
+ )
95
+ submit_btn = gr.Button("Envoyer", scale=1)
96
+
97
+ with gr.Row():
98
+ clear_btn = gr.Button("Effacer la conversation")
99
+
100
+ with gr.Column(scale=1):
101
+ with gr.Accordion("Paramètres avancés", open=False):
102
+ system_prompt = gr.Textbox(
103
+ label="Prompt système",
104
+ placeholder="Instructions pour l'assistant...",
105
+ value=DEFAULT_SYSTEM_PROMPT,
106
+ lines=5
107
+ )
108
+ temperature_slider = gr.Slider(
109
+ minimum=0.0,
110
+ maximum=1.0,
111
+ value=TEMPERATURE,
112
+ step=0.05,
113
+ label="Température"
114
+ )
115
+ max_tokens_slider = gr.Slider(
116
+ minimum=128,
117
+ maximum=2048,
118
+ value=MAX_NEW_TOKENS,
119
+ step=64,
120
+ label="Tokens max."
121
+ )
122
+
123
+ # Logique des événements
124
+ def respond(message, chat_history, system_prompt, temperature, max_tokens):
125
+ if not message.strip():
126
+ return "", chat_history
127
+
128
+ response = generate_response(message, chat_history, system_prompt, temperature, max_tokens)
129
+ chat_history.append((message, response))
130
+ return "", chat_history
131
+
132
+ msg.submit(
133
+ respond,
134
+ [msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
135
+ [msg, chatbot]
136
+ )
137
+
138
+ submit_btn.click(
139
+ respond,
140
+ [msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
141
+ [msg, chatbot]
142
+ )
143
+
144
+ clear_btn.click(lambda: [], None, chatbot)
145
+
146
+ gr.Markdown("""
147
+ ### Note importante:
148
+ Ce chatbot est destiné uniquement à l'assistance des professionnels de santé et ne remplace pas l'avis médical d'un spécialiste.
149
+ Toutes les informations fournies doivent être vérifiées par un professionnel qualifié.
150
+ """)
151
+
152
+ # Lancement de l'application
153
+ if __name__ == "__main__":
154
+ demo.queue().launch()