alisrbdni committed on
Commit
f217473
·
verified ·
1 Parent(s): afd37eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -9
app.py CHANGED
@@ -71,10 +71,13 @@ def test(net, testloader):
71
  return loss, accuracy
72
 
73
  class CustomClient(fl.client.NumPyClient):
74
def __init__(self, net, trainloader, testloader):
    """Store the model and the data loaders this client operates on.

    Args:
        net: Model whose state dict is exchanged with the server.
        trainloader: Loader passed to ``train`` during ``fit``.
        testloader: Loader passed to ``test`` during ``evaluate``.
    """
    # Keep references only; nothing is copied or moved here.
    self.net, self.trainloader, self.testloader = net, trainloader, testloader
 
 
 
78
 
79
  def get_parameters(self, config):
80
  return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
@@ -87,6 +90,9 @@ class CustomClient(fl.client.NumPyClient):
87
  def fit(self, parameters, config):
88
  self.set_parameters(parameters)
89
  train(self.net, self.trainloader, epochs=1)
 
 
 
90
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
91
 
92
  def evaluate(self, parameters, config):
@@ -94,13 +100,25 @@ class CustomClient(fl.client.NumPyClient):
94
  loss, accuracy = test(self.net, self.testloader)
95
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def main():
98
- st.write("## Federated Learning with Flower and Dynamic Models and Datasets for Mobile Devices")
99
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
100
  model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
101
 
102
- net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
103
-
104
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
105
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
106
 
@@ -108,6 +126,7 @@ def main():
108
 
109
  trainloaders = []
110
  testloaders = []
 
111
 
112
  for i in range(NUM_CLIENTS):
113
  st.write(f"### Client {i+1} Datasets")
@@ -129,12 +148,11 @@ def main():
129
  trainloaders.append(trainloader)
130
  testloaders.append(testloader)
131
 
132
- if st.button("Start Training"):
133
- round_losses = []
134
- round_accuracies = []
135
-
136
- clients = [CustomClient(net, trainloaders[i], testloaders[i]) for i in range(NUM_CLIENTS)]
137
 
 
138
  def client_fn(cid):
139
  return clients[int(cid)]
140
 
@@ -161,6 +179,10 @@ def main():
161
 
162
  st.success(f"Training completed successfully!")
163
 
 
 
 
 
164
  else:
165
  st.write("Click the 'Start Training' button to start the training process.")
166
 
@@ -169,6 +191,177 @@ if __name__ == "__main__":
169
 
170
 
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  ##ORIGINAL###
173
 
174
 
 
71
  return loss, accuracy
72
 
73
  class CustomClient(fl.client.NumPyClient):
74
def __init__(self, net, trainloader, testloader, client_id):
    """Store the model, data loaders and id, and start empty metric logs.

    Args:
        net: Model whose state dict is exchanged with the server.
        trainloader: Loader passed to ``train`` during ``fit``.
        testloader: Loader passed to ``test`` when measuring metrics.
        client_id: Identifier shown in this client's plot title.
    """
    self.net, self.trainloader, self.testloader = net, trainloader, testloader
    self.client_id = client_id
    # One entry is appended to each list per call to ``fit``.
    self.losses, self.accuracies = [], []
81
 
82
  def get_parameters(self, config):
83
  return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
 
90
  def fit(self, parameters, config):
91
  self.set_parameters(parameters)
92
  train(self.net, self.trainloader, epochs=1)
93
+ loss, accuracy = test(self.net, self.testloader)
94
+ self.losses.append(loss)
95
+ self.accuracies.append(accuracy)
96
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
97
 
98
  def evaluate(self, parameters, config):
 
100
  loss, accuracy = test(self.net, self.testloader)
101
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
102
 
103
def plot_metrics(self):
    """Render this client's recorded loss and accuracy curves in Streamlit.

    Loss is drawn in green against the left y-axis and accuracy in blue
    against a twin right y-axis. The x-axis is the 1-based position of each
    recorded entry and is labelled "Round".
    """
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()

    ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
    ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')

    ax1.set_xlabel('Round')
    ax1.set_ylabel('Loss', color='g')
    ax2.set_ylabel('Accuracy', color='b')

    # Set the title on an explicit axes instead of relying on pyplot's
    # implicit "current axes" state.
    ax1.set_title(f'Client {self.client_id} Metrics')
    st.pyplot(fig)

    # pyplot keeps every figure it creates alive until it is closed; release
    # it so repeated Streamlit reruns do not accumulate open figures.
    plt.close(fig)
116
+
117
  def main():
118
+ st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
119
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
120
  model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
121
 
 
 
122
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
123
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
124
 
 
126
 
127
  trainloaders = []
128
  testloaders = []
129
+ clients = []
130
 
131
  for i in range(NUM_CLIENTS):
132
  st.write(f"### Client {i+1} Datasets")
 
148
  trainloaders.append(trainloader)
149
  testloaders.append(testloader)
150
 
151
+ net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
152
+ client = CustomClient(net, trainloader, testloader, client_id=i+1)
153
+ clients.append(client)
 
 
154
 
155
+ if st.button("Start Training"):
156
  def client_fn(cid):
157
  return clients[int(cid)]
158
 
 
179
 
180
  st.success(f"Training completed successfully!")
181
 
182
+ for client in clients:
183
+ st.write(f"### Client {client.client_id} Model Metrics")
184
+ client.plot_metrics()
185
+
186
  else:
187
  st.write("Click the 'Start Training' button to start the training process.")
188
 
 
191
 
192
 
193
 
194
+ # 05/2024 # %%writefile app.py
195
+
196
+ # import streamlit as st
197
+ # import matplotlib.pyplot as plt
198
+ # import torch
199
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
200
+ # from datasets import load_dataset, Dataset
201
+ # from evaluate import load as load_metric
202
+ # from torch.utils.data import DataLoader
203
+ # import pandas as pd
204
+ # import random
205
+ # import warnings
206
+ # from collections import OrderedDict
207
+ # import flwr as fl
208
+
209
+ # DEVICE = torch.device("cpu")
210
+
211
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
212
+ # raw_datasets = load_dataset(dataset_name)
213
+ # raw_datasets = raw_datasets.shuffle(seed=42)
214
+ # del raw_datasets["unsupervised"]
215
+
216
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
217
+
218
+ # def tokenize_function(examples):
219
+ # return tokenizer(examples["text"], truncation=True)
220
+
221
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
222
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
223
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
224
+
225
+ # train_datasets = []
226
+ # test_datasets = []
227
+
228
+ # for _ in range(num_clients):
229
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
230
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
231
+ # train_datasets.append(train_dataset)
232
+ # test_datasets.append(test_dataset)
233
+
234
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
235
+
236
+ # return train_datasets, test_datasets, data_collator
237
+
238
+ # def train(net, trainloader, epochs):
239
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
240
+ # net.train()
241
+ # for _ in range(epochs):
242
+ # for batch in trainloader:
243
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
244
+ # outputs = net(**batch)
245
+ # loss = outputs.loss
246
+ # loss.backward()
247
+ # optimizer.step()
248
+ # optimizer.zero_grad()
249
+
250
+ # def test(net, testloader):
251
+ # metric = load_metric("accuracy")
252
+ # net.eval()
253
+ # loss = 0
254
+ # for batch in testloader:
255
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
256
+ # with torch.no_grad():
257
+ # outputs = net(**batch)
258
+ # logits = outputs.logits
259
+ # loss += outputs.loss.item()
260
+ # predictions = torch.argmax(logits, dim=-1)
261
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
262
+ # loss /= len(testloader)
263
+ # accuracy = metric.compute()["accuracy"]
264
+ # return loss, accuracy
265
+
266
+ # class CustomClient(fl.client.NumPyClient):
267
+ # def __init__(self, net, trainloader, testloader):
268
+ # self.net = net
269
+ # self.trainloader = trainloader
270
+ # self.testloader = testloader
271
+
272
+ # def get_parameters(self, config):
273
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
274
+
275
+ # def set_parameters(self, parameters):
276
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
277
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
278
+ # self.net.load_state_dict(state_dict, strict=True)
279
+
280
+ # def fit(self, parameters, config):
281
+ # self.set_parameters(parameters)
282
+ # train(self.net, self.trainloader, epochs=1)
283
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
284
+
285
+ # def evaluate(self, parameters, config):
286
+ # self.set_parameters(parameters)
287
+ # loss, accuracy = test(self.net, self.testloader)
288
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
289
+
290
+ # def main():
291
+ # st.write("## Federated Learning with Flower and Dynamic Models and Datasets for Mobile Devices")
292
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
293
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
294
+
295
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
296
+
297
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
298
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
299
+
300
+ # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
301
+
302
+ # trainloaders = []
303
+ # testloaders = []
304
+
305
+ # for i in range(NUM_CLIENTS):
306
+ # st.write(f"### Client {i+1} Datasets")
307
+
308
+ # train_df = pd.DataFrame(train_datasets[i])
309
+ # test_df = pd.DataFrame(test_datasets[i])
310
+
311
+ # st.write("#### Train Dataset")
312
+ # edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
313
+ # st.write("#### Test Dataset")
314
+ # edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
315
+
316
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
317
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
318
+
319
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
320
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
321
+
322
+ # trainloaders.append(trainloader)
323
+ # testloaders.append(testloader)
324
+
325
+ # if st.button("Start Training"):
326
+ # round_losses = []
327
+ # round_accuracies = []
328
+
329
+ # clients = [CustomClient(net, trainloaders[i], testloaders[i]) for i in range(NUM_CLIENTS)]
330
+
331
+ # def client_fn(cid):
332
+ # return clients[int(cid)]
333
+
334
+ # def weighted_average(metrics):
335
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
336
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
337
+ # examples = [num_examples for num_examples, _ in metrics]
338
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
339
+
340
+ # strategy = fl.server.strategy.FedAvg(
341
+ # fraction_fit=1.0,
342
+ # fraction_evaluate=1.0,
343
+ # evaluate_metrics_aggregation_fn=weighted_average,
344
+ # )
345
+
346
+ # fl.simulation.start_simulation(
347
+ # client_fn=client_fn,
348
+ # num_clients=NUM_CLIENTS,
349
+ # config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
350
+ # strategy=strategy,
351
+ # client_resources={"num_cpus": 1, "num_gpus": 0},
352
+ # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
353
+ # )
354
+
355
+ # st.success(f"Training completed successfully!")
356
+
357
+ # else:
358
+ # st.write("Click the 'Start Training' button to start the training process.")
359
+
360
+ # if __name__ == "__main__":
361
+ # main()
362
+
363
+
364
+
365
  ##ORIGINAL###
366
 
367