alisrbdni committed
Commit 40a7c41 · verified · Parent: 7d516a5

Update app.py

Files changed (1): app.py (+11, −5)
app.py CHANGED
@@ -391,11 +391,12 @@
 
 # if __name__ == "__main__":
 #     main()
+
 import streamlit as st
 import matplotlib.pyplot as plt
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+from transformers import ByT5Tokenizer, ByT5ForConditionalGeneration
 from datasets import load_dataset, Dataset
 from evaluate import load as load_metric
 from torch.utils.data import DataLoader
@@ -429,11 +430,11 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
     del raw_datasets["unsupervised"]
 
     if model_name == "google/byt5-small":
-        tokenizer = T5Tokenizer.from_pretrained(model_name)
+        tokenizer = ByT5Tokenizer.from_pretrained(model_name)
 
         def utf8_encode_function(examples):
-            encoded_texts = [text.encode('utf-8') for text in examples["text"]]
-            examples["input_ids"] = [tokenizer(list(encoded_text), return_tensors="pt", padding='max_length', truncation=True, max_length=512)["input_ids"].squeeze().tolist() for encoded_text in encoded_texts]
+            encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
+            examples["input_ids"] = [tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=512)["input_ids"].squeeze().tolist() for text in encoded_texts]
             return examples
 
         tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
@@ -683,7 +684,11 @@ def main():
         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
 
-        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+        if model_name == "google/byt5-small":
+            net = ByT5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
+        else:
+            net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
         client = CustomClient(net, trainloader, testloader, client_id=i+1)
         clients.append(client)
 
@@ -786,3 +791,4 @@ def main():
 if __name__ == "__main__":
     main()
 
+
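
For reference, a minimal sketch (not part of the commit) of the byte-level mapping that the switch to ByT5Tokenizer relies on: ByT5 has no learned subword vocabulary, so each UTF-8 byte maps directly to a token id offset by 3 (ids 0, 1, 2 are reserved for pad, </s> and unk), which is why wrapping list(text.encode('utf-8')) and calling the tokenizer yield matching ids. The snippet assumes transformers is installed and the google/byt5-small checkpoint named in the diff is reachable.

# Illustrative sketch only: how ByT5's byte-level ids line up with a manual UTF-8 encoding.
from transformers import ByT5Tokenizer

tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

text = "héllo"  # any UTF-8 string
ids = tokenizer(text)["input_ids"]

# ByT5 token ids are the raw UTF-8 byte values shifted by 3
# (ids 0, 1, 2 are reserved for pad, </s> and unk), plus a trailing </s>.
manual = [b + 3 for b in text.encode("utf-8")] + [tokenizer.eos_token_id]

print(ids)     # e.g. [107, 198, 172, 111, 111, 114, 1] for "héllo"
print(manual)  # should match the tokenizer output

In the diff, utf8_encode_function additionally pads or truncates to max_length=512, so the stored input_ids are this byte-level sequence padded with the pad id 0.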