Mahavaury2 committed
Commit b1e539f · verified · 1 Parent(s): eb04a36

Update app.py

Files changed (1)
  1. app.py +106 -39
app.py CHANGED
@@ -1,34 +1,48 @@
+#!/usr/bin/env python
+
 import os
 from collections.abc import Iterator
 from threading import Thread

 import gradio as gr
+import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+#
+# 1) Custom Pastel Gradient CSS
+#
 CUSTOM_CSS = """
 .gradio-container {
     background: linear-gradient(to right, #FFDEE9, #B5FFFC);
-    color: black;
 }
 """

-DESCRIPTION = """# Bonjour Dans le chat du consentement
-Mistral-7B Instruct Demo
+#
+# 2) Description: Add French greeting, plus any info
+#
+DESCRIPTION = """# Bonjour Dans le chat du consentement
+
+Mistral-7B Instruct Demo
 """

-MAX_INPUT_TOKEN_LENGTH = 4096  # just a default
+if not torch.cuda.is_available():
+    DESCRIPTION += (
+        "\n<p style='color:red;'>Running on CPU - This is likely too large to run effectively.</p>"
+    )

-# Define model/tokenizer at the top so they're visible in all scopes
-tokenizer = None
-model = None
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

-# Try to load the model only if GPU is available
+#
+# 3) Load Mistral-7B Instruct (requires gating, GPU recommended)
+#
 if torch.cuda.is_available():
     model_id = "mistralai/Mistral-7B-Instruct-v0.3"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_id,
-        trust_remote_code=True
+        model_id,
+        trust_remote_code=True  # Might be needed for custom code
     )
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
@@ -36,57 +50,110 @@ if torch.cuda.is_available():
         device_map="auto",
         trust_remote_code=True
     )
-else:
-    # Show a warning in the description
-    DESCRIPTION += "\n**Running on CPU** — This model is too large for CPU inference!"
-
-def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
-    # If there's no GPU (thus no tokenizer/model), return an error
-    if tokenizer is None or model is None:
-        yield "Error: No GPU available. Unable to load Mistral-7B-Instruct."
-        return

+def generate(
+    message: str,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    """
+    This function handles streaming chat text as the model generates it.
+    Uses Mistral's 'apply_chat_template' to handle chat-style prompting.
+    """
     conversation = [*chat_history, {"role": "user", "content": message}]

     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
+        gr.Warning(
+            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+        )
     input_ids = input_ids.to(model.device)

     streamer = TextIteratorStreamer(
-        tokenizer,
-        timeout=20.0,
-        skip_prompt=True,
+        tokenizer,
+        timeout=20.0,
+        skip_prompt=True,
         skip_special_tokens=True
     )
-
-    generate_kwargs = {
-        "input_ids": input_ids,
-        "streamer": streamer,
-        "max_new_tokens": 512,
-        "do_sample": True,
-        "temperature": 0.7,
-        "top_p": 0.9,
-        "repetition_penalty": 1.1,
-    }
-
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()

     outputs = []
     for text in streamer:
         outputs.append(text)
+        # Stream partial output as it's generated
         yield "".join(outputs)

+#
+# 4) Build the Chat Interface with extra sliders
+#
 demo = gr.ChatInterface(
     fn=generate,
     description=DESCRIPTION,
-    css=CUSTOM_CSS,
-    examples=None,
-    type="messages"
+    css=CUSTOM_CSS,  # Use our pastel gradient
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly what the Python programming language is?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    type="messages",
 )

 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue(max_size=20).launch()
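
Note on the new "import spaces" line: nothing else in this diff references the module. On a Hugging Face ZeroGPU Space that import is normally paired with the spaces.GPU decorator on the GPU-bound function, so the sketch below shows that pattern. It is an assumption about intent, not part of this commit; the decorator placement and the duration value are illustrative only.

# Hedged sketch (not in the commit): typical ZeroGPU usage of the spaces import.
# Decorating generate() asks the Space scheduler for a GPU for each call.
import spaces

@spaces.GPU(duration=60)  # assumed per-request GPU budget; tune to the token limit
def generate(message, chat_history, max_new_tokens=1024, temperature=0.6,
             top_p=0.9, top_k=50, repetition_penalty=1.2):
    ...  # body identical to the generate() defined in the diff above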