alisrbdni commited on
Commit
3c26e3a
·
verified ·
1 Parent(s): 41d8e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -1,13 +1,13 @@
1
  # %%writefile app.py
2
- # %%writefile app.py
3
 
4
  import streamlit as st
5
  import matplotlib.pyplot as plt
6
  import torch
7
  from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
8
- from datasets import load_dataset
9
  from evaluate import load as load_metric
10
  from torch.utils.data import DataLoader
 
11
  import random
12
  import warnings
13
  from collections import OrderedDict
@@ -40,10 +40,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
40
 
41
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
42
 
43
- trainloaders = [DataLoader(ds, shuffle=True, batch_size=32, collate_fn=data_collator) for ds in train_datasets]
44
- testloaders = [DataLoader(ds, batch_size=32, collate_fn=data_collator) for ds in test_datasets]
45
-
46
- return trainloaders, testloaders
47
 
48
  def train(net, trainloader, epochs):
49
  optimizer = AdamW(net.parameters(), lr=5e-5)
@@ -107,7 +104,30 @@ def main():
107
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
108
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
109
 
110
- trainloaders, testloaders = load_data(dataset_name, num_clients=NUM_CLIENTS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  if st.button("Start Training"):
113
  round_losses = []
@@ -148,6 +168,7 @@ if __name__ == "__main__":
148
  main()
149
 
150
 
 
151
  ##ORIGINAL###
152
 
153
 
 
1
  # %%writefile app.py
 
2
 
3
  import streamlit as st
4
  import matplotlib.pyplot as plt
5
  import torch
6
  from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
7
+ from datasets import load_dataset, Dataset
8
  from evaluate import load as load_metric
9
  from torch.utils.data import DataLoader
10
+ import pandas as pd
11
  import random
12
  import warnings
13
  from collections import OrderedDict
 
40
 
41
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
42
 
43
+ return train_datasets, test_datasets, data_collator
 
 
 
44
 
45
  def train(net, trainloader, epochs):
46
  optimizer = AdamW(net.parameters(), lr=5e-5)
 
104
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
105
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
106
 
107
+ train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
108
+
109
+ trainloaders = []
110
+ testloaders = []
111
+
112
+ for i in range(NUM_CLIENTS):
113
+ st.write(f"### Client {i+1} Datasets")
114
+
115
+ train_df = pd.DataFrame(train_datasets[i])
116
+ test_df = pd.DataFrame(test_datasets[i])
117
+
118
+ st.write("#### Train Dataset")
119
+ edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
120
+ st.write("#### Test Dataset")
121
+ edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
122
+
123
+ edited_train_dataset = Dataset.from_pandas(edited_train_df)
124
+ edited_test_dataset = Dataset.from_pandas(edited_test_df)
125
+
126
+ trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
127
+ testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
128
+
129
+ trainloaders.append(trainloader)
130
+ testloaders.append(testloader)
131
 
132
  if st.button("Start Training"):
133
  round_losses = []
 
168
  main()
169
 
170
 
171
+
172
  ##ORIGINAL###
173
 
174