Mahavaury2 commited on
Commit
b507d58
·
verified ·
1 Parent(s): 0c40459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -49
app.py CHANGED
@@ -9,8 +9,10 @@ import spaces
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
 
12
  DESCRIPTION = "# Mistral-7B v0.3"
13
 
 
14
  if not torch.cuda.is_available():
15
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
16
 
@@ -20,20 +22,27 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
20
 
21
  if torch.cuda.is_available():
22
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
23
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_id)
25
 
26
-
27
  @spaces.GPU
28
  def generate(
29
  message: str,
30
  chat_history: list[dict],
 
31
  max_new_tokens: int = 1024,
32
  temperature: float = 0.6,
33
  top_p: float = 0.9,
34
  top_k: int = 50,
35
  repetition_penalty: float = 1.2,
36
  ) -> Iterator[str]:
 
 
 
37
  conversation = [*chat_history, {"role": "user", "content": message}]
38
 
39
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
@@ -42,7 +51,12 @@ def generate(
42
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
43
  input_ids = input_ids.to(model.device)
44
 
45
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
46
  generate_kwargs = dict(
47
  {"input_ids": input_ids},
48
  streamer=streamer,
@@ -62,57 +76,33 @@ def generate(
62
  outputs.append(text)
63
  yield "".join(outputs)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
66
  demo = gr.ChatInterface(
67
  fn=generate,
68
- additional_inputs=[
69
- gr.Slider(
70
- label="Max new tokens",
71
- minimum=1,
72
- maximum=MAX_MAX_NEW_TOKENS,
73
- step=1,
74
- value=DEFAULT_MAX_NEW_TOKENS,
75
- ),
76
- gr.Slider(
77
- label="Temperature",
78
- minimum=0.1,
79
- maximum=4.0,
80
- step=0.1,
81
- value=0.6,
82
- ),
83
- gr.Slider(
84
- label="Top-p (nucleus sampling)",
85
- minimum=0.05,
86
- maximum=1.0,
87
- step=0.05,
88
- value=0.9,
89
- ),
90
- gr.Slider(
91
- label="Top-k",
92
- minimum=1,
93
- maximum=1000,
94
- step=1,
95
- value=50,
96
- ),
97
- gr.Slider(
98
- label="Repetition penalty",
99
- minimum=1.0,
100
- maximum=2.0,
101
- step=0.05,
102
- value=1.2,
103
- ),
104
- ],
105
- stop_btn=None,
106
- examples=[
107
- ["Hello there! How are you doing?"],
108
- ["Can you explain briefly to me what is the Python programming language?"],
109
- ["Explain the plot of Cinderella in a sentence."],
110
- ["How many hours does it take a man to eat a Helicopter?"],
111
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
112
- ],
113
  type="messages",
114
  description=DESCRIPTION,
115
- css_paths="style.css",
 
116
  )
117
 
118
  if __name__ == "__main__":
 
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
12
+ # Description de la démo
13
  DESCRIPTION = "# Mistral-7B v0.3"
14
 
15
+ # Si pas de GPU détecté, afficher un message
16
  if not torch.cuda.is_available():
17
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
18
 
 
22
 
23
  if torch.cuda.is_available():
24
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ model_id,
27
+ torch_dtype=torch.float16,
28
+ device_map="auto"
29
+ )
30
  tokenizer = AutoTokenizer.from_pretrained(model_id)
31
 
 
32
  @spaces.GPU
33
  def generate(
34
  message: str,
35
  chat_history: list[dict],
36
+ # Valeurs par défaut pour la génération
37
  max_new_tokens: int = 1024,
38
  temperature: float = 0.6,
39
  top_p: float = 0.9,
40
  top_k: int = 50,
41
  repetition_penalty: float = 1.2,
42
  ) -> Iterator[str]:
43
+ """
44
+ Génération de texte à partir de l'historique de conversation (chat_history) et du message utilisateur.
45
+ """
46
  conversation = [*chat_history, {"role": "user", "content": message}]
47
 
48
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
 
51
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
52
  input_ids = input_ids.to(model.device)
53
 
54
+ streamer = TextIteratorStreamer(
55
+ tokenizer,
56
+ timeout=20.0,
57
+ skip_prompt=True,
58
+ skip_special_tokens=True
59
+ )
60
  generate_kwargs = dict(
61
  {"input_ids": input_ids},
62
  streamer=streamer,
 
76
  outputs.append(text)
77
  yield "".join(outputs)
78
 
79
+ # CSS pastel (dégradé sur toute la page)
80
+ custom_css = """
81
+ .gradio-container {
82
+ background: linear-gradient(135deg, #FDE2E2, #E2ECFD);
83
+ }
84
+ """
85
+
86
+ # Liste des questions prédéfinies
87
+ predefined_examples = [
88
+ ["I - C’est quoi le consentement ? Comment savoir si ma copine a envie de moi ?"],
89
+ ["II - C’est quoi une agression sexuelle ?"],
90
+ ["III - C’est quoi un viol ?"],
91
+ ["IV - C’est quoi un attouchement ?"],
92
+ ["V - C’est quoi un harcèlement sexuel ?"],
93
+ ["VI - Est ce illégal de visionner du porno ?"],
94
+ ["VII - Mon copain me demande un nude, dois-je le faire ?"],
95
+ ["VIII - Mon ancien copain me menace de poster des photos de moi nue sur internet, que faire ?"],
96
+ ["IX - Que puis-je faire si un membre de ma famille me touche d’une manière bizarre, mais que j’ai peur de parler ou de ne pas être cru ?"],
97
+ ]
98
 
99
+ # Création de l'interface de chat
100
  demo = gr.ChatInterface(
101
  fn=generate,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  type="messages",
103
  description=DESCRIPTION,
104
+ css=custom_css,
105
+ examples=predefined_examples,
106
  )
107
 
108
  if __name__ == "__main__":