import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import kagglehub
from PIL import Image
import os
from pathlib import Path
import logging
import faiss
from tqdm import tqdm
import speech_recognition as sr
from gtts import gTTS
import tempfile
import torch.nn.utils.prune as prune
import random
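
# Assumed runtime dependencies for this Space (not pinned in the original):
# gradio, torch, transformers, faiss-cpu, kagglehub, Pillow, tqdm, numpy,
# SpeechRecognition, gTTS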

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class ImageSearchSystem:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Load CLIP model
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)

        # Sparsify the model: zero out the 20% smallest-magnitude weights in
        # every Linear layer. prune.remove then makes the pruning permanent,
        # freeing the mask and original-weight buffers that l1_unstructured
        # would otherwise keep alive (without it, memory usage actually grows).
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=0.2)
                prune.remove(module, 'weight')

        # Initialize dataset
        self.image_paths = []
        self.index = None
        self.initialized = False

    def initialize_dataset(self) -> None:
        """Automatically download and process the dataset with a 500-sample limit."""
        try:
            logger.info("Downloading dataset from KaggleHub...")
            dataset_path = kagglehub.dataset_download("alessandrasala79/ai-vs-human-generated-dataset")
            image_folder = os.path.join(dataset_path, 'test_data_v2')  # Adjust if needed

            # Validate dataset
            if not os.path.exists(image_folder):
                raise FileNotFoundError(f"Expected dataset folder not found: {image_folder}")

            # Load images dynamically
            all_images = [f for f in Path(image_folder).glob("**/*") if f.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            if not all_images:
                raise ValueError("No images found in the dataset!")

            # Limit dataset to 500 randomly selected samples
            self.image_paths = random.sample(all_images, min(500, len(all_images)))
            logger.info(f"Loaded {len(self.image_paths)} images (limited to 500 samples).")

            # Create image index
            self._create_image_index()
            self.initialized = True
        except Exception as e:
            logger.error(f"Dataset initialization failed: {str(e)}")
            raise
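
    # Note: random.sample above is unseeded, so a different 500-image subset is
    # indexed on every restart; calling random.seed(<int>) beforehand would make
    # the demo reproducible (a suggestion, not part of the original code).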

    def _create_image_index(self, batch_size: int = 32) -> None:
        """Create FAISS index for fast image retrieval."""
        try:
            all_features = []
            for i in tqdm(range(0, len(self.image_paths), batch_size), desc="Indexing images"):
                batch_paths = self.image_paths[i:i + batch_size]
                batch_images = [Image.open(img).convert("RGB") for img in batch_paths]
                if batch_images:
                    inputs = self.processor(images=batch_images, return_tensors="pt", padding=True)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    with torch.no_grad():
                        features = self.model.get_image_features(**inputs)
                        features = features / features.norm(dim=-1, keepdim=True)
                    all_features.append(features.cpu().numpy())
            all_features = np.concatenate(all_features, axis=0)
            self.index = faiss.IndexFlatIP(all_features.shape[1])
            self.index.add(all_features)
            logger.info("Image index created successfully")
        except Exception as e:
            logger.error(f"Failed to create image index: {str(e)}")
            raise
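
    # Because both image and text features are L2-normalized, the inner product
    # computed by IndexFlatIP is exactly cosine similarity, so search() below
    # returns the k images whose CLIP embeddings are closest to the query text.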

    def search(self, query: str, audio_path: str = None, k: int = 5):
        """Search for images using text or speech."""
        try:
            if not self.initialized:
                raise RuntimeError("System not initialized. Call initialize_dataset() first.")

            # Convert speech to text if audio input is provided
            if audio_path:
                recognizer = sr.Recognizer()
                with sr.AudioFile(audio_path) as source:
                    audio_data = recognizer.record(source)
                try:
                    query = recognizer.recognize_google(audio_data)
                except sr.UnknownValueError:
                    return [], "Could not understand the spoken query.", None

            # Process text query ('k' is the result count parameter, so the
            # comprehension uses different loop names to avoid confusion)
            inputs = self.processor(text=[query], return_tensors="pt", padding=True)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Search FAISS index
            scores, indices = self.index.search(text_features.cpu().numpy(), k)
            results = [Image.open(self.image_paths[idx]) for idx in indices[0]]

            # Generate text-to-speech confirmation. delete=False keeps the file
            # around for Gradio to serve; closing the handle before gTTS writes
            # to the path avoids file-locking issues on Windows.
            tts = gTTS(f"Showing results for {query}")
            temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
            temp_audio.close()
            tts.save(temp_audio.name)
            return results, query, temp_audio.name
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return [], "Error during search.", None


def create_demo_interface() -> gr.Interface:
    """Create Gradio interface with dark styling & speech support."""
    system = ImageSearchSystem()
    try:
        system.initialize_dataset()
    except Exception as e:
        logger.error(f"Failed to initialize system: {str(e)}")
        raise

    # Each example must supply a value for every input, so the optional audio
    # input is given as None
    examples = [
        ["a beautiful landscape with mountains", None],
        ["people working in an office", None],
        ["a cute dog playing", None],
        ["a modern city skyline at night", None],
        ["a delicious-looking meal", None],
    ]

    return gr.Interface(
        fn=system.search,
        inputs=[
            gr.Textbox(label="Enter your search query:", placeholder="Describe the image...", lines=2),
            gr.Audio(sources=["microphone"], type="filepath", label="Speak Your Query (Optional)"),
        ],
        outputs=[
            gr.Gallery(label="Search Results", show_label=True, columns=5, height="auto"),
            gr.Textbox(label="Spoken Query", interactive=False),
            gr.Audio(label="Results Spoken Out Loud"),
        ],
        title="Multi-Modal Image Search",
        description="Use text or voice to search for images.",
        examples=examples,
        cache_examples=True,
        # Gradio has no built-in theme named "dark", so dark styling is applied
        # via CSS instead (viewers can also append ?__theme=dark to the URL)
        css=".gradio-container {background-color: #121212; color: #ffffff;}",
    )


if __name__ == "__main__":
    try:
        demo = create_demo_interface()
        demo.launch(share=True, max_threads=40)
    except Exception as e:
        logger.error(f"Failed to launch app: {str(e)}")
        raise
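
# Example of using the system programmatically, bypassing the UI (a usage
# sketch, not part of the original app; the first call downloads the dataset
# and indexes up to 500 images):
#
#   system = ImageSearchSystem()
#   system.initialize_dataset()
#   images, text, audio_path = system.search("a dog playing in a park", k=3)
#   images[0].show()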