Curinha commited on
Commit
cd0c397
·
1 Parent(s): a008552

Fix model loading in sound_generator.py to access the model attribute directly

Browse files
Files changed (1) hide show
  1. sound_generator.py +2 -2
sound_generator.py CHANGED
@@ -6,8 +6,8 @@ import torch
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
  print("Using device:", device)
8
 
9
- sound_model = AudioGen.get_pretrained('facebook/audiogen-medium').to(device)
10
- music_model = MusicGen.get_pretrained('facebook/musicgen-small').to(device)
11
 
12
  # Set generation parameters (for example, audio duration of 5 seconds)
13
  sound_model.set_generation_params(duration=5)
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
  print("Using device:", device)
8
 
9
+ sound_model = AudioGen.get_pretrained('facebook/audiogen-medium').model.to(device)
10
+ music_model = MusicGen.get_pretrained('facebook/musicgen-small').model.to(device)
11
 
12
  # Set generation parameters (for example, audio duration of 5 seconds)
13
  sound_model.set_generation_params(duration=5)