joy1515 committed on
Commit
38d364b
·
verified ·
1 Parent(s): 3c3d449

adjusted code

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -27,16 +27,10 @@ class ImageSearchSystem:
27
  self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
28
  self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
29
 
30
- # Prune the model (access vision module correctly)
31
- parameters_to_prune = (
32
- (self.model.vision_model.encoder.layers[0].attention.self.query.weight, 'attention.self.query.weight'),
33
- (self.model.vision_model.encoder.layers[0].attention.self.key.weight, 'attention.self.key.weight'),
34
- (self.model.vision_model.encoder.layers[0].attention.self.value.weight, 'attention.self.value.weight')
35
- )
36
-
37
- # Prune the weights
38
- for param, name in parameters_to_prune:
39
- prune.l1_unstructured(param, amount=0.2)
40
 
41
  # Initialize dataset
42
  self.image_paths = []
@@ -44,30 +38,38 @@ class ImageSearchSystem:
44
  self.initialized = False
45
 
46
  def initialize_dataset(self) -> None:
47
- """Download and process dataset"""
48
  try:
49
- path = kagglehub.dataset_download("alessandrasala79/ai-vs-human-generated-dataset")
50
- image_folder = os.path.join(path, 'test_data_v2')
 
 
 
 
 
 
 
 
 
51
 
52
- self.image_paths = [
53
- f for f in Path(image_folder).glob("**/*")
54
- if f.suffix.lower() in ['.jpg', '.jpeg', '.png']
55
- ]
56
 
57
  if not self.image_paths:
58
- raise ValueError(f"No images found in {image_folder}")
59
-
60
- logger.info(f"Found {len(self.image_paths)} images")
61
 
 
62
  self._create_image_index()
63
  self.initialized = True
64
-
65
  except Exception as e:
66
  logger.error(f"Dataset initialization failed: {str(e)}")
67
  raise
68
 
69
  def _create_image_index(self, batch_size: int = 32) -> None:
70
- """Create FAISS index"""
71
  try:
72
  all_features = []
73
 
@@ -95,7 +97,7 @@ class ImageSearchSystem:
95
  raise
96
 
97
  def search(self, query: str, audio_path: str = None, k: int = 5):
98
- """Search for images using text or speech"""
99
  try:
100
  if not self.initialized:
101
  raise RuntimeError("System not initialized. Call initialize_dataset() first.")
@@ -134,7 +136,7 @@ class ImageSearchSystem:
134
  return [], "Error during search.", None
135
 
136
  def create_demo_interface() -> gr.Interface:
137
- """Create Gradio interface with dark mode & speech support"""
138
  system = ImageSearchSystem()
139
 
140
  try:
@@ -177,3 +179,4 @@ if __name__ == "__main__":
177
  except Exception as e:
178
  logger.error(f"Failed to launch app: {str(e)}")
179
  raise
 
 
27
  self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
28
  self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
29
 
30
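+ # Prune the model: zero the 20% smallest-magnitude (L1) weights of every Linear layer. NOTE: this is mask-based pruning (weight_orig + weight_mask are kept), so it does not reduce memory usage.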
31
+ for name, module in self.model.named_modules():
32
+ if isinstance(module, torch.nn.Linear):
33
+ prune.l1_unstructured(module, name='weight', amount=0.2)
 
 
 
 
 
 
34
 
35
  # Initialize dataset
36
  self.image_paths = []
 
38
  self.initialized = False
39
 
40
  def initialize_dataset(self) -> None:
41
+ """Automatically download and process the dataset."""
42
  try:
43
+ dataset_path = os.path.expanduser("~/.kagglehub/datasets/alessandrasala79/ai-vs-human-generated-dataset")
44
+ image_folder = os.path.join(dataset_path, 'test_data_v2')
45
+
46
+ # Download dataset if not already present
47
+ if not os.path.exists(dataset_path):
48
+ logger.info("Downloading dataset from Kaggle...")
49
+ dataset_path = kagglehub.dataset_download("alessandrasala79/ai-vs-human-generated-dataset")
50
+
51
+ # Validate dataset
52
+ if not os.path.exists(image_folder):
53
+ raise FileNotFoundError(f"Expected dataset folder not found: {image_folder}")
54
 
55
+ # Load images dynamically
56
+ self.image_paths = [f for f in Path(image_folder).glob("**/*") if f.suffix.lower() in ['.jpg', '.jpeg', '.png']]
 
 
57
 
58
  if not self.image_paths:
59
+ raise ValueError("No images found in the dataset!")
60
+
61
+ logger.info(f"Successfully loaded {len(self.image_paths)} images.")
62
 
63
+ # Create image index
64
  self._create_image_index()
65
  self.initialized = True
66
+
67
  except Exception as e:
68
  logger.error(f"Dataset initialization failed: {str(e)}")
69
  raise
70
 
71
  def _create_image_index(self, batch_size: int = 32) -> None:
72
+ """Create FAISS index for fast image retrieval."""
73
  try:
74
  all_features = []
75
 
 
97
  raise
98
 
99
  def search(self, query: str, audio_path: str = None, k: int = 5):
100
+ """Search for images using text or speech."""
101
  try:
102
  if not self.initialized:
103
  raise RuntimeError("System not initialized. Call initialize_dataset() first.")
 
136
  return [], "Error during search.", None
137
 
138
  def create_demo_interface() -> gr.Interface:
139
+ """Create Gradio interface with dark mode & speech support."""
140
  system = ImageSearchSystem()
141
 
142
  try:
 
179
  except Exception as e:
180
  logger.error(f"Failed to launch app: {str(e)}")
181
  raise
182
+