Mahavaury2 commited on
Commit
eed53c2
·
verified ·
1 Parent(s): 94c9a27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -35
app.py CHANGED
@@ -9,17 +9,13 @@ import spaces
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
12
- # Description de la démo
13
  DESCRIPTION = "# Mistral-7B v0.3"
14
-
15
- # Si pas de GPU détecté, afficher un message
16
  if not torch.cuda.is_available():
17
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
18
 
19
- MAX_MAX_NEW_TOKENS = 2048
20
- DEFAULT_MAX_NEW_TOKENS = 1024
21
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
22
 
 
23
  if torch.cuda.is_available():
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
  model = AutoModelForCausalLM.from_pretrained(
@@ -33,7 +29,6 @@ if torch.cuda.is_available():
33
  def generate(
34
  message: str,
35
  chat_history: list[dict],
36
- # Valeurs par défaut pour la génération
37
  max_new_tokens: int = 1024,
38
  temperature: float = 0.6,
39
  top_p: float = 0.9,
@@ -41,22 +36,19 @@ def generate(
41
  repetition_penalty: float = 1.2,
42
  ) -> Iterator[str]:
43
  """
44
- Génération de texte à partir de l'historique de conversation (chat_history) et du message utilisateur.
45
  """
46
  conversation = [*chat_history, {"role": "user", "content": message}]
47
-
48
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
 
 
49
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
50
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
51
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
52
  input_ids = input_ids.to(model.device)
53
 
54
- streamer = TextIteratorStreamer(
55
- tokenizer,
56
- timeout=20.0,
57
- skip_prompt=True,
58
- skip_special_tokens=True
59
- )
60
  generate_kwargs = dict(
61
  {"input_ids": input_ids},
62
  streamer=streamer,
@@ -76,37 +68,131 @@ def generate(
76
  outputs.append(text)
77
  yield "".join(outputs)
78
 
79
- # CSS pour appliquer le dégradé pastel à TOUTE la page
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  custom_css = """
81
  html, body {
82
  height: 100%;
83
  margin: 0;
84
  padding: 0;
85
  background: linear-gradient(135deg, #FDE2E2, #E2ECFD) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  }
87
  """
88
 
89
- # Liste des questions prédéfinies
90
- predefined_examples = [
91
- ["1 - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
92
- ["2 - C’est quoi une agression sexuelle ?"],
93
- ["3 - C’est quoi un viol ?"],
94
- ["4 - C’est quoi un attouchement ?"],
95
- ["5 - C’est quoi un harcèlement sexuel ?"],
96
- ["6 - Est ce illégal de visionner du porno ?"],
97
- ["7 - Mon copain me demande un nude, dois-je le faire ?"],
98
- ["8 - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
99
- ["9 - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
100
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # Création de l'interface de chat
103
- demo = gr.ChatInterface(
104
- fn=generate,
105
- type="messages",
106
- description=DESCRIPTION,
107
- css=custom_css,
108
- examples=predefined_examples,
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  if __name__ == "__main__":
112
  demo.queue(max_size=20).launch()
 
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
 
12
  DESCRIPTION = "# Mistral-7B v0.3"
 
 
13
  if not torch.cuda.is_available():
14
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
15
 
 
 
16
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
17
 
18
+ # Chargement du modèle (sur GPU si disponible)
19
  if torch.cuda.is_available():
20
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
21
  model = AutoModelForCausalLM.from_pretrained(
 
29
  def generate(
30
  message: str,
31
  chat_history: list[dict],
 
32
  max_new_tokens: int = 1024,
33
  temperature: float = 0.6,
34
  top_p: float = 0.9,
 
36
  repetition_penalty: float = 1.2,
37
  ) -> Iterator[str]:
38
  """
39
+ Génération de texte à partir de l'historique + message utilisateur.
40
  """
41
  conversation = [*chat_history, {"role": "user", "content": message}]
 
42
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
43
+
44
+ # Gestion de la limite de tokens
45
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
46
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
47
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
+
49
  input_ids = input_ids.to(model.device)
50
 
51
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
52
  generate_kwargs = dict(
53
  {"input_ids": input_ids},
54
  streamer=streamer,
 
68
  outputs.append(text)
69
  yield "".join(outputs)
70
 
71
+ # Liste des questions (chaque question est associée à une "bulle cliquable")
72
+ QUESTIONS = [
73
+ "I - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?",
74
+ "II - C’est quoi une agression sexuelle ?",
75
+ "III - C’est quoi un viol ?",
76
+ "IV - C’est quoi un attouchement ?",
77
+ "V - C’est quoi un harcèlement sexuel ?",
78
+ "VI - Est ce illégal de visionner du porno ?",
79
+ "VII - Mon copain me demande un nude, dois-je le faire ?",
80
+ "VIII - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?",
81
+ "IX - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?",
82
+ ]
83
+
84
+ # -- CSS personnalisé --
85
  custom_css = """
86
  html, body {
87
  height: 100%;
88
  margin: 0;
89
  padding: 0;
90
  background: linear-gradient(135deg, #FDE2E2, #E2ECFD) !important;
91
+ font-family: sans-serif;
92
+ }
93
+
94
+ /* Conteneur global pour centrer la "pieuvre" */
95
+ .octopus-container {
96
+ position: relative;
97
+ width: 600px;
98
+ height: 600px;
99
+ margin: 0 auto; /* centre horizontalement */
100
+ margin-top: 30px; /* marge en haut */
101
+ }
102
+
103
+ /* L'image (point d'interrogation) au centre */
104
+ .center-image {
105
+ position: absolute;
106
+ width: 100px; /* Ajuster la taille */
107
+ top: 50%;
108
+ left: 50%;
109
+ transform: translate(-50%, -50%);
110
+ }
111
+
112
+ /* Bulles de question */
113
+ .octopus-arm {
114
+ position: absolute;
115
+ width: 140px;
116
+ background: #FFFFFFCC;
117
+ padding: 8px;
118
+ border-radius: 10px;
119
+ box-shadow: 0 2px 4px rgba(0,0,0,0.2);
120
+ text-align: center;
121
+ cursor: pointer; /* Pour montrer que c'est cliquable */
122
+ }
123
+
124
+ /* Positions "rayonnantes" */
125
+ .arm1 { top: 0%; left: 50%; transform: translate(-50%, 0); }
126
+ .arm2 { top: 10%; left: 80%; }
127
+ .arm3 { top: 35%; left: 95%; transform: translateX(-100%); }
128
+ .arm4 { top: 60%; left: 80%; }
129
+ .arm5 { top: 85%; left: 50%; transform: translate(-50%, -100%); }
130
+ .arm6 { top: 60%; left: 10%; }
131
+ .arm7 { top: 35%; left: 0%; }
132
+ .arm8 { top: 10%; left: 10%; }
133
+ .arm9 { top: 0%; left: 25%; }
134
+
135
+ /* Texte dans les bulles */
136
+ .octopus-arm p {
137
+ margin: 0;
138
+ font-size: 0.85rem;
139
+ line-height: 1.2;
140
+ color: #000;
141
  }
142
  """
143
 
144
+ #
145
+ # Construction de l'application avec Blocks
146
+ #
147
+ with gr.Blocks(css=custom_css) as demo:
148
+ # Titre / Description
149
+ gr.Markdown(DESCRIPTION)
150
+
151
+ # Section HTML : la "pieuvre" + le script JS
152
+ # Notez la fonction setChatText(...) qu'on appelle en cliquant sur la bulle.
153
+ with gr.Box():
154
+ gr.HTML(
155
+ """
156
+ <div class="octopus-container">
157
+ <!-- Image point d'interrogation au centre -->
158
+ <img src="https://cdn-icons-png.flaticon.com/512/4926/4926733.png" class="center-image" />
159
+
160
+ <!-- Bulles de questions (cliquables) -->
161
+ """ +
162
+ "".join(
163
+ f"""
164
+ <div class="octopus-arm arm{i+1}" onclick="setChatText(`{QUESTIONS[i]}`)">
165
+ <p>{QUESTIONS[i]}</p>
166
+ </div>
167
+ """
168
+ for i in range(len(QUESTIONS))
169
+ ) +
170
+ """
171
+ </div>
172
 
173
+ <script>
174
+ // Au clic sur une bulle, insère la question dans la zone de saisie du ChatInterface
175
+ function setChatText(questionText) {
176
+ // Le ChatInterface de Gradio possède un <textarea> dont
177
+ // l'attribut placeholder est souvent "Type a message..."
178
+ // On cherche ce textarea et on y met la question.
179
+ let textBox = document.querySelector('textarea[placeholder="Type a message..."]');
180
+ if (textBox) {
181
+ textBox.value = questionText;
182
+ } else {
183
+ console.warn("Impossible de trouver la zone de texte du chatbot.");
184
+ }
185
+ }
186
+ </script>
187
+ """
188
+ )
189
+
190
+ # Le chatbot Gradio, en dessous
191
+ chat = gr.ChatInterface(
192
+ fn=generate,
193
+ type="messages",
194
+ description="Cliquez sur une question ci-dessus ou saisissez votre texte.",
195
+ )
196
 
197
  if __name__ == "__main__":
198
  demo.queue(max_size=20).launch()