Update app.py
Browse files
app.py
CHANGED
@@ -348,7 +348,94 @@ class CustomClient(fl.client.NumPyClient):
|
|
348 |
|
349 |
fig.tight_layout()
|
350 |
plot_placeholder.pyplot(fig)
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
def read_log_file():
|
353 |
with open("./log.txt", "r") as file:
|
354 |
return file.read()
|
@@ -423,7 +510,7 @@ def main():
|
|
423 |
config=fl.server.ServerConfig(num_rounds=1),
|
424 |
strategy=strategy,
|
425 |
client_resources={"num_cpus": 1, "num_gpus": 0},
|
426 |
-
ray_init_args={"log_to_driver":
|
427 |
)
|
428 |
|
429 |
for i, client in enumerate(clients):
|
@@ -447,7 +534,13 @@ def main():
|
|
447 |
|
448 |
# Display log.txt content
|
449 |
st.write("## Training Log")
|
450 |
-
st.text(read_log_file())
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
else:
|
453 |
st.write("Click the 'Start Training' button to start the training process.")
|
|
|
348 |
|
349 |
fig.tight_layout()
|
350 |
plot_placeholder.pyplot(fig)
|
351 |
+
import matplotlib.pyplot as plt
|
352 |
+
import re
|
353 |
+
|
354 |
+
def read_log_file(log_path='log.txt'):
|
355 |
+
with open(log_path, 'r') as file:
|
356 |
+
log_lines = file.readlines()
|
357 |
+
return log_lines
|
358 |
+
|
359 |
+
def parse_log(log_lines):
|
360 |
+
rounds = []
|
361 |
+
clients = {}
|
362 |
+
memory_usage = []
|
363 |
+
|
364 |
+
round_pattern = re.compile(r'\[ROUND (\d+)\]')
|
365 |
+
client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
|
366 |
+
memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
|
367 |
+
|
368 |
+
current_round = None
|
369 |
+
|
370 |
+
for line in log_lines:
|
371 |
+
round_match = round_pattern.search(line)
|
372 |
+
client_match = client_pattern.search(line)
|
373 |
+
memory_match = memory_pattern.search(line)
|
374 |
+
|
375 |
+
if round_match:
|
376 |
+
current_round = int(round_match.group(1))
|
377 |
+
rounds.append(current_round)
|
378 |
+
elif client_match:
|
379 |
+
client_id = int(client_match.group(1))
|
380 |
+
log_level = client_match.group(2)
|
381 |
+
message = client_match.group(3)
|
382 |
+
|
383 |
+
if client_id not in clients:
|
384 |
+
clients[client_id] = {'rounds': [], 'messages': []}
|
385 |
+
|
386 |
+
clients[client_id]['rounds'].append(current_round)
|
387 |
+
clients[client_id]['messages'].append((log_level, message))
|
388 |
+
elif memory_match:
|
389 |
+
memory_usage.append(float(memory_match.group(1)))
|
390 |
+
|
391 |
+
return rounds, clients, memory_usage
|
392 |
+
|
393 |
+
def plot_metrics(rounds, clients, memory_usage):
|
394 |
+
st.write("## Metrics Overview")
|
395 |
+
|
396 |
+
st.write("### Memory Usage")
|
397 |
+
plt.figure()
|
398 |
+
plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
|
399 |
+
plt.xlabel('Step')
|
400 |
+
plt.ylabel('Memory Usage (GB)')
|
401 |
+
plt.legend()
|
402 |
+
st.pyplot(plt)
|
403 |
+
|
404 |
+
for client_id, data in clients.items():
|
405 |
+
st.write(f"### Client {client_id} Metrics")
|
406 |
+
|
407 |
+
info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
|
408 |
+
debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
|
409 |
+
|
410 |
+
st.write("#### INFO Messages")
|
411 |
+
for msg in info_messages:
|
412 |
+
st.write(msg)
|
413 |
+
|
414 |
+
st.write("#### DEBUG Messages")
|
415 |
+
for msg in debug_messages:
|
416 |
+
st.write(msg)
|
417 |
+
|
418 |
+
# Placeholder for actual loss and accuracy values, assuming they're included in the messages
|
419 |
+
losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
|
420 |
+
accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
|
421 |
+
|
422 |
+
if losses:
|
423 |
+
plt.figure()
|
424 |
+
plt.plot(data['rounds'], losses, label='Loss')
|
425 |
+
plt.xlabel('Round')
|
426 |
+
plt.ylabel('Loss')
|
427 |
+
plt.legend()
|
428 |
+
st.pyplot(plt)
|
429 |
+
|
430 |
+
if accuracies:
|
431 |
+
plt.figure()
|
432 |
+
plt.plot(data['rounds'], accuracies, label='Accuracy')
|
433 |
+
plt.xlabel('Round')
|
434 |
+
plt.ylabel('Accuracy')
|
435 |
+
plt.legend()
|
436 |
+
st.pyplot(plt)
|
437 |
+
|
438 |
+
|
439 |
def read_log_file():
|
440 |
with open("./log.txt", "r") as file:
|
441 |
return file.read()
|
|
|
510 |
config=fl.server.ServerConfig(num_rounds=1),
|
511 |
strategy=strategy,
|
512 |
client_resources={"num_cpus": 1, "num_gpus": 0},
|
513 |
+
ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": 0}
|
514 |
)
|
515 |
|
516 |
for i, client in enumerate(clients):
|
|
|
534 |
|
535 |
# Display log.txt content
|
536 |
st.write("## Training Log")
|
537 |
+
# st.text(read_log_file())
|
538 |
+
st.write("## Training Log Analysis")
|
539 |
+
|
540 |
+
log_lines = read_log_file()
|
541 |
+
rounds, clients, memory_usage = parse_log(log_lines)
|
542 |
+
|
543 |
+
plot_metrics(rounds, clients, memory_usage)
|
544 |
|
545 |
else:
|
546 |
st.write("Click the 'Start Training' button to start the training process.")
|