alisrbdni committed on
Commit
8fe0c4a
·
verified ·
1 Parent(s): 0640d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -18
app.py CHANGED
@@ -9,7 +9,6 @@ from evaluate import load as load_metric
9
  from torch.utils.data import DataLoader
10
  import pandas as pd
11
  import random
12
- import warnings
13
  from collections import OrderedDict
14
  import flwr as fl
15
 
@@ -87,12 +86,13 @@ class CustomClient(fl.client.NumPyClient):
87
  state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
88
  self.net.load_state_dict(state_dict, strict=True)
89
 
90
- def fit(self, parameters, config):
91
  self.set_parameters(parameters)
92
  train(self.net, self.trainloader, epochs=1)
93
  loss, accuracy = test(self.net, self.testloader)
94
  self.losses.append(loss)
95
  self.accuracies.append(accuracy)
 
96
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
97
 
98
  def evaluate(self, parameters, config):
@@ -100,24 +100,28 @@ class CustomClient(fl.client.NumPyClient):
100
  loss, accuracy = test(self.net, self.testloader)
101
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
102
 
103
- def plot_metrics(self, round_num):
104
  if self.losses and self.accuracies:
105
- st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
106
- st.write(f"Loss: {self.losses[-1]:.4f}")
107
- st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
108
 
109
  fig, ax1 = plt.subplots()
110
 
111
- ax2 = ax1.twinx()
112
- ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
113
- ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
114
-
115
  ax1.set_xlabel('Round')
116
- ax1.set_ylabel('Loss', color='g')
117
- ax2.set_ylabel('Accuracy', color='b')
 
 
 
 
 
 
 
118
 
119
- plt.title(f'Client {self.client_id} Metrics')
120
- st.pyplot(fig)
121
 
122
  def main():
123
  st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
@@ -175,6 +179,7 @@ def main():
175
 
176
  for round_num in range(NUM_ROUNDS):
177
  st.write(f"### Round {round_num + 1}")
 
178
 
179
  fl.simulation.start_simulation(
180
  client_fn=client_fn,
@@ -185,11 +190,11 @@ def main():
185
  ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
186
  )
187
 
188
- for client in clients:
189
- client.plot_metrics(round_num + 1)
190
  st.write(" ")
191
 
192
- st.success(f"Training completed successfully!")
193
 
194
  # Display final metrics
195
  st.write("## Final Client Metrics")
@@ -197,7 +202,7 @@ def main():
197
  st.write(f"### Client {client.client_id}")
198
  st.write(f"Final Loss: {client.losses[-1]:.4f}")
199
  st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
200
- client.plot_metrics(NUM_ROUNDS)
201
  st.write(" ")
202
 
203
  else:
@@ -208,6 +213,216 @@ if __name__ == "__main__":
208
 
209
 
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
 
 
9
  from torch.utils.data import DataLoader
10
  import pandas as pd
11
  import random
 
12
  from collections import OrderedDict
13
  import flwr as fl
14
 
 
86
  state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
87
  self.net.load_state_dict(state_dict, strict=True)
88
 
89
+ def fit(self, parameters, config, round_num, plot_placeholder):
90
  self.set_parameters(parameters)
91
  train(self.net, self.trainloader, epochs=1)
92
  loss, accuracy = test(self.net, self.testloader)
93
  self.losses.append(loss)
94
  self.accuracies.append(accuracy)
95
+ self.plot_metrics(round_num, plot_placeholder)
96
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
97
 
98
  def evaluate(self, parameters, config):
 
100
  loss, accuracy = test(self.net, self.testloader)
101
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
102
 
103
+ def plot_metrics(self, round_num, plot_placeholder):
104
  if self.losses and self.accuracies:
105
+ plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
106
+ plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
107
+ plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
108
 
109
  fig, ax1 = plt.subplots()
110
 
111
+ color = 'tab:red'
 
 
 
112
  ax1.set_xlabel('Round')
113
+ ax1.set_ylabel('Loss', color=color)
114
+ ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
115
+ ax1.tick_params(axis='y', labelcolor=color)
116
+
117
+ ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
118
+ color = 'tab:blue'
119
+ ax2.set_ylabel('Accuracy', color=color)
120
+ ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
121
+ ax2.tick_params(axis='y', labelcolor=color)
122
 
123
+ fig.tight_layout()
124
+ plot_placeholder.pyplot(fig)
125
 
126
  def main():
127
  st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
 
179
 
180
  for round_num in range(NUM_ROUNDS):
181
  st.write(f"### Round {round_num + 1}")
182
+ plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
183
 
184
  fl.simulation.start_simulation(
185
  client_fn=client_fn,
 
190
  ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
191
  )
192
 
193
+ for i, client in enumerate(clients):
194
+ client.plot_metrics(round_num + 1, plot_placeholders[i])
195
  st.write(" ")
196
 
197
+ st.success("Training completed successfully!")
198
 
199
  # Display final metrics
200
  st.write("## Final Client Metrics")
 
202
  st.write(f"### Client {client.client_id}")
203
  st.write(f"Final Loss: {client.losses[-1]:.4f}")
204
  st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
205
+ client.plot_metrics(NUM_ROUNDS, st.empty())
206
  st.write(" ")
207
 
208
  else:
 
213
 
214
 
215
 
216
+ # # %%writefile app.py
217
+
218
+ # import streamlit as st
219
+ # import matplotlib.pyplot as plt
220
+ # import torch
221
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
222
+ # from datasets import load_dataset, Dataset
223
+ # from evaluate import load as load_metric
224
+ # from torch.utils.data import DataLoader
225
+ # import pandas as pd
226
+ # import random
227
+ # import warnings
228
+ # from collections import OrderedDict
229
+ # import flwr as fl
230
+
231
+ # DEVICE = torch.device("cpu")
232
+
233
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
234
+ # raw_datasets = load_dataset(dataset_name)
235
+ # raw_datasets = raw_datasets.shuffle(seed=42)
236
+ # del raw_datasets["unsupervised"]
237
+
238
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
239
+
240
+ # def tokenize_function(examples):
241
+ # return tokenizer(examples["text"], truncation=True)
242
+
243
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
244
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
245
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
246
+
247
+ # train_datasets = []
248
+ # test_datasets = []
249
+
250
+ # for _ in range(num_clients):
251
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
252
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
253
+ # train_datasets.append(train_dataset)
254
+ # test_datasets.append(test_dataset)
255
+
256
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
257
+
258
+ # return train_datasets, test_datasets, data_collator
259
+
260
+ # def train(net, trainloader, epochs):
261
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
262
+ # net.train()
263
+ # for _ in range(epochs):
264
+ # for batch in trainloader:
265
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
266
+ # outputs = net(**batch)
267
+ # loss = outputs.loss
268
+ # loss.backward()
269
+ # optimizer.step()
270
+ # optimizer.zero_grad()
271
+
272
+ # def test(net, testloader):
273
+ # metric = load_metric("accuracy")
274
+ # net.eval()
275
+ # loss = 0
276
+ # for batch in testloader:
277
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
278
+ # with torch.no_grad():
279
+ # outputs = net(**batch)
280
+ # logits = outputs.logits
281
+ # loss += outputs.loss.item()
282
+ # predictions = torch.argmax(logits, dim=-1)
283
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
284
+ # loss /= len(testloader)
285
+ # accuracy = metric.compute()["accuracy"]
286
+ # return loss, accuracy
287
+
288
+ # class CustomClient(fl.client.NumPyClient):
289
+ # def __init__(self, net, trainloader, testloader, client_id):
290
+ # self.net = net
291
+ # self.trainloader = trainloader
292
+ # self.testloader = testloader
293
+ # self.client_id = client_id
294
+ # self.losses = []
295
+ # self.accuracies = []
296
+
297
+ # def get_parameters(self, config):
298
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
299
+
300
+ # def set_parameters(self, parameters):
301
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
302
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
303
+ # self.net.load_state_dict(state_dict, strict=True)
304
+
305
+ # def fit(self, parameters, config):
306
+ # self.set_parameters(parameters)
307
+ # train(self.net, self.trainloader, epochs=1)
308
+ # loss, accuracy = test(self.net, self.testloader)
309
+ # self.losses.append(loss)
310
+ # self.accuracies.append(accuracy)
311
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
312
+
313
+ # def evaluate(self, parameters, config):
314
+ # self.set_parameters(parameters)
315
+ # loss, accuracy = test(self.net, self.testloader)
316
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
317
+
318
+ # def plot_metrics(self, round_num):
319
+ # if self.losses and self.accuracies:
320
+ # st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
321
+ # st.write(f"Loss: {self.losses[-1]:.4f}")
322
+ # st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
323
+
324
+ # fig, ax1 = plt.subplots()
325
+
326
+ # ax2 = ax1.twinx()
327
+ # ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
328
+ # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
329
+
330
+ # ax1.set_xlabel('Round')
331
+ # ax1.set_ylabel('Loss', color='g')
332
+ # ax2.set_ylabel('Accuracy', color='b')
333
+
334
+ # plt.title(f'Client {self.client_id} Metrics')
335
+ # st.pyplot(fig)
336
+
337
+ # def main():
338
+ # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
339
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
340
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
341
+
342
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
343
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
344
+
345
+ # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
346
+
347
+ # trainloaders = []
348
+ # testloaders = []
349
+ # clients = []
350
+
351
+ # for i in range(NUM_CLIENTS):
352
+ # st.write(f"### Client {i+1} Datasets")
353
+
354
+ # train_df = pd.DataFrame(train_datasets[i])
355
+ # test_df = pd.DataFrame(test_datasets[i])
356
+
357
+ # st.write("#### Train Dataset")
358
+ # edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
359
+ # st.write("#### Test Dataset")
360
+ # edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
361
+
362
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
363
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
364
+
365
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
366
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
367
+
368
+ # trainloaders.append(trainloader)
369
+ # testloaders.append(testloader)
370
+
371
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
372
+ # client = CustomClient(net, trainloader, testloader, client_id=i+1)
373
+ # clients.append(client)
374
+
375
+ # if st.button("Start Training"):
376
+ # def client_fn(cid):
377
+ # return clients[int(cid)]
378
+
379
+ # def weighted_average(metrics):
380
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
381
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
382
+ # examples = [num_examples for num_examples, _ in metrics]
383
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
384
+
385
+ # strategy = fl.server.strategy.FedAvg(
386
+ # fraction_fit=1.0,
387
+ # fraction_evaluate=1.0,
388
+ # evaluate_metrics_aggregation_fn=weighted_average,
389
+ # )
390
+
391
+ # for round_num in range(NUM_ROUNDS):
392
+ # st.write(f"### Round {round_num + 1}")
393
+
394
+ # fl.simulation.start_simulation(
395
+ # client_fn=client_fn,
396
+ # num_clients=NUM_CLIENTS,
397
+ # config=fl.server.ServerConfig(num_rounds=1),
398
+ # strategy=strategy,
399
+ # client_resources={"num_cpus": 1, "num_gpus": 0},
400
+ # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
401
+ # )
402
+
403
+ # for client in clients:
404
+ # client.plot_metrics(round_num + 1)
405
+ # st.write(" ")
406
+
407
+ # st.success(f"Training completed successfully!")
408
+
409
+ # # Display final metrics
410
+ # st.write("## Final Client Metrics")
411
+ # for client in clients:
412
+ # st.write(f"### Client {client.client_id}")
413
+ # st.write(f"Final Loss: {client.losses[-1]:.4f}")
414
+ # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
415
+ # client.plot_metrics(NUM_ROUNDS)
416
+ # st.write(" ")
417
+
418
+ # else:
419
+ # st.write("Click the 'Start Training' button to start the training process.")
420
+
421
+ # if __name__ == "__main__":
422
+ # main()
423
+
424
+
425
+
426
 
427
 
428