rodrisouza commited on
Commit
58cb3d4
·
verified ·
1 Parent(s): 88720d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -62
app.py CHANGED
@@ -1,63 +1,207 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import pandas as pd
5
+ from datetime import datetime, timedelta, timezone
6
+ import torch
7
+ from config import hugging_face_token, init_google_sheets_client, models, default_model_name, user_names, google_sheets_name
8
+ import spaces
9
+
10
+ # Hack for ZeroGPU
11
+ torch.jit.script = lambda f: f
12
+
13
+ # Initialize Google Sheets client
14
+ client = init_google_sheets_client()
15
+ sheet = client.open(google_sheets_name)
16
+ stories_sheet = sheet.get_worksheet(1) # Assuming stories are in the second sheet (index 1)
17
+
18
+ # Load stories from Google Sheets
19
+ def load_stories():
20
+ stories_data = stories_sheet.get_all_values()
21
+ stories = [{"title": story[0], "story": story[1]} for story in stories_data if story[0] != "Title"] # Skip header row
22
+ return stories
23
+
24
+ # Load available stories
25
+ stories = load_stories()
26
+
27
+ # Initialize the selected model
28
+ selected_model = default_model_name
29
+ tokenizer, model = None, None
30
+
31
+ # Initialize the data list
32
+ data = []
33
+
34
+ # Load the model and tokenizer once at the beginning
35
+ def load_model(model_name):
36
+ global tokenizer, model, selected_model
37
+ try:
38
+ # Release the memory of the previous model if exists
39
+ if model is not None:
40
+ del model
41
+ torch.cuda.empty_cache()
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(models[model_name], padding_side='left', token=hugging_face_token, trust_remote_code=True)
44
+
45
+ # Ensure the padding token is set
46
+ if tokenizer.pad_token is None:
47
+ tokenizer.pad_token = tokenizer.eos_token
48
+ tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
49
+
50
+ model = AutoModelForCausalLM.from_pretrained(models[model_name], token=hugging_face_token, trust_remote_code=True).to("cuda")
51
+ selected_model = model_name
52
+ except Exception as e:
53
+ print(f"Error loading model {model_name}: {e}")
54
+ raise e
55
+ return tokenizer, model
56
+
57
+ # Ensure the initial model is loaded
58
+ tokenizer, model = load_model(selected_model)
59
+
60
+ # Chat history
61
+ chat_history = []
62
+
63
+ # Function to handle interaction with model
64
+ @spaces.GPU
65
+ def interact(user_input, history):
66
+ global tokenizer, model
67
+ try:
68
+ if tokenizer is None or model is None:
69
+ raise ValueError("Tokenizer or model is not initialized.")
70
+
71
+ messages = history + [{"role": "user", "content": user_input}]
72
+
73
+ # Ensure roles alternate correctly
74
+ for i in range(1, len(messages)):
75
+ if messages[i-1].get("role") == messages[i].get("role"):
76
+ raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
77
+
78
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
79
+
80
+ # Generate response using selected model
81
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
82
+ chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id) # Increase max_new_tokens
83
+ response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
84
+
85
+ # Update chat history with generated response
86
+ history.append({"role": "user", "content": user_input})
87
+ history.append({"role": "assistant", "content": response})
88
+
89
+ formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
90
+ return "", formatted_history, history
91
+ except Exception as e:
92
+ if torch.cuda.is_available():
93
+ torch.cuda.empty_cache()
94
+ print(f"Error during interaction: {e}")
95
+ raise gr.Error(f"An error occurred during interaction: {str(e)}")
96
+
97
+ # Function to send selected story and initial message
98
+ def send_selected_story(title, model_name, system_prompt):
99
+ global chat_history
100
+ global selected_story
101
+ global data # Ensure data is reset
102
+ data = [] # Reset data for new story
103
+ tokenizer, model = load_model(model_name)
104
+ selected_story = title
105
+ for story in stories:
106
+ if story["title"] == title:
107
+ system_prompt = f"""
108
+ {system_prompt}
109
+ Here is the story:
110
+ ---
111
+ {story['story']}
112
+ ---
113
+ """
114
+ combined_message = system_prompt.strip()
115
+ if combined_message:
116
+ chat_history = [] # Reset chat history
117
+ chat_history.append({"role": "system", "content": combined_message})
118
+
119
+ # Generate the first question based on the story
120
+ question_prompt = "Please ask a simple question about the story to encourage interaction."
121
+ _, formatted_history, chat_history = interact(question_prompt, chat_history)
122
+
123
+ return formatted_history, chat_history, gr.update(value=[]) # Reset the data table
124
+ else:
125
+ print("Combined message is empty.")
126
+ else:
127
+ print("Story title does not match.")
128
+
129
+ # Function to save comment and score
130
+ def save_comment_score(chat_responses, score, comment, story_name, user_name):
131
+ last_user_message = ""
132
+ last_assistant_message = ""
133
+
134
+ # Find the last user and assistant messages
135
+ for message in reversed(chat_responses):
136
+ if isinstance(message, list) and len(message) == 2:
137
+ if message[0] and not last_user_message:
138
+ last_user_message = message[0]
139
+ elif message[1] and not last_assistant_message:
140
+ last_assistant_message = message[1]
141
+
142
+ if last_user_message and last_assistant_message:
143
+ break
144
+
145
+ timestamp = datetime.now(timezone.utc) - timedelta(hours=3) # Adjust to GMT-3
146
+ timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
147
+ model_name = selected_model
148
+
149
+ # Append data to local data storage
150
+ data.append([
151
+ timestamp_str,
152
+ user_name,
153
+ model_name,
154
+ story_name,
155
+ last_user_message,
156
+ last_assistant_message,
157
+ score,
158
+ comment
159
+ ])
160
+
161
+ # Append data to Google Sheets
162
+ sheet = client.open(google_sheets_name).sheet1 # Assuming results are saved in sheet1
163
+ sheet.append_row([timestamp_str, user_name, model_name, story_name, last_user_message, last_assistant_message, score, comment])
164
+
165
+ df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
166
+ return df, gr.update(value="") # Clear the comment input box
167
+
168
+ # Create the chat interface using Gradio Blocks
169
+ with gr.Blocks() as demo:
170
+ gr.Markdown("# Chat with Model")
171
+
172
+ model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Select Model", value=selected_model)
173
+ user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
174
+ initial_story = stories[0]["title"] if stories else None
175
+ story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
176
+
177
+ default_system_prompt = ("You are friendly chatbot and you will interact with a child who speaks Spanish and is learning English as a foreign language. "
178
+ "Everything you write should be in English. I will provide you with a short children's story in English. "
179
+ "After reading the story, please ask the child a series of five simple questions about it, one at a time, to encourage ongoing interaction. "
180
+ "Wait for the child's response to each question before asking the next one.")
181
+ system_prompt_input = gr.Textbox(lines=5, value=default_system_prompt, label="System Prompt")
182
+
183
+ send_story_button = gr.Button("Send Story")
184
+
185
+ with gr.Row():
186
+ with gr.Column(scale=1):
187
+ chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
188
+ send_message_button = gr.Button("Send")
189
+
190
+ with gr.Column(scale=2):
191
+ chatbot_output = gr.Chatbot(label="Chat History")
192
+
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
196
+ comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
197
+ save_button = gr.Button("Save Score and Comment")
198
+
199
+ data_table = gr.DataFrame(headers=["Timestamp", "User Name", "Model Name", "Story Name", "User Input", "Chat Response", "Score", "Comment"])
200
+
201
+ chat_history_json = gr.JSON(value=[], visible=False)
202
+
203
+ send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_input], outputs=[chatbot_output, chat_history_json, data_table])
204
+ send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json], outputs=[chatbot_input, chatbot_output, chat_history_json])
205
+ save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown], outputs=[data_table, comment_input])
206
+
207
+ demo.launch()