import logging

import torch
import numpy as np
from PIL import Image
import clip
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer, util

logger = logging.getLogger(__name__)


class PromptEvaluator:
    """Prompt-following assessment using CLIP and other vision-language models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False
        self.load_models()

    def load_models(self):
        """Load prompt evaluation models."""
        try:
            # Load CLIP model (primary)
            logger.info("Loading CLIP model...")
            self.load_clip()

            # Load BLIP-2 model (secondary)
            logger.info("Loading BLIP-2 model...")
            self.load_blip2()

            # Load sentence transformer for text similarity
            logger.info("Loading sentence transformer...")
            self.load_sentence_transformer()
        except Exception as e:
            logger.error(f"Error loading prompt evaluation models: {str(e)}")
            self.use_fallback_implementation()

    def load_clip(self):
        """Load CLIP model."""
        try:
            model, preprocess = clip.load("ViT-B/32", device=self.device)
            self.models['clip'] = model
            self.processors['clip'] = preprocess
            logger.info("CLIP model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load CLIP: {str(e)}")

    def load_blip2(self):
        """Load BLIP captioning model (used as the secondary 'BLIP-2' evaluator)."""
        try:
            # Note: this checkpoint is the BLIP image-captioning base model, not BLIP-2;
            # it is stored under the 'blip2' key used elsewhere in this class.
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            model = model.to(self.device)
            self.models['blip2'] = model
            self.processors['blip2'] = processor
            logger.info("BLIP-2 model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load BLIP-2: {str(e)}")

    def load_sentence_transformer(self):
        """Load sentence transformer for text similarity."""
        try:
            model = SentenceTransformer('all-MiniLM-L6-v2')
            self.models['sentence_transformer'] = model
            logger.info("Sentence transformer loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load sentence transformer: {str(e)}")

    def use_fallback_implementation(self):
        """Use simple fallback prompt evaluation."""
        logger.info("Using fallback prompt evaluation implementation")
        self.fallback_mode = True

    def evaluate_with_clip(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following using CLIP image-text similarity."""
        try:
            if 'clip' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Tokenize text
            text_tokens = clip.tokenize([prompt]).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

            # Normalize features
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Calculate cosine similarity
            similarity = (image_features @ text_features.T).item()

            # Convert similarity to a 0-10 scale.
            # CLIP cosine similarity lies in [-1, 1], but is usually 0-1 for related content.
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in CLIP evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def evaluate_with_blip2(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following by captioning the image and comparing the caption to the prompt."""
        try:
            if 'blip2' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['blip2']
            processor = self.processors['blip2']

            # Generate a caption for the image
            inputs = processor(image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                out = model.generate(**inputs, max_length=50)
            generated_caption = processor.decode(out[0], skip_special_tokens=True)

            # Compare the generated caption with the original prompt using text similarity
            if 'sentence_transformer' in self.models:
                similarity_score = self.calculate_text_similarity(prompt, generated_caption)
            else:
                # Simple word-overlap fallback
                similarity_score = self.simple_text_similarity(prompt, generated_caption)

            return similarity_score
        except Exception as e:
            logger.error(f"Error in BLIP-2 evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Calculate semantic similarity between two texts."""
        try:
            model = self.models['sentence_transformer']

            # Encode texts
            embeddings = model.encode([text1, text2])

            # Calculate cosine similarity
            similarity = util.cos_sim(embeddings[0], embeddings[1]).item()

            # Convert to 0-10 scale
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error calculating text similarity: {str(e)}")
            return self.simple_text_similarity(text1, text2)

    def simple_text_similarity(self, text1: str, text2: str) -> float:
        """Simple word-overlap (Jaccard) similarity."""
        try:
            # Convert to lowercase and split into words
            words1 = set(text1.lower().split())
            words2 = set(text2.lower().split())

            # Calculate Jaccard similarity
            intersection = len(words1.intersection(words2))
            union = len(words1.union(words2))
            if union == 0:
                return 0.0
            jaccard_similarity = intersection / union

            # Convert to 0-10 scale
            score = jaccard_similarity * 10
            return max(0.0, min(10.0, score))
        except Exception:
            return 5.0  # Default neutral score

    def extract_key_concepts(self, prompt: str) -> list:
        """Extract key concepts from the prompt for detailed analysis."""
        try:
            # Simple keyword extraction; in production this could use more sophisticated NLP.
            # Remove common stop words and very short tokens.
            stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
                          'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                          'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
                          'should'}
            words = prompt.lower().split()
            key_concepts = [word for word in words if word not in stop_words and len(word) > 2]
            return key_concepts
        except Exception:
            return []

    def evaluate_concept_presence(self, image: Image.Image, concepts: list) -> float:
        """Evaluate the presence of specific concepts in the image."""
        try:
            if 'clip' not in self.models or not concepts:
                return 5.0

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Create one text query per concept
            concept_queries = [f"a photo of {concept}" for concept in concepts]

            # Tokenize concepts
            text_tokens = clip.tokenize(concept_queries).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

            # Normalize features
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Calculate similarity against each concept query
            similarities = (image_features @ text_features.T).squeeze(0)

            # Average similarity across concepts
            avg_similarity = similarities.mean().item()

            # Convert to 0-10 scale
            score = max(0.0, min(10.0, (avg_similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in concept presence evaluation: {str(e)}")
            return 5.0

    def fallback_prompt_score(self, image: Image.Image, prompt: str) -> float:
        """Simple fallback prompt evaluation."""
        try:
            # Very basic heuristic based on prompt length; the image itself is not inspected.
            prompt_length = len(prompt.split())

            # Longer, more detailed prompts might be harder to follow perfectly
            if prompt_length < 5:
                length_penalty = 0.0
            elif prompt_length < 15:
                length_penalty = 0.5
            else:
                length_penalty = 1.0

            # Base score
            base_score = 7.0 - length_penalty
            return max(0.0, min(10.0, base_score))
        except Exception:
            return 5.0  # Default neutral score

    def evaluate(self, image: Image.Image, prompt: str) -> float:
        """
        Evaluate how well the image follows the given prompt.

        Args:
            image: PIL Image to evaluate
            prompt: Text prompt to compare against

        Returns:
            Prompt-following score from 0-10
        """
        try:
            if not prompt or not prompt.strip():
                return 0.0  # No prompt to evaluate against

            scores = []

            # CLIP evaluation (primary)
            clip_score = self.evaluate_with_clip(image, prompt)
            scores.append(clip_score)

            # BLIP-2 evaluation (secondary)
            blip2_score = self.evaluate_with_blip2(image, prompt)
            scores.append(blip2_score)

            # Concept presence evaluation
            key_concepts = self.extract_key_concepts(prompt)
            concept_score = self.evaluate_concept_presence(image, key_concepts)
            scores.append(concept_score)

            # Ensemble scoring: CLIP gets the highest weight
            weights = [0.5, 0.3, 0.2]
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info(f"Prompt scores - CLIP: {clip_score:.2f}, BLIP-2: {blip2_score:.2f}, "
                        f"Concepts: {concept_score:.2f}, Final: {final_score:.2f}")

            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in prompt evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)