update hubert classification
Browse files
app.py
CHANGED
@@ -330,15 +330,40 @@ import torch
|
|
330 |
|
331 |
# if __name__ == "__main__":
|
332 |
# main()
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
def main():
|
335 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
336 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
337 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
338 |
|
339 |
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
340 |
-
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
|
|
|
|
|
|
341 |
net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
|
|
342 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
343 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
344 |
|
|
|
330 |
|
331 |
# if __name__ == "__main__":
|
332 |
# main()
|
333 |
+
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
|
334 |
+
import torch
|
335 |
+
import soundfile as sf
|
336 |
+
|
337 |
+
def load_audio(file_path):
    """Read the audio file at ``file_path`` using ``soundfile``.

    Args:
        file_path: Location of the audio file, passed straight to ``sf.read``.

    Returns:
        The ``(waveform, sample_rate)`` pair produced by ``sf.read``.
    """
    audio, rate = sf.read(file_path)
    return audio, rate
|
341 |
+
|
342 |
+
def prepare_dataset(data_paths, feature_extractor=None):
    """Load labelled audio files and convert them into batched tensors.

    Placeholder pipeline: replace with the real data loading and processing
    logic once the dataset is available.

    Args:
        data_paths: Iterable of ``(path, label)`` pairs.
        feature_extractor: Callable invoked as
            ``feature_extractor(waveform, sampling_rate=sr, return_tensors="pt")``
            whose result exposes ``input_values``. When omitted, a module-level
            ``feature_extractor`` is used if one exists, which keeps older
            call sites that relied on a global working.

    Returns:
        A ``(features, labels)`` tuple. ``features`` is a 2-D tensor with one
        row per file, right-padded with zeros to the longest clip; ``labels``
        is ``torch.tensor`` of the labels in input order. Both are empty when
        ``data_paths`` is empty.

    Raises:
        ValueError: If no feature extractor is passed and none is defined at
            module level.
    """
    from torch.nn.utils.rnn import pad_sequence

    # The extractor used to be read as an implicit global that was only ever
    # bound inside main(), so calling this function raised NameError.
    if feature_extractor is None:
        feature_extractor = globals().get("feature_extractor")
    if feature_extractor is None:
        raise ValueError(
            "prepare_dataset needs a feature extractor: pass "
            "feature_extractor=... or define one at module level"
        )

    # Materialise once so generators work and emptiness can be checked.
    samples = list(data_paths)
    if not samples:
        # torch.cat / pad_sequence both reject an empty list.
        return torch.empty((0, 0)), torch.empty(0, dtype=torch.long)

    features = []
    labels = []
    for path, label in samples:
        waveform, sr = load_audio(path)
        # soundfile returns (frames, channels) for multi-channel files;
        # down-mix to mono so the array is not mistaken for a batch.
        if getattr(waveform, "ndim", 1) > 1:
            waveform = waveform.mean(axis=1)
        input_values = feature_extractor(
            waveform, sampling_rate=sr, return_tensors="pt"
        ).input_values
        # Assumes one clip per call, i.e. input_values is (1, T) - flatten
        # to (T,) so clips of different length can be padded together.
        features.append(input_values.reshape(-1))
        labels.append(label)

    # Clips rarely share a length, so concatenating along dim 0 would fail;
    # zero-pad on the right instead (identical result when lengths match).
    return pad_sequence(features, batch_first=True), torch.tensor(labels)
|
353 |
+
|
354 |
+
|
355 |
def main():
|
356 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
357 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
358 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
359 |
|
360 |
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
361 |
+
# processor = Wav2Vec2Processor.from_pretrained(model_name)
|
362 |
+
# net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
363 |
+
|
364 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
|
365 |
net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
366 |
+
|
367 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
368 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
369 |
|