import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datetime import datetime, timedelta, timezone
import torch
from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
import spaces

# Hack for ZeroGPU: stub out torch.jit.script, which is not supported on the ZeroGPU runtime
torch.jit.script = lambda f: f

# Initialize Google Sheets client
client = init_google_sheets_client()
sheet = client.open(google_sheets_name)
stories_sheet = sheet.worksheet("Stories")  # Assuming stories are in a separate worksheet
prompts_sheet = sheet.worksheet("System Prompts")  # Assuming system prompts are in a separate worksheet
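
# Expected worksheet layout (inferred from the loaders below): "Stories" holds
# Title | Story columns under a "Title" header row; "System Prompts" holds a
# single column headed "System Prompts".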
# Load stories from Google Sheets
def load_stories():
    stories_data = stories_sheet.get_all_values()
    stories = [{"title": story[0], "story": story[1]} for story in stories_data if story[0] != "Title"]  # Skip header row
    return stories

# Load system prompts from Google Sheets
def load_prompts():
    prompts_data = prompts_sheet.get_all_values()
    prompts = [prompt[0] for prompt in prompts_data if prompt[0] != "System Prompts"]  # Skip header row
    return prompts
# Load available stories and prompts
stories = load_stories()
prompts = load_prompts()

# Initialize the selected model
selected_model = default_model_name
tokenizer, model = None, None

# Initialize the data list
data = []
# Load the model and tokenizer once at the beginning
def load_model(model_name):
    global tokenizer, model, selected_model
    try:
        # Release the memory of the previous model, if one is loaded
        if model is not None:
            del model
            torch.cuda.empty_cache()
        tokenizer = AutoTokenizer.from_pretrained(models[model_name], padding_side='left', token=hugging_face_token, trust_remote_code=True)
        # Ensure a padding token is set; fall back to the EOS token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(models[model_name], token=hugging_face_token, trust_remote_code=True).to("cuda")
        selected_model = model_name
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise
    return tokenizer, model

# Ensure the initial model is loaded
tokenizer, model = load_model(selected_model)
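
# Note: the chat state below lives in module-level globals, so all concurrent
# Gradio sessions share one conversation. That is fine for a single-evaluator
# Space; per-session isolation would require gr.State instead.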
# Chat history and interaction count
chat_history = []
interaction_count = 0
# Handle one turn of conversation with the model
def interact(user_input, history):
    global tokenizer, model, interaction_count
    try:
        if tokenizer is None or model is None:
            raise ValueError("Tokenizer or model is not initialized.")

        # Increment interaction count
        interaction_count += 1

        # If the maximum number of interactions has been reached, say goodbye
        if interaction_count > MAX_INTERACTIONS:
            farewell_message = "Thank you for the conversation! Have a great day!"
            history.append({"role": "assistant", "content": farewell_message})
            formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
            return "", formatted_history, history

        messages = history + [{"role": "user", "content": user_input}]

        # Ensure roles alternate correctly
        for i in range(1, len(messages)):
            if messages[i - 1].get("role") == messages[i].get("role"):
                raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Generate a response with the selected model
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
        chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
        response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # Update chat history with the user input and generated response
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": response})

        # Convert to the (user, bot) pair format expected by gr.Chatbot
        formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
        return "", formatted_history, history
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Error during interaction: {e}")
        raise gr.Error(f"An error occurred during interaction: {str(e)}")
# Send the selected story and generate the opening question
def send_selected_story(title, model_name, system_prompt):
    global chat_history, selected_story, data, interaction_count
    data = []  # Reset data for the new story
    interaction_count = 0  # Reset interaction count
    load_model(model_name)
    selected_story = title
    for story in stories:
        if story["title"] == title:
            system_prompt = f"""
{system_prompt}
Here is the story:
---
{story['story']}
---
"""
            combined_message = system_prompt.strip()
            if combined_message:
                chat_history = []  # Reset chat history
                chat_history.append({"role": "system", "content": combined_message})
                # Generate the first question based on the story
                question_prompt = "Please ask a simple question about the story to encourage interaction."
                _, formatted_history, chat_history = interact(question_prompt, chat_history)
                # Reset the data table and update the selected-story textbox
                return formatted_history, chat_history, gr.update(value=[]), gr.update(value=selected_story)
            else:
                print("Combined message is empty.")
    print("Story title does not match.")
# Save the score and comment for the most recent exchange
def save_comment_score(chat_responses, score, comment, story_name, user_name, system_prompt):
    last_user_message = ""
    last_assistant_message = ""

    # Find the last user and assistant messages, scanning backwards over the
    # one-sided [user, None] / [None, assistant] Chatbot pairs
    for message in reversed(chat_responses):
        if isinstance(message, list) and len(message) == 2:
            if message[0] and not last_user_message:
                last_user_message = message[0]
            elif message[1] and not last_assistant_message:
                last_assistant_message = message[1]
        if last_user_message and last_assistant_message:
            break

    timestamp = datetime.now(timezone.utc) - timedelta(hours=3)  # Adjust to GMT-3
    timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    model_name = selected_model

    # Append the row to local data storage
    data.append([
        timestamp_str,
        user_name,
        model_name,
        system_prompt,
        story_name,
        last_user_message,
        last_assistant_message,
        score,
        comment
    ])

    # Append the same row to Google Sheets
    results_sheet = client.open(google_sheets_name).sheet1  # Assuming results are saved in sheet1
    results_sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, last_user_message, last_assistant_message, score, comment])

    df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
    return df, gr.update(value="")  # Clear the comment input box
# Create the chat interface using Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Model")

    model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model", value=selected_model)
    user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
    initial_story = stories[0]["title"] if stories else None
    story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
    system_prompt_dropdown = gr.Dropdown(choices=prompts, label="Select System Prompt")

    send_story_button = gr.Button("Send Story")
    selected_story_textbox = gr.Textbox(label="Selected Story", interactive=False)

    with gr.Row():
        with gr.Column(scale=1):
            chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
            send_message_button = gr.Button("Send")
        with gr.Column(scale=2):
            chatbot_output = gr.Chatbot(label="Chat History")

    with gr.Row():
        with gr.Column(scale=1):
            score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
            comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
            save_button = gr.Button("Save Score and Comment")

    # Headers match the DataFrame returned by save_comment_score
    data_table = gr.DataFrame(headers=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
    chat_history_json = gr.JSON(value=[], visible=False)

    send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
    send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json], outputs=[chatbot_input, chatbot_output, chat_history_json])
    save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])

demo.launch()
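# On Hugging Face Spaces the default launch() arguments are sufficient; pass
# share=True only when running locally and a temporary public link is needed.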