alisrbdni committed on
Commit 41d8e7a · verified · 1 Parent(s): ec3d192

Update app.py

Files changed (1)
  1. app.py +66 -43
app.py CHANGED
@@ -1,4 +1,5 @@
  # %%writefile app.py
+ # %%writefile app.py
 
  import streamlit as st
  import matplotlib.pyplot as plt
@@ -6,12 +7,15 @@ import torch
  from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
  from datasets import load_dataset
  from evaluate import load as load_metric
- from torch.utils.data import DataLoader, random_split
+ from torch.utils.data import DataLoader
  import random
+ import warnings
+ from collections import OrderedDict
+ import flwr as fl
 
  DEVICE = torch.device("cpu")
 
- def load_data(dataset_name, train_size=20, test_size=20):
+ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
      raw_datasets = load_dataset(dataset_name)
      raw_datasets = raw_datasets.shuffle(seed=42)
      del raw_datasets["unsupervised"]
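
Note on a context line kept by this hunk: del raw_datasets["unsupervised"] assumes that split exists, but of the three datasets offered in the UI only imdb ships an "unsupervised" split; amazon_polarity and ag_news do not and would raise a KeyError here. A minimal guard, sketched as a suggestion rather than as part of this commit:

    # Only drop the split when the dataset actually provides it
    # (imdb does; amazon_polarity and ag_news do not).
    if "unsupervised" in raw_datasets:
        del raw_datasets["unsupervised"]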
@@ -25,14 +29,21 @@ def load_data(dataset_name, train_size=20, test_size=20):
      tokenized_datasets = tokenized_datasets.remove_columns("text")
      tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
 
-     train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
-     test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
+     train_datasets = []
+     test_datasets = []
+
+     for _ in range(num_clients):
+         train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
+         test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
+         train_datasets.append(train_dataset)
+         test_datasets.append(test_dataset)
 
      data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-     trainloader = DataLoader(train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
-     testloader = DataLoader(test_dataset, batch_size=32, collate_fn=data_collator)
 
-     return trainloader, testloader
+     trainloaders = [DataLoader(ds, shuffle=True, batch_size=32, collate_fn=data_collator) for ds in train_datasets]
+     testloaders = [DataLoader(ds, batch_size=32, collate_fn=data_collator) for ds in test_datasets]
+
+     return trainloaders, testloaders
 
  def train(net, trainloader, epochs):
      optimizer = AdamW(net.parameters(), lr=5e-5)
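
A note on the new per-client sampling: each client's shard is drawn with an independent random.sample over the full train/test index range, so shards of different clients can overlap. If disjoint shards are intended, a sketch of a hypothetical partition_disjoint helper (not part of this commit) could replace the per-client loop above:

    import random

    def partition_disjoint(dataset, num_clients, shard_size, seed=42):
        # Shuffle all indices once, then hand out non-overlapping slices,
        # one slice of shard_size examples per client.
        # Assumes num_clients * shard_size <= len(dataset).
        rng = random.Random(seed)
        indices = list(range(len(dataset)))
        rng.shuffle(indices)
        return [
            dataset.select(indices[i * shard_size:(i + 1) * shard_size])
            for i in range(num_clients)
        ]

    # Possible use inside load_data:
    # train_datasets = partition_disjoint(tokenized_datasets["train"], num_clients, train_size)
    # test_datasets = partition_disjoint(tokenized_datasets["test"], num_clients, test_size)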
@@ -62,6 +73,30 @@ def test(net, testloader):
      accuracy = metric.compute()["accuracy"]
      return loss, accuracy
 
+ class CustomClient(fl.client.NumPyClient):
+     def __init__(self, net, trainloader, testloader):
+         self.net = net
+         self.trainloader = trainloader
+         self.testloader = testloader
+
+     def get_parameters(self, config):
+         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+     def set_parameters(self, parameters):
+         params_dict = zip(self.net.state_dict().keys(), parameters)
+         state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+         self.net.load_state_dict(state_dict, strict=True)
+
+     def fit(self, parameters, config):
+         self.set_parameters(parameters)
+         train(self.net, self.trainloader, epochs=1)
+         return self.get_parameters(config={}), len(self.trainloader.dataset), {}
+
+     def evaluate(self, parameters, config):
+         self.set_parameters(parameters)
+         loss, accuracy = test(self.net, self.testloader)
+         return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}
+
  def main():
      st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
      dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
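
The added CustomClient is the bridge between the PyTorch model and Flower: get_parameters/set_parameters convert the model's state_dict to and from plain NumPy arrays, fit runs one local epoch and reports the shard size that FedAvg uses for weighting, and evaluate returns the local loss and accuracy. A hypothetical smoke test (not part of the commit, assuming net and the loaders from main() are in scope) that exercises one client before launching the simulation:

    # Exercise a single client end to end on its local shard.
    client = CustomClient(net, trainloaders[0], testloaders[0])
    initial = client.get_parameters(config={})            # state_dict as a list of NumPy arrays
    updated, n_train, _ = client.fit(initial, config={})  # one local training epoch
    loss, n_test, metrics = client.evaluate(updated, config={})
    print(n_train, n_test, metrics["accuracy"])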
@@ -72,52 +107,39 @@ def main():
      NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
      NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
 
-     trainloader, testloader = load_data(dataset_name)
+     trainloaders, testloaders = load_data(dataset_name, num_clients=NUM_CLIENTS)
 
      if st.button("Start Training"):
          round_losses = []
          round_accuracies = []
 
-         for round_num in range(1, NUM_ROUNDS + 1):
-             st.write(f"## Round {round_num}")
-
-             st.write("### Training Metrics for Each Client")
-             client_losses = []
-             client_accuracies = []
+         clients = [CustomClient(net, trainloaders[i], testloaders[i]) for i in range(NUM_CLIENTS)]
 
-             for client in range(1, NUM_CLIENTS + 1):
-                 train_subset, _ = random_split(trainloader.dataset, [len(trainloader.dataset) // NUM_CLIENTS] * NUM_CLIENTS)
-                 trainloader_client = DataLoader(train_subset, shuffle=True, batch_size=32, collate_fn=trainloader.collate_fn)
-                 train(net, trainloader_client, epochs=1)
-                 client_loss, client_accuracy = test(net, testloader)
-                 st.write(f"Client {client}: Loss: {client_loss:.4f}, Accuracy: {client_accuracy:.4f}")
-                 client_losses.append(client_loss)
-                 client_accuracies.append(client_accuracy)
+         def client_fn(cid):
+             return clients[int(cid)]
 
-             avg_client_loss = sum(client_losses) / NUM_CLIENTS
-             avg_client_accuracy = sum(client_accuracies) / NUM_CLIENTS
+         def weighted_average(metrics):
+             accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+             losses = [num_examples * m["loss"] for num_examples, m in metrics]
+             examples = [num_examples for num_examples, _ in metrics]
+             return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
 
-             st.write("### Average Metrics Across All Clients")
-             st.write(f"Average Loss: {avg_client_loss:.4f}, Average Accuracy: {avg_client_accuracy:.4f}")
+         strategy = fl.server.strategy.FedAvg(
+             fraction_fit=1.0,
+             fraction_evaluate=1.0,
+             evaluate_metrics_aggregation_fn=weighted_average,
+         )
 
-             round_losses.append(avg_client_loss)
-             round_accuracies.append(avg_client_accuracy)
+         fl.simulation.start_simulation(
+             client_fn=client_fn,
+             num_clients=NUM_CLIENTS,
+             config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
+             strategy=strategy,
+             client_resources={"num_cpus": 1, "num_gpus": 0},
+             ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0}
+         )
 
-             st.write("### Accuracy Over Rounds")
-             plt.plot(range(1, round_num + 1), round_accuracies, marker='o', label="Accuracy")
-             plt.xlabel("Round")
-             plt.ylabel("Accuracy")
-             plt.title("Accuracy Over Rounds")
-             st.pyplot()
-
-             st.write("### Loss Over Rounds")
-             plt.plot(range(1, round_num + 1), round_losses, marker='o', color='red', label="Loss")
-             plt.xlabel("Round")
-             plt.ylabel("Loss")
-             plt.title("Loss Over Rounds")
-             st.pyplot()
-
-             st.success(f"Round {round_num} completed successfully!")
+         st.success(f"Training completed successfully!")
 
      else:
          st.write("Click the 'Start Training' button to start the training process.")
@@ -125,6 +147,7 @@ def main():
  if __name__ == "__main__":
      main()
 
+
  ##ORIGINAL###
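
With the manual round loop removed, round_losses and round_accuracies are no longer filled and the per-round charts disappear from the UI. If those curves are still wanted, they can be rebuilt from the History object that fl.simulation.start_simulation returns; a sketch, assuming the aggregated metric key is "accuracy" as produced by weighted_average (and that "loss" has been reconciled as noted above):

    # Capture the simulation history and plot the aggregated per-round metrics.
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
        client_resources={"num_cpus": 1, "num_gpus": 0},
        ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 0},
    )

    rounds = [r for r, _ in history.metrics_distributed["accuracy"]]
    accuracies = [a for _, a in history.metrics_distributed["accuracy"]]
    losses = [l for _, l in history.losses_distributed]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    ax1.plot(rounds, accuracies, marker="o")
    ax1.set_xlabel("Round")
    ax1.set_ylabel("Accuracy")
    ax2.plot([r for r, _ in history.losses_distributed], losses, marker="o", color="red")
    ax2.set_xlabel("Round")
    ax2.set_ylabel("Loss")
    st.pyplot(fig)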