update with hubert classification
Browse files
app.py
CHANGED
@@ -309,13 +309,36 @@ def test(net, testloader):
|
|
309 |
return loss, accuracy
|
310 |
|
311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
def main():
|
313 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
314 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
315 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
316 |
|
317 |
-
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
318 |
-
|
|
|
319 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
320 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
321 |
|
|
|
309 |
return loss, accuracy
|
310 |
|
311 |
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
from transformers import Wav2Vec2Processor, HubertForSequenceClassification
|
316 |
+
import torch
|
317 |
+
|
318 |
+
# def main():
|
319 |
+
# st.write("## Audio Classification with HuBERT")
|
320 |
+
# dataset_name = st.selectbox("Dataset", ["librispeech", "your_audio_dataset"])
|
321 |
+
# model_name = "facebook/hubert-base-ls960"
|
322 |
+
|
323 |
+
# processor = Wav2Vec2Processor.from_pretrained(model_name)
|
324 |
+
# net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
325 |
+
|
326 |
+
# train_dataset, test_dataset = load_data(dataset_name)
|
327 |
+
# # Further implementation needed for actual data preparation and training loops
|
328 |
+
|
329 |
+
# st.write("Details of further steps would be filled in based on specific requirements and dataset structure.")
|
330 |
+
|
331 |
+
# if __name__ == "__main__":
|
332 |
+
# main()
|
333 |
+
|
334 |
def main():
|
335 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
336 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
337 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
338 |
|
339 |
+
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
340 |
+
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
341 |
+
net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
342 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
343 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
344 |
|