Mahavaury2 committed
Commit 30169f7 · verified · 1 Parent(s): 7c4c2ac

Update app.py

Files changed (1)
  app.py  +34 -83
app.py CHANGED
@@ -5,34 +5,30 @@ from collections.abc import Iterator
 from threading import Thread
 
 import gradio as gr
-import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 #
-# 1) Custom Pastel Gradient CSS
+# 1) Custom Pastel Gradient CSS, and force text to black
 #
 CUSTOM_CSS = """
 .gradio-container {
     background: linear-gradient(to right, #FFDEE9, #B5FFFC);
+    color: black; /* ensure text appears in black */
 }
 """
 
 #
-# 2) Description: Add French greeting, plus any info
+# 2) Description: "Bonjour Dans le chat du consentement" in black
+#    Also add a CPU notice in black if no GPU is found.
 #
-DESCRIPTION = """# Bonjour Dans le chat du consentement
-
-Mistral-7B Instruct Demo
+DESCRIPTION = """# Bonjour Dans le chat du consentement
+Mistral-7B Instruct Demo
 """
 
 if not torch.cuda.is_available():
-    DESCRIPTION += (
-        "\n<p style='color:red;'>Running on CPU - This is likely too large to run effectively.</p>"
-    )
+    DESCRIPTION += "Running on CPU - This is likely too large to run effectively.\n"
 
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 #
@@ -41,8 +37,8 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 if torch.cuda.is_available():
     model_id = "mistralai/Mistral-7B-Instruct-v0.3"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_id,
-        trust_remote_code=True  # Might be needed for custom code
+        model_id,
+        trust_remote_code=True
     )
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
@@ -54,106 +50,61 @@ if torch.cuda.is_available():
 def generate(
     message: str,
     chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     """
-    This function handles streaming chat text as the model generates it.
-    Uses Mistral's 'apply_chat_template' to handle chat-style prompting.
+    Minimal chat generation function: no sliders, no extra params.
     """
     conversation = [*chat_history, {"role": "user", "content": message}]
 
+    # Convert conversation to tokens
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    # If it exceeds max token length, trim
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(
-            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
-        )
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+
     input_ids = input_ids.to(model.device)
 
+    # Use a streamer to yield tokens as they are generated
     streamer = TextIteratorStreamer(
-        tokenizer,
-        timeout=20.0,
-        skip_prompt=True,
+        tokenizer,
+        timeout=20.0,
+        skip_prompt=True,
         skip_special_tokens=True
     )
+
+    # Basic generation settings (feel free to adjust if you want)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
         streamer=streamer,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=512,  # Adjust if you want more or fewer tokens
         do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
+        temperature=0.7,
+        top_p=0.9,
+        repetition_penalty=1.1,
     )
+
+    # Run generation in a background thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     outputs = []
     for text in streamer:
         outputs.append(text)
-        # Stream partial output as it's generated
         yield "".join(outputs)
 
 #
-# 4) Build the Chat Interface with extra sliders
+# 4) Build the Chat Interface
+#    - No additional sliders
+#    - No pre-filled example questions
 #
 demo = gr.ChatInterface(
     fn=generate,
     description=DESCRIPTION,
-    css=CUSTOM_CSS,  # Use our pastel gradient
-    additional_inputs=[
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly what the Python programming language is?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-    type="messages",
+    css=CUSTOM_CSS,
+    examples=None,  # remove example prompts
+    type="messages"
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue().launch()
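
For a quick check of the slimmed-down generate() outside the Gradio UI, a minimal smoke test could look like the sketch below. This is hypothetical and not part of the commit: it assumes app.py is importable from the working directory and that a CUDA GPU is available, since app.py only loads the model and tokenizer when torch.cuda.is_available() is True.

    # Hypothetical smoke test for the streaming generate() in app.py.
    # Assumption: a CUDA GPU is present, so the Mistral-7B model and tokenizer
    # were actually loaded when app.py was imported.
    from app import generate

    chat_history = []  # "messages" format: a list of {"role": ..., "content": ...} dicts
    final_text = ""
    for partial_text in generate("Hello! Can you introduce yourself?", chat_history):
        # Each yield is the full response generated so far, so keep only the latest.
        final_text = partial_text
    print(final_text)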