lobrien001 committed on
Commit b36ae86 · verified · 1 Parent(s): 8459b1b

Update app.py

Files changed (1):
    app.py +186 -201

app.py CHANGED
@@ -9,209 +9,194 @@ import random
  from transformers import pipeline
  from sklearn.metrics import precision_score, recall_score, f1_score
  import json
- import requests
-
- # Set up logging
- logging.basicConfig(level=logging.DEBUG)
-
- # Initialize monitoring metrics
- REQUEST_COUNT = Counter('request_count', 'Number of requests')
- REQUEST_LATENCY = Histogram('request_latency_seconds', 'Request latency')
- RESPONSE_SIZE = Histogram('response_size_bytes', 'Response size in bytes')
- ERROR_COUNT = Counter('error_count', 'Number of errors')
- CPU_USAGE = Gauge('cpu_usage_percent', 'CPU usage percentage')
- MEM_USAGE = Gauge('mem_usage_percent', 'Memory usage percentage')
- QUEUE_LENGTH = Gauge('queue_length', 'Queue length')
-
- # Initialize NER pipeline (assuming a pre-trained model is used)
- ner_pipeline = pipeline("ner")
-
- # Initialize queue
- chat_queue = Queue()
-
- def chat_function(message, user_ner_tags, ground_truth, history=[]):
-     logging.debug("Starting chat_function")
-     with REQUEST_LATENCY.time():
-         REQUEST_COUNT.inc()
-         try:
-             chat_queue.put(message)
-             logging.info(f"Received message from user: {message}")
-
-             ner_results = ner_pipeline(message)
-
-             table_data = []
-             model_predicted_labels = []
-             user_predicted_labels = []  # List to store user-provided labels
-
-             try:
-                 user_ner_results = json.loads(user_ner_tags)  # Load user's NER results
-                 if not isinstance(user_ner_results, list):
-                     raise ValueError("Invalid format for user NER tags. Please provide a JSON list of dictionaries.")
-             except json.JSONDecodeError:
-                 user_ner_results = []  # If invalid JSON, set user results to empty list
-
-             for i, result in enumerate(ner_results):
-                 token = result['word']
-                 entity = result['entity']
-                 model_score = round(result['score'], 4)
-                 start = result['start']
-                 end = result['end']
-                 label_id = int(entity.split('_')[-1])
-                 model_predicted_labels.append(label_id)
-
-                 # Try to get the user's label for this token
-                 user_score = 0.0  # Default score if user didn't tag the token
-                 user_entity = "-"
-                 if i < len(user_ner_results) and user_ner_results[i]['word'] == token:
-                     user_entity = user_ner_results[i]['entity']
-                     user_label_id = int(user_entity.split('_')[-1])
-                     user_predicted_labels.append(user_label_id)
-
-                     # Here, you would typically have a user-provided confidence score, but for this example, let's just use 1.0
-                     user_score = 1.0
-
-                 table_data.append([token, entity, model_score, user_entity, user_score])
-
-             response_size = len(str(table_data).encode('utf-8'))
-             RESPONSE_SIZE.observe(response_size)
-
-             time.sleep(random.uniform(0.5, 2.5))
-
-             # --- Compute Metrics (Model & User) ---
-             metrics_response = ""
-             if ground_truth:
-                 try:
-                     ground_truth_labels = json.loads(ground_truth)
-                 except json.JSONDecodeError:
-                     return history + [[message, (table_data, "Invalid JSON format for ground truth labels.")]]
-
-                 model_precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
-                 model_recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
-                 model_f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
-
-                 metrics_response += "\nModel Metrics:\n"
-                 metrics_response += (f"Precision: {model_precision:.4f}\n"
-                                      f"Recall: {model_recall:.4f}\n"
-                                      f"F1 Score: {model_f1:.4f}")
-
-                 if user_ner_results:  # Only calculate user metrics if user provided tags
-                     user_precision = precision_score(ground_truth_labels, user_predicted_labels, average='weighted', zero_division=0)
-                     user_recall = recall_score(ground_truth_labels, user_predicted_labels, average='weighted', zero_division=0)
-                     user_f1 = f1_score(ground_truth_labels, user_predicted_labels, average='weighted', zero_division=0)
-
-                     metrics_response += "\nUser Metrics:\n"
-                     metrics_response += (f"Precision: {user_precision:.4f}\n"
-                                          f"Recall: {user_recall:.4f}\n"
-                                          f"F1 Score: {user_f1:.4f}")
-             else:
-                 metrics_response = "Ground truth labels not provided."
-
-             chat_queue.get()
-             logging.debug("Finished processing message")
-             return history + [[message, (table_data, metrics_response)]]
-
-         except Exception as e:
-             ERROR_COUNT.inc()
-             logging.error(f"Error in chat processing: {e}")
-             return history + [[message, f"An error occurred. Please try again. Error: {e}"]]
-
- # --- Gradio Interface ---
  with gr.Blocks(css="""
  body {
-     background-image: url("stag.jpeg");
-     background-size: cover;
-     background-repeat: no-repeat;
  }
- """, title="PLOD Filtered with Monitoring") as demo:  # Load CSS for background image
-     with gr.Tab("Chat"):
-         gr.Markdown("## Chat with the Bot")
-         message_input = gr.Textbox(label="Enter your sentence:", lines=2)
-         user_ner_tags_input = gr.Textbox(label="Enter your NER tags (JSON format):", lines=5)
-         ground_truth_input = gr.Textbox(label="Enter ground truth labels (JSON format):", lines=2)
-
-         chat_output = gr.Chatbot()
-         chat_interface = gr.Interface(fn=chat_function,
-                                       inputs=[message_input, user_ner_tags_input, ground_truth_input],
-                                       outputs=chat_output)
-         chat_interface.render()
-
-     with gr.Tab("Model Parameters"):
-         model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False)  # Display model parameters
-
-     with gr.Tab("Performance Metrics"):
-         request_count_display = gr.Number(label="Request Count", value=0)
-         avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0)
-
-     with gr.Tab("Infrastructure"):
-         cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0)
-         mem_usage_display = gr.Number(label="Memory Usage (%)", value=0)
-
-     with gr.Tab("Logs"):
-         logs_display = gr.Textbox(label="Logs", lines=10)  # Increased lines for better visibility
-
-     with gr.Tab("Stress Testing"):
-         num_requests_input = gr.Number(label="Number of Requests", value=10)
-         stress_message_input = gr.Textbox(label="Message", value="Hello bot!")
-         delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
-         stress_test_button = gr.Button("Start Stress Test")
-         stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
-
-         def run_stress_test(num_requests, message, delay):
-             stress_test_status.value = "Stress test started..."
-             try:
-                 stress_test(num_requests, message, delay)
-                 stress_test_status.value = "Stress test completed."
-             except Exception as e:
-                 stress_test_status.value = f"Stress test failed: {e}"
-
-         stress_test_button.click(run_stress_test, [num_requests_input, stress_message_input, delay_input], stress_test_status)
-
-     # --- Update Functions ---
-     def update_metrics(request_count_display, avg_latency_display):
-         while True:
-             request_count = REQUEST_COUNT._value.get()
-             latency_samples = REQUEST_LATENCY.collect()[0].samples
-             avg_latency = sum(s.value for s in latency_samples) / len(latency_samples) if latency_samples else 0
-
-             request_count_display.value = request_count
-             avg_latency_display.value = round(avg_latency, 2)
-
-             time.sleep(5)  # Update every 5 seconds
-
-     def update_usage(cpu_usage_display, mem_usage_display):
-         while True:
-             cpu_usage_display.value = psutil.cpu_percent()
-             mem_usage_display.value = psutil.virtual_memory().percent
-             CPU_USAGE.set(psutil.cpu_percent())
-             MEM_USAGE.set(psutil.virtual_memory().percent)
-             time.sleep(5)
-
-     def update_logs(logs_display):
-         while True:
-             with open("chat_log.txt", "r") as log_file:
-                 logs = log_file.readlines()
-                 logs_display.value = "".join(logs[-10:])  # Display last 10 lines
-             time.sleep(1)  # Update every 1 second
-
-     def display_model_params(model_params_display):
-         while True:
-             model_params = ner_pipeline.model.config.to_dict()
-             model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items())
-             model_params_display.value = model_params_str
-             time.sleep(10)  # Update every 10 seconds
-
-     def update_queue_length():
-         while True:
-             QUEUE_LENGTH.set(chat_queue.qsize())
-             time.sleep(1)  # Update every second
-
-     # --- Start Threads ---
-     threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
-     threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
-     threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
-     threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
-     threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
-     threading.Thread(target=update_queue_length, daemon=True).start()

  # Launch the app
- demo.launch(share=True)
 
  from transformers import pipeline
  from sklearn.metrics import precision_score, recall_score, f1_score
  import json
+ import requests  # retained: stress_test below still calls requests.post
+
+ # Load the model
+ ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
+
+ # --- Prometheus Metrics Setup ---
+ REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
+ REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
+ ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
+ RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
+ CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
+ MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
+ QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
+
+ # --- Logging Setup ---
+ logging.basicConfig(filename="chat_log.txt", level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # --- Queue and Metrics ---
+ chat_queue = Queue()  # Define chat_queue globally
+
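Note (editorial sketch, not part of app.py): once start_http_server(8000) is running (it is started in the thread section further down), these collectors are served in the Prometheus text format on port 8000, and prometheus_client appends _total to Counter names on export. A minimal sketch of driving and reading them back, assuming the requests import above:

    # Hedged sketch: exercise the collectors, then scrape this app's own exporter.
    REQUEST_COUNT.inc()                # increments gradio_request_count_total
    with REQUEST_LATENCY.time():       # records one histogram observation
        time.sleep(0.1)
    text = requests.get("http://localhost:8000/metrics").text
    print([line for line in text.splitlines() if line.startswith("gradio_request_count")])
    # e.g. ['gradio_request_count_total 1.0', 'gradio_request_count_created 1.7e+09']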
+ # --- Chat Function with Monitoring ---
+ def chat_function(message, ground_truth):
+     logging.debug("Starting chat_function")
+     with REQUEST_LATENCY.time():
+         REQUEST_COUNT.inc()
+         try:
+             start_time = time.time()
+             chat_queue.put(message)
+             logging.info(f"Received message from user: {message}")
+
+             ner_results = ner_pipeline(message)
+             logging.debug(f"NER results: {ner_results}")
+
+             detailed_response = []
+             predicted_labels = []
+             for result in ner_results:
+                 token = result['word']
+                 score = result['score']
+                 entity = result['entity']
+                 start = result['start']
+                 end = result['end']
+                 label_id = int(entity.split('_')[-1])  # Extract numeric label from entity
+                 predicted_labels.append(label_id)
+                 detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
+
+             response = "\n".join(detailed_response)
+             logging.info(f"Generated response: {response}")
+
+             response_size = len(response.encode('utf-8'))
+             RESPONSE_SIZE.observe(response_size)
+
+             time.sleep(random.uniform(0.5, 2.5))  # Simulate processing time
+
+             # Compute metrics
+             try:
+                 ground_truth_labels = json.loads(ground_truth)  # Assuming ground_truth is input as a JSON string
+             except json.JSONDecodeError:
+                 return "Invalid JSON format for ground truth labels. Please provide a valid JSON array."
+
+             precision = precision_score(ground_truth_labels, predicted_labels, average='weighted')
+             recall = recall_score(ground_truth_labels, predicted_labels, average='weighted')
+             f1 = f1_score(ground_truth_labels, predicted_labels, average='weighted')
+
+             metrics_response = (f"Precision: {precision:.4f}\n"
+                                 f"Recall: {recall:.4f}\n"
+                                 f"F1 Score: {f1:.4f}")
+
+             full_response = f"{response}\n\nMetrics:\n{metrics_response}"
+
+             chat_queue.get()
+             logging.debug("Finished processing message")
+             return full_response
+         except Exception as e:
+             ERROR_COUNT.inc()
+             logging.error(f"Error in chat processing: {e}")
+             return f"An error occurred. Please try again. Error: {e}"
+
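Note (editorial sketch, not part of app.py): the transformers token-classification pipeline yields one dict per token with 'word', 'entity', 'score', 'start' and 'end' keys, so ground_truth must decode to a list of integer label ids aligned one-to-one with the predicted tokens; sklearn's precision_score raises a ValueError if the two lists differ in length. A hypothetical call (sentence and label ids are illustrative only, not real model output):

    # Hedged usage sketch; [3, 0, 0] is a made-up ground truth for three tokens.
    # ner_pipeline("...") yields dicts such as
    #   {'word': 'EGFR', 'entity': 'LABEL_3', 'score': 0.98, 'start': 0, 'end': 4}
    print(chat_function("EGFR mutations respond", "[3, 0, 0]"))
    # -> per-token lines followed by "Metrics:" with Precision/Recall/F1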
+ # Function to simulate stress test
+ def stress_test(num_requests, message, delay):
+     def send_chat_message():
+         response = requests.post("http://127.0.0.1:7860/api/predict/", json={
+             "data": [message],
+             "fn_index": 0  # This might need to be updated based on your Gradio app's function index
+         })
+         logging.debug(response.json())
+
+     threads = []
+     for _ in range(num_requests):
+         t = threading.Thread(target=send_chat_message)
+         t.start()
+         threads.append(t)
+         time.sleep(delay)  # Delay between requests
+
+     for t in threads:
+         t.join()
+
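Note (editorial sketch, not part of app.py): the {"data": [...], "fn_index": 0} payload targets the legacy /api/predict/ route of older Gradio releases, and it sends only one element even though chat_function takes two inputs (message and ground_truth), so the request will likely fail server-side. On newer Gradio releases the supported client is the gradio_client package; a hedged equivalent, assuming that package is installed and the Interface below is exposed as /predict:

    # Hedged sketch using gradio_client instead of a raw POST (api_name is an assumption).
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")
    result = client.predict("Hello bot!", "[0]", api_name="/predict")  # message, ground-truth JSON
    print(result)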
+ # --- Gradio Interface with Background Image and Three Windows ---
  with gr.Blocks(css="""
  body {
+     background-image: url("stag.jpeg");
+     background-size: cover;
+     background-repeat: no-repeat;
  }
+ """, title="PLOD Filtered with Monitoring") as demo:  # Load CSS for background image
+     with gr.Tab("Chat"):
+         gr.Markdown("## Chat with the Bot")
+         message_input = gr.Textbox(label="Enter your sentence:", lines=2)
+         ground_truth_input = gr.Textbox(label="Enter ground truth labels (JSON format):", lines=2)
+         output = gr.Textbox(label="Response", lines=10)
+         chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, ground_truth_input], outputs=output)
+         chat_interface.render()
+
+     with gr.Tab("Model Parameters"):
+         model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False)  # Display model parameters
+
+     with gr.Tab("Performance Metrics"):
+         request_count_display = gr.Number(label="Request Count", value=0)
+         avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0)
+
+     with gr.Tab("Infrastructure"):
+         cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0)
+         mem_usage_display = gr.Number(label="Memory Usage (%)", value=0)
+
+     with gr.Tab("Logs"):
+         logs_display = gr.Textbox(label="Logs", lines=10)  # Increased lines for better visibility
+
+     with gr.Tab("Stress Testing"):
+         num_requests_input = gr.Number(label="Number of Requests", value=10)
+         message_input_stress = gr.Textbox(label="Message", value="Hello bot!")
+         delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
+         stress_test_button = gr.Button("Start Stress Test")
+         stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
+
+         def run_stress_test(num_requests, message, delay):
+             stress_test_status.value = "Stress test started..."
+             try:
+                 stress_test(num_requests, message, delay)
+                 stress_test_status.value = "Stress test completed."
+             except Exception as e:
+                 stress_test_status.value = f"Stress test failed: {e}"
+
+         stress_test_button.click(run_stress_test, [num_requests_input, message_input_stress, delay_input], stress_test_status)
+
+     # --- Update Functions ---
+     def update_metrics(request_count_display, avg_latency_display):
+         while True:
+             request_count = REQUEST_COUNT._value.get()
+             latency_samples = REQUEST_LATENCY.collect()[0].samples
+             avg_latency = (sum(s.value for s in latency_samples) / len(latency_samples)) if latency_samples else 0  # Avoid division by zero
+
+             request_count_display.value = request_count
+             avg_latency_display.value = round(avg_latency, 2)
+
+             time.sleep(5)  # Update every 5 seconds
+
+     def update_usage(cpu_usage_display, mem_usage_display):
+         while True:
+             cpu_usage_display.value = psutil.cpu_percent()
+             mem_usage_display.value = psutil.virtual_memory().percent
+             CPU_USAGE.set(psutil.cpu_percent())
+             MEM_USAGE.set(psutil.virtual_memory().percent)
+             time.sleep(5)
+
+     def update_logs(logs_display):
+         while True:
+             with open("chat_log.txt", "r") as log_file:
+                 logs = log_file.readlines()
+                 logs_display.value = "".join(logs[-10:])  # Display last 10 lines
+             time.sleep(1)  # Update every 1 second
+
+     def display_model_params(model_params_display):
+         while True:
+             model_params = ner_pipeline.model.config.to_dict()
+             model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items())
+             model_params_display.value = model_params_str
+             time.sleep(10)  # Update every 10 seconds
+
+     def update_queue_length():
+         while True:
+             QUEUE_LENGTH.set(chat_queue.qsize())
+             time.sleep(1)  # Update every second
+
+     # --- Start Threads ---
+     threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
+     threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
+     threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
+     threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
+     threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
+     threading.Thread(target=update_queue_length, daemon=True).start()
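Note (editorial sketch, not part of app.py): assigning to a component's .value inside these daemon threads mutates the Python object but never pushes an update to an open browser tab, since Gradio only re-renders through events. A hedged alternative for the request counter, using the `every` polling argument available on recent Gradio releases:

    # Hedged sketch: poll the counter every 5 s while a session is open
    # (reuses the same private _value accessor as update_metrics above).
    def read_request_count():
        return REQUEST_COUNT._value.get()

    demo.load(read_request_count, inputs=None, outputs=request_count_display, every=5)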
 
 
 
 
  # Launch the app
+ demo.launch(share=True)