Update app.py
Browse files
app.py
CHANGED
@@ -391,7 +391,6 @@
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
394 |
-
|
395 |
import streamlit as st
|
396 |
import matplotlib.pyplot as plt
|
397 |
import torch
|
@@ -412,6 +411,15 @@ import plotly.graph_objects as go
|
|
412 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
413 |
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
|
416 |
raw_datasets = load_dataset(dataset_name)
|
417 |
raw_datasets = raw_datasets.shuffle(seed=42)
|
@@ -444,7 +452,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
|
|
444 |
train_datasets.append(train_dataset)
|
445 |
test_datasets.append(test_dataset)
|
446 |
|
447 |
-
data_collator =
|
448 |
|
449 |
return train_datasets, test_datasets, data_collator, raw_datasets
|
450 |
|
@@ -701,7 +709,7 @@ def main():
|
|
701 |
st.write(f"### Round {round_num + 1} ✅")
|
702 |
|
703 |
logs = read_log_file2()
|
704 |
-
filtered_log_list = [line for line in logs.splitlines
|
705 |
filtered_logs = "\n".join(filtered_log_list)
|
706 |
|
707 |
st.markdown(filtered_logs)
|
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
|
|
394 |
import streamlit as st
|
395 |
import matplotlib.pyplot as plt
|
396 |
import torch
|
|
|
411 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
412 |
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
413 |
|
414 |
+
class CustomDataCollator(DataCollatorWithPadding):
    """Collator that right-pads integer ``input_ids`` with 0 before delegating.

    When the batch carries ``input_ids`` as sequences of Python ints, every
    example is padded with zeros to the longest ``input_ids`` in the batch and
    the result is handed to ``DataCollatorWithPadding.__call__``. Any other
    input is passed through to the parent collator unchanged.
    """

    # Per-token fields that must stay the same length as ``input_ids``.
    # They are padded with 0 alongside it when present in a feature.
    # NOTE(review): 0 is assumed to be the right pad value for both fields
    # (masked-out position / default token type) - confirm for the tokenizer
    # actually in use.
    _ALIGNED_KEYS = ("attention_mask", "token_type_ids")

    def __call__(self, features):
        """Pad int-list ``input_ids`` to a common length, then collate.

        Args:
            features: Sequence of per-example mappings. Only the first example
                is inspected to decide whether manual padding applies.

        Returns:
            Whatever ``DataCollatorWithPadding.__call__`` returns for the
            (possibly padded) features.
        """
        if features and "input_ids" in features[0]:
            first_ids = features[0]["input_ids"]
            # len() rather than truthiness so array/tensor inputs do not raise,
            # and so an empty first sequence no longer triggers IndexError.
            if len(first_ids) > 0 and isinstance(first_ids[0], int):
                # Handle byte encoding case
                max_length = max(len(f["input_ids"]) for f in features)
                padded_features = []
                for f in features:
                    # Work on a shallow copy so the caller's examples are not
                    # mutated by the collator.
                    example = dict(f)
                    ids = list(f["input_ids"])
                    pad = [0] * (max_length - len(ids))
                    example["input_ids"] = ids + pad
                    # Once input_ids are equal length the parent sees nothing
                    # left to pad, so companion fields must be extended here
                    # or they stay ragged.
                    # NOTE(review): if a feature has no attention_mask, the
                    # parent is expected to synthesize one that also covers
                    # the zero padding - confirm that is acceptable.
                    for key in self._ALIGNED_KEYS:
                        if key in f:
                            example[key] = list(f[key]) + pad
                    padded_features.append(example)
                features = padded_features
        return super().__call__(features)
|
422 |
+
|
423 |
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
|
424 |
raw_datasets = load_dataset(dataset_name)
|
425 |
raw_datasets = raw_datasets.shuffle(seed=42)
|
|
|
452 |
train_datasets.append(train_dataset)
|
453 |
test_datasets.append(test_dataset)
|
454 |
|
455 |
+
data_collator = CustomDataCollator(tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))
|
456 |
|
457 |
return train_datasets, test_datasets, data_collator, raw_datasets
|
458 |
|
|
|
709 |
st.write(f"### Round {round_num + 1} ✅")
|
710 |
|
711 |
logs = read_log_file2()
|
712 |
+
filtered_log_list = [line for line in logs.splitlines if pattern.search(line)]
|
713 |
filtered_logs = "\n".join(filtered_log_list)
|
714 |
|
715 |
st.markdown(filtered_logs)
|