Mahavaury2 committed
Commit c566ded · verified · 1 Parent(s): 0387457

Update app.py

Files changed (1):
  app.py  +146 -48
app.py CHANGED
@@ -1,61 +1,159 @@
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

- # 1) Define pastel gradient CSS
- css = """
  .gradio-container {
      background: linear-gradient(to right, #FFDEE9, #B5FFFC);
  }
  """

- title = "Bonjour Dans le chat du consentement"
-
- # 2) Load the Mistral model & tokenizer from HF Hub
- model_id = "mistralai/Mistral-7B-Instruct-v0.3"
-
- # If you're on a GPU Space, you can do:
- #     device_map = "auto"
- #     torch_dtype = torch.bfloat16
- # If you're on a CPU-only Space, remove those arguments or set device_map="cpu"
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",            # "auto" if you have GPU
-     torch_dtype=torch.bfloat16,   # for GPU. Remove or use float32 on CPU
-     trust_remote_code=True
- )

- # 3) Create a text-generation pipeline
- generate_text = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_length=512,    # adjust as needed
-     temperature=0.7,   # adjust as needed
-     do_sample=True
- )

- def mistral_inference(prompt):
      """
-     Passes user prompt to the pipeline and returns the generated text.
-     We'll strip any special tokens and limit the output.
      """
-     # The pipeline returns a list of dicts [{"generated_text": "..."}]
-     outputs = generate_text(prompt)
-     text_out = outputs[0]["generated_text"]
-     return text_out
-
- # 4) Build the Gradio interface with a pastel background & greeting
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(f"<h1 style='text-align:center;'>{title}</h1>")
-     user_input = gr.Textbox(label="Entrez votre message ici:", lines=3)
-     output = gr.Textbox(label="Réponse du Modèle:", lines=5)
-     send_button = gr.Button("Envoyer")
-
-     # Link the button to the inference function
-     send_button.click(fn=mistral_inference, inputs=user_input, outputs=output)
-
- # 5) Launch the app
  if __name__ == "__main__":
-     demo.launch()
 
+ #!/usr/bin/env python
+
+ import os
+ from collections.abc import Iterator
+ from threading import Thread
+
  import gradio as gr
+ import spaces
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+ #
+ # 1) Custom Pastel Gradient CSS
+ #
+ CUSTOM_CSS = """
  .gradio-container {
      background: linear-gradient(to right, #FFDEE9, #B5FFFC);
  }
  """

+ #
+ # 2) Description: Add French greeting, plus any info
+ #
+ DESCRIPTION = """# Bonjour Dans le chat du consentement
+
+ Mistral-7B Instruct Demo
+ """
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += (
+         "\n<p style='color:red;'>Running on CPU - This is likely too large to run effectively.</p>"
+     )
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ #
+ # 3) Load Mistral-7B Instruct (requires gating, GPU recommended)
+ #
+ if torch.cuda.is_available():
+     model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_id,
+         trust_remote_code=True  # Might be needed for custom code
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
      """
+     This function handles streaming chat text as the model generates it.
+     Uses Mistral's 'apply_chat_template' to handle chat-style prompting.
      """
+     conversation = [*chat_history, {"role": "user", "content": message}]
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(
+             f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+         )
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         timeout=20.0,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         # Stream partial output as it's generated
+         yield "".join(outputs)
+
+ #
+ # 4) Build the Chat Interface with extra sliders
+ #
+ demo = gr.ChatInterface(
+     fn=generate,
+     description=DESCRIPTION,
+     css=CUSTOM_CSS,  # Use our pastel gradient
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly what the Python programming language is?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     type="messages",
+ )
+
  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
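
Note on the streaming path in the new generate(): model.generate() is a blocking call, so the new code runs it in a background Thread and reads decoded text incrementally from a TextIteratorStreamer, yielding the accumulated output so Gradio can stream partial replies. Below is a minimal, self-contained sketch of that same pattern; the tiny sshleifer/tiny-gpt2 checkpoint is only a stand-in chosen so the sketch runs quickly on CPU, it is not used by the commit.

# Minimal sketch of the Thread + TextIteratorStreamer pattern used in the new app.py.
# Assumption: "sshleifer/tiny-gpt2" is a small stand-in model for illustration only.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=32, do_sample=True)
    # generate() blocks, so it runs in a worker thread while we consume the streamer
    Thread(target=model.generate, kwargs=kwargs).start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial  # each yield is the full text produced so far

if __name__ == "__main__":
    for chunk in stream_reply("Hello"):
        print(chunk)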
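
Note on the interface wiring: because gr.ChatInterface is created with type="messages", the chat_history argument of generate() arrives as a list of {"role": ..., "content": ...} dicts (ready to concatenate with the new user turn before apply_chat_template), and each slider in additional_inputs is passed as an extra positional argument after the history. A small sketch with a trivial echo function in place of the model, just to show that calling convention:

# Sketch of how gr.ChatInterface(type="messages") calls the chat function.
# The echo function is a stand-in for the real model; the slider maps to an extra argument.
import gradio as gr

def echo(message, history, max_new_tokens=64):
    # In "messages" mode, history is a list of {"role": ..., "content": ...} dicts
    conversation = [*history, {"role": "user", "content": message}]
    return f"({len(conversation)} turns so far) you said: {message[:max_new_tokens]}"

demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=256, step=1, value=64),
    ],
)

if __name__ == "__main__":
    demo.launch()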
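
One caveat, flagged here as a note rather than a change: in the new file, tokenizer and model are created only inside the if torch.cuda.is_available(): block, while the CPU branch merely appends a red warning to DESCRIPTION. Exact indentation is not recoverable from this view, but if generate() is defined at module level, a CPU-only Space would raise a NameError on the first message. A hedged sketch of one possible guard (illustrative only, not part of this commit):

# Possible defensive guard, illustrative only: fail with a clear message
# instead of a NameError when no model was loaded.
import torch

model = None
tokenizer = None

if torch.cuda.is_available():
    ...  # load Mistral-7B here, as in app.py

def generate(message, chat_history):
    if model is None or tokenizer is None:
        # Surface a readable message on CPU-only Spaces
        yield "No GPU available: the model was not loaded, so generation is disabled."
        return
    ...  # normal streaming path, as in app.py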