rodrisouza commited on
Commit
cc8e923
·
verified ·
1 Parent(s): 33b0fe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -38
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
- from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
8
  import spaces
9
 
10
  # Hack for ZeroGPU
@@ -13,8 +13,8 @@ torch.jit.script = lambda f: f
13
  # Initialize Google Sheets client
14
  client = init_google_sheets_client()
15
  sheet = client.open(google_sheets_name)
16
- stories_sheet = sheet.worksheet("Stories") # Assuming stories are in a separate sheet
17
- prompts_sheet = sheet.worksheet("System Prompts") # Assuming system prompts are in a separate sheet
18
 
19
  # Load stories from Google Sheets
20
  def load_stories():
@@ -23,14 +23,14 @@ def load_stories():
23
  return stories
24
 
25
  # Load system prompts from Google Sheets
26
- def load_prompts():
27
- prompts_data = prompts_sheet.get_all_values()
28
- prompts = [prompt[0] for prompt in prompts_data if prompt[0] != "System Prompts"] # Skip header row
29
- return prompts
30
 
31
- # Load available stories and prompts
32
  stories = load_stories()
33
- prompts = load_prompts()
34
 
35
  # Initialize the selected model
36
  selected_model = default_model_name
@@ -65,21 +65,18 @@ def load_model(model_name):
65
  # Ensure the initial model is loaded
66
  tokenizer, model = load_model(selected_model)
67
 
68
- # Chat history and interaction count
69
  chat_history = []
 
70
 
71
  # Function to handle interaction with model
72
  @spaces.GPU
73
- def interact(user_input, history, interaction_count):
74
- global tokenizer, model
75
  try:
76
  if tokenizer is None or model is None:
77
  raise ValueError("Tokenizer or model is not initialized.")
78
 
79
- # Concatenate a final message if max interactions are reached
80
- if interaction_count >= MAX_INTERACTIONS - 1:
81
- user_input += ". Thank you for the questions. That's all for now. Goodbye!"
82
-
83
  messages = history + [{"role": "user", "content": user_input}]
84
 
85
  # Ensure roles alternate correctly
@@ -89,10 +86,6 @@ def interact(user_input, history, interaction_count):
89
 
90
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
91
 
92
- # Check if the maximum number of interactions has been reached
93
- interaction_count += 1
94
- print(f"Interaction count: {interaction_count}") # Print the interaction count
95
-
96
  # Generate response using selected model
97
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
98
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id) # Increase max_new_tokens
@@ -101,9 +94,16 @@ def interact(user_input, history, interaction_count):
101
  # Update chat history with generated response
102
  history.append({"role": "user", "content": user_input})
103
  history.append({"role": "assistant", "content": response})
104
-
 
 
 
 
 
 
 
105
  formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
106
- return "", formatted_history, history, interaction_count
107
  except Exception as e:
108
  if torch.cuda.is_available():
109
  torch.cuda.empty_cache()
@@ -112,16 +112,13 @@ def interact(user_input, history, interaction_count):
112
 
113
  # Function to send selected story and initial message
114
  def send_selected_story(title, model_name, system_prompt):
115
- global chat_history
116
- global selected_story
117
- global data # Ensure data is reset
118
  data = [] # Reset data for new story
 
119
  tokenizer, model = load_model(model_name)
120
  selected_story = title
121
- story_text = ""
122
  for story in stories:
123
  if story["title"] == title:
124
- story_text = story["story"]
125
  system_prompt = f"""
126
  {system_prompt}
127
  Here is the story:
@@ -136,9 +133,9 @@ Here is the story:
136
 
137
  # Generate the first question based on the story
138
  question_prompt = "Please ask a simple question about the story to encourage interaction."
139
- _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, 0)
140
 
141
- return formatted_history, chat_history, gr.update(value=[]), gr.update(value=story_text), interaction_count # Reset the data table and update the selected story textbox
142
  else:
143
  print("Combined message is empty.")
144
  else:
@@ -178,11 +175,15 @@ def save_comment_score(chat_responses, score, comment, story_name, user_name, sy
178
  ])
179
 
180
  # Append data to Google Sheets
181
- sheet = client.open(google_sheets_name).sheet1 # Assuming results are saved in sheet1
182
- sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, last_user_message, last_assistant_message, score, comment])
 
 
 
 
183
 
184
  df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
185
- return df, gr.update(value="") # Clear the comment input box
186
 
187
  # Create the chat interface using Gradio Blocks
188
  with gr.Blocks() as demo:
@@ -192,11 +193,11 @@ with gr.Blocks() as demo:
192
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
193
  initial_story = stories[0]["title"] if stories else None
194
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
195
- system_prompt_dropdown = gr.Dropdown(choices=prompts, label="Select System Prompt")
196
 
197
- send_story_button = gr.Button("Send Story")
198
 
199
- selected_story_textbox = gr.Textbox(label="Selected Story", interactive=False)
 
200
 
201
  with gr.Row():
202
  with gr.Column(scale=1):
@@ -215,10 +216,9 @@ with gr.Blocks() as demo:
215
  data_table = gr.DataFrame(headers=["User Input", "Chat Response", "Score", "Comment"])
216
 
217
  chat_history_json = gr.JSON(value=[], visible=False)
218
- interaction_count_state = gr.State(0)
219
 
220
- send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox, interaction_count_state])
221
- send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count_state], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count_state])
222
  save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])
223
 
224
  demo.launch()
 
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
+ from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, max_interactions
8
  import spaces
9
 
10
  # Hack for ZeroGPU
 
13
  # Initialize Google Sheets client
14
  client = init_google_sheets_client()
15
  sheet = client.open(google_sheets_name)
16
+ stories_sheet = sheet.worksheet("Stories") # Assuming stories are in the second sheet (index 1)
17
+ system_prompts_sheet = sheet.worksheet("System Prompts") # Assuming system prompts are in a separate sheet
18
 
19
  # Load stories from Google Sheets
20
  def load_stories():
 
23
  return stories
24
 
25
  # Load system prompts from Google Sheets
26
+ def load_system_prompts():
27
+ system_prompts_data = system_prompts_sheet.get_all_values()
28
+ system_prompts = [prompt[0] for prompt in system_prompts_data[1:]] # Skip header row
29
+ return system_prompts
30
 
31
+ # Load available stories and system prompts
32
  stories = load_stories()
33
+ system_prompts = load_system_prompts()
34
 
35
  # Initialize the selected model
36
  selected_model = default_model_name
 
65
  # Ensure the initial model is loaded
66
  tokenizer, model = load_model(selected_model)
67
 
68
+ # Chat history and interaction counter
69
  chat_history = []
70
+ interaction_count = 0
71
 
72
  # Function to handle interaction with model
73
  @spaces.GPU
74
+ def interact(user_input, history):
75
+ global tokenizer, model, interaction_count
76
  try:
77
  if tokenizer is None or model is None:
78
  raise ValueError("Tokenizer or model is not initialized.")
79
 
 
 
 
 
80
  messages = history + [{"role": "user", "content": user_input}]
81
 
82
  # Ensure roles alternate correctly
 
86
 
87
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
88
 
 
 
 
 
89
  # Generate response using selected model
90
  input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
91
  chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id) # Increase max_new_tokens
 
94
  # Update chat history with generated response
95
  history.append({"role": "user", "content": user_input})
96
  history.append({"role": "assistant", "content": response})
97
+
98
+ interaction_count += 1
99
+ print(f"Interaction count: {interaction_count}")
100
+
101
+ if interaction_count >= max_interactions:
102
+ response += ". Thank you for the questions. That's all for now. Goodbye!"
103
+ history[-1]["content"] = response
104
+
105
  formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
106
+ return "", formatted_history, history
107
  except Exception as e:
108
  if torch.cuda.is_available():
109
  torch.cuda.empty_cache()
 
112
 
113
  # Function to send selected story and initial message
114
  def send_selected_story(title, model_name, system_prompt):
115
+ global chat_history, selected_story, data, interaction_count
 
 
116
  data = [] # Reset data for new story
117
+ interaction_count = 0 # Reset interaction counter
118
  tokenizer, model = load_model(model_name)
119
  selected_story = title
 
120
  for story in stories:
121
  if story["title"] == title:
 
122
  system_prompt = f"""
123
  {system_prompt}
124
  Here is the story:
 
133
 
134
  # Generate the first question based on the story
135
  question_prompt = "Please ask a simple question about the story to encourage interaction."
136
+ _, formatted_history, chat_history = interact(question_prompt, chat_history)
137
 
138
+ return formatted_history, chat_history, gr.update(value=[]), story["story"] # Reset the data table and return the story
139
  else:
140
  print("Combined message is empty.")
141
  else:
 
175
  ])
176
 
177
  # Append data to Google Sheets
178
+ try:
179
+ user_sheet = client.open(google_sheets_name).worksheet(user_name)
180
+ except gspread.exceptions.WorksheetNotFound:
181
+ user_sheet = client.open(google_sheets_name).add_worksheet(title=user_name, rows="100", cols="20")
182
+
183
+ user_sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, last_user_message, last_assistant_message, score, comment])
184
 
185
  df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
186
+ return df[["User Input", "Chat Response", "Score", "Comment"]], gr.update(value="") # Show only the required columns and clear the comment input box
187
 
188
  # Create the chat interface using Gradio Blocks
189
  with gr.Blocks() as demo:
 
193
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
194
  initial_story = stories[0]["title"] if stories else None
195
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
 
196
 
197
+ system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt")
198
 
199
+ send_story_button = gr.Button("Send Story")
200
+ selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
201
 
202
  with gr.Row():
203
  with gr.Column(scale=1):
 
216
  data_table = gr.DataFrame(headers=["User Input", "Chat Response", "Score", "Comment"])
217
 
218
  chat_history_json = gr.JSON(value=[], visible=False)
 
219
 
220
+ send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
221
+ send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json], outputs=[chatbot_input, chatbot_output, chat_history_json])
222
  save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])
223
 
224
  demo.launch()