rodrisouza committed
Commit c8779e3 · verified · 1 Parent(s): 7b73598

Update app.py

Files changed (1):
  app.py +17 -16
app.py CHANGED
@@ -6,7 +6,6 @@ from datetime import datetime, timedelta, timezone
  import torch
  from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
  import spaces
- import gspread

  # Hack for ZeroGPU
  torch.jit.script = lambda f: f
@@ -66,7 +65,7 @@ def load_model(model_name):
  # Ensure the initial model is loaded
  tokenizer, model = load_model(selected_model)

- # Chat history and interaction counter
+ # Chat history and interaction count
  chat_history = []
  interaction_count = 0
 
@@ -78,12 +77,6 @@ def interact(user_input, history, interaction_count):
      if tokenizer is None or model is None:
          raise ValueError("Tokenizer or model is not initialized.")

-     interaction_count += 1
-     print(f"Interaction count: {interaction_count}")
-
-     if interaction_count >= MAX_INTERACTIONS:
-         user_input += ". Thank you for the questions. That's all for now. Goodbye!"
-
      messages = history + [{"role": "user", "content": user_input}]

      # Ensure roles alternate correctly
@@ -95,11 +88,16 @@ def interact(user_input, history, interaction_count):

      # Generate response using selected model
      input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
-     chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)  # Increase max_new_tokens
+     chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
      response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

      # Update chat history with generated response
      history.append({"role": "user", "content": user_input})
+
+     # Check if it's the last interaction
+     interaction_count += 1
+     if interaction_count >= MAX_INTERACTIONS:
+         response += ". Thank you for the questions. That's all for now. Goodbye!"
      history.append({"role": "assistant", "content": response})

      formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
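The net effect of the two interact() hunks: the counter now advances after generation, and the goodbye sentence is appended to the model's response rather than to the user's input, so it never enters the generation prompt. A condensed sketch of the new flow, with the model call stubbed out (MAX_INTERACTIONS comes from config in the real app):

MAX_INTERACTIONS = 3  # placeholder value for illustration

def interact_sketch(user_input, history, interaction_count):
    response = "(model output)"  # stand-in for model.generate + tokenizer.decode

    history.append({"role": "user", "content": user_input})

    # Counter is bumped after generation; the farewell goes on the response only.
    interaction_count += 1
    if interaction_count >= MAX_INTERACTIONS:
        response += ". Thank you for the questions. That's all for now. Goodbye!"
    history.append({"role": "assistant", "content": response})

    return history, interaction_count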
@@ -112,8 +110,11 @@ def interact(user_input, history, interaction_count):

  # Function to send selected story and initial message
  def send_selected_story(title, model_name, system_prompt):
-     global chat_history, selected_story, data
+     global chat_history, interaction_count
+     global selected_story
+     global data  # Ensure data is reset
      data = []  # Reset data for new story
+     interaction_count = 0  # Reset interaction count
      tokenizer, model = load_model(model_name)
      selected_story = title
      for story in stories:
@@ -131,9 +132,10 @@ Here is the story:
          chat_history.append({"role": "system", "content": combined_message})

          # Generate the first question based on the story
-         _, formatted_history, chat_history, interaction_count = interact("Please ask a simple question about the story to encourage interaction.", chat_history, 0)
+         question_prompt = "Please ask a simple question about the story to encourage interaction."
+         _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, interaction_count)

-         return formatted_history, chat_history, gr.update(value=[]), story["story"], interaction_count  # Reset the data table and return the story
+         return formatted_history, chat_history, gr.update(value=[]), story["story"]  # Reset the data table and return the story
      else:
          print("Combined message is empty.")
  else:
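The two send_selected_story() hunks make the per-story reset explicit. Isolating just the state handling (a minimal sketch; model loading, story lookup, and the Gradio return values are elided):

chat_history = []       # module-level state, as in the app
interaction_count = 0
data = []

def reset_for_new_story():
    # Mirrors the top of send_selected_story: globals are declared so the
    # interaction counter and the scores/comments table start fresh.
    global interaction_count, data
    data = []              # clear scores/comments collected for the last story
    interaction_count = 0  # restart the MAX_INTERACTIONS countdown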
@@ -192,7 +194,7 @@ with gr.Blocks() as demo:
      initial_story = stories[0]["title"] if stories else None
      story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)

-     system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt")
+     system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt", value=system_prompts[0])

      send_story_button = gr.Button("Send Story")
      selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
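Passing value= preselects an entry, so the handler never receives None before the user touches the dropdown. A small standalone illustration (prompt strings hypothetical):

import gradio as gr

prompts = ["You are a friendly reading tutor.", "You are a strict examiner."]  # hypothetical

with gr.Blocks() as demo:
    # With value= set, the component starts preselected instead of empty.
    dropdown = gr.Dropdown(choices=prompts, label="Select System Prompt", value=prompts[0])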
@@ -214,10 +216,9 @@ with gr.Blocks() as demo:
      data_table = gr.DataFrame(headers=["User Input", "Chat Response", "Score", "Comment"])

      chat_history_json = gr.JSON(value=[], visible=False)
-     interaction_count_state = gr.State(0)

-     send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox, interaction_count_state])
-     send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count_state], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count_state])
+     send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
+     send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, gr.State(interaction_count)], outputs=[chatbot_input, chatbot_output, chat_history_json, gr.State(interaction_count)])
      save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])

      demo.launch()
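For context on the rewiring above: gr.State holds a per-session value that a handler receives as an input and returns as an output. Each gr.State(...) call constructs a distinct component, so the usual pattern is to create one instance and pass the same object to both inputs and outputs. A self-contained sketch of that pattern (names hypothetical):

import gradio as gr

def count_up(message, n):
    # Echo the message and advance the per-session counter.
    return f"got: {message}", n + 1

with gr.Blocks() as demo:
    box = gr.Textbox(label="Message")
    reply = gr.Textbox(label="Reply")
    counter = gr.State(0)  # single shared instance, read and written by the handler
    box.submit(fn=count_up, inputs=[box, counter], outputs=[reply, counter])

demo.launch()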
 