rodrisouza commited on
Commit
eabbb32
·
verified ·
1 Parent(s): 020432b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -30
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
- from config import hugging_face_token, init_google_sheets_client, models, quantized_models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
8
  import spaces
9
 
10
  # Hack for ZeroGPU
@@ -48,28 +48,14 @@ def load_model(model_name):
48
  del model
49
  torch.cuda.empty_cache()
50
 
51
- tokenizer = AutoTokenizer.from_pretrained(
52
- models[model_name],
53
- padding_side='left',
54
- token=hugging_face_token,
55
- trust_remote_code=True
56
- )
57
 
58
  # Ensure the padding token is set
59
  if tokenizer.pad_token is None:
60
  tokenizer.pad_token = tokenizer.eos_token
61
  tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
62
 
63
- model = AutoModelForCausalLM.from_pretrained(
64
- models[model_name],
65
- token=hugging_face_token,
66
- trust_remote_code=True
67
- )
68
-
69
- # Only move to CUDA if it's not a quantized model
70
- if model_name not in quantized_models:
71
- model = model.to("cuda")
72
-
73
  selected_model = model_name
74
  except Exception as e:
75
  print(f"Error loading model {model_name}: {e}")
@@ -90,12 +76,6 @@ def interact(user_input, history, interaction_count):
90
  if tokenizer is None or model is None:
91
  raise ValueError("Tokenizer or model is not initialized.")
92
 
93
- # Determine the device to use (either CUDA if available, or CPU)
94
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
-
96
- # Ensure the model is on the correct device
97
- model.to(device)
98
-
99
  if interaction_count >= MAX_INTERACTIONS:
100
  user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
101
 
@@ -108,8 +88,8 @@ def interact(user_input, history, interaction_count):
108
 
109
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
 
111
- # Move input tensor to the same device as the model
112
- input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
113
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
114
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
115
 
@@ -202,9 +182,6 @@ def load_user_guide():
202
  with open('user_guide.txt', 'r') as file:
203
  return file.read()
204
 
205
- # Combine both model dictionaries
206
- all_models = {**models, **quantized_models}
207
-
208
  # Create the chat interface using Gradio Blocks
209
  with gr.Blocks() as demo:
210
  with gr.Tabs():
@@ -213,7 +190,7 @@ with gr.Blocks() as demo:
213
 
214
  gr.Markdown("## Context")
215
  with gr.Group():
216
- model_dropdown = gr.Dropdown(choices=list(all_models.keys()), label="Select Model", value=default_model_name)
217
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
218
  initial_story = stories[0]["title"] if stories else None
219
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
@@ -244,4 +221,4 @@ with gr.Blocks() as demo:
244
  send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count])
245
  save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])
246
 
247
- demo.launch()
 
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
+ from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
8
  import spaces
9
 
10
  # Hack for ZeroGPU
 
48
  del model
49
  torch.cuda.empty_cache()
50
 
51
+ tokenizer = AutoTokenizer.from_pretrained(models[model_name], padding_side='left', token=hugging_face_token, trust_remote_code=True)
 
 
 
 
 
52
 
53
  # Ensure the padding token is set
54
  if tokenizer.pad_token is None:
55
  tokenizer.pad_token = tokenizer.eos_token
56
  tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
57
 
58
+ model = AutoModelForCausalLM.from_pretrained(models[model_name], token=hugging_face_token, trust_remote_code=True).to("cuda")
 
 
 
 
 
 
 
 
 
59
  selected_model = model_name
60
  except Exception as e:
61
  print(f"Error loading model {model_name}: {e}")
 
76
  if tokenizer is None or model is None:
77
  raise ValueError("Tokenizer or model is not initialized.")
78
 
 
 
 
 
 
 
79
  if interaction_count >= MAX_INTERACTIONS:
80
  user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
81
 
 
88
 
89
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
90
 
91
+ # Generate response using selected model
92
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
93
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
94
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
95
 
 
182
  with open('user_guide.txt', 'r') as file:
183
  return file.read()
184
 
 
 
 
185
  # Create the chat interface using Gradio Blocks
186
  with gr.Blocks() as demo:
187
  with gr.Tabs():
 
190
 
191
  gr.Markdown("## Context")
192
  with gr.Group():
193
+ model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model", value=selected_model)
194
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
195
  initial_story = stories[0]["title"] if stories else None
196
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
 
221
  send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count])
222
  save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])
223
 
224
+ demo.launch()