demo-chatbot-v3 / app.py
rodrisouza's picture
Update app.py
c8779e3 verified
raw
history blame
9.81 kB
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datetime import datetime, timedelta, timezone
import torch
from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
import spaces
# Hack for ZeroGPU: replace torch.jit.script with an identity function so any
# callable passed to it is returned unchanged.
torch.jit.script = lambda f: f
# Initialize Google Sheets client
client = init_google_sheets_client()
# Open the spreadsheet whose name comes from config.google_sheets_name.
sheet = client.open(google_sheets_name)
# Both worksheets are looked up by tab name, not by position in the spreadsheet.
stories_sheet = sheet.worksheet("Stories")
system_prompts_sheet = sheet.worksheet("System Prompts")
# Load stories from Google Sheets
def load_stories():
    """Load the available stories from the "Stories" worksheet.

    Returns:
        A list of ``{"title": ..., "story": ...}`` dicts, one per data row,
        with the title taken from the first column and the story text from
        the second column.
    """
    stories = []
    for row in stories_sheet.get_all_values():
        # A row with fewer than two cells (for example an empty row) cannot
        # hold both a title and a story; skip it instead of raising IndexError.
        if len(row) < 2:
            continue
        # The header row is recognised by the value of its first cell.
        if row[0] == "Title":
            continue
        stories.append({"title": row[0], "story": row[1]})
    return stories
# Load system prompts from Google Sheets
def load_system_prompts():
    """Return the first-column value of every row below the header row.

    The rows are read from the "System Prompts" worksheet.
    """
    rows = system_prompts_sheet.get_all_values()
    prompts = []
    # The first row is the header, so iteration starts at the second row.
    for row in rows[1:]:
        prompts.append(row[0])
    return prompts
# Load available stories and system prompts once, when the module is executed
stories = load_stories()
system_prompts = load_system_prompts()
# Name of the currently selected model; load_model() overwrites it on success
selected_model = default_model_name
# Module-level tokenizer/model handles; both are None until load_model() sets them
tokenizer, model = None, None
# Rows of saved feedback: reset by send_selected_story(), appended to by
# save_comment_score()
data = []
# Load the model and tokenizer once at the beginning
def load_model(model_name):
    """Load the tokenizer and model registered under ``model_name``.

    The loaded objects are stored in the module-level ``tokenizer`` and
    ``model`` globals, and ``selected_model`` is set to ``model_name``.
    When that model is already loaded, the existing objects are returned
    without loading anything again.

    Args:
        model_name: Key into the ``models`` mapping imported from ``config``.

    Returns:
        The ``(tokenizer, model)`` pair held in the module globals.

    Raises:
        Exception: Any error raised while looking up or loading the
            checkpoint; it is printed and then re-raised unchanged.
    """
    global tokenizer, model, selected_model
    # Skip the expensive reload when the requested model is already in place.
    if model_name == selected_model and tokenizer is not None and model is not None:
        return tokenizer, model
    try:
        # Drop the previous model before loading the next one. The globals
        # are rebound to None rather than deleted: ``del`` on a global removes
        # the name itself, so a failed load would make later
        # ``model is None`` checks raise NameError.
        had_model = model is not None
        tokenizer = None
        model = None
        if had_model:
            torch.cuda.empty_cache()
        checkpoint = models[model_name]
        new_tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            padding_side='left',
            token=hugging_face_token,
            trust_remote_code=True,
        )
        # Ensure the padding token is set
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token
            new_tokenizer.add_special_tokens({'pad_token': new_tokenizer.eos_token})
        new_model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            token=hugging_face_token,
            trust_remote_code=True,
        ).to("cuda")
        # Publish the new objects only after both loads succeeded.
        tokenizer, model, selected_model = new_tokenizer, new_model, model_name
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise
    return tokenizer, model
# Load the default model when the module is executed. load_model() assigns
# the tokenizer/model globals itself and also returns them.
tokenizer, model = load_model(selected_model)
# Module-level chat history and interaction counter; send_selected_story()
# rebinds both through ``global``.
chat_history = []
interaction_count = 0
# Function to handle interaction with model
@spaces.GPU
def interact(user_input, history, interaction_count):
    """Generate the assistant's reply to ``user_input`` and update the chat state.

    Args:
        user_input: Text of the new user message.
        history: List of ``{"role": ..., "content": ...}`` message dicts. The
            new user message and the generated reply are appended in place.
            ``None`` is treated as an empty history.
        interaction_count: Number of interactions completed before this call.

    Returns:
        A 4-tuple ``("", formatted_history, history, interaction_count)``:
        an empty string, the user/assistant entries of ``history`` as
        ``(user_text, None)`` / ``(None, assistant_text)`` tuples, the
        updated ``history`` and the incremented counter.

    Raises:
        gr.Error: Wraps any exception raised while building the prompt or
            generating the reply.
    """
    try:
        if tokenizer is None or model is None:
            raise ValueError("Tokenizer or model is not initialized.")
        if history is None:
            history = []
        messages = history + [{"role": "user", "content": user_input}]
        # Ensure roles alternate correctly
        for previous, current in zip(messages, messages[1:]):
            if previous.get("role") == current.get("role"):
                raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens, so
        # the tokenizer must not add them a second time.
        encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to("cuda")
        input_ids = encoded["input_ids"]
        # Pass the attention mask explicitly: the pad token is the EOS token,
        # so generate() cannot reliably infer the mask from the ids alone.
        output_ids = model.generate(
            input_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = input_ids.shape[-1]
        response = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
        # Update chat history with the user message
        history.append({"role": "user", "content": user_input})
        # Check if it's the last interaction
        interaction_count += 1
        if interaction_count >= MAX_INTERACTIONS:
            response += ". Thank you for the questions. That's all for now. Goodbye!"
        history.append({"role": "assistant", "content": response})
        # System messages are left out of the chat display.
        formatted_history = [
            (entry["content"], None) if entry["role"] == "user" else (None, entry["content"])
            for entry in history
            if entry["role"] in ("user", "assistant")
        ]
        return "", formatted_history, history, interaction_count
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Error during interaction: {e}")
        raise gr.Error(f"An error occurred during interaction: {str(e)}") from e
# Function to send selected story and initial message
def send_selected_story(title, model_name, system_prompt):
    """Start a new conversation about the story called ``title``.

    Resets the module-level ``data``, ``chat_history`` and
    ``interaction_count``, loads ``model_name`` through ``load_model``,
    seeds the history with a system message built from ``system_prompt`` and
    the story text, and asks the model for an opening question through
    ``interact``.

    Args:
        title: Title of the story to use; must match an entry in ``stories``.
        model_name: Name of the model passed to ``load_model``.
        system_prompt: Instruction text placed before the story in the
            system message.

    Returns:
        A 4-tuple ``(formatted_history, chat_history, table_update, story_text)``
        where ``table_update`` is ``gr.update(value=[])``.

    Raises:
        gr.Error: If no story in ``stories`` has the given title.
    """
    global chat_history, interaction_count
    global selected_story
    global data  # Ensure data is reset

    # Validate the title before touching any state or loading a model, and
    # fail with a visible error instead of returning None to Gradio.
    story = next((candidate for candidate in stories if candidate["title"] == title), None)
    if story is None:
        print("Story title does not match.")
        raise gr.Error(f"Story not found: {title!r}")

    data = []  # Reset data for new story
    interaction_count = 0  # Reset interaction count
    load_model(model_name)  # Updates the module-level tokenizer/model
    selected_story = title

    combined_message = (
        f"\n{system_prompt}\nHere is the story:\n---\n{story['story']}\n---\n"
    ).strip()
    chat_history = [{"role": "system", "content": combined_message}]  # Reset chat history

    # Generate the first question based on the story
    question_prompt = "Please ask a simple question about the story to encourage interaction."
    _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, interaction_count)
    # Reset the data table and return the story
    return formatted_history, chat_history, gr.update(value=[]), story["story"]
# Function to save comment and score
def save_comment_score(chat_responses, score, comment, story_name, user_name, system_prompt):
    """Record a score and comment for the latest exchange in the chat.

    The row is appended to the worksheet named after ``user_name`` (created
    if it does not exist) and to the module-level ``data`` list.

    Args:
        chat_responses: Chat history as a sequence of two-item
            ``[user_text, assistant_text]`` pairs; either item may be empty.
        score: Score given to the latest assistant message.
        comment: Free-text comment.
        story_name: Title of the story being discussed.
        user_name: Name of the evaluating user; also the worksheet name.
        system_prompt: System prompt in use for the conversation.

    Returns:
        A 2-tuple: a DataFrame with the "User Input", "Chat Response",
        "Score" and "Comment" columns of every row in ``data``, and
        ``gr.update(value="")`` to clear the comment box.

    Raises:
        gr.Error: If ``user_name`` is empty.
    """
    # Imported here because the except clause below needs the name and the
    # file's import block does not provide it.
    import gspread

    if not user_name:
        raise gr.Error("Please select a user name before saving.")

    last_user_message = ""
    last_assistant_message = ""
    # Walk the history backwards to find the most recent user and assistant texts.
    for message in reversed(chat_responses or []):
        if not (isinstance(message, (list, tuple)) and len(message) == 2):
            continue
        user_text, assistant_text = message
        if user_text and not last_user_message:
            last_user_message = user_text
        if assistant_text and not last_assistant_message:
            last_assistant_message = assistant_text
        if last_user_message and last_assistant_message:
            break

    # Timestamps are recorded in GMT-3.
    timestamp = datetime.now(timezone(timedelta(hours=-3)))
    timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    row = [
        timestamp_str,
        user_name,
        selected_model,
        system_prompt,
        story_name,
        last_user_message,
        last_assistant_message,
        score,
        comment,
    ]

    # Append data to Google Sheets first, so the local table only lists rows
    # that were actually stored.
    spreadsheet = client.open(google_sheets_name)
    try:
        user_sheet = spreadsheet.worksheet(user_name)
    except gspread.exceptions.WorksheetNotFound:
        user_sheet = spreadsheet.add_worksheet(title=user_name, rows=100, cols=20)
    user_sheet.append_row(row)

    # Append data to local data storage
    data.append(row)
    df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
    # Show only the required columns and clear the comment input box
    return df[["User Input", "Chat Response", "Score", "Comment"]], gr.update(value="")
# Create the chat interface using Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Model")
    model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model", value=selected_model)
    user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
    initial_story = stories[0]["title"] if stories else None
    story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
    # Guarded like initial_story so an empty prompt list does not crash startup.
    initial_system_prompt = system_prompts[0] if system_prompts else None
    system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt", value=initial_system_prompt)
    send_story_button = gr.Button("Send Story")
    selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
            send_message_button = gr.Button("Send")
        with gr.Column(scale=2):
            chatbot_output = gr.Chatbot(label="Chat History")
    with gr.Row():
        with gr.Column(scale=1):
            score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
            comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
            save_button = gr.Button("Save Score and Comment")
    # NOTE(review): the table is placed below the feedback row; confirm this
    # matches the intended layout.
    data_table = gr.DataFrame(headers=["User Input", "Chat Response", "Score", "Comment"])
    chat_history_json = gr.JSON(value=[], visible=False)
    # A single State instance must be used as both input and output of
    # interact; two separate gr.State(...) objects would never share a value,
    # so the count would restart on every message.
    interaction_count_state = gr.State(interaction_count)
    # send_selected_story runs exactly one interaction (the opening question)
    # starting from zero, so the per-session counter is set to 1 afterwards.
    send_story_button.click(
        fn=send_selected_story,
        inputs=[story_dropdown, model_dropdown, system_prompt_dropdown],
        outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox],
    ).then(fn=lambda: 1, inputs=None, outputs=interaction_count_state)
    send_message_button.click(
        fn=interact,
        inputs=[chatbot_input, chat_history_json, interaction_count_state],
        outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count_state],
    )
    save_button.click(
        fn=save_comment_score,
        inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown],
        outputs=[data_table, comment_input],
    )
demo.launch()