alisrbdni committed on
Commit
b722033
·
verified ·
1 Parent(s): dd0abb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -330,26 +330,27 @@ import torch
330
 
331
  # if __name__ == "__main__":
332
  # main()
333
- from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
334
- import torch
335
- import soundfile as sf
336
-
337
- def load_audio(file_path):
338
- # Load an audio file, return waveform and sampling rate
339
- waveform, sample_rate = sf.read(file_path)
340
- return waveform, sample_rate
341
-
342
- def prepare_dataset(data_paths):
343
- # Dummy function to simulate loading and processing a dataset
344
- # Replace this with actual data loading and processing logic
345
- features = []
346
- labels = []
347
- for path, label in data_paths:
348
- waveform, sr = load_audio(path)
349
- input_values = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt").input_values
350
- features.append(input_values)
351
- labels.append(label)
352
- return torch.cat(features, dim=0), torch.tensor(labels)
 
353
 
354
 
355
  def main():
@@ -361,8 +362,8 @@ def main():
361
  # processor = Wav2Vec2Processor.from_pretrained(model_name)
362
  # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
363
 
364
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
365
- net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
366
 
367
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
368
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
 
330
 
331
  # if __name__ == "__main__":
332
  # main()
333
+
334
+ # from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
335
+ # import torch
336
+ # import soundfile as sf
337
+
338
+ # def load_audio(file_path):
339
+ # # Load an audio file, return waveform and sampling rate
340
+ # waveform, sample_rate = sf.read(file_path)
341
+ # return waveform, sample_rate
342
+
343
+ # def prepare_dataset(data_paths):
344
+ # # Dummy function to simulate loading and processing a dataset
345
+ # # Replace this with actual data loading and processing logic
346
+ # features = []
347
+ # labels = []
348
+ # for path, label in data_paths:
349
+ # waveform, sr = load_audio(path)
350
+ # input_values = feature_extractor(waveform, sampling_rate=sr, return_tensors="pt").input_values
351
+ # features.append(input_values)
352
+ # labels.append(label)
353
+ # return torch.cat(features, dim=0), torch.tensor(labels)
354
 
355
 
356
  def main():
 
362
  # processor = Wav2Vec2Processor.from_pretrained(model_name)
363
  # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
364
 
365
+ # feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
366
+ # net = HubertForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
367
 
368
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
369
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)