alisrbdni commited on
Commit
be7b4da
·
verified ·
1 Parent(s): 6e158d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -3
app.py CHANGED
@@ -348,7 +348,94 @@ class CustomClient(fl.client.NumPyClient):
348
 
349
  fig.tight_layout()
350
  plot_placeholder.pyplot(fig)
351
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  def read_log_file():
353
  with open("./log.txt", "r") as file:
354
  return file.read()
@@ -423,7 +510,7 @@ def main():
423
  config=fl.server.ServerConfig(num_rounds=1),
424
  strategy=strategy,
425
  client_resources={"num_cpus": 1, "num_gpus": 0},
426
- ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
427
  )
428
 
429
  for i, client in enumerate(clients):
@@ -447,7 +534,13 @@ def main():
447
 
448
  # Display log.txt content
449
  st.write("## Training Log")
450
- st.text(read_log_file())
 
 
 
 
 
 
451
 
452
  else:
453
  st.write("Click the 'Start Training' button to start the training process.")
 
348
 
349
  fig.tight_layout()
350
  plot_placeholder.pyplot(fig)
351
+ import matplotlib.pyplot as plt
352
+ import re
353
+
354
+ def read_log_file(log_path='log.txt'):
355
+ with open(log_path, 'r') as file:
356
+ log_lines = file.readlines()
357
+ return log_lines
358
+
359
+ def parse_log(log_lines):
360
+ rounds = []
361
+ clients = {}
362
+ memory_usage = []
363
+
364
+ round_pattern = re.compile(r'\[ROUND (\d+)\]')
365
+ client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
366
+ memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
367
+
368
+ current_round = None
369
+
370
+ for line in log_lines:
371
+ round_match = round_pattern.search(line)
372
+ client_match = client_pattern.search(line)
373
+ memory_match = memory_pattern.search(line)
374
+
375
+ if round_match:
376
+ current_round = int(round_match.group(1))
377
+ rounds.append(current_round)
378
+ elif client_match:
379
+ client_id = int(client_match.group(1))
380
+ log_level = client_match.group(2)
381
+ message = client_match.group(3)
382
+
383
+ if client_id not in clients:
384
+ clients[client_id] = {'rounds': [], 'messages': []}
385
+
386
+ clients[client_id]['rounds'].append(current_round)
387
+ clients[client_id]['messages'].append((log_level, message))
388
+ elif memory_match:
389
+ memory_usage.append(float(memory_match.group(1)))
390
+
391
+ return rounds, clients, memory_usage
392
+
393
+ def plot_metrics(rounds, clients, memory_usage):
394
+ st.write("## Metrics Overview")
395
+
396
+ st.write("### Memory Usage")
397
+ plt.figure()
398
+ plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
399
+ plt.xlabel('Step')
400
+ plt.ylabel('Memory Usage (GB)')
401
+ plt.legend()
402
+ st.pyplot(plt)
403
+
404
+ for client_id, data in clients.items():
405
+ st.write(f"### Client {client_id} Metrics")
406
+
407
+ info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
408
+ debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
409
+
410
+ st.write("#### INFO Messages")
411
+ for msg in info_messages:
412
+ st.write(msg)
413
+
414
+ st.write("#### DEBUG Messages")
415
+ for msg in debug_messages:
416
+ st.write(msg)
417
+
418
+ # Placeholder for actual loss and accuracy values, assuming they're included in the messages
419
+ losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
420
+ accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
421
+
422
+ if losses:
423
+ plt.figure()
424
+ plt.plot(data['rounds'], losses, label='Loss')
425
+ plt.xlabel('Round')
426
+ plt.ylabel('Loss')
427
+ plt.legend()
428
+ st.pyplot(plt)
429
+
430
+ if accuracies:
431
+ plt.figure()
432
+ plt.plot(data['rounds'], accuracies, label='Accuracy')
433
+ plt.xlabel('Round')
434
+ plt.ylabel('Accuracy')
435
+ plt.legend()
436
+ st.pyplot(plt)
437
+
438
+
439
  def read_log_file():
440
  with open("./log.txt", "r") as file:
441
  return file.read()
 
510
  config=fl.server.ServerConfig(num_rounds=1),
511
  strategy=strategy,
512
  client_resources={"num_cpus": 1, "num_gpus": 0},
513
+ ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": 0}
514
  )
515
 
516
  for i, client in enumerate(clients):
 
534
 
535
  # Display log.txt content
536
  st.write("## Training Log")
537
+ # st.text(read_log_file())
538
+ st.write("## Training Log Analysis")
539
+
540
+ log_lines = read_log_file()
541
+ rounds, clients, memory_usage = parse_log(log_lines)
542
+
543
+ plot_metrics(rounds, clients, memory_usage)
544
 
545
  else:
546
  st.write("Click the 'Start Training' button to start the training process.")