analist commited on
Commit
aa17b8b
·
verified ·
1 Parent(s): 12a623d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -188
app.py CHANGED
@@ -1,212 +1,98 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  import torch
4
- from peft import PeftModel, PeftConfig
5
 
6
- # Configuration du modèle
7
- ADAPTER_MODEL_NAME = "analist/llama3.1-8B-omnimed-rl"
8
- BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B" # Modèle de base pour Llama 3.1 8B
9
- DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
10
- Répondez de manière précise, concise et professionnelle aux questions médicales.
11
- Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information."""
12
 
13
- # Définition des paramètres
14
- MAX_NEW_TOKENS = 1024
15
- TEMPERATURE = 0.7
16
- TOP_P = 0.9
17
- REPETITION_PENALTY = 1.1
18
-
19
- # Configuration pour la quantification 4-bit
20
- bnb_config = BitsAndBytesConfig(
21
- load_in_4bit=True,
22
- bnb_4bit_quant_type="nf4",
23
- bnb_4bit_compute_dtype=torch.float16,
24
- bnb_4bit_use_double_quant=True,
25
- )
26
 
27
- # Chargement du modèle et du tokenizer
28
- print("Chargement du modèle de base et du tokenizer...")
29
- tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_NAME)
30
 
31
- print("Chargement du modèle de base quantifié...")
32
- base_model = AutoModelForCausalLM.from_pretrained(
33
- ADAPTER_MODEL_NAME,
34
- quantization_config=bnb_config,
35
- device_map="auto",
36
- trust_remote_code=True
 
 
 
37
  )
38
 
39
- print("Application des adaptateurs...")
40
- model = PeftModel.from_pretrained(
41
- base_model,
42
- device_map="auto",
43
- )
44
-
45
- print("Modèle et tokenizer chargés avec succès!")
46
 
47
  # Fonction pour générer une réponse
48
- def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS):
49
- # Construction du contexte de chat
50
- chat_context = []
51
-
52
- # Ajout du prompt système s'il existe
53
- if system_prompt.strip():
54
- chat_context.append({"role": "system", "content": system_prompt})
55
- else:
56
- chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
57
-
58
- # Ajout de l'historique des conversations
59
- for user_msg, assistant_msg in chat_history:
60
- chat_context.append({"role": "user", "content": user_msg})
61
- chat_context.append({"role": "assistant", "content": assistant_msg})
62
-
63
- # Ajout du message actuel
64
- chat_context.append({"role": "user", "content": message})
65
-
66
- # Préparation du texte d'entrée avec le template de chat
67
- input_text = tokenizer.apply_chat_template(
68
- chat_context,
69
- tokenize=False,
70
- add_generation_prompt=True
71
- )
72
-
73
- # Tokenisation de l'entrée
74
- model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
75
-
76
- # Mise à jour des paramètres de génération
77
- model.generation_config.temperature = temperature
78
- model.generation_config.max_new_tokens = max_tokens
79
-
80
- # Génération de la réponse
81
- with torch.no_grad():
82
- generated_ids = model.generate(
83
- **model_inputs,
84
- use_cache=True,
85
- )
86
 
87
- # Extraction uniquement de la nouvelle partie générée
88
- new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:]
89
- response = tokenizer.decode(new_tokens, skip_special_tokens=True)
90
 
91
- return response
 
92
 
93
- # Construction du contexte de chat
94
- chat_context = []
95
 
96
- # Ajout du prompt système s'il existe
97
- if system_prompt.strip():
98
- chat_context.append({"role": "system", "content": system_prompt})
99
- else:
100
- chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
101
 
102
- # Ajout de l'historique des conversations
103
- for user_msg, assistant_msg in chat_history:
104
- chat_context.append({"role": "user", "content": user_msg})
105
- chat_context.append({"role": "assistant", "content": assistant_msg})
106
 
107
- # Ajout du message actuel
108
- chat_context.append({"role": "user", "content": message})
109
-
110
- # Tokenisation et génération
111
- input_ids = tokenizer.apply_chat_template(
112
- chat_context,
113
- return_tensors="pt"
114
- ).to(model.device)
115
-
116
- # Génération de la réponse
117
- with torch.no_grad():
118
- outputs = model.generate(
119
- input_ids=input_ids,
120
- max_new_tokens=max_tokens,
121
- temperature=temperature,
122
- top_p=TOP_P,
123
- repetition_penalty=REPETITION_PENALTY,
124
- do_sample=temperature > 0,
125
- )
126
-
127
- # Décodage de la réponse
128
- generated_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
129
- return generated_text
130
 
131
- # Interface utilisateur Gradio
132
- with gr.Blocks(css="footer {visibility: hidden}") as demo:
133
- gr.Markdown(
134
- """
135
- # OmniMed - Assistant médical IA
136
-
137
- Ce chatbot alimenté par le modèle Llama 3.1 fine-tuné est conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
138
- Posez des questions médicales et recevez des réponses précises et professionnelles.
139
- """
140
- )
141
-
142
- with gr.Row():
143
- with gr.Column(scale=4):
144
- chatbot = gr.Chatbot(height=500, show_copy_button=True)
145
-
146
- with gr.Row():
147
- msg = gr.Textbox(
148
- placeholder="Posez votre question médicale ici...",
149
- show_label=False,
150
- container=False,
151
- scale=12
152
- )
153
- submit_btn = gr.Button("Envoyer", scale=1)
154
-
155
- with gr.Row():
156
- clear_btn = gr.Button("Effacer la conversation")
157
-
158
- with gr.Column(scale=1):
159
- with gr.Accordion("Paramètres avancés", open=False):
160
- system_prompt = gr.Textbox(
161
- label="Prompt système",
162
- placeholder="Instructions pour l'assistant...",
163
- value=DEFAULT_SYSTEM_PROMPT,
164
- lines=5
165
- )
166
- temperature_slider = gr.Slider(
167
- minimum=0.0,
168
- maximum=1.0,
169
- value=TEMPERATURE,
170
- step=0.05,
171
- label="Température"
172
- )
173
- max_tokens_slider = gr.Slider(
174
- minimum=128,
175
- maximum=2048,
176
- value=MAX_NEW_TOKENS,
177
- step=64,
178
- label="Tokens max."
179
- )
180
-
181
- # Logique des événements
182
- def respond(message, chat_history, system_prompt, temperature, max_tokens):
183
- if not message.strip():
184
- return "", chat_history
185
-
186
- response = generate_response(message, chat_history, system_prompt, temperature, max_tokens)
187
- chat_history.append((message, response))
188
- return "", chat_history
189
 
190
- msg.submit(
191
- respond,
192
- [msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
193
- [msg, chatbot]
194
- )
195
 
196
- submit_btn.click(
197
- respond,
198
- [msg, chatbot, system_prompt, temperature_slider, max_tokens_slider],
199
- [msg, chatbot]
200
- )
201
 
202
- clear_btn.click(lambda: [], None, chatbot)
203
-
204
  gr.Markdown("""
205
- ### Note importante:
206
- Ce chatbot est destiné uniquement à l'assistance des professionnels de santé et ne remplace pas l'avis médical d'un spécialiste.
207
- Toutes les informations fournies doivent être vérifiées par un professionnel qualifié.
 
 
208
  """)
209
 
210
- # Lancement de l'application
211
  if __name__ == "__main__":
212
- demo.queue().launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
 
5
+ # Définir le modèle et le tokenizer
6
+ # Utilisation d'un modèle français pour le domaine médical
7
+ MODEL_NAME = "mistralai/Mistral-7B-v0.1" # Vous pouvez utiliser un modèle plus adapté au français comme "camembert" ou un modèle médical spécifique
 
 
 
8
 
9
+ # Fonction pour charger le modèle et le tokenizer
10
+ def load_model():
11
+ print("Chargement du modèle et du tokenizer...")
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ MODEL_NAME,
15
+ torch_dtype=torch.float16,
16
+ device_map="auto",
17
+ load_in_8bit=True # Quantification pour réduire l'utilisation de la mémoire
18
+ )
19
+ return model, tokenizer
 
 
20
 
21
+ # Charger le modèle et le tokenizer
22
+ model, tokenizer = load_model()
 
23
 
24
+ # Créer un pipeline de génération de texte
25
+ generator = pipeline(
26
+ "text-generation",
27
+ model=model,
28
+ tokenizer=tokenizer,
29
+ max_length=512,
30
+ do_sample=True,
31
+ temperature=0.7,
32
+ top_p=0.9,
33
  )
34
 
35
+ # Message initial pour orienter le modèle vers le domaine médical
36
+ SYSTEM_PROMPT = """Tu es un assistant médical IA nommé MediBot. Tu fournis des informations médicales générales et des conseils de santé basés sur des données scientifiques.
37
+ Important: Tu n'es pas un médecin et tu dois toujours recommander à l'utilisateur de consulter un professionnel de santé pour un diagnostic ou un traitement spécifique.
38
+ Réponds aux questions médicales de manière précise et claire, en te basant sur des informations médicales vérifiées."""
 
 
 
39
 
40
  # Fonction pour générer une réponse
41
+ def generate_response(message, chat_history):
42
+ # Construire le prompt avec l'historique de conversation et le nouveau message
43
+ prompt = SYSTEM_PROMPT + "\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # Ajouter l'historique du chat
46
+ for user_msg, bot_msg in chat_history:
47
+ prompt += f"Utilisateur: {user_msg}\nMediBot: {bot_msg}\n\n"
48
 
49
+ # Ajouter le nouveau message
50
+ prompt += f"Utilisateur: {message}\nMediBot:"
51
 
52
+ # Générer la réponse
53
+ response = generator(prompt, max_new_tokens=256)[0]["generated_text"]
54
 
55
+ # Extraire seulement la partie réponse du modèle
56
+ response_only = response.split("MediBot:")[-1].strip()
 
 
 
57
 
58
+ # Ajouter un rappel de consulter un professionnel
59
+ if len(response_only) > 0 and not "consulter un professionnel" in response_only.lower():
60
+ response_only += "\n\n(N'oubliez pas que ces informations sont générales. Pour des conseils personnalisés, consultez un professionnel de santé.)"
 
61
 
62
+ return response_only
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # Fonction pour gérer l'historique du chat Gradio
65
+ def chatbot_response(message, history):
66
+ bot_message = generate_response(message, history)
67
+ history.append((message, bot_message))
68
+ return "", history
69
+
70
+ # Interface Gradio
71
+ with gr.Blocks(title="MediBot - Assistant Médical IA") as demo:
72
+ gr.Markdown("# MediBot - Votre Assistant Médical basé sur l'IA")
73
+ gr.Markdown("""
74
+ ### ⚠️ IMPORTANT - AVERTISSEMENT MÉDICAL ⚠️
75
+ Ce chatbot utilise l'intelligence artificielle pour fournir des informations médicales générales.
76
+ Il ne remplace en aucun cas l'avis d'un professionnel de santé qualifié.
77
+ Pour toute question médicale sérieuse, veuillez consulter un médecin.
78
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ chatbot = gr.Chatbot(height=500)
81
+ msg = gr.Textbox(placeholder="Posez votre question médicale ici...", label="Votre question")
82
+ clear = gr.Button("Effacer la conversation")
 
 
83
 
84
+ # Configurer les événements
85
+ msg.submit(chatbot_response, [msg, chatbot], [msg, chatbot])
86
+ clear.click(lambda: None, None, chatbot, queue=False)
 
 
87
 
 
 
88
  gr.Markdown("""
89
+ ### Exemples de questions que vous pouvez poser:
90
+ - Quels sont les symptômes courants de la grippe?
91
+ - Comment prévenir l'hypertension artérielle?
92
+ - Quels aliments sont recommandés pour les diabétiques?
93
+ - Quelles sont les causes fréquentes des migraines?
94
  """)
95
 
96
+ # Lancer l'application
97
  if __name__ == "__main__":
98
+ demo.launch(share=True) # share=True permet de générer un lien public temporaire