alisrbdni committed on
Commit 09b13a5 · verified · 1 Parent(s): f6b1769

Update app.py

Files changed (1)
  1. app.py +53 -445
app.py CHANGED
@@ -1,429 +1,5 @@
 
- # # %%writefile app.py
-
- # import streamlit as st
- # import matplotlib.pyplot as plt
- # import torch
- # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
- # from datasets import load_dataset, Dataset
- # from evaluate import load as load_metric
- # from torch.utils.data import DataLoader
- # import pandas as pd
- # import random
- # from collections import OrderedDict
- # import flwr as fl
- # from logging import INFO, DEBUG
- # from flwr.common.logger import log
- # import logging
- # import streamlit
-
- # # If you're curious of all the loggers
-
- # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
-
- # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
- # raw_datasets = load_dataset(dataset_name)
- # raw_datasets = raw_datasets.shuffle(seed=42)
- # del raw_datasets["unsupervised"]
-
- # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
- # def tokenize_function(examples):
- # return tokenizer(examples["text"], truncation=True)
-
- # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
- # tokenized_datasets = tokenized_datasets.remove_columns("text")
- # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
-
- # train_datasets = []
- # test_datasets = []
-
- # for _ in range(num_clients):
- # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
- # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
- # train_datasets.append(train_dataset)
- # test_datasets.append(test_dataset)
-
- # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-
- # return train_datasets, test_datasets, data_collator, raw_datasets
-
- # def train(net, trainloader, epochs):
- # optimizer = AdamW(net.parameters(), lr=5e-5)
- # net.train()
- # for _ in range(epochs):
- # for batch in trainloader:
- # batch = {k: v.to(DEVICE) for k, v in batch.items()}
- # outputs = net(**batch)
- # loss = outputs.loss
- # loss.backward()
- # optimizer.step()
- # optimizer.zero_grad()
-
-
-
-
- # # class SaveModelStrategy(fl.server.strategy.FedAvg):
- # # def aggregate_fit(
- # # self,
- # # server_round: int,
- # # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
- # # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
- # # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
- # # """Aggregate model weights using weighted average and store checkpoint"""
-
- # # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
- # # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
-
- # # if aggregated_parameters is not None:
- # # print(f"Saving round {server_round} aggregated_parameters...")
-
- # # # Convert `Parameters` to `List[np.ndarray]`
- # # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
-
- # # # Convert `List[np.ndarray]` to PyTorch`state_dict`
- # # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
- # # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
- # # net.load_state_dict(state_dict, strict=True)
-
- # # # Save the model
- # # torch.save(net.state_dict(), f"model_round_{server_round}.pth")
-
- # # return aggregated_parameters, aggregated_metrics
-
-
- # def test(net, testloader):
- # metric = load_metric("accuracy")
- # net.eval()
- # loss = 0
- # for batch in testloader:
- # batch = {k: v.to(DEVICE) for k, v in batch.items()}
- # with torch.no_grad():
- # outputs = net(**batch)
- # logits = outputs.logits
- # loss += outputs.loss.item()
- # predictions = torch.argmax(logits, dim=-1)
- # metric.add_batch(predictions=predictions, references=batch["labels"])
- # loss /= len(testloader)
- # accuracy = metric.compute()["accuracy"]
- # return loss, accuracy
-
- # class CustomClient(fl.client.NumPyClient):
- # def __init__(self, net, trainloader, testloader, client_id):
- # self.net = net
- # self.trainloader = trainloader
- # self.testloader = testloader
- # self.client_id = client_id
- # self.losses = []
- # self.accuracies = []
-
- # def get_parameters(self, config):
- # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
-
- # def set_parameters(self, parameters):
- # params_dict = zip(self.net.state_dict().keys(), parameters)
- # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
- # self.net.load_state_dict(state_dict, strict=True)
-
- # def fit(self, parameters, config):
- # log(INFO, f"Client {self.client_id} is starting fit()")
- # self.set_parameters(parameters)
- # train(self.net, self.trainloader, epochs=1)
- # loss, accuracy = test(self.net, self.testloader)
- # self.losses.append(loss)
- # self.accuracies.append(accuracy)
- # log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
- # return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
-
- # def evaluate(self, parameters, config):
- # log(INFO, f"Client {self.client_id} is starting evaluate()")
- # self.set_parameters(parameters)
- # loss, accuracy = test(self.net, self.testloader)
- # log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
- # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
-
- # def plot_metrics(self, round_num, plot_placeholder):
- # if self.losses and self.accuracies:
- # plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
- # plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
- # plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
-
- # fig, ax1 = plt.subplots()
-
- # color = 'tab:red'
- # ax1.set_xlabel('Round')
- # ax1.set_ylabel('Loss', color=color)
- # ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
- # ax1.tick_params(axis='y', labelcolor=color)
-
- # ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
- # color = 'tab:blue'
- # ax2.set_ylabel('Accuracy', color=color)
- # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
- # ax2.tick_params(axis='y', labelcolor=color)
-
- # fig.tight_layout()
- # plot_placeholder.pyplot(fig)
- # import matplotlib.pyplot as plt
- # import re
-
- # def read_log_file(log_path='./log.txt'):
- # with open(log_path, 'r') as file:
- # log_lines = file.readlines()
- # return log_lines
-
- # def parse_log(log_lines):
- # rounds = []
- # clients = {}
- # memory_usage = []
-
- # round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
- # client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
- # memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
-
- # current_round = None
-
- # for line in log_lines:
- # round_match = round_pattern.search(line)
- # client_match = client_pattern.search(line)
- # memory_match = memory_pattern.search(line)
-
- # if round_match:
- # current_round = int(round_match.group(1))
- # rounds.append(current_round)
- # elif client_match:
- # client_id = int(client_match.group(1))
- # log_level = client_match.group(2)
- # message = client_match.group(3)
-
- # if client_id not in clients:
- # clients[client_id] = {'rounds': [], 'messages': []}
-
- # clients[client_id]['rounds'].append(current_round)
- # clients[client_id]['messages'].append((log_level, message))
- # elif memory_match:
- # memory_usage.append(float(memory_match.group(1)))
-
- # return rounds, clients, memory_usage
-
- # def plot_metrics(rounds, clients, memory_usage):
- # st.write("## Metrics Overview")
-
- # st.write("### Memory Usage")
- # plt.figure()
- # plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
- # plt.xlabel('Step')
- # plt.ylabel('Memory Usage (GB)')
- # plt.legend()
- # st.pyplot(plt)
-
- # for client_id, data in clients.items():
- # st.write(f"### Client {client_id} Metrics")
-
- # info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
- # debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
-
- # st.write("#### INFO Messages")
- # for msg in info_messages:
- # st.write(msg)
-
- # st.write("#### DEBUG Messages")
- # for msg in debug_messages:
- # st.write(msg)
-
- # # Placeholder for actual loss and accuracy values, assuming they're included in the messages
- # losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
- # accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
-
- # if losses:
- # plt.figure()
- # plt.plot(data['rounds'], losses, label='Loss')
- # plt.xlabel('Round')
- # plt.ylabel('Loss')
- # plt.legend()
- # st.pyplot(plt)
-
- # if accuracies:
- # plt.figure()
- # plt.plot(data['rounds'], accuracies, label='Accuracy')
- # plt.xlabel('Round')
- # plt.ylabel('Accuracy')
- # plt.legend()
- # st.pyplot(plt)
-
-
- # def read_log_file2():
- # with open("./log.txt", "r") as file:
- # return file.read()
- # def main():
-
- # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
- # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
- # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
-
- # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
- # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
-
- # train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
-
- # trainloaders = []
- # testloaders = []
- # clients = []
-
- # for i in range(NUM_CLIENTS):
- # st.write(f"### Client {i+1} Datasets")
-
- # train_df = pd.DataFrame(train_datasets[i])
- # test_df = pd.DataFrame(test_datasets[i])
-
- # st.write("#### Train Dataset (Words)")
- # st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
- # st.write("#### Train Dataset (Tokens)")
- # edited_train_df = st.data_editor(train_df, key=f"train_{i}")
-
- # st.write("#### Test Dataset (Words)")
- # st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
- # st.write("#### Test Dataset (Tokens)")
- # edited_test_df = st.data_editor(test_df, key=f"test_{i}")
-
- # edited_train_dataset = Dataset.from_pandas(edited_train_df)
- # edited_test_dataset = Dataset.from_pandas(edited_test_df)
-
- # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
- # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
-
- # trainloaders.append(trainloader)
- # testloaders.append(testloader)
-
- # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
- # client = CustomClient(net, trainloader, testloader, client_id=i+1)
- # clients.append(client)
-
- # if st.button("Start Training"):
- # def client_fn(cid):
- # return clients[int(cid)]
-
- # def weighted_average(metrics):
- # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
- # losses = [num_examples * m["loss"] for num_examples, m in metrics]
- # examples = [num_examples for num_examples, _ in metrics]
- # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
-
- # strategy = fl.server.strategy.FedAvg(
- # fraction_fit=1.0,
- # fraction_evaluate=1.0,
- # evaluate_metrics_aggregation_fn=weighted_average,
- # )
-
- # for round_num in range(NUM_ROUNDS):
- # st.write(f"### Round {round_num + 1} ✅")
-
- # st.markdown(print(st.logger._loggers))
- # st.markdown(read_log_file2())
- # logs = read_log_file2()
- # import re
- # import plotly.graph_objects as go
- # import streamlit as st
- # import pandas as pd
-
- # # Log data
- # log_data = logs
-
- # # Extract relevant data
- # accuracy_pattern = re.compile(r"'accuracy': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
- # loss_pattern = re.compile(r"'loss': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
-
- # accuracy_matches = accuracy_pattern.findall(log_data)
- # loss_matches = loss_pattern.findall(log_data)
-
- # rounds = [int(match[0]) for match in accuracy_matches]
- # accuracies = [float(match[1]) for match in accuracy_matches]
- # losses = [float(match[1]) for match in loss_matches]
-
- # # Create accuracy plot
- # accuracy_fig = go.Figure()
- # accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
- # accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
-
- # # Create loss plot
- # loss_fig = go.Figure()
- # loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
- # loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
-
- # # Display plots in Streamlit
- # st.plotly_chart(accuracy_fig)
- # st.plotly_chart(loss_fig)
-
- # # Display data table
- # data = {
- # 'Round': rounds,
- # 'Accuracy': accuracies,
- # 'Loss': losses
- # }
-
- # df = pd.DataFrame(data)
- # st.write("## Training Metrics")
- # st.table(df)
-
-
-
-
-
-
-
- # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
-
- # fl.simulation.start_simulation(
- # client_fn=client_fn,
- # num_clients=NUM_CLIENTS,
- # config=fl.server.ServerConfig(num_rounds=1),
- # strategy=strategy,
- # client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
- # ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
- # )
-
- # for i, client in enumerate(clients):
- # client.plot_metrics(round_num + 1, plot_placeholders[i])
- # st.write(" ")
-
- # st.success("Training completed successfully!")
-
- # # Display final metrics
- # st.write("## Final Client Metrics")
- # for client in clients:
- # st.write(f"### Client {client.client_id}")
- # if client.losses and client.accuracies:
- # st.write(f"Final Loss: {client.losses[-1]:.4f}")
- # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
- # client.plot_metrics(NUM_ROUNDS, st.empty())
- # else:
- # st.write("No metrics available.")
-
- # st.write(" ")
-
- # # Display log.txt content
- # st.write("## Training Log")
- # # st.text(read_log_file())
- # st.write("## Training Log Analysis")
-
- # log_lines = read_log_file()
- # rounds, clients, memory_usage = parse_log(log_lines)
-
- # plot_metrics(rounds, clients, memory_usage)
-
- # else:
- # st.write("Click the 'Start Training' button to start the training process.")
-
- # if __name__ == "__main__":
- # main()
-
-
-
-
-
- # ##############NEW
-
+ # # #################
  # import streamlit as st
  # import matplotlib.pyplot as plt
  # import torch
@@ -441,8 +17,6 @@
  # import re
  # import plotly.graph_objects as go
 
- # # If you're curious of all the loggers
-
  # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
@@ -568,7 +142,7 @@
  # clients = {}
  # memory_usage = []
 
- # round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
+ # round_pattern = re.compile(r'ROUND (\d+)')
  # client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
  # memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
 
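Note: the old pattern, r'ROUND(\d+)ROUND (\d+)', demanded the literal text "ROUND" twice in one line and so could never match. A minimal sketch of the corrected pattern against an invented log line in Flower's "[ROUND n]" style (the exact log format here is an assumption, not taken from the commit):

    import re

    round_pattern = re.compile(r'ROUND (\d+)')  # the corrected pattern

    line = "INFO flwr | [ROUND 3]"  # hypothetical Flower log line
    match = round_pattern.search(line)
    assert match is not None and int(match.group(1)) == 3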
@@ -648,6 +222,22 @@
 
  # def main():
  # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
+ # logs = read_log_file2()
+ # # cleanLogs = # Define a pattern to match relevant log entries
+ # pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
+
+
+ # # Filter the log data
+ # filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
+ # st.markdown(filtered_logs)
+
+ # # Provide a download button for the logs
+ # st.download_button(
+ # label="Download Logs",
+ # data="\n".join(filtered_logs),
+ # file_name="./log.txt",
+ # mime="text/plain"
+ # )
  # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
  # model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
 
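Note: main() now pre-filters the raw log so that only lines mentioning memory, loss, accuracy, rounds, or clients are rendered and offered for download. A small sketch of that filter on made-up log text:

    import re

    pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)

    # Hypothetical log excerpt; the middle line matches no keyword and is dropped.
    logs = "INFO: [ROUND 1]\nDEBUG: ray spilled objects\nClient 1 | INFO | loss=0.41"
    filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
    print(filtered_logs)  # ['INFO: [ROUND 1]', 'Client 1 | INFO | loss=0.41']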
@@ -691,7 +281,7 @@
 
  # if st.button("Start Training"):
  # def client_fn(cid):
- # return clients[int(cid)]
+ # return clients[int(cid)].to_client()
 
  # def weighted_average(metrics):
  # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
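Note: recent Flower releases expect client_fn to return a flwr.client.Client rather than a bare NumPyClient, and NumPyClient.to_client() performs that conversion. A minimal sketch of the pattern, with a trivial stand-in client invented for illustration:

    import flwr as fl

    class TrivialClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            return []  # a stand-in with no model parameters

    def client_fn(cid: str):
        # to_client() wraps the NumPy-style client in the Client interface
        # that newer versions of fl.simulation.start_simulation expect.
        return TrivialClient().to_client()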
@@ -709,13 +299,23 @@
  # st.write(f"### Round {round_num + 1} ✅")
 
  # logs = read_log_file2()
- # st.markdown(logs)
- # # Extract relevant data
+ # filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
+ # filtered_logs = "\n".join(filtered_log_list)
+
+ # st.markdown(filtered_logs)
+ # # Provide a download button for the logs
+ # # st.download_button(
+ # # label="Download Logs",
+ # # data=logs,
+ # # file_name="./log.txt",
+ # # mime="text/plain"
+ # # )
+ # # # Extract relevant data
  # accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
  # loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
 
- # accuracy_matches = accuracy_pattern.findall(logs)
- # loss_matches = loss_pattern.findall(logs)
+ # accuracy_matches = accuracy_pattern.findall(filtered_logs)
+ # loss_matches = loss_pattern.findall(filtered_logs)
 
  # rounds = [int(match[0]) for match in accuracy_matches]
  # accuracies = [float(match[1]) for match in accuracy_matches]
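Note: the metric regexes assume each aggregated entry is rendered as a {round, value} pair in the filtered log. A sketch on an invented fragment of that shape:

    import re

    accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")

    # Hypothetical filtered-log fragment in the shape the pattern expects.
    sample = "metrics_distributed {'accuracy': {1, 0.8125} 'loss': {1, 0.4732}}"
    print(accuracy_pattern.findall(sample))  # [('1', '0.8125')]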
@@ -792,9 +392,6 @@
  # if __name__ == "__main__":
  # main()
 
-
-
- # #################
  import streamlit as st
  import matplotlib.pyplot as plt
  import torch
@@ -815,19 +412,28 @@ import plotly.graph_objects as go
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
 
- def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
+ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False):
  raw_datasets = load_dataset(dataset_name)
  raw_datasets = raw_datasets.shuffle(seed=42)
  del raw_datasets["unsupervised"]
 
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ if not use_utf8:
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
- def tokenize_function(examples):
- return tokenizer(examples["text"], truncation=True)
+ def tokenize_function(examples):
+ return tokenizer(examples["text"], truncation=True)
 
- tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
- tokenized_datasets = tokenized_datasets.remove_columns("text")
- tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+ tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+ tokenized_datasets = tokenized_datasets.remove_columns("text")
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+ else:
+ def utf8_encode_function(examples):
+ examples["text"] = [text.encode('utf-8') for text in examples["text"]]
+ return examples
+
+ tokenized_datasets = raw_datasets.map(utf8_encode_function, batched=True)
+ tokenized_datasets = tokenized_datasets.remove_columns("text")
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
 
  train_datasets = []
  test_datasets = []
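Note: the new UTF-8 branch only replaces the text column with raw byte strings and never produces input_ids, so the padding collator created below has nothing to pad on that path. A hypothetical byte-level tokenize function (not part of this commit) would map bytes to integer ids instead:

    def byte_tokenize_function(examples):
        # Hypothetical sketch: map each UTF-8 byte to an id in 0-255. A real
        # byte-level vocabulary would also reserve padding/special-token ids.
        encodings = [list(text.encode("utf-8")) for text in examples["text"]]
        return {
            "input_ids": encodings,
            "attention_mask": [[1] * len(ids) for ids in encodings],
        }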
@@ -838,7 +444,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
  train_datasets.append(train_dataset)
  test_datasets.append(test_dataset)
 
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+ data_collator = DataCollatorWithPadding(tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"))
 
  return train_datasets, test_datasets, data_collator, raw_datasets
 
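Note: the collator now reloads the tokenizer because the tokenizer local only exists on the non-UTF-8 path; a single shared instance would avoid the second from_pretrained call. A sketch, assuming the same bert-base-uncased checkpoint:

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # load once

    # Reuse the one tokenizer for both tokenization and the padding collator.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)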
@@ -1038,8 +644,9 @@ def main():
 
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
+ use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)
 
- train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
+ train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8)
 
  trainloaders = []
  testloaders = []
@@ -1187,3 +794,4 @@ def main():
  if __name__ == "__main__":
  main()
 
+