alisrbdni committed
Commit 0640d06 · verified · 1 Parent(s): 2c9aff2

Update app.py

Files changed (1)
  1. app.py +80 -239
app.py CHANGED
@@ -1,3 +1,5 @@
  import streamlit as st
  import matplotlib.pyplot as plt
  import torch
@@ -7,6 +9,8 @@ from evaluate import load as load_metric
  from torch.utils.data import DataLoader
  import pandas as pd
  import random
  import flwr as fl

  DEVICE = torch.device("cpu")
@@ -17,6 +21,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
      del raw_datasets["unsupervised"]

      tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
      def tokenize_function(examples):
          return tokenizer(examples["text"], truncation=True)

@@ -97,273 +102,109 @@ class CustomClient(fl.client.NumPyClient):

      def plot_metrics(self, round_num):
          if self.losses and self.accuracies:
              fig, ax1 = plt.subplots()

-             color = 'tab:red'
-             ax1.set_xlabel('Round')
-             ax1.set_ylabel('Loss', color=color)
-             ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
-             ax1.tick_params(axis='y', labelcolor=color)

-             ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
-             color = 'tab:blue'
-             ax2.set_ylabel('Accuracy', color=color)  # we already handled the x-label with ax1
-             ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
-             ax2.tick_params(axis='y', labelcolor=color)

-             fig.tight_layout()  # otherwise the right y-label is slightly clipped
              st.pyplot(fig)
-             st.write(f"Round {round_num} - Loss: {self.losses[-1]:.4f}, Accuracy: {self.accuracies[-1]:.4f}")

  def main():
-     st.title("Federated Learning with Dynamic Models and Datasets")
-     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"], index=0)
-     model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"], index=0)

-     NUM_CLIENTS = st.slider("Number of Clients", 1, 10, 3)
-     NUM_ROUNDS = st.slider("Number of Rounds", 1, 10, 5)

      train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)

-     if st.button("Initialize Clients"):
-         trainloaders = []
-         testloaders = []
-         clients = []

-         for i in range(NUM_CLIENTS):
-             train_df = pd.DataFrame(train_datasets[i])
-             test_df = pd.DataFrame(test_datasets[i])

-             edited_train_dataset = Dataset.from_pandas(train_df)
-             edited_test_dataset = Dataset.from_pandas(test_df)

-             trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=4, collate_fn=data_collator)
-             testloader = DataLoader(edited_test_dataset, batch_size=4, collate_fn=data_collator)

-             net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-             client = CustomClient(net, trainloader, testloader, client_id=i+1)
-             clients.append(client)

-         for round_num in range(1, NUM_ROUNDS + 1):
-             st.write(f"### Round {round_num}")
-             for client in clients:
-                 _, _, _ = client.fit({}, {})
-                 client.plot_metrics(round_num)

-         st.success("Training completed successfully!")

- if __name__ == "__main__":
-     main()

- # # %%writefile app.py

- # import streamlit as st
- # import matplotlib.pyplot as plt
- # import torch
- # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
- # from datasets import load_dataset, Dataset
- # from evaluate import load as load_metric
- # from torch.utils.data import DataLoader
- # import pandas as pd
- # import random
- # import warnings
- # from collections import OrderedDict
- # import flwr as fl
-
- # DEVICE = torch.device("cpu")

- # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
- #     raw_datasets = load_dataset(dataset_name)
- #     raw_datasets = raw_datasets.shuffle(seed=42)
- #     del raw_datasets["unsupervised"]
-
- #     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
- #     def tokenize_function(examples):
- #         return tokenizer(examples["text"], truncation=True)
-
- #     tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
- #     tokenized_datasets = tokenized_datasets.remove_columns("text")
- #     tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
-
- #     train_datasets = []
- #     test_datasets = []
-
- #     for _ in range(num_clients):
- #         train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
- #         test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
- #         train_datasets.append(train_dataset)
- #         test_datasets.append(test_dataset)
-
- #     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-
- #     return train_datasets, test_datasets, data_collator
-
- # def train(net, trainloader, epochs):
- #     optimizer = AdamW(net.parameters(), lr=5e-5)
- #     net.train()
- #     for _ in range(epochs):
- #         for batch in trainloader:
- #             batch = {k: v.to(DEVICE) for k, v in batch.items()}
- #             outputs = net(**batch)
- #             loss = outputs.loss
- #             loss.backward()
- #             optimizer.step()
- #             optimizer.zero_grad()
-
- # def test(net, testloader):
- #     metric = load_metric("accuracy")
- #     net.eval()
- #     loss = 0
- #     for batch in testloader:
- #         batch = {k: v.to(DEVICE) for k, v in batch.items()}
- #         with torch.no_grad():
- #             outputs = net(**batch)
- #         logits = outputs.logits
- #         loss += outputs.loss.item()
- #         predictions = torch.argmax(logits, dim=-1)
- #         metric.add_batch(predictions=predictions, references=batch["labels"])
- #     loss /= len(testloader)
- #     accuracy = metric.compute()["accuracy"]
- #     return loss, accuracy
-
- # class CustomClient(fl.client.NumPyClient):
- #     def __init__(self, net, trainloader, testloader, client_id):
- #         self.net = net
- #         self.trainloader = trainloader
- #         self.testloader = testloader
- #         self.client_id = client_id
- #         self.losses = []
- #         self.accuracies = []
-
- #     def get_parameters(self, config):
- #         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

- #     def set_parameters(self, parameters):
- #         params_dict = zip(self.net.state_dict().keys(), parameters)
- #         state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
- #         self.net.load_state_dict(state_dict, strict=True)
-
- #     def fit(self, parameters, config):
- #         self.set_parameters(parameters)
- #         train(self.net, self.trainloader, epochs=1)
- #         loss, accuracy = test(self.net, self.testloader)
- #         self.losses.append(loss)
- #         self.accuracies.append(accuracy)
- #         return self.get_parameters(config={}), len(self.trainloader.dataset), {}
-
- #     def evaluate(self, parameters, config):
- #         self.set_parameters(parameters)
- #         loss, accuracy = test(self.net, self.testloader)
- #         return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
-
- #     def plot_metrics(self, round_num):
- #         if self.losses and self.accuracies:
- #             st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
- #             st.write(f"Loss: {self.losses[-1]:.4f}")
- #             st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
-
- #             fig, ax1 = plt.subplots()
-
- #             ax2 = ax1.twinx()
- #             ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
- #             ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
-
- #             ax1.set_xlabel('Round')
- #             ax1.set_ylabel('Loss', color='g')
- #             ax2.set_ylabel('Accuracy', color='b')
-
- #             plt.title(f'Client {self.client_id} Metrics')
- #             st.pyplot(fig)
-
- # def main():
- #     st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
- #     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
- #     model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
-
- #     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
- #     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
-
- #     train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
-
- #     trainloaders = []
- #     testloaders = []
- #     clients = []
-
- #     for i in range(NUM_CLIENTS):
- #         st.write(f"### Client {i+1} Datasets")
-
- #         train_df = pd.DataFrame(train_datasets[i])
- #         test_df = pd.DataFrame(test_datasets[i])
-
- #         st.write("#### Train Dataset")
- #         edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
- #         st.write("#### Test Dataset")
- #         edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
-
- #         edited_train_dataset = Dataset.from_pandas(edited_train_df)
- #         edited_test_dataset = Dataset.from_pandas(edited_test_df)
-
- #         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
- #         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
-
- #         trainloaders.append(trainloader)
- #         testloaders.append(testloader)
-
- #         net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
- #         client = CustomClient(net, trainloader, testloader, client_id=i+1)
- #         clients.append(client)
-
- #     if st.button("Start Training"):
- #         def client_fn(cid):
- #             return clients[int(cid)]
-
- #         def weighted_average(metrics):
- #             accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
- #             losses = [num_examples * m["loss"] for num_examples, m in metrics]
- #             examples = [num_examples for num_examples, _ in metrics]
- #             return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
-
- #         strategy = fl.server.strategy.FedAvg(
- #             fraction_fit=1.0,
- #             fraction_evaluate=1.0,
- #             evaluate_metrics_aggregation_fn=weighted_average,
- #         )
-
- #         for round_num in range(NUM_ROUNDS):
- #             st.write(f"### Round {round_num + 1}")
-
- #             fl.simulation.start_simulation(
- #                 client_fn=client_fn,
- #                 num_clients=NUM_CLIENTS,
- #                 config=fl.server.ServerConfig(num_rounds=1),
- #                 strategy=strategy,
- #                 client_resources={"num_cpus": 1, "num_gpus": 0},
- #                 ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
- #             )

- #             for client in clients:
- #                 client.plot_metrics(round_num + 1)
- #                 st.write(" ")

- #         st.success(f"Training completed successfully!")

- #         # Display final metrics
- #         st.write("## Final Client Metrics")
- #         for client in clients:
- #             st.write(f"### Client {client.client_id}")
- #             st.write(f"Final Loss: {client.losses[-1]:.4f}")
- #             st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
- #             client.plot_metrics(NUM_ROUNDS)
- #             st.write(" ")

- #     else:
- #         st.write("Click the 'Start Training' button to start the training process.")
-
- # if __name__ == "__main__":
- #     main()

+ # %%writefile app.py
+
  import streamlit as st
  import matplotlib.pyplot as plt
  import torch

  from torch.utils.data import DataLoader
  import pandas as pd
  import random
+ import warnings
+ from collections import OrderedDict
  import flwr as fl

  DEVICE = torch.device("cpu")

      del raw_datasets["unsupervised"]

      tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
      def tokenize_function(examples):
          return tokenizer(examples["text"], truncation=True)


      def plot_metrics(self, round_num):
          if self.losses and self.accuracies:
+             st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
+             st.write(f"Loss: {self.losses[-1]:.4f}")
+             st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
+
              fig, ax1 = plt.subplots()

+             ax2 = ax1.twinx()
+             ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
+             ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')

+             ax1.set_xlabel('Round')
+             ax1.set_ylabel('Loss', color='g')
+             ax2.set_ylabel('Accuracy', color='b')

+             plt.title(f'Client {self.client_id} Metrics')
              st.pyplot(fig)

  def main():
+     st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
+     dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
+     model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])

+     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
+     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

      train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)

+     trainloaders = []
+     testloaders = []
+     clients = []

+     for i in range(NUM_CLIENTS):
+         st.write(f"### Client {i+1} Datasets")

+         train_df = pd.DataFrame(train_datasets[i])
+         test_df = pd.DataFrame(test_datasets[i])

+         st.write("#### Train Dataset")
+         edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
+         st.write("#### Test Dataset")
+         edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")

+         edited_train_dataset = Dataset.from_pandas(edited_train_df)
+         edited_test_dataset = Dataset.from_pandas(edited_test_df)

+         trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
+         testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)

+         trainloaders.append(trainloader)
+         testloaders.append(testloader)

+         net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+         client = CustomClient(net, trainloader, testloader, client_id=i+1)
+         clients.append(client)

+     if st.button("Start Training"):
+         def client_fn(cid):
+             return clients[int(cid)]

+         def weighted_average(metrics):
+             accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+             losses = [num_examples * m["loss"] for num_examples, m in metrics]
+             examples = [num_examples for num_examples, _ in metrics]
+             return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

+         strategy = fl.server.strategy.FedAvg(
+             fraction_fit=1.0,
+             fraction_evaluate=1.0,
+             evaluate_metrics_aggregation_fn=weighted_average,
+         )

+         for round_num in range(NUM_ROUNDS):
+             st.write(f"### Round {round_num + 1}")

+             fl.simulation.start_simulation(
+                 client_fn=client_fn,
+                 num_clients=NUM_CLIENTS,
+                 config=fl.server.ServerConfig(num_rounds=1),
+                 strategy=strategy,
+                 client_resources={"num_cpus": 1, "num_gpus": 0},
+                 ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
+             )

+             for client in clients:
+                 client.plot_metrics(round_num + 1)
+                 st.write(" ")

+         st.success(f"Training completed successfully!")

+         # Display final metrics
+         st.write("## Final Client Metrics")
+         for client in clients:
+             st.write(f"### Client {client.client_id}")
+             st.write(f"Final Loss: {client.losses[-1]:.4f}")
+             st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
+             client.plot_metrics(NUM_ROUNDS)
+             st.write(" ")

+     else:
+         st.write("Click the 'Start Training' button to start the training process.")

+ if __name__ == "__main__":
+     main()
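
For reference, the weighted_average helper that this commit activates (it is passed to fl.server.strategy.FedAvg as evaluate_metrics_aggregation_fn) can be exercised on its own, outside Streamlit and the Flower simulation. The sketch below restates it with toy (num_examples, metrics) pairs; the client sizes and metric values are invented purely for illustration, not taken from the app.

    # Standalone sketch of the weighted_average aggregation used above.
    # The (num_examples, metrics) pairs below are invented for illustration;
    # in app.py, Flower collects them from each client's evaluate() call.
    def weighted_average(metrics):
        accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
        losses = [num_examples * m["loss"] for num_examples, m in metrics]
        examples = [num_examples for num_examples, _ in metrics]
        return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

    if __name__ == "__main__":
        # Two hypothetical clients holding 20 and 40 evaluation examples.
        toy_metrics = [
            (20, {"accuracy": 0.70, "loss": 0.60}),
            (40, {"accuracy": 0.85, "loss": 0.40}),
        ]
        print(weighted_average(toy_metrics))  # roughly {'accuracy': 0.8, 'loss': 0.467}

The result is an example-weighted mean, so a client evaluated on more data moves the aggregate more; it assumes each client's metrics dict carries both the "accuracy" and "loss" keys.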