Shriharsh committed on
Commit 67cfd82 · verified · 1 Parent(s): b484597

Update app.py

Files changed (1)
  1. app.py +68 -34
app.py CHANGED
@@ -1,9 +1,23 @@
+import os
+import logging
 from huggingface_hub import InferenceClient
 import gradio as gr
+from requests.exceptions import ConnectionError
 
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.3"
-)
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize the Hugging Face Inference Client
+try:
+    client = InferenceClient(
+        model="mistralai/Mistral-7B-Instruct-v0.3",
+        token=os.getenv("HF_TOKEN"),  # Ensure HF_TOKEN is set in your environment
+        timeout=30,
+    )
+except Exception as e:
+    logger.error(f"Failed to initialize InferenceClient: {e}")
+    raise
 
 def format_prompt(message, history):
     prompt = "<s>"
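Note: the hunk above shows only the first line of format_prompt; this commit leaves the function body untouched. For orientation, here is a plausible sketch of the [INST]-style assembly such a helper typically performs for Mistral-7B-Instruct. It is a hypothetical reconstruction from the "<s>" prefix shown, not the file's actual code:

# Hypothetical sketch of a Mistral-instruct prompt builder; the real
# format_prompt in app.py is elided by the diff and may differ.
def format_prompt(message, history):
    prompt = "<s>"
    for user_turn, bot_turn in history:
        # Wrap each past exchange in [INST] ... [/INST] tags and close
        # the model's reply with the end-of-sequence token.
        prompt += f"[INST] {user_turn} [/INST] {bot_turn}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt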
@@ -16,32 +30,47 @@ def format_prompt(message, history):
 def generate(
     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
 ):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
+    try:
+        temperature = float(temperature)
+        if temperature < 1e-2:
+            temperature = 1e-2
+        top_p = float(top_p)
 
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
+        generate_kwargs = dict(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+            seed=42,
+        )
 
-    formatted_prompt = format_prompt(prompt, history)
+        formatted_prompt = format_prompt(prompt, history)
+        logger.info("Sending request to Hugging Face API")
 
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+        stream = client.text_generation(
+            formatted_prompt,
+            **generate_kwargs,
+            stream=True,
+            details=True,
+            return_full_text=False,
+        )
+        output = ""
 
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
+        for response in stream:
+            output += response.token.text
+            yield output
+        return output
 
+    except ConnectionError as e:
+        logger.error(f"Network error: {e}")
+        yield "Error: Unable to connect to the Hugging Face API. Please check your internet connection and try again."
+    except Exception as e:
+        logger.error(f"Error during text generation: {e}")
+        yield f"Error: {str(e)}"
 
-additional_inputs=[
+# Define additional inputs for Gradio interface
+additional_inputs = [
     gr.Slider(
         label="Temperature",
         value=0.9,
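Because generate is a generator, each yield hands Gradio the full text accumulated so far rather than a single token, and after this commit both except branches yield a user-visible "Error: ..." string instead of raising. A minimal way to exercise that behavior outside the UI is sketched below; it assumes the snippet is appended to app.py and that the Inference API is reachable:

# Sketch: drive the streaming generator directly, without Gradio.
# Each yielded value is the cumulative output, so only the last matters.
last = ""
for partial in generate("Explain Newton's second law.", history=[]):
    last = partial
print(last)  # final answer, or an "Error: ..." message on failure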
@@ -58,7 +87,7 @@ additional_inputs=[
         maximum=1048,
         step=64,
         interactive=True,
-        info="The maximum numbers of new tokens",
+        info="The maximum number of new tokens",
     ),
     gr.Slider(
         label="Top-p (nucleus sampling)",
@@ -77,21 +106,26 @@ additional_inputs=[
         step=0.05,
         interactive=True,
         info="Penalize repeated tokens",
-    )
+    ),
 ]
 
-# Create a Chatbot object with the desired height
-chatbot = gr.Chatbot(height=450,
-                     layout="bubble")
+# Create a Chatbot object
+chatbot = gr.Chatbot(height=450, layout="bubble")
 
+# Build the Gradio interface
 with gr.Blocks() as demo:
-    gr.HTML("<h1><center>🤖 Mistral-7B-Chat 💬<h1><center>")
+    gr.HTML("<h1><center>🤖 Mistral-7B-Chat 💬</center></h1>")
     gr.ChatInterface(
-        generate,
-        chatbot=chatbot, # Use the created Chatbot object
+        fn=generate,
+        chatbot=chatbot,
         additional_inputs=additional_inputs,
-        examples=[["Give me the code for Binary Search in C++"], ["Explain the chapter of The Grand Inquistor from The Brothers Karmazov."], ["Explain Newton's second law."]],
-
+        examples=[
+            ["Give me the code for Binary Search in C++"],
+            ["Explain the chapter of The Grand Inquisitor from The Brothers Karamazov."],
+            ["Explain Newton's second law."],
+        ],
     )
 
-demo.queue().launch(debug=True)
+if __name__ == "__main__":
+    logger.info("Starting Gradio application")
+    demo.launch()
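One behavioral note on the hunk above: the old entry point was demo.queue().launch(debug=True), while the new guard launches without an explicit queue. Gradio 4.x enables the queue by default, but on Gradio 3.x streaming (generator) outputs are only delivered through the queue, so a Space pinned to 3.x would presumably want it back. A sketch, assuming 3.x semantics:

# Sketch: restore the request queue for Gradio 3.x, where streaming
# generator outputs require the queue to reach the browser.
if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.queue().launch()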
 
 
 
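Finally, the rewritten client initialization depends on an HF_TOKEN environment variable. A quick pre-flight check before starting the Space might look like the following hypothetical snippet; it assumes the token is exported and the serverless Inference API is reachable:

import os
from huggingface_hub import InferenceClient

# Fail fast if the token the app expects is missing from the environment.
assert os.getenv("HF_TOKEN"), "export HF_TOKEN before starting the app"

client = InferenceClient(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    token=os.getenv("HF_TOKEN"),
    timeout=30,
)
# One small non-streaming call to confirm auth and connectivity.
print(client.text_generation("Hello", max_new_tokens=8))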