joy1515 committed
Commit a2fb18c · verified · 1 parent: 8826e17

adjusted pruning

Files changed (1): app.py (+9, -4)
app.py CHANGED
@@ -22,17 +22,22 @@ class ImageSearchSystem:
     def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {self.device}")
-
+
         # Load CLIP model
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
 
-        # Prune the model
+        # Prune the model (access vision module directly)
         parameters_to_prune = (
-            (self.model.visual.encoder.layers, 'weight'),
+            (self.model.vision.transformer.encoder.layers, 'attention.self.query.weight'),
+            (self.model.vision.transformer.encoder.layers, 'attention.self.key.weight'),
+            (self.model.vision.transformer.encoder.layers, 'attention.self.value.weight')
         )
+
+        # Prune the weights
         prune.l1_unstructured(parameters_to_prune, amount=0.2)
+
+
         # Initialize dataset
         self.image_paths = []
         self.index = None
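
Note that even after this change the pruning call would likely fail at runtime: Hugging Face's CLIPModel exposes the vision tower as vision_model (each encoder layer holds a self_attn module with q_proj/k_proj/v_proj linear layers, not an attention.self.query path), and torch.nn.utils.prune.l1_unstructured prunes a single module's named parameter rather than a collection of (module, name) tuples; prune.global_unstructured is the call that accepts such a list. Below is a minimal sketch of what the commit appears to be aiming for, assuming the goal is 20% L1 pruning of the vision encoder's query/key/value projection weights; the module paths and the sparsity target are taken from the transformers CLIP implementation and the amount=0.2 used in the diff, not from this repository.

# Sketch only: vision_model.encoder.layers[*].self_attn.{q,k,v}_proj follows the
# Hugging Face transformers CLIP implementation; amount=0.2 mirrors the commit.
import torch
import torch.nn.utils.prune as prune
from transformers import CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)

# Collect (module, parameter_name) pairs for the q/k/v projections of every
# self-attention block in the vision encoder.
parameters_to_prune = [
    (getattr(layer.self_attn, proj), "weight")
    for layer in model.vision_model.encoder.layers
    for proj in ("q_proj", "k_proj", "v_proj")
]

# Zero out 20% of the selected weights, ranked globally by L1 magnitude.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Optionally bake the masks into the weights and drop the re-parametrization.
for module, name in parameters_to_prune:
    prune.remove(module, name)

One design note: global_unstructured ranks all selected weights together, so the resulting sparsity can vary from layer to layer; calling prune.l1_unstructured(module, "weight", amount=0.2) on each projection instead would enforce 20% sparsity in every layer individually.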