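"""Asynchronous image feature extraction pipeline.

Scans an input folder for files whose names contain "resized_", computes two
embeddings per image (a TinyCLIP image embedding and a sentence-transformer
embedding of a BLIP-generated caption), and writes each as a .npy file. A
JSON checkpoint records which files are done so an interrupted run resumes
where it left off.
"""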
import asyncio
import io
import json
import logging
import os
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Set

import aiofiles
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import (
    CLIPProcessor,
    CLIPModel,
    BlipProcessor,
    BlipForConditionalGeneration,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# All models run on CPU; switch to "cuda" if a GPU is available.
device = "cpu"
@dataclass
class State:
    """Checkpoint state: the set of filenames that have been fully processed."""

    processed_files: Set[str] = field(default_factory=set)

    def to_dict(self) -> dict:
        return {"processed_files": list(self.processed_files)}

    @staticmethod
    def from_dict(state_dict: dict) -> "State":
        return State(processed_files=set(state_dict.get("processed_files", [])))
class ImageProcessor(ABC):
    """Interface for turning a PIL image into a fixed-size embedding."""

    @abstractmethod
    def process(self, image: Image.Image) -> np.ndarray:
        ...
class CLIPImageProcessor(ImageProcessor):
    """Embeds images with the TinyCLIP vision encoder."""

    def __init__(self):
        self.model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
        ).to(device)
        self.processor = CLIPProcessor.from_pretrained(
            "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
        )
        logger.info("Initialized CLIP model and processor")

    def process(self, image: Image.Image) -> np.ndarray:
        inputs = self.processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            outputs = self.model.get_image_features(**inputs)
        # Flatten the (1, dim) batch output to match the caption embedding shape.
        return outputs.cpu().numpy().flatten()
class ImageCaptioningProcessor(ImageProcessor):
    """Captions an image with BLIP, then embeds the caption text."""

    def __init__(self):
        self.image_caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ).to(device)
        self.image_caption_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.text_embedding_model = SentenceTransformer(
            "all-MiniLM-L6-v2", device=device
        )
        logger.info("Initialized BLIP model and processor")

    def process(self, image: Image.Image) -> np.ndarray:
        inputs = self.image_caption_processor(images=image, return_tensors="pt").to(
            device
        )
        output = self.image_caption_model.generate(**inputs)
        caption = self.image_caption_processor.decode(
            output[0], skip_special_tokens=True
        )
        # all-MiniLM-L6-v2 produces a 384-dim embedding
        return self.text_embedding_model.encode(caption).flatten()
class ImageFeatureExtractor:
    """Coordinates loading, processing, and saving via asyncio queues."""

    def __init__(
        self,
        clip_processor: CLIPImageProcessor,
        caption_processor: ImageCaptioningProcessor,
        max_queue_size: int = 100,
        checkpoint_file: str = "checkpoint.json",
    ):
        self.clip_processor = clip_processor
        self.caption_processor = caption_processor
        self.image_queue = asyncio.Queue(maxsize=max_queue_size)
        self.processed_images_queue = asyncio.Queue()
        self.checkpoint_file = checkpoint_file
        self.state = self.load_state()
        # A thread pool, not a process pool: a process pool would have to
        # pickle the torch models into every worker call, and torch releases
        # the GIL during inference, so threads are simpler and faster here.
        self.executor = ThreadPoolExecutor()
        self.total_images = 0
        self.processed_count = 0
        logger.info(
            "Initialized ImageFeatureExtractor with checkpoint file: %s",
            checkpoint_file,
        )
    async def image_loader(self, input_folder: str):
        logger.info("Loading images from %s", input_folder)
        for filename in os.listdir(input_folder):
            if "resized_" in filename and filename not in self.state.processed_files:
                try:
                    file_path = os.path.join(input_folder, filename)
                    await self.image_queue.put((filename, file_path))
                    self.total_images += 1
                    logger.info("Loaded image %s into queue", filename)
                except Exception as e:
                    logger.error(f"Error loading image {filename}: {e}")
        await self.image_queue.put(None)  # Sentinel to signal end of images
        logger.info("Total images to process: %d", self.total_images)
    async def image_processor_worker(self, loop: asyncio.AbstractEventLoop):
        while True:
            item = await self.image_queue.get()
            if item is None:
                await self.image_queue.put(None)  # Propagate sentinel to other workers
                self.image_queue.task_done()
                break
            filename, file_path = item
            try:
                logger.info("Processing image %s", filename)
                # convert("RGB") forces a full decode and normalizes palette,
                # grayscale, and RGBA inputs to the RGB both models expect.
                image = Image.open(file_path).convert("RGB")
                clip_embedding, caption_embedding = await asyncio.gather(
                    loop.run_in_executor(
                        self.executor, self.clip_processor.process, image
                    ),
                    loop.run_in_executor(
                        self.executor, self.caption_processor.process, image
                    ),
                )
                await self.processed_images_queue.put(
                    (filename, clip_embedding, caption_embedding)
                )
                logger.info("Processed image %s", filename)
            except Exception as e:
                logger.error(f"Error processing image {filename}: {e}")
            finally:
                self.image_queue.task_done()
    async def save_processed_images(self, output_folder: str):
        # Run until run() sends the sentinel; counting saved items against
        # total_images would race with the loader, which is still counting
        # when this task starts.
        while True:
            item = await self.processed_images_queue.get()
            if item is None:
                self.processed_images_queue.task_done()
                break
            filename, clip_embedding, caption_embedding = item
            try:
                clip_output_path = os.path.join(
                    output_folder, f"{os.path.splitext(filename)[0]}_clip.npy"
                )
                caption_output_path = os.path.join(
                    output_folder, f"{os.path.splitext(filename)[0]}_caption.npy"
                )
                await asyncio.gather(
                    self.save_embedding(clip_output_path, clip_embedding),
                    self.save_embedding(caption_output_path, caption_embedding),
                )
                self.state.processed_files.add(filename)
                self.save_state()
                self.processed_count += 1
                logger.info(
                    "Saved embeddings for %s (%d/%d)",
                    filename,
                    self.processed_count,
                    self.total_images,
                )
            except Exception as e:
                logger.error(f"Error saving processed image {filename}: {e}")
            finally:
                self.processed_images_queue.task_done()

    async def save_embedding(self, output_path: str, embedding: np.ndarray):
        # Serialize in .npy format so np.load can restore dtype and shape;
        # raw tobytes() would lose both, despite the .npy extension.
        buffer = io.BytesIO()
        np.save(buffer, embedding)
        async with aiofiles.open(output_path, "wb") as f:
            await f.write(buffer.getvalue())
    def load_state(self) -> State:
        try:
            with open(self.checkpoint_file, "r") as f:
                state_dict = json.load(f)
            logger.info("Loaded state from checkpoint")
            return State.from_dict(state_dict)
        except (FileNotFoundError, json.JSONDecodeError):
            logger.info("No checkpoint found, starting with empty state")
            return State()

    def save_state(self):
        with open(self.checkpoint_file, "w") as f:
            json.dump(self.state.to_dict(), f)
        # Called once per saved image, so keep it out of the INFO stream.
        logger.debug("Saved state to checkpoint")
    async def run(
        self,
        input_folder: str,
        output_folder: str,
        loop: asyncio.AbstractEventLoop,
        num_workers: int = 2,
    ):
        os.makedirs(output_folder, exist_ok=True)
        logger.info("Output folder %s created", output_folder)
        saver = loop.create_task(self.save_processed_images(output_folder))
        workers = [
            loop.create_task(self.image_processor_worker(loop))
            for _ in range(num_workers)
        ]
        # The loader and workers finish via the image_queue sentinel; only
        # after every worker has exited does the saver get its own sentinel,
        # so all processed images are saved before run() returns.
        await self.image_loader(input_folder)
        await asyncio.gather(*workers)
        await self.processed_images_queue.put(None)
        await saver
        self.executor.shutdown()
class ImageFeatureExtractorFactory:
    @staticmethod
    def create() -> ImageFeatureExtractor:
        logger.info(
            "Creating ImageFeatureExtractor with CLIPImageProcessor and ImageCaptioningProcessor"
        )
        return ImageFeatureExtractor(CLIPImageProcessor(), ImageCaptioningProcessor())
async def main(loop: asyncio.AbstractEventLoop, input_folder: str, output_folder: str):
    logger.info("Starting main function")
    extractor = ImageFeatureExtractorFactory.create()
    try:
        await extractor.run(input_folder, output_folder, loop)
    except Exception as e:
        logger.error(f"An error occurred during execution: {e}")
    finally:
        logger.info("Image processing completed.")
if __name__ == "__main__":
    from pathlib import Path

    PROJECT_ROOT = Path(__file__).resolve().parent.parent
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    logger.info("Event loop created and set")
    input_folder = str(PROJECT_ROOT / "data/images")
    output_folder = str(PROJECT_ROOT / "data/features")
    loop.run_until_complete(main(loop, input_folder, output_folder))
    loop.close()
    logger.info("Event loop closed")