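"""Prompt-following evaluation for generated images.

Combines CLIP image-text similarity, BLIP captioning with sentence-embedding
comparison, and per-concept CLIP checks into a single 0-10 score, falling back
to simple heuristics when a model cannot be loaded.
"""
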
import logging

import torch
from PIL import Image

import clip
from sentence_transformers import SentenceTransformer, util
from transformers import BlipProcessor, BlipForConditionalGeneration

logger = logging.getLogger(__name__)

class PromptEvaluator:
    """Prompt-following assessment using CLIP and other vision-language models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.processors = {}
        self.fallback_mode = False  # Set to True if model loading fails entirely
        self.load_models()

    def load_models(self):
        """Load prompt evaluation models."""
        try:
            # Load CLIP model (primary)
            logger.info("Loading CLIP model...")
            self.load_clip()

            # Load BLIP captioning model (secondary)
            logger.info("Loading BLIP captioning model...")
            self.load_blip2()

            # Load sentence transformer for text similarity
            logger.info("Loading sentence transformer...")
            self.load_sentence_transformer()
        except Exception as e:
            logger.error(f"Error loading prompt evaluation models: {str(e)}")
            self.use_fallback_implementation()

    def load_clip(self):
        """Load the CLIP model."""
        try:
            model, preprocess = clip.load("ViT-B/32", device=self.device)
            self.models['clip'] = model
            self.processors['clip'] = preprocess
            logger.info("CLIP model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load CLIP: {str(e)}")

    def load_blip2(self):
        """Load the BLIP image-captioning model (the base BLIP checkpoint, not BLIP-2)."""
        try:
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            model = model.to(self.device)
            # Stored under the 'blip2' key so the rest of the evaluator stays unchanged
            self.models['blip2'] = model
            self.processors['blip2'] = processor
            logger.info("BLIP captioning model loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load BLIP captioning model: {str(e)}")

    def load_sentence_transformer(self):
        """Load a sentence transformer for text similarity."""
        try:
            model = SentenceTransformer('all-MiniLM-L6-v2')
            self.models['sentence_transformer'] = model
            logger.info("Sentence transformer loaded successfully")
        except Exception as e:
            logger.warning(f"Could not load sentence transformer: {str(e)}")

    def use_fallback_implementation(self):
        """Switch to simple fallback prompt evaluation."""
        logger.info("Using fallback prompt evaluation implementation")
        self.fallback_mode = True

    def evaluate_with_clip(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following using CLIP image-text similarity."""
        try:
            if 'clip' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Tokenize text (truncate prompts longer than CLIP's 77-token context)
            text_tokens = clip.tokenize([prompt], truncate=True).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

                # Normalize features
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                # Calculate cosine similarity
                similarity = (image_features @ text_features.T).item()

            # Map cosine similarity from [-1, 1] to a 0-10 scale
            # (well-matched image-text pairs typically land well below 1.0)
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in CLIP evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def evaluate_with_blip2(self, image: Image.Image, prompt: str) -> float:
        """Evaluate prompt following by captioning the image with BLIP and comparing the caption to the prompt."""
        try:
            if 'blip2' not in self.models:
                return self.fallback_prompt_score(image, prompt)

            model = self.models['blip2']
            processor = self.processors['blip2']

            # Generate a caption for the image
            inputs = processor(image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                out = model.generate(**inputs, max_length=50)
            generated_caption = processor.decode(out[0], skip_special_tokens=True)

            # Compare the generated caption with the original prompt using text similarity
            if 'sentence_transformer' in self.models:
                similarity_score = self.calculate_text_similarity(prompt, generated_caption)
            else:
                # Simple word-overlap fallback
                similarity_score = self.simple_text_similarity(prompt, generated_caption)

            return similarity_score
        except Exception as e:
            logger.error(f"Error in BLIP caption evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)

    def calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Calculate semantic similarity between two texts."""
        try:
            model = self.models['sentence_transformer']

            # Encode texts
            embeddings = model.encode([text1, text2])

            # Calculate cosine similarity
            similarity = util.cos_sim(embeddings[0], embeddings[1]).item()

            # Map cosine similarity from [-1, 1] to a 0-10 scale
            score = max(0.0, min(10.0, (similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error calculating text similarity: {str(e)}")
            return self.simple_text_similarity(text1, text2)

    def simple_text_similarity(self, text1: str, text2: str) -> float:
        """Simple word-overlap (Jaccard) similarity."""
        try:
            # Convert to lowercase and split into words
            words1 = set(text1.lower().split())
            words2 = set(text2.lower().split())

            # Calculate Jaccard similarity
            intersection = len(words1.intersection(words2))
            union = len(words1.union(words2))
            if union == 0:
                return 0.0
            jaccard_similarity = intersection / union

            # Convert to 0-10 scale
            score = jaccard_similarity * 10
            return max(0.0, min(10.0, score))
        except Exception:
            return 5.0  # Default neutral score

    def extract_key_concepts(self, prompt: str) -> list:
        """Extract key concepts from the prompt for detailed analysis."""
        try:
            # Simple keyword extraction by dropping common stop words;
            # in production this could use more sophisticated NLP
            stop_words = {
                'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
                'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should'
            }
            words = prompt.lower().split()
            key_concepts = [word for word in words if word not in stop_words and len(word) > 2]
            return key_concepts
        except Exception:
            return []

    def evaluate_concept_presence(self, image: Image.Image, concepts: list) -> float:
        """Evaluate the presence of specific concepts in the image."""
        try:
            if 'clip' not in self.models or not concepts:
                return 5.0

            model = self.models['clip']
            preprocess = self.processors['clip']

            # Preprocess image
            image_tensor = preprocess(image).unsqueeze(0).to(self.device)

            # Create one query per concept and tokenize
            concept_queries = [f"a photo of {concept}" for concept in concepts]
            text_tokens = clip.tokenize(concept_queries).to(self.device)

            # Get features
            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                text_features = model.encode_text(text_tokens)

                # Normalize features
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                # Calculate per-concept similarities and average them
                similarities = (image_features @ text_features.T).squeeze(0)
                avg_similarity = similarities.mean().item()

            # Map cosine similarity from [-1, 1] to a 0-10 scale
            score = max(0.0, min(10.0, (avg_similarity + 1) * 5))
            return score
        except Exception as e:
            logger.error(f"Error in concept presence evaluation: {str(e)}")
            return 5.0

    def fallback_prompt_score(self, image: Image.Image, prompt: str) -> float:
        """Simple heuristic fallback when model-based prompt evaluation is unavailable."""
        try:
            # Very basic heuristic based only on prompt length: longer, more
            # detailed prompts are assumed to be harder to follow perfectly
            prompt_length = len(prompt.split())
            if prompt_length < 5:
                length_penalty = 0.0
            elif prompt_length < 15:
                length_penalty = 0.5
            else:
                length_penalty = 1.0

            base_score = 7.0 - length_penalty
            return max(0.0, min(10.0, base_score))
        except Exception:
            return 5.0  # Default neutral score

    def evaluate(self, image: Image.Image, prompt: str) -> float:
        """
        Evaluate how well the image follows the given prompt.

        Args:
            image: PIL Image to evaluate
            prompt: Text prompt to compare against

        Returns:
            Prompt-following score from 0-10
        """
        try:
            if not prompt or not prompt.strip():
                return 0.0  # No prompt to evaluate against

            scores = []

            # CLIP evaluation (primary)
            clip_score = self.evaluate_with_clip(image, prompt)
            scores.append(clip_score)

            # BLIP caption evaluation (secondary)
            blip2_score = self.evaluate_with_blip2(image, prompt)
            scores.append(blip2_score)

            # Concept presence evaluation
            key_concepts = self.extract_key_concepts(prompt)
            concept_score = self.evaluate_concept_presence(image, key_concepts)
            scores.append(concept_score)

            # Weighted ensemble; CLIP gets the highest weight
            weights = [0.5, 0.3, 0.2]
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info(f"Prompt scores - CLIP: {clip_score:.2f}, BLIP: {blip2_score:.2f}, "
                        f"Concepts: {concept_score:.2f}, Final: {final_score:.2f}")

            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error(f"Error in prompt evaluation: {str(e)}")
            return self.fallback_prompt_score(image, prompt)
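

# Minimal usage sketch: builds the evaluator and scores one image against a prompt.
# The image path and prompt below are placeholders for illustration only; running
# this assumes the CLIP, BLIP, and sentence-transformer weights can be downloaded.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    evaluator = PromptEvaluator()
    # "example.png" is a hypothetical file, not shipped with this module.
    test_image = Image.open("example.png").convert("RGB")
    score = evaluator.evaluate(test_image, "a red bicycle leaning against a brick wall")
    print(f"Prompt-following score: {score:.2f}/10")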