Update app.py
app.py
CHANGED
@@ -6,15 +6,12 @@ import torch
 from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
 from datasets import load_dataset
 from evaluate import load as load_metric
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, random_split
 import random

 DEVICE = torch.device("cpu")
-NUM_ROUNDS = 3

-
-
-def load_data(dataset_name):
+def load_data(dataset_name, train_size=20, test_size=20):
     raw_datasets = load_dataset(dataset_name)
     raw_datasets = raw_datasets.shuffle(seed=42)
     del raw_datasets["unsupervised"]
@@ -24,19 +21,16 @@ def load_data(dataset_name):
     def tokenize_function(examples):
         return tokenizer(examples["text"], truncation=True)

-    train_population = random.sample(range(len(raw_datasets["train"])), 20)
-    test_population = random.sample(range(len(raw_datasets["test"])), 20)
-
     tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
-    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
-    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
-
     tokenized_datasets = tokenized_datasets.remove_columns("text")
     tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

+    train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
+    test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
+
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-    trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
-    testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
+    trainloader = DataLoader(train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
+    testloader = DataLoader(test_dataset, batch_size=32, collate_fn=data_collator)

     return trainloader, testloader

@@ -54,8 +48,8 @@ def train(net, trainloader, epochs):

 def test(net, testloader):
     metric = load_metric("accuracy")
-    loss = 0
     net.eval()
+    loss = 0
     for batch in testloader:
         batch = {k: v.to(DEVICE) for k, v in batch.items()}
         with torch.no_grad():
@@ -64,29 +58,16 @@ def test(net, testloader):
             loss += outputs.loss.item()
         predictions = torch.argmax(logits, dim=-1)
         metric.add_batch(predictions=predictions, references=batch["labels"])
-    loss /= len(testloader.dataset)
+    loss /= len(testloader)
     accuracy = metric.compute()["accuracy"]
     return loss, accuracy

-
-
-
-
-from transformers import Wav2Vec2Processor, HubertForSequenceClassification
-import torch
-
-
 def main():
-    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
-    dataset_name = st.selectbox("Dataset", ["imdb", "audio_instruction_task", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
+    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
+    model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])

     net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-    # processor = Wav2Vec2Processor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-
-    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)

     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
@@ -95,28 +76,42 @@ def main():

     if st.button("Start Training"):
         round_losses = []
-        round_accuracies = []  # Store accuracy values for each round
+        round_accuracies = []
+
         for round_num in range(1, NUM_ROUNDS + 1):
             st.write(f"## Round {round_num}")

             st.write("### Training Metrics for Each Client")
+            client_losses = []
+            client_accuracies = []
+
             for client in range(1, NUM_CLIENTS + 1):
-                client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
-                st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
+                train_subset, _ = random_split(trainloader.dataset, [len(trainloader.dataset) // NUM_CLIENTS] * NUM_CLIENTS)
+                trainloader_client = DataLoader(train_subset, shuffle=True, batch_size=32, collate_fn=trainloader.collate_fn)
+                train(net, trainloader_client, epochs=1)
+                client_loss, client_accuracy = test(net, testloader)
+                st.write(f"Client {client}: Loss: {client_loss:.4f}, Accuracy: {client_accuracy:.4f}")
+                client_losses.append(client_loss)
+                client_accuracies.append(client_accuracy)
+
+            avg_client_loss = sum(client_losses) / NUM_CLIENTS
+            avg_client_accuracy = sum(client_accuracies) / NUM_CLIENTS
+
+            st.write("### Average Metrics Across All Clients")
+            st.write(f"Average Loss: {avg_client_loss:.4f}, Average Accuracy: {avg_client_accuracy:.4f}")
+
+            round_losses.append(avg_client_loss)
+            round_accuracies.append(avg_client_accuracy)

             st.write("### Accuracy Over Rounds")
-            round_accuracies.append(client_accuracy)  # Append the accuracy for this round
-            plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
+            plt.plot(range(1, round_num + 1), round_accuracies, marker='o', label="Accuracy")
             plt.xlabel("Round")
             plt.ylabel("Accuracy")
             plt.title("Accuracy Over Rounds")
             st.pyplot()

             st.write("### Loss Over Rounds")
-            loss_value = random.random()  # Placeholder for loss values
-            round_losses.append(loss_value)
-            rounds = list(range(1, round_num + 1))
-            plt.plot(rounds, round_losses)
+            plt.plot(range(1, round_num + 1), round_losses, marker='o', color='red', label="Loss")
             plt.xlabel("Round")
             plt.ylabel("Loss")
             plt.title("Loss Over Rounds")
@@ -130,6 +125,141 @@ def main():
 if __name__ == "__main__":
     main()

+##ORIGINAL###
+
+
+# # %%writefile app.py
+
+# import streamlit as st
+# import matplotlib.pyplot as plt
+# import torch
+# from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
+# from datasets import load_dataset
+# from evaluate import load as load_metric
+# from torch.utils.data import DataLoader
+# import random
+
+# DEVICE = torch.device("cpu")
+# NUM_ROUNDS = 3
+
+
+
+# def load_data(dataset_name):
+#     raw_datasets = load_dataset(dataset_name)
+#     raw_datasets = raw_datasets.shuffle(seed=42)
+#     del raw_datasets["unsupervised"]
+
+#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+#     def tokenize_function(examples):
+#         return tokenizer(examples["text"], truncation=True)
+
+#     train_population = random.sample(range(len(raw_datasets["train"])), 20)
+#     test_population = random.sample(range(len(raw_datasets["test"])), 20)
+
+#     tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+#     tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
+#     tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
+
+#     tokenized_datasets = tokenized_datasets.remove_columns("text")
+#     tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+#     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+#     trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
+#     testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
+
+#     return trainloader, testloader
+
+# def train(net, trainloader, epochs):
+#     optimizer = AdamW(net.parameters(), lr=5e-5)
+#     net.train()
+#     for _ in range(epochs):
+#         for batch in trainloader:
+#             batch = {k: v.to(DEVICE) for k, v in batch.items()}
+#             outputs = net(**batch)
+#             loss = outputs.loss
+#             loss.backward()
+#             optimizer.step()
+#             optimizer.zero_grad()
+
+# def test(net, testloader):
+#     metric = load_metric("accuracy")
+#     loss = 0
+#     net.eval()
+#     for batch in testloader:
+#         batch = {k: v.to(DEVICE) for k, v in batch.items()}
+#         with torch.no_grad():
+#             outputs = net(**batch)
+#         logits = outputs.logits
+#         loss += outputs.loss.item()
+#         predictions = torch.argmax(logits, dim=-1)
+#         metric.add_batch(predictions=predictions, references=batch["labels"])
+#     loss /= len(testloader.dataset)
+#     accuracy = metric.compute()["accuracy"]
+#     return loss, accuracy
+
+
+
+
+
+# from transformers import Wav2Vec2Processor, HubertForSequenceClassification
+# import torch
+
+
+# def main():
+#     st.write("## Federated Learning with dynamic models and datasets for mobile devices")
+#     dataset_name = st.selectbox("Dataset", ["imdb", "audio_instruction_task", "amazon_polarity", "ag_news"])
+#     model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+
+#     net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+#     # processor = Wav2Vec2Processor.from_pretrained(model_name)
+#     # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+#     # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+#     # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+#     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
+#     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
+
+#     trainloader, testloader = load_data(dataset_name)
+
+#     if st.button("Start Training"):
+#         round_losses = []
+#         round_accuracies = []  # Store accuracy values for each round
+#         for round_num in range(1, NUM_ROUNDS + 1):
+#             st.write(f"## Round {round_num}")
+
+#             st.write("### Training Metrics for Each Client")
+#             for client in range(1, NUM_CLIENTS + 1):
+#                 client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
+#                 st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
+
+#             st.write("### Accuracy Over Rounds")
+#             round_accuracies.append(client_accuracy)  # Append the accuracy for this round
+#             plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
+#             plt.xlabel("Round")
+#             plt.ylabel("Accuracy")
+#             plt.title("Accuracy Over Rounds")
+#             st.pyplot()
+
+#             st.write("### Loss Over Rounds")
+#             loss_value = random.random()  # Placeholder for loss values
+#             round_losses.append(loss_value)
+#             rounds = list(range(1, round_num + 1))
+#             plt.plot(rounds, round_losses)
+#             plt.xlabel("Round")
+#             plt.ylabel("Loss")
+#             plt.title("Loss Over Rounds")
+#             st.pyplot()
+
+#             st.success(f"Round {round_num} completed successfully!")
+
+#     else:
+#         st.write("Click the 'Start Training' button to start the training process.")
+
+# if __name__ == "__main__":
+#     main()
+###ORIGINAL##
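
A note on the new per-client loop: `random_split` is re-run on the full training set for every client and only the first shard is kept, so the shards drawn for different clients can overlap; the two-name unpack `train_subset, _ = ...` also assumes exactly two shards, and with integer lengths `random_split` requires the sizes to sum to the dataset length, which fails when `len(trainloader.dataset)` is not divisible by `NUM_CLIENTS`. Below is a minimal sketch of one alternative, splitting once per round into disjoint shards; `make_client_loaders` is a hypothetical helper introduced here for illustration, not part of the commit.

# Hypothetical helper (not in app.py): one disjoint shard per client.
from torch.utils.data import DataLoader, random_split

def make_client_loaders(dataset, num_clients, batch_size, collate_fn):
    sizes = [len(dataset) // num_clients] * num_clients
    sizes[-1] += len(dataset) - sum(sizes)  # fold the remainder into the last shard
    shards = random_split(dataset, sizes)   # disjoint random subsets
    return [DataLoader(s, shuffle=True, batch_size=batch_size, collate_fn=collate_fn)
            for s in shards]

Inside the round loop this would replace the per-client `random_split` call, e.g. `client_loaders = make_client_loaders(trainloader.dataset, NUM_CLIENTS, 32, trainloader.collate_fn)`.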
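On the `test` change: `outputs.loss` from a Hugging Face sequence-classification model is already the mean loss over the batch, so dividing the accumulated sum by `len(testloader)` (the number of batches) gives a mean of batch means, whereas the old `loss /= len(testloader.dataset)` divided a sum of batch means by the number of examples and understated the loss. A sketch of an exact per-example average, weighting each batch mean by its batch size, follows; it assumes `net`, `testloader`, and `DEVICE` as defined in app.py and is an illustration, not code from the commit.

import torch

def test_exact_loss(net, testloader):
    net.eval()
    total_loss, total_examples = 0.0, 0
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        n = batch["labels"].size(0)
        total_loss += outputs.loss.item() * n  # undo the per-batch mean
        total_examples += n
    return total_loss / total_examples

With a fixed batch size the two estimates agree except for a smaller final batch; the weighted version stays exact in that case.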
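Finally, the round loop trains the single shared `net` sequentially on each client's shard, so a round amounts to sequential fine-tuning rather than federated averaging. For comparison, a rough FedAvg-style round under the same `train` function from app.py; `fedavg_round` and its names are illustrative only, not what the commit implements.

import copy
import torch

def fedavg_round(global_net, client_loaders, epochs=1):
    # Each client trains its own copy of the global model...
    states = []
    for loader in client_loaders:
        local_net = copy.deepcopy(global_net)
        train(local_net, loader, epochs)  # train() as defined in app.py
        states.append(local_net.state_dict())
    # ...then the copies' weights are averaged back into the global model.
    avg_state = {
        key: torch.stack([s[key].float() for s in states]).mean(dim=0).to(states[0][key].dtype)
        for key in states[0]
    }
    global_net.load_state_dict(avg_state)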