import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import json
import io
import base64
from typing import List, Dict, Tuple, Optional
import logging
from pathlib import Path
import random

# Simplified imports for testing
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("Warning: PyTorch not available, using mock implementations")
# Import evaluation modules with fallbacks
try:
    from models.quality_evaluator import QualityEvaluator
    from models.aesthetics_evaluator import AestheticsEvaluator
    from models.prompt_evaluator import PromptEvaluator
    from models.ai_detection_evaluator import AIDetectionEvaluator
    from utils.metadata_extractor import extract_png_metadata
    from utils.scoring import calculate_final_score
except ImportError as e:
    print(f"Warning: Could not import evaluation modules: {e}")

    # Use mock implementations
    import hashlib  # used for a stable, content-based seed in MockEvaluator

    class MockEvaluator:
        def __init__(self):
            pass

        # FIX: Make mock evaluation deterministic based on image content
        def evaluate(self, image: Image.Image, *args, **kwargs):
            try:
                img_bytes = image.tobytes()
                # Built-in hash() of bytes is salted per process, so use a stable
                # digest: the same image then gets the same mock score across runs
                img_hash = hashlib.md5(img_bytes).hexdigest()
                random.seed(img_hash)
                # Return a consistent score for the same image
                return random.uniform(5.0, 9.5)
            except Exception:
                return random.uniform(5.0, 9.5)  # Fallback for any error
    QualityEvaluator = MockEvaluator
    AestheticsEvaluator = MockEvaluator
    PromptEvaluator = MockEvaluator
    AIDetectionEvaluator = MockEvaluator

    def extract_png_metadata(path):
        return None

    # Use the corrected scoring logic from scoring.py
    from scoring import calculate_final_score
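
# NOTE (assumption): judging from the call site in evaluate_images below,
# calculate_final_score takes (quality_score, aesthetics_score, prompt_score,
# ai_detection_score, has_prompt) and returns a single 0-10 value, e.g. a
# weighted mean of the 0-10 metrics with the AI-detection probability applied
# as a penalty. The exact weighting lives in scoring.py / utils/scoring.py and
# is not defined in this file.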

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageEvaluationApp:
    def __init__(self):
        self.quality_evaluator = None
        self.aesthetics_evaluator = None
        self.prompt_evaluator = None
        self.ai_detection_evaluator = None
        self.models_loaded = False

    def load_models(self, selected_models: Dict[str, bool]):
        """Load selected evaluation models"""
        try:
            if selected_models.get('quality', True) and self.quality_evaluator is None:
                logger.info("Loading quality evaluation models...")
                self.quality_evaluator = QualityEvaluator()
            if selected_models.get('aesthetics', True) and self.aesthetics_evaluator is None:
                logger.info("Loading aesthetics evaluation models...")
                self.aesthetics_evaluator = AestheticsEvaluator()
            if selected_models.get('prompt', True) and self.prompt_evaluator is None:
                logger.info("Loading prompt evaluation models...")
                self.prompt_evaluator = PromptEvaluator()
            if selected_models.get('ai_detection', True) and self.ai_detection_evaluator is None:
                logger.info("Loading AI detection models...")
                self.ai_detection_evaluator = AIDetectionEvaluator()
            self.models_loaded = True
            logger.info("All selected models loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise e

    def evaluate_images(
        self,
        images: List[str],
        enable_quality: bool = True,
        enable_aesthetics: bool = True,
        enable_prompt: bool = True,
        enable_ai_detection: bool = True,
        anime_mode: bool = False,
        progress=gr.Progress()
    ) -> Tuple[pd.DataFrame, str]:
        """
        Evaluate uploaded images and return results
        """
        if not images:
            return pd.DataFrame(), "No images uploaded."
        try:
            selected_models = {
                'quality': enable_quality,
                'aesthetics': enable_aesthetics,
                'prompt': enable_prompt,
                'ai_detection': enable_ai_detection
            }
            progress(0.1, desc="Loading models...")
            self.load_models(selected_models)

            results = []
            total_images = len(images)
            for i, image_path in enumerate(images):
                progress((i + 1) / total_images * 0.9 + 0.1,
                         desc=f"Evaluating image {i+1}/{total_images}")
                try:
                    image = Image.open(image_path).convert('RGB')
                    filename = Path(image_path).name
                    metadata = extract_png_metadata(image_path)
                    prompt = metadata.get('prompt', '') if metadata else ''
                    scores = {
                        'filename': filename,
                        'quality_score': 0.0,
                        'aesthetics_score': 0.0,
                        'prompt_score': 0.0,
                        'ai_detection_score': 0.0,
                        'has_prompt': bool(prompt)
                    }
                    if enable_quality and self.quality_evaluator:
                        scores['quality_score'] = self.quality_evaluator.evaluate(image, anime_mode=anime_mode)
                    if enable_aesthetics and self.aesthetics_evaluator:
                        scores['aesthetics_score'] = self.aesthetics_evaluator.evaluate(image, anime_mode=anime_mode)
                    if enable_prompt and self.prompt_evaluator and prompt:
                        scores['prompt_score'] = self.prompt_evaluator.evaluate(image, prompt)
                    if enable_ai_detection and self.ai_detection_evaluator:
                        scores['ai_detection_score'] = self.ai_detection_evaluator.evaluate(image)
                    scores['final_score'] = calculate_final_score(
                        scores['quality_score'],
                        scores['aesthetics_score'],
                        scores['prompt_score'],
                        scores['ai_detection_score'],
                        scores['has_prompt']
                    )
                    thumbnail = image.copy()
                    thumbnail.thumbnail((100, 100), Image.Resampling.LANCZOS)
                    buffer = io.BytesIO()
                    thumbnail.save(buffer, format='PNG')
                    thumbnail_b64 = base64.b64encode(buffer.getvalue()).decode()
                    # FIX: Use markdown format for Gradio dataframe image display
                    scores['thumbnail'] = f"![thumbnail](data:image/png;base64,{thumbnail_b64})"
                    results.append(scores)
                except Exception as e:
                    logger.error(f"Error evaluating {image_path}: {str(e)}")
                    results.append({
                        'filename': Path(image_path).name,
                        'error': str(e),
                        'thumbnail': ''
                    })

            if not results:
                return pd.DataFrame(), "Evaluation failed for all images."

            df = pd.DataFrame(results)
            # Guard: if every image failed, no score columns were created at all
            if 'final_score' not in df.columns:
                df['final_score'] = np.nan

            # FIX: Create a display-ready dataframe with proper formatting and column names
            if not df.empty:
                # Separate error rows from successfully scored rows
                error_df = df[df['final_score'].isna()]
                valid_df = df.dropna(subset=['final_score'])
                if not valid_df.empty:
                    valid_df = valid_df.sort_values('final_score', ascending=False).reset_index(drop=True)
                    valid_df.index = valid_df.index + 1
                    valid_df = valid_df.reset_index().rename(columns={'index': 'Rank'})
                    # Format columns for display
                    display_cols = {
                        'Rank': 'Rank',
                        'thumbnail': 'Thumbnail',
                        'filename': 'Filename',
                        'final_score': 'Final Score',
                        'quality_score': 'Quality',
                        'aesthetics_score': 'Aesthetics',
                        'prompt_score': 'Prompt',
                        'ai_detection_score': 'AI Detection'
                    }
                    display_df = valid_df[list(display_cols.keys())]
                    display_df = display_df.rename(columns=display_cols)
                    # Apply formatting
                    for col in ['Final Score', 'Quality', 'Aesthetics', 'Prompt']:
                        display_df[col] = display_df[col].map('{:.2f}'.format)
                    display_df['AI Detection'] = display_df['AI Detection'].map('{:.1%}'.format)
                else:
                    display_df = pd.DataFrame()

            status_msg = f"Successfully evaluated {len(df[df['final_score'].notna()])} images."
            error_count = len(df[df['final_score'].isna()])
            if error_count > 0:
                status_msg += f" {error_count} images had evaluation errors."
            return display_df, status_msg
        except Exception as e:
            logger.error(f"Error in evaluate_images: {str(e)}")
            return pd.DataFrame(), f"Error during evaluation: {str(e)}"
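
# Illustrative headless usage (assumption: ./sample.png is a local test file);
# passing a no-op progress callback keeps the call independent of a live Gradio
# event context:
#
#   app = ImageEvaluationApp()
#   df, status = app.evaluate_images(["sample.png"], progress=lambda *a, **k: None)
#   print(status)
#   print(df)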


def create_interface():
    app = ImageEvaluationApp()
    css = """
    .gradio-container { max-width: 1400px !important; }
    .results-table { font-size: 14px; }
    .results-table .thumbnail-cell img { max-width: 100px; max-height: 100px; object-fit: cover; }
    """
    with gr.Blocks(css=css, title="AI Image Evaluation Tool") as interface:
        gr.Markdown("# 🎨 AI Image Evaluation Tool")
        gr.Markdown("Upload your AI-generated images to evaluate their quality, aesthetics, prompt following, and detect AI generation.")
        with gr.Row():
            with gr.Column(scale=1):
                images_input = gr.File(label="Upload Images", file_count="multiple", file_types=["image"], height=200)
                gr.Markdown("### Model Selection")
                with gr.Row():
                    enable_quality = gr.Checkbox(label="Image Quality", value=True)
                    enable_aesthetics = gr.Checkbox(label="Aesthetics", value=True)
                with gr.Row():
                    enable_prompt = gr.Checkbox(label="Prompt Following", value=True)
                    enable_ai_detection = gr.Checkbox(label="AI Detection", value=True)
                gr.Markdown("### Options")
                anime_mode = gr.Checkbox(label="Anime/Art Mode", value=False)
                evaluate_btn = gr.Button("🚀 Evaluate Images", variant="primary", size="lg")
                status_output = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=3):
                gr.Markdown("### 📊 Evaluation Results")
                # FIX: Update headers and datatypes to match the new formatted DataFrame
                results_output = gr.Dataframe(
                    headers=["Rank", "Thumbnail", "Filename", "Final Score", "Quality", "Aesthetics", "Prompt", "AI Detection"],
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str"],
                    label="Results",
                    interactive=False,
                    wrap=True,
                    elem_classes=["results-table"]
                )
        evaluate_btn.click(
            fn=app.evaluate_images,
            inputs=[images_input, enable_quality, enable_aesthetics, enable_prompt, enable_ai_detection, anime_mode],
            outputs=[results_output, status_output]
        )
        with gr.Accordion("ℹ️ Help & Information", open=False):
            # Help text remains the same as it describes the intended functionality
            gr.Markdown("""
            ### How to Use
            1. **Upload Images**: Select multiple PNG/JPG images.
            2. **Select Models**: Choose which evaluation metrics to use.
            3. **Anime Mode**: Enable for better evaluation of anime/art style images.
            4. **Evaluate**: Click the button to start evaluation.

            ### Scoring System
            - **Quality Score**: Technical image quality (0-10).
            - **Aesthetics Score**: Visual appeal and composition (0-10).
            - **Prompt Score**: How well the image follows the text prompt (0-10, requires metadata).
            - **AI Detection**: Probability of being AI-generated (0-1, lower is better for the final score).
            - **Final Score**: Weighted combination of all metrics (0-10).
            """)
    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
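
# Local usage notes (assumptions, not stated in the original file): the imports
# above imply gradio, numpy, pandas and pillow as the minimum requirements, with
# torch and the models/ and utils/ packages optional (mock implementations are
# used when they are missing). A typical local run would be:
#   pip install gradio numpy pandas pillow
#   python app.py        # app.py is the assumed filename for this script
# Binding to 0.0.0.0 on port 7860 matches the standard Hugging Face Spaces setup.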