analist commited on
Commit
1fc8fc9
·
verified ·
1 Parent(s): 9f45c43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
  # Configuration du modèle
6
- MODEL_NAME = "analist/llama3.1-8B-omnimed-rl"
 
7
  DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
8
  Répondez de manière précise, concise et professionnelle aux questions médicales.
9
  Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information."""
@@ -14,15 +16,33 @@ TEMPERATURE = 0.7
14
  TOP_P = 0.9
15
  REPETITION_PENALTY = 1.1
16
 
 
 
 
 
 
 
 
 
17
  # Chargement du modèle et du tokenizer
18
- print("Chargement du modèle et du tokenizer...")
19
- model = AutoModelForCausalLM.from_pretrained(
20
- MODEL_NAME,
21
- torch_dtype=torch.float16,
 
 
 
22
  device_map="auto",
23
  trust_remote_code=True
24
  )
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
26
  print("Modèle et tokenizer chargés avec succès!")
27
 
28
  # Fonction pour générer une réponse
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  import torch
4
+ from peft import PeftModel, PeftConfig
5
 
6
  # Configuration du modèle
7
+ ADAPTER_MODEL_NAME = "analist/llama3.1-8B-omnimed-rl"
8
+ BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B" # Modèle de base pour Llama 3.1 8B
9
  DEFAULT_SYSTEM_PROMPT = """Vous êtes OmniMed, un assistant médical IA conçu pour aider les professionnels de santé dans leurs tâches quotidiennes.
10
  Répondez de manière précise, concise et professionnelle aux questions médicales.
11
  Basez vos réponses sur des connaissances médicales établies et indiquez clairement lorsque vous n'êtes pas certain d'une information."""
 
16
  TOP_P = 0.9
17
  REPETITION_PENALTY = 1.1
18
 
19
+ # Configuration pour la quantification 4-bit
20
+ bnb_config = BitsAndBytesConfig(
21
+ load_in_4bit=True,
22
+ bnb_4bit_quant_type="nf4",
23
+ bnb_4bit_compute_dtype=torch.float16,
24
+ bnb_4bit_use_double_quant=True,
25
+ )
26
+
27
  # Chargement du modèle et du tokenizer
28
+ print("Chargement du modèle de base et du tokenizer...")
29
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
30
+
31
+ print("Chargement du modèle de base quantifié...")
32
+ base_model = AutoModelForCausalLM.from_pretrained(
33
+ BASE_MODEL_NAME,
34
+ quantization_config=bnb_config,
35
  device_map="auto",
36
  trust_remote_code=True
37
  )
38
+
39
+ print("Application des adaptateurs...")
40
+ model = PeftModel.from_pretrained(
41
+ base_model,
42
+ ADAPTER_MODEL_NAME,
43
+ device_map="auto",
44
+ )
45
+
46
  print("Modèle et tokenizer chargés avec succès!")
47
 
48
  # Fonction pour générer une réponse