simnJS commited on
Commit
6d581b7
·
verified ·
1 Parent(s): aa415cb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from peft import PeftModel
5
+
6
+ # Remplacez par le modèle de base et l'adaptateur LoRA que vous utilisez
7
+ BASE_MODEL = "bigcode/starcoder2-3b"
8
+ ADAPTER_REPO = "simnJS/autotrain-fxp6j-p5s8i"
9
+
10
+ # 1. Charger le tokenizer
11
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
12
+
13
+ # 2. Charger le modèle de base en FP16 sur le GPU (si vous avez la VRAM nécessaire)
14
+ base_model = AutoModelForCausalLM.from_pretrained(
15
+ BASE_MODEL,
16
+ torch_dtype=torch.float16,
17
+ device_map="auto" # Permet de placer le modèle sur le GPU automatiquement
18
+ )
19
+
20
+ # 3. Charger l'adaptateur LoRA
21
+ model = PeftModel.from_pretrained(
22
+ base_model,
23
+ ADAPTER_REPO,
24
+ torch_dtype=torch.float16
25
+ )
26
+
27
+ # 4. Fonction pour générer une réponse
28
+ def generate_answer(user_message, history):
29
+ """
30
+ user_message: le dernier message de l'utilisateur
31
+ history: liste de tuples (message_utilisateur, réponse_modèle)
32
+ """
33
+ # Construire le prompt en tenant compte de l'historique si besoin
34
+ # Ici, on fait simple et on utilise juste le dernier message
35
+ prompt = user_message
36
+
37
+ # Encoder le prompt
38
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
39
+ # Générer la réponse
40
+ outputs = model.generate(
41
+ **inputs,
42
+ max_new_tokens=100,
43
+ temperature=0.7,
44
+ do_sample=True,
45
+ top_p=0.9
46
+ )
47
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
+ # Ajouter dans l'historique
49
+ history.append((user_message, answer))
50
+ return history, history
51
+
52
+ # 5. Interface Gradio de type chatbot
53
+ with gr.Blocks() as demo:
54
+ gr.Markdown("# Chat avec mon modèle LoRA Verse")
55
+ chatbot = gr.Chatbot()
56
+ msg = gr.Textbox(label="Tapez votre message ici...")
57
+ state = gr.State([]) # pour stocker l'historique des messages
58
+
59
+ def submit_message(user_message, history):
60
+ return generate_answer(user_message, history)
61
+
62
+ msg.submit(submit_message, inputs=[msg, state], outputs=[chatbot, state])
63
+ # ou un bouton si vous préférez
64
+ # send_btn = gr.Button("Envoyer")
65
+ # send_btn.click(fn=submit_message, inputs=[msg, state], outputs=[chatbot, state])
66
+
67
+ demo.launch()