lobrien001 committed
Commit f54f343 · verified · 1 Parent(s): b36ae86

Update app.py

Files changed (1):
  app.py +161 -161
app.py CHANGED
@@ -9,6 +9,7 @@ import random
 from transformers import pipeline
 from sklearn.metrics import precision_score, recall_score, f1_score
 import json
+import requests
 
 # Load the model
 ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
@@ -26,177 +27,176 @@ QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
 logging.basicConfig(filename="chat_log.txt", level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # --- Queue and Metrics ---
 chat_queue = Queue()  # Define chat_queue globally
 
 # --- Chat Function with Monitoring ---
 def chat_function(message, ground_truth):
     logging.debug("Starting chat_function")
     with REQUEST_LATENCY.time():
         REQUEST_COUNT.inc()
         try:
-            start_time = time.time()
             chat_queue.put(message)
             logging.info(f"Received message from user: {message}")
 
             ner_results = ner_pipeline(message)
             logging.debug(f"NER results: {ner_results}")
 
             detailed_response = []
             predicted_labels = []
             for result in ner_results:
                 token = result['word']
                 score = result['score']
                 entity = result['entity']
                 start = result['start']
                 end = result['end']
                 label_id = int(entity.split('_')[-1])  # Extract numeric label from entity
                 predicted_labels.append(label_id)
                 detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
 
             response = "\n".join(detailed_response)
             logging.info(f"Generated response: {response}")
 
             response_size = len(response.encode('utf-8'))
             RESPONSE_SIZE.observe(response_size)
 
             time.sleep(random.uniform(0.5, 2.5))  # Simulate processing time
 
             # Compute metrics
             try:
                 ground_truth_labels = json.loads(ground_truth)  # Assuming ground_truth is input as a JSON string
             except json.JSONDecodeError:
                 return "Invalid JSON format for ground truth labels. Please provide a valid JSON array."
 
-            precision = precision_score(ground_truth_labels, predicted_labels, average='weighted')
-            recall = recall_score(ground_truth_labels, predicted_labels, average='weighted')
-            f1 = f1_score(ground_truth_labels, predicted_labels, average='weighted')
+            precision = precision_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
+            recall = recall_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
+            f1 = f1_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
 
             metrics_response = (f"Precision: {precision:.4f}\n"
                                 f"Recall: {recall:.4f}\n"
                                 f"F1 Score: {f1:.4f}")
 
             full_response = f"{response}\n\nMetrics:\n{metrics_response}"
 
             chat_queue.get()
             logging.debug("Finished processing message")
             return full_response
         except Exception as e:
             ERROR_COUNT.inc()
             logging.error(f"Error in chat processing: {e}")
             return f"An error occurred. Please try again. Error: {e}"
 
 # Function to simulate stress test
 def stress_test(num_requests, message, delay):
     def send_chat_message():
         response = requests.post("http://127.0.0.1:7860/api/predict/", json={
             "data": [message],
             "fn_index": 0  # This might need to be updated based on your Gradio app's function index
         })
         logging.debug(response.json())
 
     threads = []
     for _ in range(num_requests):
         t = threading.Thread(target=send_chat_message)
         t.start()
         threads.append(t)
         time.sleep(delay)  # Delay between requests
 
     for t in threads:
         t.join()
 
 # --- Gradio Interface with Background Image and Three Windows ---
 with gr.Blocks(css="""
 body {
     background-image: url("stag.jpeg");
     background-size: cover;
     background-repeat: no-repeat;
 }
 """, title="PLOD Filtered with Monitoring") as demo:  # Load CSS for background image
     with gr.Tab("Chat"):
         gr.Markdown("## Chat with the Bot")
         message_input = gr.Textbox(label="Enter your sentence:", lines=2)
         ground_truth_input = gr.Textbox(label="Enter ground truth labels (JSON format):", lines=2)
         output = gr.Textbox(label="Response", lines=10)
         chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, ground_truth_input], outputs=output)
         chat_interface.render()
 
     with gr.Tab("Model Parameters"):
         model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False)  # Display model parameters
 
     with gr.Tab("Performance Metrics"):
         request_count_display = gr.Number(label="Request Count", value=0)
         avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0)
 
     with gr.Tab("Infrastructure"):
         cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0)
         mem_usage_display = gr.Number(label="Memory Usage (%)", value=0)
 
     with gr.Tab("Logs"):
         logs_display = gr.Textbox(label="Logs", lines=10)  # Increased lines for better visibility
 
     with gr.Tab("Stress Testing"):
         num_requests_input = gr.Number(label="Number of Requests", value=10)
         message_input_stress = gr.Textbox(label="Message", value="Hello bot!")
         delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
         stress_test_button = gr.Button("Start Stress Test")
         stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
 
         def run_stress_test(num_requests, message, delay):
             stress_test_status.value = "Stress test started..."
             try:
                 stress_test(num_requests, message, delay)
                 stress_test_status.value = "Stress test completed."
             except Exception as e:
                 stress_test_status.value = f"Stress test failed: {e}"
 
         stress_test_button.click(run_stress_test, [num_requests_input, message_input_stress, delay_input], stress_test_status)
 
     # --- Update Functions ---
     def update_metrics(request_count_display, avg_latency_display):
         while True:
             request_count = REQUEST_COUNT._value.get()
             latency_samples = REQUEST_LATENCY.collect()[0].samples
             avg_latency = sum(s.value for s in latency_samples) / len(latency_samples if latency_samples else [1])  # Avoid division by zero
 
             request_count_display.value = request_count
             avg_latency_display.value = round(avg_latency, 2)
 
             time.sleep(5)  # Update every 5 seconds
 
     def update_usage(cpu_usage_display, mem_usage_display):
         while True:
             cpu_usage_display.value = psutil.cpu_percent()
             mem_usage_display.value = psutil.virtual_memory().percent
             CPU_USAGE.set(psutil.cpu_percent())
             MEM_USAGE.set(psutil.virtual_memory().percent)
             time.sleep(5)
 
     def update_logs(logs_display):
         while True:
             with open("chat_log.txt", "r") as log_file:
                 logs = log_file.readlines()
                 logs_display.value = "".join(logs[-10:])  # Display last 10 lines
             time.sleep(1)  # Update every 1 second
 
     def display_model_params(model_params_display):
         while True:
             model_params = ner_pipeline.model.config.to_dict()
             model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items())
             model_params_display.value = model_params_str
             time.sleep(10)  # Update every 10 seconds
 
     def update_queue_length():
         while True:
             QUEUE_LENGTH.set(chat_queue.qsize())
             time.sleep(1)  # Update every second
 
     # --- Start Threads ---
     threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
     threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
     threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
     threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
     threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
     threading.Thread(target=update_queue_length, daemon=True).start()
 
 # Launch the app
-demo.launch(share=True).
+demo.launch(share=True)
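
For reference, chat_function iterates over the default per-token output of a transformers token-classification pipeline. A minimal sketch of that shape (illustrative sentence; the LABEL_3-style entity strings that int(entity.split('_')[-1]) relies on are an assumption about this model's config, since named tags such as B-LF would make that cast raise ValueError):

from transformers import pipeline

ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")

# Each result dict carries the keys chat_function reads:
# 'word', 'score', 'entity', 'start', 'end'.
for result in ner_pipeline("RNA is a nucleic acid."):
    print(result["word"], result["entity"],
          f"{result['score']:.4f}", result["start"], result["end"])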
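
For context on the zero_division=0 additions: with scikit-learn's default zero_division='warn', any label with an empty denominator (for example, a predicted label that never appears in the ground truth has recall 0/0) triggers an UndefinedMetricWarning and scores 0; zero_division=0 keeps the 0 but stays silent. A toy run of the three calls as chat_function makes them (hypothetical label ids; both lists must have the same length or scikit-learn raises a ValueError):

from sklearn.metrics import precision_score, recall_score, f1_score

ground_truth_labels = [0, 3, 3, 0]  # e.g. parsed from the JSON input "[0, 3, 3, 0]"
predicted_labels = [0, 3, 4, 0]     # label 4 never occurs in the ground truth

# Label 4's recall is 0/0; zero_division=0 scores it 0 without a warning.
precision = precision_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
recall = recall_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
f1 = f1_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)

print(f"Precision: {precision:.4f}")  # 1.0000 (label 4 has zero support, hence zero weight)
print(f"Recall: {recall:.4f}")        # 0.7500
print(f"F1 Score: {f1:.4f}")          # 0.8333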
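
Since the app keeps start_http_server(8000) from prometheus_client, every counter, gauge, and histogram is scrapeable over HTTP while the Space runs. A quick sanity check (only chat_queue_length is named in this hunk's context; the other metric names live in the Counter/Gauge/Histogram declarations outside the diff):

import requests

# prometheus_client serves the text exposition format at /metrics.
body = requests.get("http://127.0.0.1:8000/metrics", timeout=5).text
for line in body.splitlines():
    if line.startswith("chat_queue_length"):
        print(line)  # e.g. "chat_queue_length 0.0"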