alisrbdni committed
Commit 2859e8b · verified · 1 Parent(s): 68ebf06

Update app.py

Files changed (1):
  1. app.py +120 -114

app.py CHANGED
@@ -14,6 +14,126 @@ NUM_ROUNDS = 3
 
 
 
+def load_data(dataset_name):
+    raw_datasets = load_dataset(dataset_name)
+    raw_datasets = raw_datasets.shuffle(seed=42)
+    del raw_datasets["unsupervised"]
+
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+    def tokenize_function(examples):
+        return tokenizer(examples["text"], truncation=True)
+
+    train_population = random.sample(range(len(raw_datasets["train"])), 20)
+    test_population = random.sample(range(len(raw_datasets["test"])), 20)
+
+    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
+    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
+    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
+
+    tokenized_datasets = tokenized_datasets.remove_columns("text")
+    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
+    testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
+
+    return trainloader, testloader
+
+def train(net, trainloader, epochs):
+    optimizer = AdamW(net.parameters(), lr=5e-5)
+    net.train()
+    for _ in range(epochs):
+        for batch in trainloader:
+            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            outputs = net(**batch)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+def test(net, testloader):
+    metric = load_metric("accuracy")
+    loss = 0
+    net.eval()
+    for batch in testloader:
+        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        with torch.no_grad():
+            outputs = net(**batch)
+        logits = outputs.logits
+        loss += outputs.loss.item()
+        predictions = torch.argmax(logits, dim=-1)
+        metric.add_batch(predictions=predictions, references=batch["labels"])
+    loss /= len(testloader.dataset)
+    accuracy = metric.compute()["accuracy"]
+    return loss, accuracy
+
+
+
+
+
+from transformers import Wav2Vec2Processor, HubertForSequenceClassification
+import torch
+
+
+def main():
+    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
+    dataset_name = st.selectbox("Dataset", ["imdb", "audio_instruction_task", "amazon_polarity", "ag_news"])
+    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
+
+    net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+    # processor = Wav2Vec2Processor.from_pretrained(model_name)
+    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
+    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
+
+    trainloader, testloader = load_data(dataset_name)
+
+    if st.button("Start Training"):
+        round_losses = []
+        round_accuracies = []  # Store accuracy values for each round
+        for round_num in range(1, NUM_ROUNDS + 1):
+            st.write(f"## Round {round_num}")
+
+            st.write("### Training Metrics for Each Client")
+            for client in range(1, NUM_CLIENTS + 1):
+                client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
+                st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
+
+            st.write("### Accuracy Over Rounds")
+            round_accuracies.append(client_accuracy)  # Append the accuracy for this round
+            plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
+            plt.xlabel("Round")
+            plt.ylabel("Accuracy")
+            plt.title("Accuracy Over Rounds")
+            st.pyplot()
+
+            st.write("### Loss Over Rounds")
+            loss_value = random.random()  # Placeholder for loss values
+            round_losses.append(loss_value)
+            rounds = list(range(1, round_num + 1))
+            plt.plot(rounds, round_losses)
+            plt.xlabel("Round")
+            plt.ylabel("Loss")
+            plt.title("Loss Over Rounds")
+            st.pyplot()
+
+            st.success(f"Round {round_num} completed successfully!")
+
+    else:
+        st.write("Click the 'Start Training' button to start the training process.")
+
+if __name__ == "__main__":
+    main()
+
+
+
+
+
 # ########################TinyLLM####################################
 
 # import torch
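
Note: the `test()` helper added above still calls `load_metric`, which is deprecated and has since been removed from `datasets`; the standalone `evaluate` package exposes the same metric with the same `add_batch`/`compute` interface. A minimal sketch, assuming `evaluate` is installed:

```python
# Minimal sketch: accuracy via `evaluate` instead of the deprecated
# datasets.load_metric; a drop-in for the lookup in test().
import evaluate

metric = evaluate.load("accuracy")
metric.add_batch(predictions=[0, 1, 1], references=[0, 1, 0])
print(metric.compute()["accuracy"])  # 2 of 3 correct -> 0.666...
```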
@@ -254,67 +374,8 @@ NUM_ROUNDS = 3
 
 # ########################TinyLLM##################################
 
-def load_data(dataset_name):
-    raw_datasets = load_dataset(dataset_name)
-    raw_datasets = raw_datasets.shuffle(seed=42)
-    del raw_datasets["unsupervised"]
-
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
-    def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True)
-
-    train_population = random.sample(range(len(raw_datasets["train"])), 20)
-    test_population = random.sample(range(len(raw_datasets["test"])), 20)
-
-    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
-    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
-    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
-
-    tokenized_datasets = tokenized_datasets.remove_columns("text")
-    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
-
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-    trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
-    testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
-
-    return trainloader, testloader
-
-def train(net, trainloader, epochs):
-    optimizer = AdamW(net.parameters(), lr=5e-5)
-    net.train()
-    for _ in range(epochs):
-        for batch in trainloader:
-            batch = {k: v.to(DEVICE) for k, v in batch.items()}
-            outputs = net(**batch)
-            loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-
-def test(net, testloader):
-    metric = load_metric("accuracy")
-    loss = 0
-    net.eval()
-    for batch in testloader:
-        batch = {k: v.to(DEVICE) for k, v in batch.items()}
-        with torch.no_grad():
-            outputs = net(**batch)
-        logits = outputs.logits
-        loss += outputs.loss.item()
-        predictions = torch.argmax(logits, dim=-1)
-        metric.add_batch(predictions=predictions, references=batch["labels"])
-    loss /= len(testloader.dataset)
-    accuracy = metric.compute()["accuracy"]
-    return loss, accuracy
-
-
-
 
 
-from transformers import Wav2Vec2Processor, HubertForSequenceClassification
-import torch
-
 # def main():
 #     st.write("## Audio Classification with HuBERT")
 #     dataset_name = st.selectbox("Dataset", ["librispeech", "your_audio_dataset"])
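
Note: `load_data` deletes the "unsupervised" split unconditionally, but of the selectable datasets only imdb ships one; amazon_polarity and ag_news would raise a `KeyError` here (ag_news also has four classes, so the hard-coded `num_labels=2` in `main()` would need to vary per dataset). A guarded drop is a one-line fix; a sketch, relying on `DatasetDict` being a `dict` subclass:

```python
from datasets import load_dataset

raw_datasets = load_dataset("imdb").shuffle(seed=42)
# pop() with a default skips datasets (amazon_polarity, ag_news)
# that have no "unsupervised" split, instead of raising KeyError.
raw_datasets.pop("unsupervised", None)
print(list(raw_datasets))  # ['train', 'test']
```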
@@ -351,58 +412,3 @@ import torch
 #     features.append(input_values)
 #     labels.append(label)
 #     return torch.cat(features, dim=0), torch.tensor(labels)
-
-
-def main():
-    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
-    dataset_name = st.selectbox("Dataset", ["imdb", "audio_instruction_task", "amazon_polarity", "ag_news"])
-    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
-
-    net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-    # processor = Wav2Vec2Processor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-
-    # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
-    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-
-    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
-    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
-
-    trainloader, testloader = load_data(dataset_name)
-
-    if st.button("Start Training"):
-        round_losses = []
-        round_accuracies = []  # Store accuracy values for each round
-        for round_num in range(1, NUM_ROUNDS + 1):
-            st.write(f"## Round {round_num}")
-
-            st.write("### Training Metrics for Each Client")
-            for client in range(1, NUM_CLIENTS + 1):
-                client_loss, client_accuracy = test(net, testloader)  # Placeholder for actual client metrics
-                st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
-
-            st.write("### Accuracy Over Rounds")
-            round_accuracies.append(client_accuracy)  # Append the accuracy for this round
-            plt.plot(range(1, round_num + 1), round_accuracies, marker='o')  # Plot accuracy over rounds
-            plt.xlabel("Round")
-            plt.ylabel("Accuracy")
-            plt.title("Accuracy Over Rounds")
-            st.pyplot()
-
-            st.write("### Loss Over Rounds")
-            loss_value = random.random()  # Placeholder for loss values
-            round_losses.append(loss_value)
-            rounds = list(range(1, round_num + 1))
-            plt.plot(rounds, round_losses)
-            plt.xlabel("Round")
-            plt.ylabel("Loss")
-            plt.title("Loss Over Rounds")
-            st.pyplot()
-
-            st.success(f"Round {round_num} completed successfully!")
-
-    else:
-        st.write("Click the 'Start Training' button to start the training process.")
-
-if __name__ == "__main__":
-    main()
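
Note: `main()` (removed here, re-added at the top of the file) calls `st.pyplot()` with no argument, which relies on matplotlib's global figure; Streamlit has deprecated that form, and because the figure is never cleared, the accuracy and loss charts accumulate across rounds. A sketch of the figure-based call, with placeholder values standing in for the loop variables in `main()`:

```python
import matplotlib.pyplot as plt
import streamlit as st

round_num, round_accuracies = 3, [0.52, 0.61, 0.68]  # placeholder values

# One explicit figure per chart, handed to st.pyplot, so the accuracy
# and loss plots stop sharing matplotlib's global state across rounds.
fig, ax = plt.subplots()
ax.plot(range(1, round_num + 1), round_accuracies, marker="o")
ax.set_xlabel("Round")
ax.set_ylabel("Accuracy")
ax.set_title("Accuracy Over Rounds")
st.pyplot(fig)
plt.close(fig)  # release the figure once Streamlit has rendered it
```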
 
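Note on the data pipeline: `tokenize_function` truncates but does not pad, and `DataCollatorWithPadding` then pads each batch to the longest sequence in that batch (dynamic padding), which is why `load_data` needs no fixed `max_length`. A self-contained sketch of that behavior:

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Tokenize without padding, as tokenize_function does ...
features = [tokenizer(t, truncation=True)
            for t in ["short text", "a noticeably longer input text"]]
# ... then let the collator pad to the longest sequence in this batch.
batch = collator(features)
print(batch["input_ids"].shape)  # (2, length of the longer encoding)
```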
 
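Note: in `test()`, `outputs.loss` is already the mean over its batch, so dividing the accumulated sum by `len(testloader.dataset)` deflates the reported loss by roughly a factor of the batch size; dividing by `len(testloader)` (the number of batches) gives the intended average, assuming roughly equal batch sizes. A two-batch arithmetic sketch:

```python
# Two batches of 10 examples with per-batch mean losses 0.4 and 0.6.
batch_mean_losses = [0.4, 0.6]
dataset_size = 20

deflated = sum(batch_mean_losses) / dataset_size             # 0.05
mean_loss = sum(batch_mean_losses) / len(batch_mean_losses)  # 0.50
print(deflated, mean_loss)
```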