alisrbdni committed on
Commit
11a8e77
·
verified ·
1 Parent(s): 0466f74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -391,7 +391,6 @@
391
 
392
  # if __name__ == "__main__":
393
  # main()
394
-
395
  import streamlit as st
396
  import matplotlib.pyplot as plt
397
  import torch
@@ -412,6 +411,15 @@ import plotly.graph_objects as go
412
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
413
  fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
414
 
 
 
 
 
 
 
 
 
 
415
  def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
416
  raw_datasets = load_dataset(dataset_name)
417
  raw_datasets = raw_datasets.shuffle(seed=42)
@@ -444,7 +452,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
444
  train_datasets.append(train_dataset)
445
  test_datasets.append(test_dataset)
446
 
447
- data_collator = DataCollatorWithPadding(tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))
448
 
449
  return train_datasets, test_datasets, data_collator, raw_datasets
450
 
@@ -701,7 +709,7 @@ def main():
701
  st.write(f"### Round {round_num + 1} ✅")
702
 
703
  logs = read_log_file2()
704
- filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
705
  filtered_logs = "\n".join(filtered_log_list)
706
 
707
  st.markdown(filtered_logs)
 
391
 
392
  # if __name__ == "__main__":
393
  # main()
 
394
  import streamlit as st
395
  import matplotlib.pyplot as plt
396
  import torch
 
411
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
412
  fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
413
 
414
class CustomDataCollator(DataCollatorWithPadding):
    """Data collator that pre-pads raw integer ``input_ids`` lists.

    Extends ``DataCollatorWithPadding`` to handle features whose
    ``input_ids`` are plain Python lists of ints (e.g. the byte/UTF-8
    encoding path) by zero-padding them to a common length before
    delegating to the parent collator.
    """

    def __call__(self, features):
        """Pad raw int ``input_ids`` to equal length, then collate.

        Parameters:
            features: list of feature dicts; each may carry an
                ``input_ids`` list.

        Returns:
            The batch produced by ``DataCollatorWithPadding.__call__``.
        """
        # Guard against an empty batch and empty id lists before
        # peeking at features[0]['input_ids'][0] — the original
        # indexed unconditionally and raised IndexError on empty input.
        if (
            features
            and 'input_ids' in features[0]
            and features[0]['input_ids']
            and isinstance(features[0]['input_ids'][0], int)
        ):
            # Handle byte encoding case: pad every sequence with 0s
            # up to the longest sequence in the batch so the parent
            # collator receives uniform-length lists.
            max_length = max(len(f['input_ids']) for f in features)
            for f in features:
                f['input_ids'] += [0] * (max_length - len(f['input_ids']))
        return super().__call__(features)
423
  def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
424
  raw_datasets = load_dataset(dataset_name)
425
  raw_datasets = raw_datasets.shuffle(seed=42)
 
452
  train_datasets.append(train_dataset)
453
  test_datasets.append(test_dataset)
454
 
455
+ data_collator = CustomDataCollator(tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))
456
 
457
  return train_datasets, test_datasets, data_collator, raw_datasets
458
 
 
709
  st.write(f"### Round {round_num + 1} ✅")
710
 
711
  logs = read_log_file2()
712
+ filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
713
  filtered_logs = "\n".join(filtered_log_list)
714
 
715
  st.markdown(filtered_logs)