IAMJB committed
Commit 6e2127b · 1 Parent(s): c1f3a99

hyperbolic

Files changed (2):
  1. df/PaperCentral.py +1 -1
  2. paper_chat_tab.py +215 -116
df/PaperCentral.py CHANGED
@@ -483,7 +483,7 @@ class PaperCentral:
         neurips_id = re.search(r'id=([^&]+)', row["proceedings"])
         if neurips_id:
             neurips_id = neurips_id.group(1)
-            return f'<a href="/?tab=tab-chat-with-paper&paper_id={neurips_id}" id="custom_button" target="_blank">✨ Chat with paper</a>'
+            return f'<a href="/?tab=tab-chat-with-paper&paper_id={neurips_id}" id="custom_button" target="_self">✨ Chat with paper</a>'
         else:
             return ""
 
 
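Review note: the only change in this file swaps target="_blank" for target="_self", so the "Chat with paper" link now navigates within the same tab instead of opening a new one. For context, the receiving page presumably reads the paper_id query parameter back out of the URL; a minimal sketch of that pattern in Gradio (hypothetical wiring, not code from this commit):

import gradio as gr

def read_paper_id(request: gr.Request):
    # Query params arrive on the request, e.g. /?tab=tab-chat-with-paper&paper_id=abc123
    return request.query_params.get("paper_id", "")

with gr.Blocks() as demo:
    paper_id = gr.State(value="")
    demo.load(read_paper_id, None, paper_id)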
paper_chat_tab.py CHANGED
@@ -1,10 +1,12 @@
 import gradio as gr
 from PyPDF2 import PdfReader
 from bs4 import BeautifulSoup
-
+import openai
+import traceback
 import requests
 from io import BytesIO
 from transformers import AutoTokenizer
+import json
 
 import os
 from openai import OpenAI
@@ -12,13 +14,41 @@ from openai import OpenAI
 # Cache for tokenizers to avoid reloading
 tokenizer_cache = {}
 
+# Global variables for providers
+PROVIDERS = {
+    "Hyperbolic": {
+        "name": "hyperbolic",
+        "logo": "https://www.nftgators.com/wp-content/uploads/2024/07/Hyperbolic.jpg",
+        "endpoint": "https://api.hyperbolic.xyz/v1",
+        "api_key_env_var": "HYPERBOLIC_API_KEY",
+        "models": [
+            "meta-llama/Meta-Llama-3.1-405B-Instruct",
+        ],
+        "type": "tuples",
+        "max_total_tokens": "50000",
+    },
+    "SambaNova": {
+        "name": "SambaNova",
+        "logo": "https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg",
+        "endpoint": "https://api.sambanova.ai/v1/",
+        "api_key_env_var": "SAMBANOVA_API_KEY",
+        "models": [
+            "Meta-Llama-3.1-70B-Instruct",
+            # Add more models if needed
+        ],
+        "type": "tuples",
+        "max_total_tokens": "50000",
+    },
+
+}
+
 
 # Function to fetch paper information from OpenReview
 def fetch_paper_info_neurips(paper_id):
     url = f"https://openreview.net/forum?id={paper_id}"
     response = requests.get(url)
     if response.status_code != 200:
-        return None, None
+        return None
 
     html_content = response.content
     soup = BeautifulSoup(html_content, 'html.parser')
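Review note: the new PROVIDERS mapping is the single source of truth for endpoint, key env var, model list, chatbot message type, and token budget. A quick sketch of how an entry resolves into an OpenAI-compatible client (values taken from this diff; assumes the named environment variable is exported):

import os
from openai import OpenAI

provider = PROVIDERS["Hyperbolic"]
client = OpenAI(
    base_url=provider["endpoint"],                    # "https://api.hyperbolic.xyz/v1"
    api_key=os.environ[provider["api_key_env_var"]],  # reads HYPERBOLIC_API_KEY
)
model = provider["models"][0]  # "meta-llama/Meta-Llama-3.1-405B-Instruct"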
@@ -44,7 +74,6 @@ def fetch_paper_info_neurips(paper_id):
     abstract = 'Abstract not found'
 
     # Construct preamble in Markdown
-    # preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n**Abstract:**\n{abstract}"
     preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
 
     return preamble
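Review note: with the commented-out abstract variant removed, the preamble is title plus authors only. For illustration, with hypothetical values the f-string renders as:

title, paper_id, author_list = "An Example Paper", "abc123", "Ada Lovelace, Alan Turing"
preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
# **[An Example Paper](https://openreview.net/forum?id=abc123)**
#
# Ada Lovelace, Alan Turing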
@@ -75,110 +104,33 @@ def fetch_paper_content(paper_id):
     return None
 
 
-def paper_chat_tab(paper_id):
-    with gr.Blocks() as demo:
-        with gr.Column():
-            # Textbox to display the paper title and authors
-            content = gr.Markdown(value="")
-
-            # Preamble message to hint the user
-            gr.Markdown("**Note:** Providing your own sambanova token can help you avoid rate limits.")
-
-            # Input for Hugging Face token
-            hf_token_input = gr.Textbox(
-                label="Enter your sambanova token (optional)",
-                type="password",
-                placeholder="Enter your sambanova token to avoid rate limits"
-            )
-
-            models = [
-                # "Meta-Llama-3.1-8B-Instruct",
-                "Meta-Llama-3.1-70B-Instruct",
-                # "Meta-Llama-3.1-405B-Instruct",
-            ]
-
-            default_model = models[0]
-
-            # Dropdown for selecting the model
-            model_dropdown = gr.Dropdown(
-                label="Select Model",
-                choices=models,
-                value=default_model
-            )
-
-            # State to store the paper content
-            paper_content = gr.State()
-
-            # Create a column for each model, only visible if it's the default model
-            columns = []
-            for model_name in models:
-                column = gr.Column(visible=(model_name == default_model))
-                with column:
-                    chatbot = create_chat_interface(model_name, paper_content, hf_token_input)
-                columns.append(column)
-                gr.HTML(
-                    '<img src="https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg" width="100px" />')
-                gr.Markdown("**Note:** This model is supported by SambaNova.")
-
-            # Update visibility of columns based on the selected model
-            def update_columns(selected_model):
-                visibility = []
-                for model_name in models:
-                    is_visible = model_name == selected_model
-                    visibility.append(gr.update(visible=is_visible))
-                return visibility
-
-            model_dropdown.change(
-                fn=update_columns,
-                inputs=model_dropdown,
-                outputs=columns,
-                api_name=False,
-                queue=False,
-            )
-
-            # Function to update the content Markdown and paper_content when paper ID or model changes
-            def update_paper_info(paper_id, selected_model):
-                preamble = fetch_paper_info_neurips(paper_id)
-                text = fetch_paper_content(paper_id)
-                if text is None:
-                    return preamble, None
-
-                return preamble, text
-
-            # Update paper content when paper ID or model changes
-            paper_id.change(
-                fn=update_paper_info,
-                inputs=[paper_id, model_dropdown],
-                outputs=[content, paper_content]
-            )
-
-            model_dropdown.change(
-                fn=update_paper_info,
-                inputs=[paper_id, model_dropdown],
-                outputs=[content, paper_content],
-                queue=False,
-            )
-    return demo
-
-
-def create_chat_interface(model_name, paper_content, hf_token_input):
-    # Load tokenizer and cache it
-    if model_name not in tokenizer_cache:
-        # Load the tokenizer from Hugging Face
-        # tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
-                                                  token=os.environ.get("HF_TOKEN"))
-        tokenizer_cache[model_name] = tokenizer
-    else:
-        tokenizer = tokenizer_cache[model_name]
-
-    max_total_tokens = 50000  # Maximum tokens allowed
-
+def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
+                          provider_max_total_tokens):
     # Define the function to handle the chat
-    def get_fn(message, history, paper_content_value, hf_token_value):
+    print("the type is", default_type.value)
+
+    def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
+               max_total_tokens):
+        provider_info = PROVIDERS[provider_name_value]
+        endpoint = provider_info['endpoint']
+        api_key_env_var = provider_info['api_key_env_var']
+        models = provider_info['models']
+        max_total_tokens = int(max_total_tokens)
+
+        # Load tokenizer and cache it
+        tokenizer_key = f"{provider_name_value}_{model_name_value}"
+        if tokenizer_key not in tokenizer_cache:
+            # Load the tokenizer; adjust the model path based on the provider and model
+            # This is a placeholder; you need to provide the correct tokenizer path
+            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
+                                                      token=os.environ.get("HF_TOKEN"))
+            tokenizer_cache[tokenizer_key] = tokenizer
+        else:
+            tokenizer = tokenizer_cache[tokenizer_key]
+
         # Include the paper content as context
         if paper_content_value:
-            context = f"The following is the content of the paper:\n{paper_content_value}\n\n"
+            context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
         else:
             context = ""
 
 
@@ -237,24 +189,25 @@ def create_chat_interface(model_name, paper_content, hf_token_input):
         # Rebuild the final messages list including the (possibly truncated) context
         final_messages = []
         if context:
-            final_messages.append({"role": "system", "content": context})
+            final_messages.append(
+                {"role": "system", "content": f"{context}"})
         final_messages.extend(messages)
 
-        # Use the Hugging Face token if provided
-        api_key = hf_token_value or os.environ.get("SAMBANOVA_API_KEY")
+        # Use the provider's API key
+        api_key = hf_token_value or os.environ.get(api_key_env_var)
         if not api_key:
             raise ValueError("API token is not provided.")
 
-        # Initialize the OpenAI client
+        # Initialize the OpenAI client with the provider's endpoint
         client = OpenAI(
-            base_url="https://api.sambanova.ai/v1/",
+            base_url=endpoint,
             api_key=api_key,
         )
 
         try:
             # Create the chat completion
             completion = client.chat.completions.create(
-                model=model_name,
+                model=model_name_value,
                 messages=final_messages,
                 stream=True,
             )
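Review note: the "(possibly truncated) context" comment refers to budget logic in unchanged lines between these hunks; note that max_total_tokens now arrives per provider as a string and is cast with int() inside get_fn. A sketch of the truncation idea (hypothetical helper, not the code in this file):

def truncate_to_budget(tokenizer, text: str, budget: int) -> str:
    # Clip the paper text to the token budget before using it as system context.
    ids = tokenizer.encode(text)
    return text if len(ids) <= budget else tokenizer.decode(ids[:budget])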
@@ -263,9 +216,20 @@ def create_chat_interface(model_name, paper_content, hf_token_input):
                 delta = chunk.choices[0].delta.content or ""
                 response_text += delta
                 yield response_text
-        except Exception as e:
-            error_message = f"Error: {str(e)}"
-            yield error_message
+        except json.JSONDecodeError as e:
+            print("Failed to decode JSON during the completion creation process.")
+            print(f"Error Message: {e.msg}")
+            print(f"Error Position: Line {e.lineno}, Column {e.colno} (Character {e.pos})")
+            print(f"Problematic JSON Data: {e.doc}")
+            yield f"{e.doc}"
+        except openai.OpenAIError as openai_err:
+            # Handle other OpenAI-related errors
+            print(f"An OpenAI error occurred: {openai_err}")
+            yield f"{openai_err}"
+        except Exception as ex:
+            # Handle any other exceptions
+            print(f"An unexpected error occurred: {ex}")
+            yield f"{ex}"
 
     # Create the ChatInterface
     chat_interface = gr.ChatInterface(
@@ -274,9 +238,144 @@ def create_chat_interface(model_name, paper_content, hf_token_input):
             label="Chatbot",
             scale=1,
             height=400,
-            autoscroll=True
+            autoscroll=True,
         ),
-        additional_inputs=[paper_content, hf_token_input],
-        # examples=["What are the main findings of this paper?", "Explain the methodology used in this research."]
+        additional_inputs=[paper_content, hf_token_input, provider_dropdown, model_dropdown, provider_max_total_tokens],
+        type="tuples",
     )
     return chat_interface
+
+
+def paper_chat_tab(paper_id):
+    with gr.Column():
+        # Textbox to display the paper title and authors
+        content = gr.Markdown(value="")
+
+        # Preamble message to hint the user
+        gr.Markdown("**Note:** Providing your own API token can help you avoid rate limits.")
+
+        # Input for API token
+        provider_names = list(PROVIDERS.keys())
+        default_provider = provider_names[0]
+
+        default_type = gr.State(value=PROVIDERS[default_provider]["type"])
+        default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
+
+        provider_dropdown = gr.Dropdown(
+            label="Select Provider",
+            choices=provider_names,
+            value=default_provider
+        )
+
+        hf_token_input = gr.Textbox(
+            label=f"Enter your {default_provider} API token (optional)",
+            type="password",
+            placeholder=f"Enter your {default_provider} API token to avoid rate limits"
+        )
+
+        # Dropdown for selecting the model
+        model_dropdown = gr.Dropdown(
+            label="Select Model",
+            choices=PROVIDERS[default_provider]['models'],
+            value=PROVIDERS[default_provider]['models'][0]
+        )
+
+        # Placeholder for the provider logo
+        logo_html = gr.HTML(
+            value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
+        )
+
+        # Note about the provider
+        note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")
+
+        # State to store the paper content
+        paper_content = gr.State()
+
+        # Function to update models and logo when provider changes
+        def update_provider(selected_provider):
+            provider_info = PROVIDERS[selected_provider]
+            models = provider_info['models']
+            logo_url = provider_info['logo']
+            chatbot_message_type = provider_info['type']
+            max_total_tokens = provider_info['max_total_tokens']
+
+            # Update the models dropdown
+            model_dropdown_choices = gr.update(choices=models, value=models[0])
+
+            # Update the logo image
+            logo_html_content = f'<img src="{logo_url}" width="100px" />'
+            logo_html_update = gr.update(value=logo_html_content)
+
+            # Update the note markdown
+            note_markdown_update = gr.update(value=f"**Note:** This model is supported by {selected_provider}.")
+
+            # Update the hf_token_input label and placeholder
+            hf_token_input_update = gr.update(
+                label=f"Enter your {selected_provider} API token (optional)",
+                placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
+            )
+
+            return model_dropdown_choices, logo_html_update, note_markdown_update, hf_token_input_update, chatbot_message_type, max_total_tokens
+
+        provider_dropdown.change(
+            fn=update_provider,
+            inputs=provider_dropdown,
+            outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type, default_max_total_tokens],
+            queue=False
+        )
+
+        # Function to update the paper info
+        def update_paper_info(paper_id_value, selected_model):
+            preamble = fetch_paper_info_neurips(paper_id_value)
+            text = fetch_paper_content(paper_id_value)
+            if preamble is None:
+                preamble = "Paper not found or could not retrieve paper information."
+            if text is None:
+                return preamble, None
+            return preamble, text
+
+        # Update paper content when paper ID or model changes
+        paper_id.change(
+            fn=update_paper_info,
+            inputs=[paper_id, model_dropdown],
+            outputs=[content, paper_content]
+        )
+
+        model_dropdown.change(
+            fn=update_paper_info,
+            inputs=[paper_id, model_dropdown],
+            outputs=[content, paper_content],
+            queue=False,
+        )
+
+        # Create the chat interface
+        chat_interface = create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input,
+                                               default_type, default_max_total_tokens)
+
+
+def main():
+    """
+    Launches the Gradio app.
+    """
+    with gr.Blocks(css_paths="style.css") as demo:
+        x = gr.State(value="")  # Initialize with an empty state
+
+        def update_state():
+            """
+            Function to update the state.
+            """
+            return "5G7ve8E1Lu"
+
+        with gr.Row():
+            update_button = gr.Button("Update State")  # Button to update the state
+
+        # Update the state and reflect the change in the display
+        update_button.click(update_state, inputs=[], outputs=[x])
+        paper_chat_tab(x)
+
+    demo.launch(ssr_mode=False)
+
+
+# Run the main function when the script is executed
+if __name__ == "__main__":
+    main()
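Review note: the module is now runnable standalone; main() wires a hard-coded id ("5G7ve8E1Lu") behind the "Update State" button so the chat tab can be exercised without going through PaperCentral. A minimal smoke check before chatting (assumes the env var named in PROVIDERS is exported):

import os
provider = PROVIDERS["SambaNova"]
assert os.environ.get(provider["api_key_env_var"]), "export SAMBANOVA_API_KEY first"
# then: python paper_chat_tab.py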