alisrbdni commited on
Commit
7d516a5
·
verified ·
1 Parent(s): 269a3e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -391,7 +391,6 @@
391
 
392
  # if __name__ == "__main__":
393
  # main()
394
-
395
  import streamlit as st
396
  import matplotlib.pyplot as plt
397
  import torch
@@ -433,7 +432,8 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
433
  tokenizer = T5Tokenizer.from_pretrained(model_name)
434
 
435
  def utf8_encode_function(examples):
436
- examples["input_ids"] = [tokenizer(text.encode('utf-8'), return_tensors="pt")["input_ids"].squeeze().tolist() for text in examples["text"]]
 
437
  return examples
438
 
439
  tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
@@ -611,7 +611,6 @@ def plot_metrics(rounds, clients, memory_usage):
611
  for msg in debug_messages:
612
  st.write(msg)
613
 
614
- # Placeholder for actual loss and accuracy values, assuming they're included in the messages
615
  losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
616
  accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
617
 
 
391
 
392
  # if __name__ == "__main__":
393
  # main()
 
394
  import streamlit as st
395
  import matplotlib.pyplot as plt
396
  import torch
 
432
  tokenizer = T5Tokenizer.from_pretrained(model_name)
433
 
434
  def utf8_encode_function(examples):
435
+ encoded_texts = [text.encode('utf-8') for text in examples["text"]]
436
+ examples["input_ids"] = [tokenizer(list(encoded_text), return_tensors="pt", padding='max_length', truncation=True, max_length=512)["input_ids"].squeeze().tolist() for encoded_text in encoded_texts]
437
  return examples
438
 
439
  tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
 
611
  for msg in debug_messages:
612
  st.write(msg)
613
 
 
614
  losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
615
  accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
616