Update app.py
app.py
CHANGED
@@ -14,6 +14,126 @@ NUM_ROUNDS = 3
 
 
 
+def load_data(dataset_name):
+    raw_datasets = load_dataset(dataset_name)
+    raw_datasets = raw_datasets.shuffle(seed=42)
+    del raw_datasets["unsupervised"]
+
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+    def tokenize_function(examples):
+        return tokenizer(examples["text"], truncation=True)
+
+    train_population = random.sample(range(len(raw_datasets["train"])), 20)
+    test_population = random.sample(range(len(raw_datasets["test"])), 20)
+
+    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
+    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
+
+    tokenized_datasets = tokenized_datasets.remove_columns("text")
+    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
+    testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
+
+    return trainloader, testloader
+
+def train(net, trainloader, epochs):
+    optimizer = AdamW(net.parameters(), lr=5e-5)
+    net.train()
+    for _ in range(epochs):
+        for batch in trainloader:
+            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            outputs = net(**batch)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+def test(net, testloader):
+    metric = load_metric("accuracy")
+    loss = 0
+    net.eval()
+    for batch in testloader:
+        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        with torch.no_grad():
+            outputs = net(**batch)
+        logits = outputs.logits
+        loss += outputs.loss.item()
+        predictions = torch.argmax(logits, dim=-1)
+        metric.add_batch(predictions=predictions, references=batch["labels"])
+    loss /= len(testloader.dataset)
+    accuracy = metric.compute()["accuracy"]
+    return loss, accuracy
+
+
+
+
+
+from transformers import Wav2Vec2Processor, HubertForSequenceClassification
+import torch
+
+
+def main():
+    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
+    dataset_name = st.selectbox("Dataset", ["imdb","audio_instruction_task", "amazon_polarity", "ag_news"])
+    model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])
+
+    net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+    # processor = Wav2Vec2Processor.from_pretrained(model_name)
+    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
+    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
+
+    trainloader, testloader = load_data(dataset_name)
+
+    if st.button("Start Training"):
+        round_losses = []
+        round_accuracies = []  # Store accuracy values for each round
+        for round_num in range(1, NUM_ROUNDS + 1):
+            st.write(f"## Round {round_num}")
+
+            st.write("### Training Metrics for Each Client")
+            for client in range(1, NUM_CLIENTS + 1):
+                client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
+                st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
+
+            st.write("### Accuracy Over Rounds")
+            round_accuracies.append(client_accuracy)  # Append the accuracy for this round
+            plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
+            plt.xlabel("Round")
+            plt.ylabel("Accuracy")
+            plt.title("Accuracy Over Rounds")
+            st.pyplot()
+
+            st.write("### Loss Over Rounds")
+            loss_value = random.random()  # Placeholder for loss values
+            round_losses.append(loss_value)
+            rounds = list(range(1, round_num + 1))
+            plt.plot(rounds, round_losses)
+            plt.xlabel("Round")
+            plt.ylabel("Loss")
+            plt.title("Loss Over Rounds")
+            st.pyplot()
+
+            st.success(f"Round {round_num} completed successfully!")
+
+    else:
+        st.write("Click the 'Start Training' button to start the training process.")
+
+if __name__ == "__main__":
+    main()
+
+
+
+
+
 # ########################TinyLLM####################################
 
 # import torch
@@ -254,67 +374,8 @@ NUM_ROUNDS = 3
 
 # ########################TinyLLM##################################
 
-def load_data(dataset_name):
-    raw_datasets = load_dataset(dataset_name)
-    raw_datasets = raw_datasets.shuffle(seed=42)
-    del raw_datasets["unsupervised"]
-
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
-    def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True)
-
-    train_population = random.sample(range(len(raw_datasets["train"])), 20)
-    test_population = random.sample(range(len(raw_datasets["test"])), 20)
-
-    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
-    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
-    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
-
-    tokenized_datasets = tokenized_datasets.remove_columns("text")
-    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
-
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-    trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
-    testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
-
-    return trainloader, testloader
-
-def train(net, trainloader, epochs):
-    optimizer = AdamW(net.parameters(), lr=5e-5)
-    net.train()
-    for _ in range(epochs):
-        for batch in trainloader:
-            batch = {k: v.to(DEVICE) for k, v in batch.items()}
-            outputs = net(**batch)
-            loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-
-def test(net, testloader):
-    metric = load_metric("accuracy")
-    loss = 0
-    net.eval()
-    for batch in testloader:
-        batch = {k: v.to(DEVICE) for k, v in batch.items()}
-        with torch.no_grad():
-            outputs = net(**batch)
-        logits = outputs.logits
-        loss += outputs.loss.item()
-        predictions = torch.argmax(logits, dim=-1)
-        metric.add_batch(predictions=predictions, references=batch["labels"])
-    loss /= len(testloader.dataset)
-    accuracy = metric.compute()["accuracy"]
-    return loss, accuracy
-
-
-
 
 
-from transformers import Wav2Vec2Processor, HubertForSequenceClassification
-import torch
-
 # def main():
 #     st.write("## Audio Classification with HuBERT")
 #     dataset_name = st.selectbox("Dataset", ["librispeech", "your_audio_dataset"])
@@ -351,58 +412,3 @@ import torch
 #         features.append(input_values)
 #         labels.append(label)
 #     return torch.cat(features, dim=0), torch.tensor(labels)
-
-
-def main():
-    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
-    dataset_name = st.selectbox("Dataset", ["imdb","audio_instruction_task", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased","facebook/hubert-base-ls960", "distilbert-base-uncased"])
-
-    net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-    # processor = Wav2Vec2Processor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-
-    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-
-    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
-    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
-
-    trainloader, testloader = load_data(dataset_name)
-
-    if st.button("Start Training"):
-        round_losses = []
-        round_accuracies = []  # Store accuracy values for each round
-        for round_num in range(1, NUM_ROUNDS + 1):
-            st.write(f"## Round {round_num}")
-
-            st.write("### Training Metrics for Each Client")
-            for client in range(1, NUM_CLIENTS + 1):
-                client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
-                st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
-
-            st.write("### Accuracy Over Rounds")
-            round_accuracies.append(client_accuracy)  # Append the accuracy for this round
-            plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
-            plt.xlabel("Round")
-            plt.ylabel("Accuracy")
-            plt.title("Accuracy Over Rounds")
-            st.pyplot()
-
-            st.write("### Loss Over Rounds")
-            loss_value = random.random()  # Placeholder for loss values
-            round_losses.append(loss_value)
-            rounds = list(range(1, round_num + 1))
-            plt.plot(rounds, round_losses)
-            plt.xlabel("Round")
-            plt.ylabel("Loss")
-            plt.title("Loss Over Rounds")
-            st.pyplot()
-
-            st.success(f"Round {round_num} completed successfully!")
-
-    else:
-        st.write("Click the 'Start Training' button to start the training process.")
-
-if __name__ == "__main__":
-    main()
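Note on the added load_data(): it always deletes the "unsupervised" split and always loads the bert-base-uncased tokenizer, even though main() lets the user pick amazon_polarity, ag_news, or a HuBERT checkpoint; a dataset without an "unsupervised" split would raise a KeyError, and a non-BERT model ends up paired with a BERT tokenizer. A minimal sketch of a more defensive variant follows; build_loaders, tokenizer_name, samples, and batch_size are illustrative names, not part of this commit.

# Sketch only: drop non-train/test splits defensively and let the caller pick the tokenizer.
# `build_loaders` and its parameters are illustrative, not part of app.py.
import random
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

def build_loaders(dataset_name, tokenizer_name="bert-base-uncased", samples=20, batch_size=32):
    raw = load_dataset(dataset_name).shuffle(seed=42)
    for split in list(raw.keys()):  # "unsupervised" exists for imdb but not for ag_news etc.
        if split not in ("train", "test"):
            del raw[split]

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenized = raw.map(lambda ex: tokenizer(ex["text"], truncation=True), batched=True)

    tokenized["train"] = tokenized["train"].select(random.sample(range(len(tokenized["train"])), samples))
    tokenized["test"] = tokenized["test"].select(random.sample(range(len(tokenized["test"])), samples))
    tokenized = tokenized.remove_columns("text").rename_column("label", "labels")

    collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(tokenized["train"], shuffle=True, batch_size=batch_size, collate_fn=collator)
    testloader = DataLoader(tokenized["test"], batch_size=batch_size, collate_fn=collator)
    return trainloader, testloader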
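Note on the added train(): it runs plain AdamW at 5e-5 with no learning-rate schedule, which matches the minimal Hugging Face fine-tuning loop. If a schedule were wanted, the same loop could look roughly like the sketch below; the linear schedule and the train_with_schedule name are assumptions, not something this commit introduces.

# Sketch: the same loop as train(), plus an optional linear LR schedule (not in the commit).
from torch.optim import AdamW
from transformers import get_scheduler

def train_with_schedule(net, trainloader, epochs, device):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    scheduler = get_scheduler("linear", optimizer=optimizer,
                              num_warmup_steps=0,
                              num_training_steps=epochs * len(trainloader))
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = net(**batch).loss
            loss.backward()
            optimizer.step()
            scheduler.step()     # advance the learning-rate schedule once per batch
            optimizer.zero_grad()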
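Note on the added test(): it calls load_metric("accuracy"), which comes from the datasets package and is deprecated in newer releases in favour of the separate evaluate library. If that import stops working, an equivalent of the same loop would look roughly like this; evaluate_model is an illustrative name.

# Sketch: test() rewritten against the `evaluate` package, since datasets.load_metric is deprecated.
import evaluate
import torch

def evaluate_model(net, testloader, device):
    metric = evaluate.load("accuracy")
    total_loss = 0.0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        total_loss += outputs.loss.item()
        predictions = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    return total_loss / len(testloader.dataset), metric.compute()["accuracy"]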
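Note on the round loop in main(): per-client metrics come from the test() placeholder, only the last client's client_accuracy is appended to round_accuracies, and the loss curve plots random.random() values. The sketch below averages the placeholder metrics across clients instead; run_round is an illustrative helper that assumes app.py's test() is in scope, not part of this commit.

# Sketch: average the per-client placeholder metrics instead of keeping only the last client's value.
def run_round(net, testloader, num_clients):
    losses, accuracies = [], []
    for _ in range(num_clients):
        loss, acc = test(net, testloader)  # same placeholder call used in main()
        losses.append(loss)
        accuracies.append(acc)
    return sum(losses) / num_clients, sum(accuracies) / num_clients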
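Note on the plotting in main(): st.pyplot() is called with no figure, which relies on Streamlit's deprecated global-pyplot behaviour and keeps drawing onto the same implicit figure across rounds. A sketch that builds an explicit figure per chart; plot_rounds is an illustrative helper, not part of app.py.

# Sketch: plot a per-round metric on an explicit Matplotlib figure and hand it to st.pyplot(fig).
import matplotlib.pyplot as plt
import streamlit as st

def plot_rounds(values, label):
    fig, ax = plt.subplots()
    ax.plot(range(1, len(values) + 1), values, marker="o")
    ax.set_xlabel("Round")
    ax.set_ylabel(label)
    ax.set_title(f"{label} Over Rounds")
    st.pyplot(fig)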