alisrbdni commited on
Commit
bc648b4
·
verified ·
1 Parent(s): b18beee

update with hubert classification

Browse files
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -309,13 +309,36 @@ def test(net, testloader):
309
  return loss, accuracy
310
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  def main():
313
  st.write("## Federated Learning with dynamic models and datasets for mobile devices")
314
  dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
315
  model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
316
 
317
- net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
318
-
 
319
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
320
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
321
 
 
309
  return loss, accuracy
310
 
311
 
312
+
313
+
314
+
315
+ from transformers import Wav2Vec2Processor, HubertForSequenceClassification
316
+ import torch
317
+
318
+ # def main():
319
+ # st.write("## Audio Classification with HuBERT")
320
+ # dataset_name = st.selectbox("Dataset", ["librispeech", "your_audio_dataset"])
321
+ # model_name = "facebook/hubert-base-ls960"
322
+
323
+ # processor = Wav2Vec2Processor.from_pretrained(model_name)
324
+ # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
325
+
326
+ # train_dataset, test_dataset = load_data(dataset_name)
327
+ # # Further implementation needed for actual data preparation and training loops
328
+
329
+ # st.write("Details of further steps would be filled in based on specific requirements and dataset structure.")
330
+
331
+ # if __name__ == "__main__":
332
+ # main()
333
+
334
  def main():
335
  st.write("## Federated Learning with dynamic models and datasets for mobile devices")
336
  dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
337
  model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
338
 
339
+ # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
340
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
341
+ net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
342
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
343
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
344