alisrbdni committed on
Commit dd0abb7 · verified · 1 Parent(s): bc648b4

update hubert classification

Files changed (1)
  1. app.py +27 -2
app.py CHANGED
@@ -330,15 +330,40 @@ import torch
 
 # if __name__ == "__main__":
 #     main()
-
+from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
+import torch
+import soundfile as sf
+
+def load_audio(file_path):
+    # Load an audio file, return waveform and sampling rate
+    waveform, sample_rate = sf.read(file_path)
+    return waveform, sample_rate
+
+def prepare_dataset(data_paths):
+    # Dummy function to simulate loading and processing a dataset
+    # Replace this with actual data loading and processing logic
+    features = []
+    labels = []
+    for path, label in data_paths:
+        waveform, sr = load_audio(path)
+        input_values = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt").input_values
+        features.append(input_values)
+        labels.append(label)
+    return torch.cat(features, dim=0), torch.tensor(labels)
+
+
 def main():
     st.write("## Federated Learning with dynamic models and datasets for mobile devices")
     dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
     model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
 
     # net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
-    processor = Wav2Vec2Processor.from_pretrained(model_name)
+    # processor = Wav2Vec2Processor.from_pretrained(model_name)
+    # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
     net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
+
     NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
     NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
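
For reference, a minimal sketch (not part of the commit) of how the new helpers might be exercised end to end. It assumes hypothetical local files speech1.wav and speech2.wav (16 kHz mono) and creates feature_extractor at module level, since prepare_dataset reads it from the enclosing scope while the diff only creates it inside main(); it also assumes clips of equal length, because torch.cat along dim 0 requires matching shapes.

# Sketch under the assumptions above; file paths and labels are hypothetical.
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
import torch
import soundfile as sf

# Module-level extractor so prepare_dataset can see it before main() runs.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")

def load_audio(file_path):
    # Read the waveform and its native sampling rate from disk
    waveform, sample_rate = sf.read(file_path)
    return waveform, sample_rate

def prepare_dataset(data_paths):
    # Turn (path, label) pairs into model inputs; torch.cat assumes
    # equal-length clips, so real data would need padding or truncation.
    features = []
    labels = []
    for path, label in data_paths:
        waveform, sr = load_audio(path)
        input_values = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt").input_values
        features.append(input_values)
        labels.append(label)
    return torch.cat(features, dim=0), torch.tensor(labels)

if __name__ == "__main__":
    # Hypothetical (path, label) pairs; replace with real audio files.
    inputs, targets = prepare_dataset([("speech1.wav", 0), ("speech2.wav", 1)])
    model = HubertForSequenceClassification.from_pretrained("facebook/hubert-base-ls960", num_labels=2)
    with torch.no_grad():
        logits = model(input_values=inputs).logits
    print(logits.shape, targets)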