alisrbdni commited on
Commit
c75b17c
·
verified ·
1 Parent(s): f217473

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -12
app.py CHANGED
@@ -93,6 +93,7 @@ class CustomClient(fl.client.NumPyClient):
93
  loss, accuracy = test(self.net, self.testloader)
94
  self.losses.append(loss)
95
  self.accuracies.append(accuracy)
 
96
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
97
 
98
  def evaluate(self, parameters, config):
@@ -168,20 +169,23 @@ def main():
168
  evaluate_metrics_aggregation_fn=weighted_average,
169
  )
170
 
171
- fl.simulation.start_simulation(
172
- client_fn=client_fn,
173
- num_clients=NUM_CLIENTS,
174
- config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
175
- strategy=strategy,
176
- client_resources={"num_cpus": 1, "num_gpus": 0},
177
- ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
178
- )
179
 
180
- st.success(f"Training completed successfully!")
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- for client in clients:
183
- st.write(f"### Client {client.client_id} Model Metrics")
184
- client.plot_metrics()
185
 
186
  else:
187
  st.write("Click the 'Start Training' button to start the training process.")
@@ -190,6 +194,198 @@ if __name__ == "__main__":
190
  main()
191
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  # 05/2024 # %%writefile app.py
195
 
 
93
  loss, accuracy = test(self.net, self.testloader)
94
  self.losses.append(loss)
95
  self.accuracies.append(accuracy)
96
+ self.plot_metrics()
97
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
98
 
99
  def evaluate(self, parameters, config):
 
169
  evaluate_metrics_aggregation_fn=weighted_average,
170
  )
171
 
172
+ for round_num in range(NUM_ROUNDS):
173
+ st.write(f"### Round {round_num + 1}")
 
 
 
 
 
 
174
 
175
+ fl.simulation.start_simulation(
176
+ client_fn=client_fn,
177
+ num_clients=NUM_CLIENTS,
178
+ config=fl.server.ServerConfig(num_rounds=1),
179
+ strategy=strategy,
180
+ client_resources={"num_cpus": 1, "num_gpus": 0},
181
+ ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
182
+ )
183
+
184
+ for client in clients:
185
+ st.write(f"### Client {client.client_id} Metrics for Round {round_num + 1}")
186
+ client.plot_metrics()
187
 
188
+ st.success(f"Training completed successfully!")
 
 
189
 
190
  else:
191
  st.write("Click the 'Start Training' button to start the training process.")
 
194
  main()
195
 
196
 
197
+ # # %%writefile app.py
198
+
199
+ # import streamlit as st
200
+ # import matplotlib.pyplot as plt
201
+ # import torch
202
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
203
+ # from datasets import load_dataset, Dataset
204
+ # from evaluate import load as load_metric
205
+ # from torch.utils.data import DataLoader
206
+ # import pandas as pd
207
+ # import random
208
+ # import warnings
209
+ # from collections import OrderedDict
210
+ # import flwr as fl
211
+
212
+ # DEVICE = torch.device("cpu")
213
+
214
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
215
+ # raw_datasets = load_dataset(dataset_name)
216
+ # raw_datasets = raw_datasets.shuffle(seed=42)
217
+ # del raw_datasets["unsupervised"]
218
+
219
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
220
+
221
+ # def tokenize_function(examples):
222
+ # return tokenizer(examples["text"], truncation=True)
223
+
224
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
225
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
226
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
227
+
228
+ # train_datasets = []
229
+ # test_datasets = []
230
+
231
+ # for _ in range(num_clients):
232
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
233
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
234
+ # train_datasets.append(train_dataset)
235
+ # test_datasets.append(test_dataset)
236
+
237
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
238
+
239
+ # return train_datasets, test_datasets, data_collator
240
+
241
+ # def train(net, trainloader, epochs):
242
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
243
+ # net.train()
244
+ # for _ in range(epochs):
245
+ # for batch in trainloader:
246
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
247
+ # outputs = net(**batch)
248
+ # loss = outputs.loss
249
+ # loss.backward()
250
+ # optimizer.step()
251
+ # optimizer.zero_grad()
252
+
253
+ # def test(net, testloader):
254
+ # metric = load_metric("accuracy")
255
+ # net.eval()
256
+ # loss = 0
257
+ # for batch in testloader:
258
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
259
+ # with torch.no_grad():
260
+ # outputs = net(**batch)
261
+ # logits = outputs.logits
262
+ # loss += outputs.loss.item()
263
+ # predictions = torch.argmax(logits, dim=-1)
264
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
265
+ # loss /= len(testloader)
266
+ # accuracy = metric.compute()["accuracy"]
267
+ # return loss, accuracy
268
+
269
+ # class CustomClient(fl.client.NumPyClient):
270
+ # def __init__(self, net, trainloader, testloader, client_id):
271
+ # self.net = net
272
+ # self.trainloader = trainloader
273
+ # self.testloader = testloader
274
+ # self.client_id = client_id
275
+ # self.losses = []
276
+ # self.accuracies = []
277
+
278
+ # def get_parameters(self, config):
279
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
280
+
281
+ # def set_parameters(self, parameters):
282
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
283
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
284
+ # self.net.load_state_dict(state_dict, strict=True)
285
+
286
+ # def fit(self, parameters, config):
287
+ # self.set_parameters(parameters)
288
+ # train(self.net, self.trainloader, epochs=1)
289
+ # loss, accuracy = test(self.net, self.testloader)
290
+ # self.losses.append(loss)
291
+ # self.accuracies.append(accuracy)
292
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
293
+
294
+ # def evaluate(self, parameters, config):
295
+ # self.set_parameters(parameters)
296
+ # loss, accuracy = test(self.net, self.testloader)
297
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
298
+
299
+ # def plot_metrics(self):
300
+ # fig, ax1 = plt.subplots()
301
+
302
+ # ax2 = ax1.twinx()
303
+ # ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
304
+ # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
305
+
306
+ # ax1.set_xlabel('Round')
307
+ # ax1.set_ylabel('Loss', color='g')
308
+ # ax2.set_ylabel('Accuracy', color='b')
309
+
310
+ # plt.title(f'Client {self.client_id} Metrics')
311
+ # st.pyplot(fig)
312
+
313
+ # def main():
314
+ # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
315
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
316
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
317
+
318
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
319
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
320
+
321
+ # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
322
+
323
+ # trainloaders = []
324
+ # testloaders = []
325
+ # clients = []
326
+
327
+ # for i in range(NUM_CLIENTS):
328
+ # st.write(f"### Client {i+1} Datasets")
329
+
330
+ # train_df = pd.DataFrame(train_datasets[i])
331
+ # test_df = pd.DataFrame(test_datasets[i])
332
+
333
+ # st.write("#### Train Dataset")
334
+ # edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
335
+ # st.write("#### Test Dataset")
336
+ # edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
337
+
338
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
339
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
340
+
341
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
342
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
343
+
344
+ # trainloaders.append(trainloader)
345
+ # testloaders.append(testloader)
346
+
347
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
348
+ # client = CustomClient(net, trainloader, testloader, client_id=i+1)
349
+ # clients.append(client)
350
+
351
+ # if st.button("Start Training"):
352
+ # def client_fn(cid):
353
+ # return clients[int(cid)]
354
+
355
+ # def weighted_average(metrics):
356
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
357
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
358
+ # examples = [num_examples for num_examples, _ in metrics]
359
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
360
+
361
+ # strategy = fl.server.strategy.FedAvg(
362
+ # fraction_fit=1.0,
363
+ # fraction_evaluate=1.0,
364
+ # evaluate_metrics_aggregation_fn=weighted_average,
365
+ # )
366
+
367
+ # fl.simulation.start_simulation(
368
+ # client_fn=client_fn,
369
+ # num_clients=NUM_CLIENTS,
370
+ # config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
371
+ # strategy=strategy,
372
+ # client_resources={"num_cpus": 1, "num_gpus": 0},
373
+ # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
374
+ # )
375
+
376
+ # st.success(f"Training completed successfully!")
377
+
378
+ # for client in clients:
379
+ # st.write(f"### Client {client.client_id} Model Metrics")
380
+ # client.plot_metrics()
381
+
382
+ # else:
383
+ # st.write("Click the 'Start Training' button to start the training process.")
384
+
385
+ # if __name__ == "__main__":
386
+ # main()
387
+
388
+
389
 
390
  # 05/2024 # %%writefile app.py
391