alisrbdni commited on
Commit
044bfc8
·
verified ·
1 Parent(s): 8a9be66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -228
app.py CHANGED
@@ -1,3 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # # %%writefile app.py
2
 
3
  # import streamlit as st
@@ -11,6 +228,8 @@
11
  # import random
12
  # from collections import OrderedDict
13
  # import flwr as fl
 
 
14
 
15
  # DEVICE = torch.device("cpu")
16
 
@@ -87,16 +306,20 @@
87
  # self.net.load_state_dict(state_dict, strict=True)
88
 
89
  # def fit(self, parameters, config):
 
90
  # self.set_parameters(parameters)
91
  # train(self.net, self.trainloader, epochs=1)
92
  # loss, accuracy = test(self.net, self.testloader)
93
  # self.losses.append(loss)
94
  # self.accuracies.append(accuracy)
 
95
  # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
96
 
97
  # def evaluate(self, parameters, config):
 
98
  # self.set_parameters(parameters)
99
  # loss, accuracy = test(self.net, self.testloader)
 
100
  # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
101
 
102
  # def plot_metrics(self, round_num, plot_placeholder):
@@ -122,10 +345,14 @@
122
  # fig.tight_layout()
123
  # plot_placeholder.pyplot(fig)
124
 
 
 
 
 
125
  # def main():
126
  # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
127
  # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
128
- # model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])
129
 
130
  # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
131
  # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
@@ -204,239 +431,16 @@
204
  # client.plot_metrics(NUM_ROUNDS, st.empty())
205
  # st.write(" ")
206
 
 
 
 
 
207
  # else:
208
  # st.write("Click the 'Start Training' button to start the training process.")
209
 
210
  # if __name__ == "__main__":
211
  # main()
212
 
213
- #############
214
- # %%writefile app.py
215
-
216
- import streamlit as st
217
- import matplotlib.pyplot as plt
218
- import torch
219
- from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
220
- from datasets import load_dataset, Dataset
221
- from evaluate import load as load_metric
222
- from torch.utils.data import DataLoader
223
- import pandas as pd
224
- import random
225
- from collections import OrderedDict
226
- import flwr as fl
227
- from logging import INFO, DEBUG
228
- from flwr.common.logger import log
229
-
230
- DEVICE = torch.device("cpu")
231
-
232
- def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
233
- raw_datasets = load_dataset(dataset_name)
234
- raw_datasets = raw_datasets.shuffle(seed=42)
235
- del raw_datasets["unsupervised"]
236
-
237
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
238
-
239
- def tokenize_function(examples):
240
- return tokenizer(examples["text"], truncation=True)
241
-
242
- tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
243
- tokenized_datasets = tokenized_datasets.remove_columns("text")
244
- tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
245
-
246
- train_datasets = []
247
- test_datasets = []
248
-
249
- for _ in range(num_clients):
250
- train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
251
- test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
252
- train_datasets.append(train_dataset)
253
- test_datasets.append(test_dataset)
254
-
255
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
256
-
257
- return train_datasets, test_datasets, data_collator
258
-
259
- def train(net, trainloader, epochs):
260
- optimizer = AdamW(net.parameters(), lr=5e-5)
261
- net.train()
262
- for _ in range(epochs):
263
- for batch in trainloader:
264
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
265
- outputs = net(**batch)
266
- loss = outputs.loss
267
- loss.backward()
268
- optimizer.step()
269
- optimizer.zero_grad()
270
-
271
- def test(net, testloader):
272
- metric = load_metric("accuracy")
273
- net.eval()
274
- loss = 0
275
- for batch in testloader:
276
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
277
- with torch.no_grad():
278
- outputs = net(**batch)
279
- logits = outputs.logits
280
- loss += outputs.loss.item()
281
- predictions = torch.argmax(logits, dim=-1)
282
- metric.add_batch(predictions=predictions, references=batch["labels"])
283
- loss /= len(testloader)
284
- accuracy = metric.compute()["accuracy"]
285
- return loss, accuracy
286
-
287
- class CustomClient(fl.client.NumPyClient):
288
- def __init__(self, net, trainloader, testloader, client_id):
289
- self.net = net
290
- self.trainloader = trainloader
291
- self.testloader = testloader
292
- self.client_id = client_id
293
- self.losses = []
294
- self.accuracies = []
295
-
296
- def get_parameters(self, config):
297
- return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
298
-
299
- def set_parameters(self, parameters):
300
- params_dict = zip(self.net.state_dict().keys(), parameters)
301
- state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
302
- self.net.load_state_dict(state_dict, strict=True)
303
-
304
- def fit(self, parameters, config):
305
- log(INFO, f"Client {self.client_id} is starting fit()")
306
- self.set_parameters(parameters)
307
- train(self.net, self.trainloader, epochs=1)
308
- loss, accuracy = test(self.net, self.testloader)
309
- self.losses.append(loss)
310
- self.accuracies.append(accuracy)
311
- log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
312
- return self.get_parameters(config={}), len(self.trainloader.dataset), {}
313
-
314
- def evaluate(self, parameters, config):
315
- log(INFO, f"Client {self.client_id} is starting evaluate()")
316
- self.set_parameters(parameters)
317
- loss, accuracy = test(self.net, self.testloader)
318
- log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
319
- return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
320
-
321
- def plot_metrics(self, round_num, plot_placeholder):
322
- if self.losses and self.accuracies:
323
- plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
324
- plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
325
- plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
326
-
327
- fig, ax1 = plt.subplots()
328
-
329
- color = 'tab:red'
330
- ax1.set_xlabel('Round')
331
- ax1.set_ylabel('Loss', color=color)
332
- ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
333
- ax1.tick_params(axis='y', labelcolor=color)
334
-
335
- ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
336
- color = 'tab:blue'
337
- ax2.set_ylabel('Accuracy', color=color)
338
- ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
339
- ax2.tick_params(axis='y', labelcolor=color)
340
-
341
- fig.tight_layout()
342
- plot_placeholder.pyplot(fig)
343
-
344
- def read_log_file():
345
- with open("log.txt", "r") as file:
346
- return file.read()
347
-
348
- def main():
349
- st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
350
- dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
351
- model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
352
-
353
- NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
354
- NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
355
-
356
- train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
357
-
358
- trainloaders = []
359
- testloaders = []
360
- clients = []
361
-
362
- for i in range(NUM_CLIENTS):
363
- st.write(f"### Client {i+1} Datasets")
364
-
365
- train_df = pd.DataFrame(train_datasets[i])
366
- test_df = pd.DataFrame(test_datasets[i])
367
-
368
- st.write("#### Train Dataset")
369
- edited_train_df = st.data_editor(train_df, key=f"train_{i}")
370
- st.write("#### Test Dataset")
371
- edited_test_df = st.data_editor(test_df, key=f"test_{i}")
372
-
373
- edited_train_dataset = Dataset.from_pandas(edited_train_df)
374
- edited_test_dataset = Dataset.from_pandas(edited_test_df)
375
-
376
- trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
377
- testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
378
-
379
- trainloaders.append(trainloader)
380
- testloaders.append(testloader)
381
-
382
- net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
383
- client = CustomClient(net, trainloader, testloader, client_id=i+1)
384
- clients.append(client)
385
-
386
- if st.button("Start Training"):
387
- def client_fn(cid):
388
- return clients[int(cid)]
389
-
390
- def weighted_average(metrics):
391
- accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
392
- losses = [num_examples * m["loss"] for num_examples, m in metrics]
393
- examples = [num_examples for num_examples, _ in metrics]
394
- return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
395
-
396
- strategy = fl.server.strategy.FedAvg(
397
- fraction_fit=1.0,
398
- fraction_evaluate=1.0,
399
- evaluate_metrics_aggregation_fn=weighted_average,
400
- )
401
-
402
- for round_num in range(NUM_ROUNDS):
403
- st.write(f"### Round {round_num + 1}")
404
- plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
405
-
406
- fl.simulation.start_simulation(
407
- client_fn=client_fn,
408
- num_clients=NUM_CLIENTS,
409
- config=fl.server.ServerConfig(num_rounds=1),
410
- strategy=strategy,
411
- client_resources={"num_cpus": 1, "num_gpus": 0},
412
- ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
413
- )
414
-
415
- for i, client in enumerate(clients):
416
- client.plot_metrics(round_num + 1, plot_placeholders[i])
417
- st.write(" ")
418
-
419
- st.success("Training completed successfully!")
420
-
421
- # Display final metrics
422
- st.write("## Final Client Metrics")
423
- for client in clients:
424
- st.write(f"### Client {client.client_id}")
425
- st.write(f"Final Loss: {client.losses[-1]:.4f}")
426
- st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
427
- client.plot_metrics(NUM_ROUNDS, st.empty())
428
- st.write(" ")
429
-
430
- # Display log.txt content
431
- st.write("## Training Log")
432
- st.text(read_log_file())
433
-
434
- else:
435
- st.write("Click the 'Start Training' button to start the training process.")
436
-
437
- if __name__ == "__main__":
438
- main()
439
-
440
  #############
441
 
442
  # # %%writefile app.py
 
1
+ # %%writefile app.py
2
+
3
+ import streamlit as st
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
7
+ from datasets import load_dataset, Dataset
8
+ from evaluate import load as load_metric
9
+ from torch.utils.data import DataLoader
10
+ import pandas as pd
11
+ import random
12
+ from collections import OrderedDict
13
+ import flwr as fl
14
+
15
+ DEVICE = torch.device("cpu")
16
+
17
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, tokenizer_name="bert-base-uncased"):
    """Load, tokenize, and partition a text-classification dataset for FL clients.

    Args:
        dataset_name: Hub dataset id ("imdb", "amazon_polarity", or "ag_news").
        train_size: number of training examples sampled per client.
        test_size: number of test examples sampled per client.
        num_clients: how many independent client partitions to draw.
        tokenizer_name: tokenizer checkpoint to use (defaults to the original
            hard-coded "bert-base-uncased" for backward compatibility).

    Returns:
        (train_datasets, test_datasets, data_collator) where the two lists hold
        one tokenized `datasets.Dataset` per client.
    """
    raw_datasets = load_dataset(dataset_name)
    raw_datasets = raw_datasets.shuffle(seed=42)

    # Only imdb ships an "unsupervised" split; deleting it unconditionally
    # raises KeyError for ag_news / amazon_polarity.
    if "unsupervised" in raw_datasets:
        del raw_datasets["unsupervised"]

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # imdb and ag_news store the text under "text"; amazon_polarity uses "content".
    text_column = "text" if "text" in raw_datasets["train"].column_names else "content"

    def tokenize_function(examples):
        return tokenizer(examples[text_column], truncation=True)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

    # Drop every original column except the label (amazon_polarity also has a
    # "title" column that would break the padding collator), then rename the
    # label to the name the HF models expect.
    original_columns = [c for c in raw_datasets["train"].column_names if c != "label"]
    tokenized_datasets = tokenized_datasets.remove_columns(original_columns)
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    train_datasets = []
    test_datasets = []

    # Each client gets an independent random sample (clients may overlap).
    for _ in range(num_clients):
        train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
        test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
        train_datasets.append(train_dataset)
        test_datasets.append(test_dataset)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    return train_datasets, test_datasets, data_collator
43
def read_log_file():
    """Return the full contents of the Flower training log (./log.txt)."""
    with open("./log.txt", "r") as log_file:
        contents = log_file.read()
    return contents
46
def train(net, trainloader, epochs):
    """Fine-tune `net` on `trainloader` for `epochs` full passes.

    Runs plain supervised training: forward, backward, optimizer step,
    gradient reset per batch.  Loss comes from the model itself (HF
    sequence-classification models return `.loss` when `labels` are passed).
    """
    # transformers.AdamW is deprecated (removed in recent transformers
    # releases); torch.optim.AdamW is the documented replacement.
    optimizer = torch.optim.AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
57
+
58
def test(net, testloader):
    """Evaluate `net` over `testloader`; return (mean batch loss, accuracy)."""
    metric = load_metric("accuracy")
    net.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            total_loss += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])
    mean_loss = total_loss / len(testloader)
    accuracy = metric.compute()["accuracy"]
    return mean_loss, accuracy
73
+
74
class CustomClient(fl.client.NumPyClient):
    """Flower NumPy client that records per-round loss/accuracy and can plot them."""

    def __init__(self, net, trainloader, testloader, client_id):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id
        # Parallel histories: one (loss, accuracy) pair appended per fit() round.
        self.losses = []
        self.accuracies = []

    def get_parameters(self, config):
        """Export the model weights as a list of numpy arrays (state_dict order)."""
        return [tensor.cpu().numpy() for _, tensor in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load a list of numpy arrays back into the model, matched by state_dict order."""
        keyed_params = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict((name, torch.Tensor(array)) for name, array in keyed_params)
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """One local training epoch, followed by a local evaluation for bookkeeping."""
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the provided global parameters on this client's test set."""
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

    def plot_metrics(self, round_num, plot_placeholder):
        """Write latest metrics and draw a dual-axis loss/accuracy history plot."""
        if not (self.losses and self.accuracies):
            return
        plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
        plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
        plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")

        fig, loss_axis = plt.subplots()

        loss_axis.set_xlabel('Round')
        loss_axis.set_ylabel('Loss', color='tab:red')
        loss_axis.plot(range(1, len(self.losses) + 1), self.losses, color='tab:red')
        loss_axis.tick_params(axis='y', labelcolor='tab:red')

        acc_axis = loss_axis.twinx()  # instantiate a second axes that shares the same x-axis
        acc_axis.set_ylabel('Accuracy', color='tab:blue')
        acc_axis.plot(range(1, len(self.accuracies) + 1), self.accuracies, color='tab:blue')
        acc_axis.tick_params(axis='y', labelcolor='tab:blue')

        fig.tight_layout()
        plot_placeholder.pyplot(fig)
126
+
127
def main():
    """Streamlit entry point: pick dataset/model, let the user edit per-client
    data, then run Flower federated-learning simulation rounds and plot metrics."""
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])

    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

    train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)

    # ag_news is a 4-way task; the other two are binary.  Hard-coding
    # num_labels=2 makes the classification head mismatch ag_news label ids.
    num_labels = {"imdb": 2, "amazon_polarity": 2, "ag_news": 4}[dataset_name]

    trainloaders = []
    testloaders = []
    clients = []

    for i in range(NUM_CLIENTS):
        st.write(f"### Client {i+1} Datasets")

        train_df = pd.DataFrame(train_datasets[i])
        test_df = pd.DataFrame(test_datasets[i])

        st.write("#### Train Dataset")
        edited_train_df = st.data_editor(train_df, key=f"train_{i}")
        st.write("#### Test Dataset")
        edited_test_df = st.data_editor(test_df, key=f"test_{i}")

        edited_train_dataset = Dataset.from_pandas(edited_train_df)
        edited_test_dataset = Dataset.from_pandas(edited_test_df)

        trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
        testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)

        trainloaders.append(trainloader)
        testloaders.append(testloader)

        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(DEVICE)
        client = CustomClient(net, trainloader, testloader, client_id=i+1)
        clients.append(client)

    if st.button("Start Training"):
        def client_fn(cid):
            # Flower passes client ids as strings "0".."N-1".
            return clients[int(cid)]

        def weighted_average(metrics):
            # Clients only report "accuracy" from evaluate(); the loss is
            # aggregated by the strategy from evaluate()'s first return value,
            # so reading m["loss"] here would raise KeyError.
            accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
            examples = [num_examples for num_examples, _ in metrics]
            return {"accuracy": sum(accuracies) / sum(examples)}

        strategy = fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=1.0,
            evaluate_metrics_aggregation_fn=weighted_average,
        )

        # Configure the file logger once; calling this inside the round loop
        # stacks duplicate handlers and repeats every log line.
        fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")

        for round_num in range(NUM_ROUNDS):
            st.write(f"### Round {round_num + 1}")
            plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]

            fl.simulation.start_simulation(
                client_fn=client_fn,
                num_clients=NUM_CLIENTS,
                config=fl.server.ServerConfig(num_rounds=1),
                strategy=strategy,
                client_resources={"num_cpus": 1, "num_gpus": 0},
                ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
            )

            for i, client in enumerate(clients):
                # Fixed: original `st.markdown("LOGS : "read_log_file())` was a
                # syntax error (missing concatenation operator).
                st.markdown("LOGS : " + read_log_file())
                client.plot_metrics(round_num + 1, plot_placeholders[i])
                st.write(" ")

        st.success("Training completed successfully!")

        # Display final metrics
        st.write("## Final Client Metrics")
        for client in clients:
            st.write(f"### Client {client.client_id}")
            st.write(f"Final Loss: {client.losses[-1]:.4f}")
            st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
            client.plot_metrics(NUM_ROUNDS, st.empty())
            st.write(" ")

    else:
        st.write("Click the 'Start Training' button to start the training process.")


if __name__ == "__main__":
    main()
216
+
217
+ # #############
218
  # # %%writefile app.py
219
 
220
  # import streamlit as st
 
228
  # import random
229
  # from collections import OrderedDict
230
  # import flwr as fl
231
+ # from logging import INFO, DEBUG
232
+ # from flwr.common.logger import log
233
 
234
  # DEVICE = torch.device("cpu")
235
 
 
306
  # self.net.load_state_dict(state_dict, strict=True)
307
 
308
  # def fit(self, parameters, config):
309
+ # log(INFO, f"Client {self.client_id} is starting fit()")
310
  # self.set_parameters(parameters)
311
  # train(self.net, self.trainloader, epochs=1)
312
  # loss, accuracy = test(self.net, self.testloader)
313
  # self.losses.append(loss)
314
  # self.accuracies.append(accuracy)
315
+ # log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
316
  # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
317
 
318
  # def evaluate(self, parameters, config):
319
+ # log(INFO, f"Client {self.client_id} is starting evaluate()")
320
  # self.set_parameters(parameters)
321
  # loss, accuracy = test(self.net, self.testloader)
322
+ # log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
323
  # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
324
 
325
  # def plot_metrics(self, round_num, plot_placeholder):
 
345
  # fig.tight_layout()
346
  # plot_placeholder.pyplot(fig)
347
 
348
+ # def read_log_file():
349
+ # with open("log.txt", "r") as file:
350
+ # return file.read()
351
+
352
  # def main():
353
  # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
354
  # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
355
+ # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
356
 
357
  # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
358
  # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
 
431
  # client.plot_metrics(NUM_ROUNDS, st.empty())
432
  # st.write(" ")
433
 
434
+ # # Display log.txt content
435
+ # st.write("## Training Log")
436
+ # st.text(read_log_file())
437
+
438
  # else:
439
  # st.write("Click the 'Start Training' button to start the training process.")
440
 
441
  # if __name__ == "__main__":
442
  # main()
443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  #############
445
 
446
  # # %%writefile app.py