Spaces:
Runtime error
Runtime error
File size: 10,917 Bytes
0712e23 496ed17 0712e23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datetime import datetime, timedelta, timezone
import torch
from config import hugging_face_token, init_google_sheets_client, models, quantized_models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
import spaces
# Hack for ZeroGPU
# Replaces torch.jit.script with an identity function, so anything passed to
# it is returned unchanged instead of being compiled.
torch.jit.script = lambda f: f
# Initialize Google Sheets client
# `init_google_sheets_client` and `google_sheets_name` come from the local
# `config` module; the spreadsheet is opened once here, at import time.
client = init_google_sheets_client()
sheet = client.open(google_sheets_name)
# Worksheet handles read by load_stories() and load_system_prompts().
stories_sheet = sheet.worksheet("Stories")
system_prompts_sheet = sheet.worksheet("System Prompts")
# Load stories from Google Sheets
def load_stories():
    """Read the "Stories" worksheet into a list of story records.

    The first cell of each row is used as the title and the second as the
    story text.

    Returns:
        A list of ``{"title": ..., "story": ...}`` dicts, one per usable row.
        Rows whose title cell is exactly ``"Title"`` (the header) are skipped,
        as are rows with fewer than two cells.
    """
    rows = stories_sheet.get_all_values()
    return [
        {"title": row[0], "story": row[1]}
        for row in rows
        # The length check comes first so that an empty or one-cell row is
        # skipped instead of raising IndexError on row[0] / row[1].
        if len(row) >= 2 and row[0] != "Title"
    ]
# Load system prompts from Google Sheets
def load_system_prompts():
    """Return the first-column value of every row below the header.

    Reads all values from the "System Prompts" worksheet handle and drops
    the first row, which is treated as the header.
    """
    rows = system_prompts_sheet.get_all_values()
    body = rows[1:]  # everything after the header row
    return [cells[0] for cells in body]
# Load available stories and system prompts
# Both lists are fetched once at import time and reused by the UI dropdowns.
stories = load_stories()
system_prompts = load_system_prompts()
# Initialize the selected model
# `selected_model` tracks the name passed to the last successful load_model()
# call; `tokenizer` and `model` stay None until load_model() assigns them.
selected_model = default_model_name
tokenizer, model = None, None
# Initialize the data list
# Rows are appended by save_comment_score() and the list is reset by
# send_selected_story().
data = []
# Load the requested model and tokenizer into the module-level globals
def load_model(model_name):
    """Load ``model_name`` and its tokenizer, replacing the current pair.

    The name is looked up in ``models`` first and then in
    ``quantized_models`` (both from ``config``). If the requested model is
    already the loaded one, the existing objects are returned without
    reloading.

    Args:
        model_name: Key into ``models`` or ``quantized_models``.

    Returns:
        The ``(tokenizer, model)`` pair now stored in the module globals.

    Raises:
        ValueError: If ``model_name`` is in neither dictionary.
        Exception: Any error raised while loading is printed and re-raised.
            In that case the globals ``tokenizer`` and ``model`` are left as
            ``None`` if a previous model had already been released.
    """
    import gc

    global tokenizer, model, selected_model

    # Nothing to do when the requested model is already in memory.
    if model_name == selected_model and model is not None and tokenizer is not None:
        return tokenizer, model

    try:
        # Validate the name before touching the currently loaded model so a
        # bad name does not throw away a working model.
        if model_name in models:
            model_path = models[model_name]
        elif model_name in quantized_models:
            model_path = quantized_models[model_name]
        else:
            raise ValueError(f"Model {model_name} not found in either models or quantized_models.")

        # Release the previous model before loading the next one. The globals
        # are rebound to None rather than deleted: `del` on a global removes
        # the name itself, so a failed load would make every later
        # `model is None` check raise NameError.
        if model is not None:
            model = None
            tokenizer = None
            gc.collect()
            torch.cuda.empty_cache()

        new_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side='left',
            token=hugging_face_token,
            trust_remote_code=True
        )
        # Ensure the padding token is set (falls back to the EOS token).
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token

        new_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            token=hugging_face_token,
            trust_remote_code=True
        )
        # Only move to CUDA if it's not a quantized model
        if model_name not in quantized_models:
            new_model = new_model.to("cuda")

        # Publish all three globals together, only after everything succeeded.
        tokenizer, model, selected_model = new_tokenizer, new_model, model_name
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise
    return tokenizer, model
# Ensure the initial model is loaded
# Runs at import time; load_model() also assigns the same globals itself.
tokenizer, model = load_model(selected_model)
# Chat history
# Module-level conversation list; send_selected_story() rebinds it for each
# new story.
chat_history = []
# Function to handle interaction with model
@spaces.GPU
def interact(user_input, history, interaction_count, model_name):
    """Generate one assistant reply and return the updated conversation.

    Args:
        user_input: The user's message text.
        history: List of ``{"role": ..., "content": ...}`` dicts for the
            conversation so far. It is not modified; a new list is returned.
            ``None`` is treated as an empty history.
        interaction_count: Number of interactions so far. Once it reaches
            ``MAX_INTERACTIONS`` a closing sentence is appended to the user
            message.
        model_name: Model name selected in the UI. Kept for interface
            compatibility; the loaded model (``selected_model``) decides
            whether the model may be moved between devices.

    Returns:
        A 4-tuple ``("", formatted_history, history, interaction_count)``:
        an empty string (used to clear the input box), the user/assistant
        turns as ``(user, None)`` / ``(None, assistant)`` pairs, the full
        message list including the new turn, and the incremented count.

    Raises:
        gr.Error: Wraps any exception raised while preparing the prompt or
            generating the reply.
    """
    global tokenizer, model
    try:
        if tokenizer is None or model is None:
            raise ValueError("Tokenizer or model is not initialized.")

        # Determine the device to use (either CUDA if available, or CPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Only move the model if the *loaded* model is not quantized. The
        # dropdown value (`model_name`) can differ from what is in memory
        # when the user changes it without re-sending a story.
        if selected_model not in quantized_models:
            model = model.to(device)

        if interaction_count >= MAX_INTERACTIONS:
            user_input += ". Thank you for your questions. Our session is now over. Goodbye!"

        messages = list(history or []) + [{"role": "user", "content": user_input}]

        # Ensure roles alternate correctly
        for previous, current in zip(messages, messages[1:]):
            if previous.get("role") == current.get("role"):
                raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # The rendered template already carries its special tokens, so they
        # must not be added a second time when tokenizing the text.
        encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # NOTE(review): `temperature` is passed without `do_sample=True`;
        # confirm whether sampling or greedy decoding is intended.
        generated_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.1,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(generated_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)

        # Build the updated history as a new list instead of mutating the
        # caller's list.
        updated_history = messages + [{"role": "assistant", "content": response}]
        interaction_count += 1

        # System messages are kept in the history but not shown in the chat.
        formatted_history = [
            (entry["content"], None) if entry["role"] == "user" else (None, entry["content"])
            for entry in updated_history
            if entry["role"] in ["user", "assistant"]
        ]
        return "", formatted_history, updated_history, interaction_count
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Error during interaction: {e}")
        raise gr.Error(f"An error occurred during interaction: {str(e)}") from e
# Function to send selected story and initial message
def send_selected_story(title, model_name, system_prompt):
    """Start a new conversation about the story called ``title``.

    Loads the requested model, builds a system message from ``system_prompt``
    and the story text, and asks the model for an opening question.

    Args:
        title: Title of the story to look up in the module-level ``stories``.
        model_name: Name of the model to load via ``load_model``.
        system_prompt: Instruction text placed before the story.

    Returns:
        A 4-tuple ``(formatted_history, chat_history, table_update, story)``:
        the chat pairs for display, the full message list, a Gradio update
        that clears the data table, and the raw story text.

    Raises:
        gr.Error: If no story has the given title, or if ``interact`` fails.
    """
    global chat_history
    global selected_story
    global data  # Ensure data is reset

    # Resolve the story before loading a model or resetting any state, so an
    # unknown title fails fast with a visible message instead of returning
    # None for the four outputs.
    story = next((candidate for candidate in stories if candidate["title"] == title), None)
    if story is None:
        print("Story title does not match.")
        raise gr.Error(f"Story '{title}' was not found.")

    data = []  # Reset data for new story
    selected_story = title
    load_model(model_name)  # Load the appropriate model

    # The literal text guarantees this message is never empty.
    combined_message = f"""
{system_prompt}
Here is the story:
---
{story['story']}
---
""".strip()

    chat_history = [{"role": "system", "content": combined_message}]  # Reset chat history
    question_prompt = "Please ask a simple question about the story to encourage interaction."
    # The interaction count restarts at 1 for a new story.
    # NOTE(review): the count returned by interact() is not part of this
    # function's return value, so it does not reach the UI counter; confirm
    # whether the counter should be reset when a new story is sent.
    _, formatted_history, chat_history, _ = interact(question_prompt, chat_history, 1, model_name)
    return formatted_history, chat_history, gr.update(value=[]), story["story"]
# Function to save comment and score
def save_comment_score(chat_responses, score, comment, story_name, user_name, system_prompt):
    """Store one evaluation row in Google Sheets and in the local table.

    Args:
        chat_responses: Chat pairs ``[user_message, assistant_message]`` as
            held by the chatbot component; ``None`` is treated as empty.
        score: Score chosen by the evaluator.
        comment: Free-text comment.
        story_name: Title of the evaluated story.
        user_name: Evaluator name; also the title of the worksheet the row is
            written to (created if it does not exist yet).
        system_prompt: System prompt used for the conversation.

    Returns:
        A 2-tuple: a DataFrame with the "Chat History", "Score" and "Comment"
        columns of all rows saved so far, and a Gradio update that clears the
        comment box.

    Raises:
        gr.Error: If no user name was selected.
    """
    # Needed for the WorksheetNotFound handler below; the module never
    # imports gspread at top level.
    import gspread

    if not user_name:
        raise gr.Error("Please select a user name before saving.")

    # Create formatted chat history with roles
    transcript_lines = []
    for message in chat_responses or []:
        if message[0]:  # User message
            transcript_lines.append(f"User: {message[0]}\n")
        if message[1]:  # Assistant message
            transcript_lines.append(f"Assistant: {message[1]}\n")
    full_chat_history = "".join(transcript_lines)

    # Wall-clock time at UTC-3, expressed with a real fixed-offset timezone.
    timestamp = datetime.now(timezone(timedelta(hours=-3)))
    timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    model_name = selected_model

    row = [
        timestamp_str,
        user_name,
        model_name,
        system_prompt,
        story_name,
        full_chat_history,
        score,
        comment
    ]

    # Append data to Google Sheets
    spreadsheet = client.open(google_sheets_name)
    try:
        user_sheet = spreadsheet.worksheet(user_name)
    except gspread.exceptions.WorksheetNotFound:
        user_sheet = spreadsheet.add_worksheet(title=user_name, rows=100, cols=20)
    user_sheet.append_row(row)

    # Record the row locally only after the remote write succeeded, so the
    # on-screen table never shows rows that are missing from the sheet.
    data.append(row)

    df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "Chat History", "Score", "Comment"])
    # Show only the required columns and clear the comment input box
    return df[["Chat History", "Score", "Comment"]], gr.update(value="")
# Function to load user guide from a file
def load_user_guide():
    """Return the full text of ``user_guide.txt``.

    The path is relative to the current working directory. The file is
    decoded as UTF-8 explicitly so the result does not depend on the
    platform's default encoding.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    with open('user_guide.txt', 'r', encoding='utf-8') as file:
        return file.read()
# Combine both model dictionaries
# On a duplicate key the entry from `quantized_models` wins.
all_models = {**models, **quantized_models}

# Create the chat interface using Gradio Blocks
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Chat"):
            gr.Markdown("# Demo Chatbot V3")
            gr.Markdown("## Context")
            with gr.Group():
                model_dropdown = gr.Dropdown(choices=list(all_models.keys()), label="Select Model", value=default_model_name)
                user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
                # Either sheet may be empty, so both defaults fall back to
                # None instead of indexing an empty list.
                initial_story = stories[0]["title"] if stories else None
                initial_system_prompt = system_prompts[0] if system_prompts else None
                story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
                system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt", value=initial_system_prompt)
                send_story_button = gr.Button("Send Story")
            gr.Markdown("## Chat")
            with gr.Group():
                selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
                chatbot_output = gr.Chatbot(label="Chat History")
                chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
                send_message_button = gr.Button("Send")
            gr.Markdown("## Evaluation")
            with gr.Group():
                score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
                comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
                save_button = gr.Button("Save Score and Comment")
                data_table = gr.DataFrame(headers=["Chat History", "Score", "Comment"])
        with gr.TabItem("User Guide"):
            gr.Textbox(label="User Guide", value=load_user_guide(), lines=20)

    # Hidden components carrying per-session state between callbacks.
    chat_history_json = gr.JSON(value=[], visible=False)
    interaction_count = gr.Number(value=0, visible=False)

    # "Send Story": load the model, start a new chat, clear the table and
    # show the story text.
    send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
    # "Send": one chat turn; clears the input box and updates history/count.
    send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count, model_dropdown], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count])
    # "Save Score and Comment": persist the evaluation and refresh the table.
    save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])

demo.launch()