alisrbdni committed on
Commit
8a9be66
·
verified ·
1 Parent(s): 764d201

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -2
app.py CHANGED
@@ -1,3 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # %%writefile app.py
2
 
3
  import streamlit as st
@@ -11,6 +224,8 @@ import pandas as pd
11
  import random
12
  from collections import OrderedDict
13
  import flwr as fl
 
 
14
 
15
  DEVICE = torch.device("cpu")
16
 
@@ -87,16 +302,20 @@ class CustomClient(fl.client.NumPyClient):
87
  self.net.load_state_dict(state_dict, strict=True)
88
 
89
  def fit(self, parameters, config):
 
90
  self.set_parameters(parameters)
91
  train(self.net, self.trainloader, epochs=1)
92
  loss, accuracy = test(self.net, self.testloader)
93
  self.losses.append(loss)
94
  self.accuracies.append(accuracy)
 
95
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
96
 
97
  def evaluate(self, parameters, config):
 
98
  self.set_parameters(parameters)
99
  loss, accuracy = test(self.net, self.testloader)
 
100
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
101
 
102
  def plot_metrics(self, round_num, plot_placeholder):
@@ -122,10 +341,14 @@ class CustomClient(fl.client.NumPyClient):
122
  fig.tight_layout()
123
  plot_placeholder.pyplot(fig)
124
 
 
 
 
 
125
  def main():
126
  st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
127
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
128
- model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])
129
 
130
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
131
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
@@ -204,13 +427,17 @@ def main():
204
  client.plot_metrics(NUM_ROUNDS, st.empty())
205
  st.write(" ")
206
 
 
 
 
 
207
  else:
208
  st.write("Click the 'Start Training' button to start the training process.")
209
 
210
  if __name__ == "__main__":
211
  main()
212
 
213
-
214
 
215
  # # %%writefile app.py
216
 
 
1
+ # # %%writefile app.py
2
+
3
+ # import streamlit as st
4
+ # import matplotlib.pyplot as plt
5
+ # import torch
6
+ # from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
7
+ # from datasets import load_dataset, Dataset
8
+ # from evaluate import load as load_metric
9
+ # from torch.utils.data import DataLoader
10
+ # import pandas as pd
11
+ # import random
12
+ # from collections import OrderedDict
13
+ # import flwr as fl
14
+
15
+ # DEVICE = torch.device("cpu")
16
+
17
+ # def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
18
+ # raw_datasets = load_dataset(dataset_name)
19
+ # raw_datasets = raw_datasets.shuffle(seed=42)
20
+ # del raw_datasets["unsupervised"]
21
+
22
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
23
+
24
+ # def tokenize_function(examples):
25
+ # return tokenizer(examples["text"], truncation=True)
26
+
27
+ # tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
28
+ # tokenized_datasets = tokenized_datasets.remove_columns("text")
29
+ # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
30
+
31
+ # train_datasets = []
32
+ # test_datasets = []
33
+
34
+ # for _ in range(num_clients):
35
+ # train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
36
+ # test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
37
+ # train_datasets.append(train_dataset)
38
+ # test_datasets.append(test_dataset)
39
+
40
+ # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
41
+
42
+ # return train_datasets, test_datasets, data_collator
43
+
44
+ # def train(net, trainloader, epochs):
45
+ # optimizer = AdamW(net.parameters(), lr=5e-5)
46
+ # net.train()
47
+ # for _ in range(epochs):
48
+ # for batch in trainloader:
49
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
50
+ # outputs = net(**batch)
51
+ # loss = outputs.loss
52
+ # loss.backward()
53
+ # optimizer.step()
54
+ # optimizer.zero_grad()
55
+
56
+ # def test(net, testloader):
57
+ # metric = load_metric("accuracy")
58
+ # net.eval()
59
+ # loss = 0
60
+ # for batch in testloader:
61
+ # batch = {k: v.to(DEVICE) for k, v in batch.items()}
62
+ # with torch.no_grad():
63
+ # outputs = net(**batch)
64
+ # logits = outputs.logits
65
+ # loss += outputs.loss.item()
66
+ # predictions = torch.argmax(logits, dim=-1)
67
+ # metric.add_batch(predictions=predictions, references=batch["labels"])
68
+ # loss /= len(testloader)
69
+ # accuracy = metric.compute()["accuracy"]
70
+ # return loss, accuracy
71
+
72
+ # class CustomClient(fl.client.NumPyClient):
73
+ # def __init__(self, net, trainloader, testloader, client_id):
74
+ # self.net = net
75
+ # self.trainloader = trainloader
76
+ # self.testloader = testloader
77
+ # self.client_id = client_id
78
+ # self.losses = []
79
+ # self.accuracies = []
80
+
81
+ # def get_parameters(self, config):
82
+ # return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
83
+
84
+ # def set_parameters(self, parameters):
85
+ # params_dict = zip(self.net.state_dict().keys(), parameters)
86
+ # state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
87
+ # self.net.load_state_dict(state_dict, strict=True)
88
+
89
+ # def fit(self, parameters, config):
90
+ # self.set_parameters(parameters)
91
+ # train(self.net, self.trainloader, epochs=1)
92
+ # loss, accuracy = test(self.net, self.testloader)
93
+ # self.losses.append(loss)
94
+ # self.accuracies.append(accuracy)
95
+ # return self.get_parameters(config={}), len(self.trainloader.dataset), {}
96
+
97
+ # def evaluate(self, parameters, config):
98
+ # self.set_parameters(parameters)
99
+ # loss, accuracy = test(self.net, self.testloader)
100
+ # return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
101
+
102
+ # def plot_metrics(self, round_num, plot_placeholder):
103
+ # if self.losses and self.accuracies:
104
+ # plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
105
+ # plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
106
+ # plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
107
+
108
+ # fig, ax1 = plt.subplots()
109
+
110
+ # color = 'tab:red'
111
+ # ax1.set_xlabel('Round')
112
+ # ax1.set_ylabel('Loss', color=color)
113
+ # ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
114
+ # ax1.tick_params(axis='y', labelcolor=color)
115
+
116
+ # ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
117
+ # color = 'tab:blue'
118
+ # ax2.set_ylabel('Accuracy', color=color)
119
+ # ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
120
+ # ax2.tick_params(axis='y', labelcolor=color)
121
+
122
+ # fig.tight_layout()
123
+ # plot_placeholder.pyplot(fig)
124
+
125
+ # def main():
126
+ # st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
127
+ # dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
128
+ # model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])
129
+
130
+ # NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
131
+ # NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
132
+
133
+ # train_datasets, test_datasets, data_collator = load_data(dataset_name, num_clients=NUM_CLIENTS)
134
+
135
+ # trainloaders = []
136
+ # testloaders = []
137
+ # clients = []
138
+
139
+ # for i in range(NUM_CLIENTS):
140
+ # st.write(f"### Client {i+1} Datasets")
141
+
142
+ # train_df = pd.DataFrame(train_datasets[i])
143
+ # test_df = pd.DataFrame(test_datasets[i])
144
+
145
+ # st.write("#### Train Dataset")
146
+ # edited_train_df = st.data_editor(train_df, key=f"train_{i}")
147
+ # st.write("#### Test Dataset")
148
+ # edited_test_df = st.data_editor(test_df, key=f"test_{i}")
149
+
150
+ # edited_train_dataset = Dataset.from_pandas(edited_train_df)
151
+ # edited_test_dataset = Dataset.from_pandas(edited_test_df)
152
+
153
+ # trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
154
+ # testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
155
+
156
+ # trainloaders.append(trainloader)
157
+ # testloaders.append(testloader)
158
+
159
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
160
+ # client = CustomClient(net, trainloader, testloader, client_id=i+1)
161
+ # clients.append(client)
162
+
163
+ # if st.button("Start Training"):
164
+ # def client_fn(cid):
165
+ # return clients[int(cid)]
166
+
167
+ # def weighted_average(metrics):
168
+ # accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
169
+ # losses = [num_examples * m["loss"] for num_examples, m in metrics]
170
+ # examples = [num_examples for num_examples, _ in metrics]
171
+ # return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
172
+
173
+ # strategy = fl.server.strategy.FedAvg(
174
+ # fraction_fit=1.0,
175
+ # fraction_evaluate=1.0,
176
+ # evaluate_metrics_aggregation_fn=weighted_average,
177
+ # )
178
+
179
+ # for round_num in range(NUM_ROUNDS):
180
+ # st.write(f"### Round {round_num + 1}")
181
+ # plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
182
+
183
+ # fl.simulation.start_simulation(
184
+ # client_fn=client_fn,
185
+ # num_clients=NUM_CLIENTS,
186
+ # config=fl.server.ServerConfig(num_rounds=1),
187
+ # strategy=strategy,
188
+ # client_resources={"num_cpus": 1, "num_gpus": 0},
189
+ # ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
190
+ # )
191
+
192
+ # for i, client in enumerate(clients):
193
+ # client.plot_metrics(round_num + 1, plot_placeholders[i])
194
+ # st.write(" ")
195
+
196
+ # st.success("Training completed successfully!")
197
+
198
+ # # Display final metrics
199
+ # st.write("## Final Client Metrics")
200
+ # for client in clients:
201
+ # st.write(f"### Client {client.client_id}")
202
+ # st.write(f"Final Loss: {client.losses[-1]:.4f}")
203
+ # st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
204
+ # client.plot_metrics(NUM_ROUNDS, st.empty())
205
+ # st.write(" ")
206
+
207
+ # else:
208
+ # st.write("Click the 'Start Training' button to start the training process.")
209
+
210
+ # if __name__ == "__main__":
211
+ # main()
212
+
213
+ #############
214
  # %%writefile app.py
215
 
216
  import streamlit as st
 
224
  import random
225
  from collections import OrderedDict
226
  import flwr as fl
227
+ from logging import INFO, DEBUG
228
+ from flwr.common.logger import log
229
 
230
  DEVICE = torch.device("cpu")
231
 
 
302
  self.net.load_state_dict(state_dict, strict=True)
303
 
304
  def fit(self, parameters, config):
305
+ log(INFO, f"Client {self.client_id} is starting fit()")
306
  self.set_parameters(parameters)
307
  train(self.net, self.trainloader, epochs=1)
308
  loss, accuracy = test(self.net, self.testloader)
309
  self.losses.append(loss)
310
  self.accuracies.append(accuracy)
311
+ log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
312
  return self.get_parameters(config={}), len(self.trainloader.dataset), {}
313
 
314
  def evaluate(self, parameters, config):
315
+ log(INFO, f"Client {self.client_id} is starting evaluate()")
316
  self.set_parameters(parameters)
317
  loss, accuracy = test(self.net, self.testloader)
318
+ log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
319
  return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
320
 
321
  def plot_metrics(self, round_num, plot_placeholder):
 
341
  fig.tight_layout()
342
  plot_placeholder.pyplot(fig)
343
 
344
def read_log_file(path="log.txt"):
    """Return the text of the training log, or a notice if it does not exist.

    The visible code never shows ``log.txt`` being created, so a missing
    file is treated as an expected condition rather than an error: the
    caller renders the returned string directly in the UI, and raising
    here would abort the page after training has already finished.

    Args:
        path: Location of the log file. Defaults to ``log.txt`` relative to
            the current working directory, matching the previous behavior.

    Returns:
        The complete file contents as a string, or a short human-readable
        message when no file exists at ``path``.
    """
    try:
        # Explicit UTF-8 with replacement keeps the read independent of the
        # host locale and tolerant of stray bytes in the log.
        with open(path, "r", encoding="utf-8", errors="replace") as file:
            return file.read()
    except FileNotFoundError:
        return f"No training log found at '{path}'."
347
+
348
  def main():
349
  st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
350
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
351
+ model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
352
 
353
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
354
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
 
427
  client.plot_metrics(NUM_ROUNDS, st.empty())
428
  st.write(" ")
429
 
430
+ # Display log.txt content
431
+ st.write("## Training Log")
432
+ st.text(read_log_file())
433
+
434
  else:
435
  st.write("Click the 'Start Training' button to start the training process.")
436
 
437
  if __name__ == "__main__":
438
  main()
439
 
440
+ #############
441
 
442
  # # %%writefile app.py
443