abdullahalzubaer commited on
Commit
c0a64ff
·
verified ·
1 Parent(s): e5135c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -1
app.py CHANGED
@@ -1,3 +1,104 @@
 
 
 
1
  import gradio as gr
2
 
3
- gr.Interface.load("mistralai/Mistral-7B-Instruct-v0.2").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref: https://huggingface.co/spaces/Skier8402/mistral-super-fast/blob/main/app.py
2
+
3
+ from huggingface_hub import InferenceClient
4
  import gradio as gr
5
 
6
+ client = InferenceClient(
7
+ "mistralai/Mistral-7B-Instruct-v0.2"
8
+ )
9
+
10
+
11
+ def format_prompt(message, history):
12
+ prompt = "<s>"
13
+ for user_prompt, bot_response in history:
14
+ prompt += f"[INST] {user_prompt} [/INST]"
15
+ prompt += f" {bot_response}</s> "
16
+ prompt += f"[INST] {message} [/INST]"
17
+ return prompt
18
+
19
+ def generate(
20
+ prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
21
+ ):
22
+ temperature = float(temperature)
23
+ if temperature < 1e-2:
24
+ temperature = 1e-2
25
+ top_p = float(top_p)
26
+
27
+ generate_kwargs = dict(
28
+ temperature=temperature,
29
+ max_new_tokens=max_new_tokens,
30
+ top_p=top_p,
31
+ repetition_penalty=repetition_penalty,
32
+ do_sample=True,
33
+ seed=42,
34
+ )
35
+
36
+ formatted_prompt = format_prompt(prompt, history)
37
+
38
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
39
+ output = ""
40
+
41
+ for response in stream:
42
+ output += response.token.text
43
+ yield output
44
+ return output
45
+
46
+
47
+ additional_inputs=[
48
+ gr.Slider(
49
+ label="Temperature",
50
+ value=0.9,
51
+ minimum=0.0,
52
+ maximum=1.0,
53
+ step=0.05,
54
+ interactive=True,
55
+ info="Higher values produce more diverse outputs",
56
+ ),
57
+ gr.Slider(
58
+ label="Max new tokens",
59
+ value=256,
60
+ minimum=0,
61
+ maximum=1048,
62
+ step=64,
63
+ interactive=True,
64
+ info="The maximum numbers of new tokens",
65
+ ),
66
+ gr.Slider(
67
+ label="Top-p (nucleus sampling)",
68
+ value=0.90,
69
+ minimum=0.0,
70
+ maximum=1,
71
+ step=0.05,
72
+ interactive=True,
73
+ info="Higher values sample more low-probability tokens",
74
+ ),
75
+ gr.Slider(
76
+ label="Repetition penalty",
77
+ value=1.2,
78
+ minimum=1.0,
79
+ maximum=2.0,
80
+ step=0.05,
81
+ interactive=True,
82
+ info="Penalize repeated tokens",
83
+ )
84
+ ]
85
+
86
+ css = """
87
+ #mkd {
88
+ height: 500px;
89
+ overflow: auto;
90
+ border: 1px solid #ccc;
91
+ }
92
+ """
93
+
94
+ with gr.Blocks(css=css) as demo:
95
+ gr.HTML("<h1><center>Mistral 7B Instruct<h1><center>")
96
+ gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2'>Mistral-7B-Instruct</a> model. 💬<h3><center>")
97
+ gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚<h3><center>")
98
+ gr.ChatInterface(
99
+ generate,
100
+ additional_inputs=additional_inputs,
101
+ examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
102
+ )
103
+
104
+ demo.queue().launch(debug=True)