joy1515 committed on
Commit
8826e17
·
verified ·
1 Parent(s): 370ab3d

Implemented pruning

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
12
  import speech_recognition as sr
13
  from gtts import gTTS
14
  import tempfile
 
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -27,12 +28,10 @@ class ImageSearchSystem:
27
  self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
28
 
29
  # Prune the model
30
- pruned_model = self.model.prune(
31
- pruning_method="l1_unstructured",
32
- num_heads_to_prune=10,
33
- num_layers_to_prune=2,
34
  )
35
- self.model = pruned_model
36
 
37
  # Initialize dataset
38
  self.image_paths = []
 
12
  import speech_recognition as sr
13
  from gtts import gTTS
14
  import tempfile
15
+ import torch.nn.utils.prune as prune
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
28
  self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
29
 
30
  # Prune the model
31
+ parameters_to_prune = (
32
+ (self.model.visual.encoder.layers, 'weight'),
 
 
33
  )
34
+ prune.l1_unstructured(parameters_to_prune, amount=0.2)
35
 
36
  # Initialize dataset
37
  self.image_paths = []