alisrbdni commited on
Commit
2c9aff2
·
verified ·
1 Parent(s): 3ca7e98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -80
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # %%writefile app.py
2
-
3
  import streamlit as st
4
  import matplotlib.pyplot as plt
5
  import torch
@@ -9,8 +7,6 @@ from evaluate import load as load_metric
9
  from torch.utils.data import DataLoader
10
  import pandas as pd
11
  import random
12
- import warnings
13
- from collections import OrderedDict
14
  import flwr as fl
15
 
16
  DEVICE = torch.device("cpu")
@@ -21,7 +17,6 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
21
  del raw_datasets["unsupervised"]
22
 
23
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
24
-
25
  def tokenize_function(examples):
26
  return tokenizer(examples["text"], truncation=True)
27
 
@@ -102,109 +97,273 @@ class CustomClient(fl.client.NumPyClient):
102
 
103
  def plot_metrics(self, round_num):
104
  if self.losses and self.accuracies:
105
- st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
106
- st.write(f"Loss: {self.losses[-1]:.4f}")
107
- st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
108
-
109
  fig, ax1 = plt.subplots()
110
 
111
- ax2 = ax1.twinx()
112
- ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
113
- ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
114
-
115
  ax1.set_xlabel('Round')
116
- ax1.set_ylabel('Loss', color='g')
117
- ax2.set_ylabel('Accuracy', color='b')
 
 
 
 
 
 
 
118
 
119
- plt.title(f'Client {self.client_id} Metrics')
120
  st.pyplot(fig)
 
121
 
122
  def main():
123
- st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
124
- dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
125
- model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
126
 
127
- NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
128
- NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
129
 
130
  train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
131
 
132
- trainloaders = []
133
- testloaders = []
134
- clients = []
 
135
 
136
- for i in range(NUM_CLIENTS):
137
- st.write(f"### Client {i+1} Datasets")
 
138
 
139
- train_df = pd.DataFrame(train_datasets[i])
140
- test_df = pd.DataFrame(test_datasets[i])
141
 
142
- st.write("#### Train Dataset")
143
- edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
144
- st.write("#### Test Dataset")
145
- edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
146
 
147
- edited_train_dataset = Dataset.from_pandas(edited_train_df)
148
- edited_test_dataset = Dataset.from_pandas(edited_test_df)
 
149
 
150
- trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
151
- testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
 
 
 
152
 
153
- trainloaders.append(trainloader)
154
- testloaders.append(testloader)
155
 
156
- net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
157
- client = CustomClient(net, trainloader, testloader, client_id=i+1)
158
- clients.append(client)
159
 
160
- if st.button("Start Training"):
161
- def client_fn(cid):
162
- return clients[int(cid)]
163
 
164
- def weighted_average(metrics):
165
- accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
166
- losses = [num_examples * m["loss"] for num_examples, m in metrics]
167
- examples = [num_examples for num_examples, _ in metrics]
168
- return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
169
 
170
- strategy = fl.server.strategy.FedAvg(
171
- fraction_fit=1.0,
172
- fraction_evaluate=1.0,
173
- evaluate_metrics_aggregation_fn=weighted_average,
174
- )
175
 
176
- for round_num in range(NUM_ROUNDS):
177
- st.write(f"### Round {round_num + 1}")
 
 
 
 
 
 
 
 
 
 
178
 
179
- fl.simulation.start_simulation(
180
- client_fn=client_fn,
181
- num_clients=NUM_CLIENTS,
182
- config=fl.server.ServerConfig(num_rounds=1),
183
- strategy=strategy,
184
- client_resources={"num_cpus": 1, "num_gpus": 0},
185
- ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
186
- )
187
 
188
- for client in clients:
189
- client.plot_metrics(round_num + 1)
190
- st.write(" ")
 
191
 
192
- st.success(f"Training completed successfully!")
193
 
194
- # Display final metrics
195
- st.write("## Final Client Metrics")
196
- for client in clients:
197
- st.write(f"### Client {client.client_id}")
198
- st.write(f"Final Loss: {client.losses[-1]:.4f}")
199
- st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
200
- client.plot_metrics(NUM_ROUNDS)
201
- st.write(" ")
202
 
203
- else:
204
- st.write("Click the 'Start Training' button to start the training process.")
 
205
 
206
- if __name__ == "__main__":
207
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
 
210
 
 
 
 
1
  import streamlit as st
2
  import matplotlib.pyplot as plt
3
  import torch
 
7
  from torch.utils.data import DataLoader
8
  import pandas as pd
9
  import random
 
 
10
  import flwr as fl
11
 
12
  DEVICE = torch.device("cpu")
 
17
  del raw_datasets["unsupervised"]
18
 
19
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
20
  def tokenize_function(examples):
21
  return tokenizer(examples["text"], truncation=True)
22
 
 
97
 
98
  def plot_metrics(self, round_num):
99
  if self.losses and self.accuracies:
 
 
 
 
100
  fig, ax1 = plt.subplots()
101
 
102
+ color = 'tab:red'
 
 
 
103
  ax1.set_xlabel('Round')
104
+ ax1.set_ylabel('Loss', color=color)
105
+ ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
106
+ ax1.tick_params(axis='y', labelcolor=color)
107
+
108
+ ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
109
+ color = 'tab:blue'
110
+ ax2.set_ylabel('Accuracy', color=color) # we already handled the x-label with ax1
111
+ ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
112
+ ax2.tick_params(axis='y', labelcolor=color)
113
 
114
+ fig.tight_layout() # otherwise the right y-label is slightly clipped
115
  st.pyplot(fig)
116
+ st.write(f"Round {round_num} - Loss: {self.losses[-1]:.4f}, Accuracy: {self.accuracies[-1]:.4f}")
117
 
118
  def main():
119
+ st.title("Federated Learning with Dynamic Models and Datasets")
120
+ dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"], index=0)
121
+ model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"], index=0)
122
 
123
+ NUM_CLIENTS = st.slider("Number of Clients", 1, 10, 3)
124
+ NUM_ROUNDS = st.slider("Number of Rounds", 1, 10, 5)
125
 
126
  train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
127
 
128
+ if st.button("Initialize Clients"):
129
+ trainloaders = []
130
+ testloaders = []
131
+ clients = []
132
 
133
+ for i in range(NUM_CLIENTS):
134
+ train_df = pd.DataFrame(train_datasets[i])
135
+ test_df = pd.DataFrame(test_datasets[i])
136
 
137
+ edited_train_dataset = Dataset.from_pandas(train_df)
138
+ edited_test_dataset = Dataset.from_pandas(test_df)
139
 
140
+ trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=4, collate_fn=data_collator)
141
+ testloader = DataLoader(edited_test_dataset, batch_size=4, collate_fn=data_collator)
 
 
142
 
143
+ net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
144
+ client = CustomClient(net, trainloader, testloader, client_id=i+1)
145
+ clients.append(client)
146
 
147
+ for round_num in range(1, NUM_ROUNDS + 1):
148
+ st.write(f"### Round {round_num}")
149
+ for client in clients:
150
+ _, _, _ = client.fit({}, {})
151
+ client.plot_metrics(round_num)
152
 
153
+ st.success("Training completed successfully!")
 
154
 
155
+ if __name__ == "__main__":
156
+ main()
 
157
 
 
 
 
158
 
 
 
 
 
 
159
 
160
+ # # %%writefile app.py
 
 
 
 
161
 
162
+ # import streamlit as st
163
+ # import matplotlib.pyplot as plt
164
+ # import torch
165
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
166
+ # from datasets import load_dataset, Dataset
167
+ # from evaluate import load as load_metric
168
+ # from torch.utils.data import DataLoader
169
+ # import pandas as pd
170
+ # import random
171
+ # import warnings
172
+ # from collections import OrderedDict
173
+ # import flwr as fl
174
 
175
+ # DEVICE = torch.device("cpu")
 
 
 
 
 
 
 
176
 
177
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
178
+ # raw_datasets = load_dataset(dataset_name)
179
+ # raw_datasets = raw_datasets.shuffle(seed=42)
180
+ # del raw_datasets["unsupervised"]
181
 
182
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
183
 
184
+ # def tokenize_function(examples):
185
+ # return tokenizer(examples["text"], truncation=True)
 
 
 
 
 
 
186
 
187
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
188
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
189
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
190
 
191
+ # train_datasets = []
192
+ # test_datasets = []
193
+
194
+ # for _ in range(num_clients):
195
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
196
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
197
+ # train_datasets.append(train_dataset)
198
+ # test_datasets.append(test_dataset)
199
+
200
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
201
+
202
+ # return train_datasets, test_datasets, data_collator
203
+
204
+ # def train(net, trainloader, epochs):
205
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
206
+ # net.train()
207
+ # for _ in range(epochs):
208
+ # for batch in trainloader:
209
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
210
+ # outputs = net(**batch)
211
+ # loss = outputs.loss
212
+ # loss.backward()
213
+ # optimizer.step()
214
+ # optimizer.zero_grad()
215
+
216
+ # def test(net, testloader):
217
+ # metric = load_metric("accuracy")
218
+ # net.eval()
219
+ # loss = 0
220
+ # for batch in testloader:
221
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
222
+ # with torch.no_grad():
223
+ # outputs = net(**batch)
224
+ # logits = outputs.logits
225
+ # loss += outputs.loss.item()
226
+ # predictions = torch.argmax(logits, dim=-1)
227
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
228
+ # loss /= len(testloader)
229
+ # accuracy = metric.compute()["accuracy"]
230
+ # return loss, accuracy
231
+
232
+ # class CustomClient(fl.client.NumPyClient):
233
+ # def __init__(self, net, trainloader, testloader, client_id):
234
+ # self.net = net
235
+ # self.trainloader = trainloader
236
+ # self.testloader = testloader
237
+ # self.client_id = client_id
238
+ # self.losses = []
239
+ # self.accuracies = []
240
+
241
+ # def get_parameters(self, config):
242
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
243
+
244
+ # def set_parameters(self, parameters):
245
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
246
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
247
+ # self.net.load_state_dict(state_dict, strict=True)
248
+
249
+ # def fit(self, parameters, config):
250
+ # self.set_parameters(parameters)
251
+ # train(self.net, self.trainloader, epochs=1)
252
+ # loss, accuracy = test(self.net, self.testloader)
253
+ # self.losses.append(loss)
254
+ # self.accuracies.append(accuracy)
255
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
256
+
257
+ # def evaluate(self, parameters, config):
258
+ # self.set_parameters(parameters)
259
+ # loss, accuracy = test(self.net, self.testloader)
260
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
261
+
262
+ # def plot_metrics(self, round_num):
263
+ # if self.losses and self.accuracies:
264
+ # st.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
265
+ # st.write(f"Loss: {self.losses[-1]:.4f}")
266
+ # st.write(f"Accuracy: {self.accuracies[-1]:.4f}")
267
+
268
+ # fig, ax1 = plt.subplots()
269
+
270
+ # ax2 = ax1.twinx()
271
+ # ax1.plot(range(1, len(self.losses) + 1), self.losses, 'g-')
272
+ # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, 'b-')
273
+
274
+ # ax1.set_xlabel('Round')
275
+ # ax1.set_ylabel('Loss', color='g')
276
+ # ax2.set_ylabel('Accuracy', color='b')
277
+
278
+ # plt.title(f'Client {self.client_id} Metrics')
279
+ # st.pyplot(fig)
280
+
281
+ # def main():
282
+ # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
283
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
284
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
285
+
286
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
287
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
288
+
289
+ # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
290
+
291
+ # trainloaders = []
292
+ # testloaders = []
293
+ # clients = []
294
+
295
+ # for i in range(NUM_CLIENTS):
296
+ # st.write(f"### Client {i+1} Datasets")
297
+
298
+ # train_df = pd.DataFrame(train_datasets[i])
299
+ # test_df = pd.DataFrame(test_datasets[i])
300
+
301
+ # st.write("#### Train Dataset")
302
+ # edited_train_df = st.experimental_data_editor(train_df, key=f"train_{i}")
303
+ # st.write("#### Test Dataset")
304
+ # edited_test_df = st.experimental_data_editor(test_df, key=f"test_{i}")
305
+
306
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
307
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
308
+
309
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
310
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
311
+
312
+ # trainloaders.append(trainloader)
313
+ # testloaders.append(testloader)
314
+
315
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
316
+ # client = CustomClient(net, trainloader, testloader, client_id=i+1)
317
+ # clients.append(client)
318
+
319
+ # if st.button("Start Training"):
320
+ # def client_fn(cid):
321
+ # return clients[int(cid)]
322
+
323
+ # def weighted_average(metrics):
324
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
325
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
326
+ # examples = [num_examples for num_examples, _ in metrics]
327
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
328
+
329
+ # strategy = fl.server.strategy.FedAvg(
330
+ # fraction_fit=1.0,
331
+ # fraction_evaluate=1.0,
332
+ # evaluate_metrics_aggregation_fn=weighted_average,
333
+ # )
334
+
335
+ # for round_num in range(NUM_ROUNDS):
336
+ # st.write(f"### Round {round_num + 1}")
337
+
338
+ # fl.simulation.start_simulation(
339
+ # client_fn=client_fn,
340
+ # num_clients=NUM_CLIENTS,
341
+ # config=fl.server.ServerConfig(num_rounds=1),
342
+ # strategy=strategy,
343
+ # client_resources={"num_cpus": 1, "num_gpus": 0},
344
+ # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
345
+ # )
346
+
347
+ # for client in clients:
348
+ # client.plot_metrics(round_num + 1)
349
+ # st.write(" ")
350
+
351
+ # st.success(f"Training completed successfully!")
352
+
353
+ # # Display final metrics
354
+ # st.write("## Final Client Metrics")
355
+ # for client in clients:
356
+ # st.write(f"### Client {client.client_id}")
357
+ # st.write(f"Final Loss: {client.losses[-1]:.4f}")
358
+ # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
359
+ # client.plot_metrics(NUM_ROUNDS)
360
+ # st.write(" ")
361
+
362
+ # else:
363
+ # st.write("Click the 'Start Training' button to start the training process.")
364
+
365
+ # if __name__ == "__main__":
366
+ # main()
367
 
368
 
369