rodrisouza committed on
Commit
1b6eb0b
·
verified ·
1 Parent(s): 4ef26a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
- from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name
8
  import spaces
9
 
10
  # Hack for ZeroGPU
@@ -13,8 +13,8 @@ torch.jit.script = lambda f: f
13
  # Initialize Google Sheets client
14
  client = init_google_sheets_client()
15
  sheet = client.open(google_sheets_name)
16
- stories_sheet = sheet.worksheet("Stories") # Assuming stories are in the second sheet (index 1)
17
- system_prompts_sheet = sheet.worksheet("System Prompts") # Assuming system prompts are in a separate sheet
18
 
19
  # Load stories from Google Sheets
20
  def load_stories():
@@ -23,14 +23,14 @@ def load_stories():
23
  return stories
24
 
25
  # Load system prompts from Google Sheets
26
- def load_system_prompts():
27
- system_prompts_data = system_prompts_sheet.get_all_values()
28
- system_prompts = [prompt[0] for prompt in system_prompts_data[1:]] # Skip header row
29
- return system_prompts
30
 
31
- # Load available stories and system prompts
32
  stories = load_stories()
33
- system_prompts = load_system_prompts()
34
 
35
  # Initialize the selected model
36
  selected_model = default_model_name
@@ -65,17 +65,28 @@ def load_model(model_name):
65
  # Ensure the initial model is loaded
66
  tokenizer, model = load_model(selected_model)
67
 
68
- # Chat history
69
  chat_history = []
 
70
 
71
  # Function to handle interaction with model
72
  @spaces.GPU
73
  def interact(user_input, history):
74
- global tokenizer, model
75
  try:
76
  if tokenizer is None or model is None:
77
  raise ValueError("Tokenizer or model is not initialized.")
78
 
 
 
 
 
 
 
 
 
 
 
79
  messages = history + [{"role": "user", "content": user_input}]
80
 
81
  # Ensure roles alternate correctly
@@ -97,17 +108,16 @@ def interact(user_input, history):
97
  formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
98
  return "", formatted_history, history
99
  except Exception as e:
100
- if torch.cuda.is_available():
101
  torch.cuda.empty_cache()
102
  print(f"Error during interaction: {e}")
103
  raise gr.Error(f"An error occurred during interaction: {str(e)}")
104
 
105
  # Function to send selected story and initial message
106
  def send_selected_story(title, model_name, system_prompt):
107
- global chat_history
108
- global selected_story
109
- global data # Ensure data is reset
110
  data = [] # Reset data for new story
 
111
  tokenizer, model = load_model(model_name)
112
  selected_story = title
113
  for story in stories:
@@ -128,7 +138,7 @@ Here is the story:
128
  question_prompt = "Please ask a simple question about the story to encourage interaction."
129
  _, formatted_history, chat_history = interact(question_prompt, chat_history)
130
 
131
- return formatted_history, chat_history, gr.update(value=[]), story["story"] # Reset the data table and return the story
132
  else:
133
  print("Combined message is empty.")
134
  else:
@@ -168,15 +178,11 @@ def save_comment_score(chat_responses, score, comment, story_name, user_name, sy
168
  ])
169
 
170
  # Append data to Google Sheets
171
- try:
172
- user_sheet = client.open(google_sheets_name).worksheet(user_name)
173
- except gspread.exceptions.WorksheetNotFound:
174
- user_sheet = client.open(google_sheets_name).add_worksheet(title=user_name, rows="100", cols="20")
175
-
176
- user_sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, last_user_message, last_assistant_message, score, comment])
177
 
178
  df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
179
- return df[["User Input", "Chat Response", "Score", "Comment"]], gr.update(value="") # Show only the required columns and clear the comment input box
180
 
181
  # Create the chat interface using Gradio Blocks
182
  with gr.Blocks() as demo:
@@ -186,11 +192,11 @@ with gr.Blocks() as demo:
186
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
187
  initial_story = stories[0]["title"] if stories else None
188
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
189
-
190
- system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt")
191
 
192
  send_story_button = gr.Button("Send Story")
193
- selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
 
194
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
 
4
  import pandas as pd
5
  from datetime import datetime, timedelta, timezone
6
  import torch
7
+ from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
8
  import spaces
9
 
10
  # Hack for ZeroGPU
 
13
  # Initialize Google Sheets client
14
  client = init_google_sheets_client()
15
  sheet = client.open(google_sheets_name)
16
+ stories_sheet = sheet.worksheet("Stories") # Assuming stories are in a separate sheet
17
+ prompts_sheet = sheet.worksheet("System Prompts") # Assuming system prompts are in a separate sheet
18
 
19
  # Load stories from Google Sheets
20
  def load_stories():
 
23
  return stories
24
 
25
  # Load system prompts from Google Sheets
26
+ def load_prompts():
27
+ prompts_data = prompts_sheet.get_all_values()
28
+ prompts = [prompt[0] for prompt in prompts_data if prompt[0] != "System Prompts"] # Skip header row
29
+ return prompts
30
 
31
+ # Load available stories and prompts
32
  stories = load_stories()
33
+ prompts = load_prompts()
34
 
35
  # Initialize the selected model
36
  selected_model = default_model_name
 
65
  # Ensure the initial model is loaded
66
  tokenizer, model = load_model(selected_model)
67
 
68
+ # Chat history and interaction count
69
  chat_history = []
70
+ interaction_count = 0
71
 
72
  # Function to handle interaction with model
73
  @spaces.GPU
74
  def interact(user_input, history):
75
+ global tokenizer, model, interaction_count
76
  try:
77
  if tokenizer is None or model is None:
78
  raise ValueError("Tokenizer or model is not initialized.")
79
 
80
+ # Increment interaction count
81
+ interaction_count += 1
82
+
83
+ # Check if the maximum number of interactions has been reached
84
+ if interaction_count > MAX_INTERACTIONS:
85
+ farewell_message = "Thank you for the conversation! Have a great day!"
86
+ history.append({"role": "assistant", "content": farewell_message})
87
+ formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
88
+ return "", formatted_history, history
89
+
90
  messages = history + [{"role": "user", "content": user_input}]
91
 
92
  # Ensure roles alternate correctly
 
108
  formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
109
  return "", formatted_history, history
110
  except Exception as e:
111
+ if torch.cuda.is_available():
112
  torch.cuda.empty_cache()
113
  print(f"Error during interaction: {e}")
114
  raise gr.Error(f"An error occurred during interaction: {str(e)}")
115
 
116
  # Function to send selected story and initial message
117
  def send_selected_story(title, model_name, system_prompt):
118
+ global chat_history, selected_story, data, interaction_count
 
 
119
  data = [] # Reset data for new story
120
+ interaction_count = 0 # Reset interaction count
121
  tokenizer, model = load_model(model_name)
122
  selected_story = title
123
  for story in stories:
 
138
  question_prompt = "Please ask a simple question about the story to encourage interaction."
139
  _, formatted_history, chat_history = interact(question_prompt, chat_history)
140
 
141
+ return formatted_history, chat_history, gr.update(value=[]), gr.update(value=selected_story) # Reset the data table and update the selected story textbox
142
  else:
143
  print("Combined message is empty.")
144
  else:
 
178
  ])
179
 
180
  # Append data to Google Sheets
181
+ sheet = client.open(google_sheets_name).sheet1 # Assuming results are saved in sheet1
182
+ sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, last_user_message, last_assistant_message, score, comment])
 
 
 
 
183
 
184
  df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
185
+ return df, gr.update(value="") # Clear the comment input box
186
 
187
  # Create the chat interface using Gradio Blocks
188
  with gr.Blocks() as demo:
 
192
  user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
193
  initial_story = stories[0]["title"] if stories else None
194
  story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
195
+ system_prompt_dropdown = gr.Dropdown(choices=prompts, label="Select System Prompt")
 
196
 
197
  send_story_button = gr.Button("Send Story")
198
+
199
+ selected_story_textbox = gr.Textbox(label="Selected Story", interactive=False)
200
 
201
  with gr.Row():
202
  with gr.Column(scale=1):