Spaces:
Build error
Build error
update hubert classification
Browse files
app.py
CHANGED
|
@@ -330,15 +330,40 @@ import torch
|
|
| 330 |
|
| 331 |
# if __name__ == "__main__":
|
| 332 |
# main()
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
def main():
|
| 335 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
| 336 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
| 337 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
| 338 |
|
| 339 |
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 340 |
-
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
| 341 |
net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
|
|
|
| 342 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
| 343 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
| 344 |
|
|
|
|
| 330 |
|
| 331 |
# if __name__ == "__main__":
|
| 332 |
# main()
|
| 333 |
+
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
|
| 334 |
+
import torch
|
| 335 |
+
import soundfile as sf
|
| 336 |
+
|
| 337 |
+
def load_audio(file_path):
    """Read an audio file from disk.

    Args:
        file_path: path to an audio file readable by ``soundfile``.

    Returns:
        A ``(waveform, sample_rate)`` pair as produced by ``sf.read``.
    """
    samples, rate = sf.read(file_path)
    return samples, rate
|
| 341 |
+
|
| 342 |
+
def prepare_dataset(data_paths, feature_extractor=None):
    """Load and featurize a list of labelled audio files.

    Fix: the original body read a free variable ``feature_extractor`` that is
    only ever bound as a *local* inside ``main()``, so any call raised
    ``NameError``. It is now an explicit, backward-compatible keyword
    parameter; when omitted, a default HuBERT extractor is loaded lazily.

    Args:
        data_paths: iterable of ``(path, label)`` pairs.
        feature_extractor: optional ``Wav2Vec2FeatureExtractor``; loaded from
            ``facebook/hubert-base-ls960`` when not supplied.

    Returns:
        ``(features, labels)`` — a tensor of stacked input values and a
        1-D tensor of integer labels.
    """
    if feature_extractor is None:
        # Lazy default keeps the old call signature working.
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/hubert-base-ls960"
        )
    features = []
    labels = []
    for path, label in data_paths:
        waveform, sr = load_audio(path)
        input_values = feature_extractor(
            waveform, sampling_rate=sr, return_tensors="pt"
        ).input_values
        features.append(input_values)
        labels.append(label)
    # NOTE(review): torch.cat along dim=0 assumes every clip yields the same
    # sequence length; variable-length audio would need padding — confirm
    # upstream that inputs are fixed-length.
    return torch.cat(features, dim=0), torch.tensor(labels)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
def main():
|
| 356 |
st.write("## Federated Learning with dynamic models and datasets for mobile devices")
|
| 357 |
dataset_name = st.selectbox("Dataset", ["audio_instruction_task","imdb", "amazon_polarity", "ag_news"])
|
| 358 |
model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
|
| 359 |
|
| 360 |
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 361 |
+
# processor = Wav2Vec2Processor.from_pretrained(model_name)
|
| 362 |
+
# net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 363 |
+
|
| 364 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
|
| 365 |
net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
| 366 |
+
|
| 367 |
NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
|
| 368 |
NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
|
| 369 |
|