alisrbdni commited on
Commit
ad31de9
·
verified ·
1 Parent(s): 6282b8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -3
app.py CHANGED
@@ -214,8 +214,6 @@
214
  # if __name__ == "__main__":
215
  # main()
216
 
217
-
218
-
219
  # %%writefile app.py
220
 
221
  import streamlit as st
@@ -273,5 +271,182 @@ def train(net, trainloader, epochs):
273
  optimizer.step()
274
  optimizer.zero_grad()
275
 
276
- def test(net, testloader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
 
214
  # if __name__ == "__main__":
215
  # main()
216
 
 
 
217
  # %%writefile app.py
218
 
219
  import streamlit as st
 
271
  optimizer.step()
272
  optimizer.zero_grad()
273
 
274
+ def test(net, testloader):
275
+ metric = load_metric("accuracy")
276
+ net.eval()
277
+ loss = 0
278
+ for batch in testloader:
279
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
280
+ with torch.no_grad():
281
+ outputs = net(**batch)
282
+ logits = outputs.logits
283
+ loss += outputs.loss.item()
284
+ predictions = torch.argmax(logits, dim=-1)
285
+ metric.add_batch(predictions=predictions, references=batch["labels"])
286
+ loss /= len(testloader)
287
+ accuracy = metric.compute()["accuracy"]
288
+ return loss, accuracy
289
+
290
+ class CustomClient(fl.client.NumPyClient):
291
+ def __init__(self, net, trainloader, testloader, client_id):
292
+ self.net = net
293
+ self.trainloader = trainloader
294
+ self.testloader = testloader
295
+ self.client_id = client_id
296
+ self.losses = []
297
+ self.accuracies = []
298
+
299
+ def get_parameters(self, config):
300
+ return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
301
+
302
+ def set_parameters(self, parameters):
303
+ params_dict = zip(self.net.state_dict().keys(), parameters)
304
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
305
+ self.net.load_state_dict(state_dict, strict=True)
306
+
307
+ def fit(self, parameters, config):
308
+ log(INFO, f"Client {self.client_id} is starting fit()")
309
+ self.set_parameters(parameters)
310
+ train(self.net, self.trainloader, epochs=1)
311
+ loss, accuracy = test(self.net, self.testloader)
312
+ self.losses.append(loss)
313
+ self.accuracies.append(accuracy)
314
+ log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
315
+ return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
316
+
317
+ def evaluate(self, parameters, config):
318
+ log(INFO, f"Client {self.client_id} is starting evaluate()")
319
+ self.set_parameters(parameters)
320
+ loss, accuracy = test(self.net, self.testloader)
321
+ log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
322
+ return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
323
+
324
+ def plot_metrics(self, round_num, plot_placeholder):
325
+ if self.losses and self.accuracies:
326
+ plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
327
+ plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
328
+ plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
329
+
330
+ fig, ax1 = plt.subplots()
331
+
332
+ color = 'tab:red'
333
+ ax1.set_xlabel('Round')
334
+ ax1.set_ylabel('Loss', color=color)
335
+ ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
336
+ ax1.tick_params(axis='y', labelcolor=color)
337
+
338
+ ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
339
+ color = 'tab:blue'
340
+ ax2.set_ylabel('Accuracy', color=color)
341
+ ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
342
+ ax2.tick_params(axis='y', labelcolor=color)
343
+
344
+ fig.tight_layout()
345
+ plot_placeholder.pyplot(fig)
346
+
347
+ def read_log_file():
348
+ with open("log.txt", "r") as file:
349
+ return file.read()
350
+
351
+ def main():
352
+ st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
353
+ dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
354
+ model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
355
+
356
+ NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
357
+ NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
358
+
359
+ train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
360
+
361
+ trainloaders = []
362
+ testloaders = []
363
+ clients = []
364
+
365
+ for i in range(NUM_CLIENTS):
366
+ st.write(f"### Client {i+1} Datasets")
367
+
368
+ train_df = pd.DataFrame(train_datasets[i])
369
+ test_df = pd.DataFrame(test_datasets[i])
370
+
371
+ st.write("#### Train Dataset (Words)")
372
+ st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
373
+ st.write("#### Train Dataset (Tokens)")
374
+ edited_train_df = st.data_editor(train_df, key=f"train_{i}")
375
+
376
+ st.write("#### Test Dataset (Words)")
377
+ st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
378
+ st.write("#### Test Dataset (Tokens)")
379
+ edited_test_df = st.data_editor(test_df, key=f"test_{i}")
380
+
381
+ edited_train_dataset = Dataset.from_pandas(edited_train_df)
382
+ edited_test_dataset = Dataset.from_pandas(edited_test_df)
383
+
384
+ trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
385
+ testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
386
+
387
+ trainloaders.append(trainloader)
388
+ testloaders.append(testloader)
389
+
390
+ net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
391
+ client = CustomClient(net, trainloader, testloader, client_id=i+1)
392
+ clients.append(client)
393
+
394
+ if st.button("Start Training"):
395
+ def client_fn(cid):
396
+ return clients[int(cid)]
397
+
398
+ def weighted_average(metrics):
399
+ accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
400
+ losses = [num_examples * m["loss"] for num_examples, m in metrics]
401
+ examples = [num_examples for num_examples, _ in metrics]
402
+ return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
403
+
404
+ strategy = fl.server.strategy.FedAvg(
405
+ fraction_fit=1.0,
406
+ fraction_evaluate=1.0,
407
+ evaluate_metrics_aggregation_fn=weighted_average,
408
+ )
409
+
410
+ for round_num in range(NUM_ROUNDS):
411
+ st.write(f"### Round {round_num + 1}")
412
+ plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
413
+
414
+ fl.simulation.start_simulation(
415
+ client_fn=client_fn,
416
+ num_clients=NUM_CLIENTS,
417
+ config=fl.server.ServerConfig(num_rounds=1),
418
+ strategy=strategy,
419
+ client_resources={"num_cpus": 1, "num_gpus": 0},
420
+ ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
421
+ )
422
+
423
+ for i, client in enumerate(clients):
424
+ client.plot_metrics(round_num + 1, plot_placeholders[i])
425
+ st.write(" ")
426
+
427
+ st.success("Training completed successfully!")
428
+
429
+ # Display final metrics
430
+ st.write("## Final Client Metrics")
431
+ for client in clients:
432
+ st.write(f"### Client {client.client_id}")
433
+ if client.losses and client.accuracies:
434
+ st.write(f"Final Loss: {client.losses[-1]:.4f}")
435
+ st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
436
+ client.plot_metrics(NUM_ROUNDS, st.empty())
437
+ else:
438
+ st.write("No metrics available.")
439
+
440
+ st.write(" ")
441
+
442
+ # Display log.txt content
443
+ st.write("## Training Log")
444
+ st.text(read_log_file())
445
+
446
+ else:
447
+ st.write("Click the 'Start Training' button to start the training process.")
448
+
449
+ if __name__ == "__main__":
450
+ main()
451
+
452