Update app.py
Browse files
app.py
CHANGED
@@ -391,11 +391,12 @@
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
|
|
394 |
import streamlit as st
|
395 |
import matplotlib.pyplot as plt
|
396 |
import torch
|
397 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
398 |
-
from transformers import
|
399 |
from datasets import load_dataset, Dataset
|
400 |
from evaluate import load as load_metric
|
401 |
from torch.utils.data import DataLoader
|
@@ -429,11 +430,11 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
|
|
429 |
del raw_datasets["unsupervised"]
|
430 |
|
431 |
if model_name == "google/byt5-small":
|
432 |
-
tokenizer =
|
433 |
|
434 |
def utf8_encode_function(examples):
|
435 |
-
encoded_texts = [text.encode('utf-8') for text in examples["text"]]
|
436 |
-
examples["input_ids"] = [tokenizer(
|
437 |
return examples
|
438 |
|
439 |
tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
|
@@ -683,7 +684,11 @@ def main():
|
|
683 |
trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
|
684 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
685 |
|
686 |
-
|
|
|
|
|
|
|
|
|
687 |
client = CustomClient(net, trainloader, testloader, client_id=i+1)
|
688 |
clients.append(client)
|
689 |
|
@@ -786,3 +791,4 @@ def main():
|
|
786 |
if __name__ == "__main__":
|
787 |
main()
|
788 |
|
|
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
394 |
+
|
395 |
import streamlit as st
|
396 |
import matplotlib.pyplot as plt
|
397 |
import torch
|
398 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
399 |
+
from transformers import ByT5Tokenizer, ByT5ForConditionalGeneration
|
400 |
from datasets import load_dataset, Dataset
|
401 |
from evaluate import load as load_metric
|
402 |
from torch.utils.data import DataLoader
|
|
|
430 |
del raw_datasets["unsupervised"]
|
431 |
|
432 |
if model_name == "google/byt5-small":
|
433 |
+
tokenizer = ByT5Tokenizer.from_pretrained(model_name)
|
434 |
|
435 |
def utf8_encode_function(examples):
|
436 |
+
encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
|
437 |
+
examples["input_ids"] = [tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=512)["input_ids"].squeeze().tolist() for text in encoded_texts]
|
438 |
return examples
|
439 |
|
440 |
tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
|
|
|
684 |
trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
|
685 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
686 |
|
687 |
+
if model_name == "google/byt5-small":
|
688 |
+
net = ByT5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
|
689 |
+
else:
|
690 |
+
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
691 |
+
|
692 |
client = CustomClient(net, trainloader, testloader, client_id=i+1)
|
693 |
clients.append(client)
|
694 |
|
|
|
791 |
if __name__ == "__main__":
|
792 |
main()
|
793 |
|
794 |
+
|