joy1515 committed
Commit 370ab3d · verified · 1 Parent(s): dc0ad5c

Update app.py

Files changed (1)
  1. app.py +175 -167
app.py CHANGED
@@ -1,167 +1,175 @@
- import gradio as gr
- import torch
- from transformers import CLIPProcessor, CLIPModel
- import numpy as np
- import kagglehub
- from PIL import Image
- import os
- from pathlib import Path
- import logging
- import faiss
- from tqdm import tqdm
- import speech_recognition as sr
- from gtts import gTTS
- import tempfile
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
-
- class ImageSearchSystem:
-     def __init__(self):
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         logger.info(f"Using device: {self.device}")
-
-         # Load CLIP model
-         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
-         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
-
-         # Initialize dataset
-         self.image_paths = []
-         self.index = None
-         self.initialized = False
-
-     def initialize_dataset(self) -> None:
-         """Download and process dataset"""
-         try:
-             path = kagglehub.dataset_download("alessandrasala79/ai-vs-human-generated-dataset")
-             image_folder = os.path.join(path, 'test_data_v2')
-
-             self.image_paths = [
-                 f for f in Path(image_folder).glob("**/*")
-                 if f.suffix.lower() in ['.jpg', '.jpeg', '.png']
-             ]
-
-             if not self.image_paths:
-                 raise ValueError(f"No images found in {image_folder}")
-
-             logger.info(f"Found {len(self.image_paths)} images")
-
-             self._create_image_index()
-             self.initialized = True
-
-         except Exception as e:
-             logger.error(f"Dataset initialization failed: {str(e)}")
-             raise
-
-     def _create_image_index(self, batch_size: int = 32) -> None:
-         """Create FAISS index"""
-         try:
-             all_features = []
-
-             for i in tqdm(range(0, len(self.image_paths), batch_size), desc="Indexing images"):
-                 batch_paths = self.image_paths[i:i + batch_size]
-                 batch_images = [Image.open(img).convert("RGB") for img in batch_paths]
-
-                 if batch_images:
-                     inputs = self.processor(images=batch_images, return_tensors="pt", padding=True)
-                     inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                     with torch.no_grad():
-                         features = self.model.get_image_features(**inputs)
-                         features = features / features.norm(dim=-1, keepdim=True)
-                         all_features.append(features.cpu().numpy())
-
-             all_features = np.concatenate(all_features, axis=0)
-             self.index = faiss.IndexFlatIP(all_features.shape[1])
-             self.index.add(all_features)
-
-             logger.info("Image index created successfully")
-
-         except Exception as e:
-             logger.error(f"Failed to create image index: {str(e)}")
-             raise
-
-     def search(self, query: str, audio_path: str = None, k: int = 5):
-         """Search for images using text or speech"""
-         try:
-             if not self.initialized:
-                 raise RuntimeError("System not initialized. Call initialize_dataset() first.")
-
-             # Convert speech to text if audio input is provided
-             if audio_path:
-                 recognizer = sr.Recognizer()
-                 with sr.AudioFile(audio_path) as source:
-                     audio_data = recognizer.record(source)
-                 try:
-                     query = recognizer.recognize_google(audio_data)
-                 except sr.UnknownValueError:
-                     return [], "Could not understand the spoken query.", None
-
-             # Process text query
-             inputs = self.processor(text=[query], return_tensors="pt", padding=True)
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 text_features = self.model.get_text_features(**inputs)
-                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
-
-             # Search FAISS index
-             scores, indices = self.index.search(text_features.cpu().numpy(), k)
-             results = [Image.open(self.image_paths[idx]) for idx in indices[0]]
-
-             # Generate Text-to-Speech
-             tts = gTTS(f"Showing results for {query}")
-             temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
-             tts.save(temp_audio.name)
-
-             return results, query, temp_audio.name
-
-         except Exception as e:
-             logger.error(f"Search failed: {str(e)}")
-             return [], "Error during search.", None
-
- def create_demo_interface() -> gr.Interface:
-     """Create Gradio interface with dark mode & speech support"""
-     system = ImageSearchSystem()
-
-     try:
-         system.initialize_dataset()
-     except Exception as e:
-         logger.error(f"Failed to initialize system: {str(e)}")
-         raise
-
-     examples = [
-         ["a beautiful landscape with mountains"],
-         ["people working in an office"],
-         ["a cute dog playing"],
-         ["a modern city skyline at night"],
-         ["a delicious-looking meal"]
-     ]
-
-     return gr.Interface(
-         fn=system.search,
-         inputs=[
-             gr.Textbox(label="Enter your search query:", placeholder="Describe the image...", lines=2),
-             gr.Audio(source="microphone", type="filepath", label="Speak Your Query (Optional)")
-         ],
-         outputs=[
-             gr.Gallery(label="Search Results", show_label=True, columns=5, height="auto"),
-             gr.Textbox(label="Spoken Query", interactive=False),
-             gr.Audio(label="Results Spoken Out Loud")
-         ],
-         title="Multi-Modal Image Search",
-         description="Use text or voice to search for images.",
-         theme="dark",
-         examples=examples,
-         cache_examples=True,
-         css=".gradio-container {background-color: #121212; color: #ffffff;}"
-     )
-
- if __name__ == "__main__":
-     try:
-         demo = create_demo_interface()
-         demo.launch(share=True, enable_queue=True, max_threads=40)
-     except Exception as e:
-         logger.error(f"Failed to launch app: {str(e)}")
-         raise
+ import gradio as gr
+ import torch
+ import torch.nn.utils.prune as prune
+ from transformers import CLIPProcessor, CLIPModel
+ import numpy as np
+ import kagglehub
+ from PIL import Image
+ import os
+ from pathlib import Path
+ import logging
+ import faiss
+ from tqdm import tqdm
+ import speech_recognition as sr
+ from gtts import gTTS
+ import tempfile
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ class ImageSearchSystem:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+
+         # Load CLIP model
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
+
+         # Prune the model: CLIPModel has no .prune() method, so L1 unstructured
+         # pruning is applied per linear layer via torch.nn.utils.prune instead.
+         # The amount=0.2 (20% of weights) is illustrative; tune against retrieval quality.
+         for module in self.model.modules():
+             if isinstance(module, torch.nn.Linear):
+                 prune.l1_unstructured(module, name="weight", amount=0.2)
+                 prune.remove(module, "weight")  # bake the pruning mask into the weights
+
+         # Initialize dataset
+         self.image_paths = []
+         self.index = None
+         self.initialized = False
+
42
+     def initialize_dataset(self) -> None:
+         """Download and process dataset"""
+         try:
+             path = kagglehub.dataset_download("alessandrasala79/ai-vs-human-generated-dataset")
+             image_folder = os.path.join(path, 'test_data_v2')
+
+             self.image_paths = [
+                 f for f in Path(image_folder).glob("**/*")
+                 if f.suffix.lower() in ['.jpg', '.jpeg', '.png']
+             ]
+
+             if not self.image_paths:
+                 raise ValueError(f"No images found in {image_folder}")
+
+             logger.info(f"Found {len(self.image_paths)} images")
+
+             self._create_image_index()
+             self.initialized = True
+
+         except Exception as e:
+             logger.error(f"Dataset initialization failed: {str(e)}")
+             raise
+
65
+     def _create_image_index(self, batch_size: int = 32) -> None:
+         """Create FAISS index"""
+         try:
+             all_features = []
+
+             for i in tqdm(range(0, len(self.image_paths), batch_size), desc="Indexing images"):
+                 batch_paths = self.image_paths[i:i + batch_size]
+                 batch_images = [Image.open(img).convert("RGB") for img in batch_paths]
+
+                 if batch_images:
+                     inputs = self.processor(images=batch_images, return_tensors="pt", padding=True)
+                     inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                     with torch.no_grad():
+                         features = self.model.get_image_features(**inputs)
+                         features = features / features.norm(dim=-1, keepdim=True)
+                         all_features.append(features.cpu().numpy())
+
+             all_features = np.concatenate(all_features, axis=0)
+             self.index = faiss.IndexFlatIP(all_features.shape[1])
+             self.index.add(all_features)
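+             # Because the features are L2-normalized above, inner-product search
+             # on IndexFlatIP ranks by cosine similarity (higher score = closer match).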
86
+
+             logger.info("Image index created successfully")
+
+         except Exception as e:
+             logger.error(f"Failed to create image index: {str(e)}")
+             raise
+
+     def search(self, query: str, audio_path: str = None, k: int = 5):
+         """Search for images using text or speech"""
+         try:
+             if not self.initialized:
+                 raise RuntimeError("System not initialized. Call initialize_dataset() first.")
+
+             # Convert speech to text if audio input is provided
+             # (recognize_google calls Google's free Web Speech API and needs internet access)
+             if audio_path:
+                 recognizer = sr.Recognizer()
+                 with sr.AudioFile(audio_path) as source:
+                     audio_data = recognizer.record(source)
+                 try:
+                     query = recognizer.recognize_google(audio_data)
+                 except sr.UnknownValueError:
+                     return [], "Could not understand the spoken query.", None
+
+             # Process text query
+             inputs = self.processor(text=[query], return_tensors="pt", padding=True)
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 text_features = self.model.get_text_features(**inputs)
+                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+             # Search FAISS index
+             scores, indices = self.index.search(text_features.cpu().numpy(), k)
+             results = [Image.open(self.image_paths[idx]) for idx in indices[0]]
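+             # indices has shape (1, k): the top-k matches for the single text query,
+             # already sorted best-first by similarity score.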
120
+
+             # Generate Text-to-Speech
+             tts = gTTS(f"Showing results for {query}")
+             temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
+             temp_audio.close()  # release the handle so gTTS can write to the path (required on Windows)
+             tts.save(temp_audio.name)
+
+             return results, query, temp_audio.name
+
+         except Exception as e:
+             logger.error(f"Search failed: {str(e)}")
+             return [], "Error during search.", None
131
+
+ def create_demo_interface() -> gr.Interface:
+     """Create Gradio interface with dark mode & speech support"""
+     system = ImageSearchSystem()
+
+     try:
+         system.initialize_dataset()
+     except Exception as e:
+         logger.error(f"Failed to initialize system: {str(e)}")
+         raise
+
+     examples = [
+         ["a beautiful landscape with mountains"],
+         ["people working in an office"],
+         ["a cute dog playing"],
+         ["a modern city skyline at night"],
+         ["a delicious-looking meal"]
+     ]
+
+     return gr.Interface(
+         fn=system.search,
+         inputs=[
+             gr.Textbox(label="Enter your search query:", placeholder="Describe the image...", lines=2),
+             gr.Audio(source="microphone", type="filepath", label="Speak Your Query (Optional)")
+         ],
+         outputs=[
+             gr.Gallery(label="Search Results", show_label=True, columns=5, height="auto"),
+             gr.Textbox(label="Spoken Query", interactive=False),
+             gr.Audio(label="Results Spoken Out Loud")
+         ],
+         title="Multi-Modal Image Search",
+         description="Use text or voice to search for images.",
+         theme="dark",
+         examples=examples,
+         cache_examples=True,
+         css=".gradio-container {background-color: #121212; color: #ffffff;}"
+     )
+
+ if __name__ == "__main__":
+     try:
+         demo = create_demo_interface()
+         demo.launch(share=True, enable_queue=True, max_threads=40)
+     except Exception as e:
+         logger.error(f"Failed to launch app: {str(e)}")
+         raise
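
Note on the new pruning step: a quick standalone sanity check (a minimal sketch assuming the torch.nn.utils.prune substitution used above; the 20% amount is illustrative and not part of the original commit):

import torch
import torch.nn.utils.prune as prune
from transformers import CLIPModel

# Prune 20% of each linear layer's weights by L1 magnitude, then
# report the resulting sparsity across all linear layers.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
linears = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
for module in linears:
    prune.l1_unstructured(module, name="weight", amount=0.2)
    prune.remove(module, "weight")  # make the zeroed weights permanent

zeros = sum(int((m.weight == 0).sum()) for m in linears)
total = sum(m.weight.numel() for m in linears)
print(f"Linear-layer sparsity: {zeros / total:.1%}")

The printed sparsity should land near 20%, since l1_unstructured zeroes that fraction of weights in each layer.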