import gradio as gr import pandas as pd import torch import torchvision.transforms as T from PIL import Image import numpy as np import io import base64 import os import shutil import tempfile # PIQ imports try: import piq except ImportError: print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.") piq = None # IQA-PyTorch imports try: # This import needs to succeed for NIQE and MUSIQ from iqa_pytorch import IQA except ImportError as e: print(f"ERROR: IQA-PyTorch library import failed: {e}. Some metrics (NIQE, MUSIQ-NR) will be unavailable. Check installation and dependencies (like kornia).") IQA = None except Exception as e: print(f"ERROR: An unexpected error occurred during IQA-PyTorch import: {e}") IQA = None # --- Configuration --- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") MAX_IMAGES_PER_BATCH = 100 THUMBNAIL_SIZE = (64, 64) # (width, height) for preview # --- Metric Normalization Parameters (Approximate typical ranges) --- # For "lower is better" metrics, score is (max_val - current_val) / (max_val - min_val) # For "higher is better" metrics, score is (current_val - min_val) / (max_val - min_val) # These are heuristics and can be adjusted. METRIC_RANGES = { "brisque": {"min": 0, "max": 120, "lower_is_better": True}, # Typical BRISQUE range "niqe": {"min": 0, "max": 12, "lower_is_better": True}, # Typical NIQE range "musiq_nr": {"min": 10, "max": 90, "lower_is_better": False} # Example MUSIQ range } # --- Metric Functions --- def get_brisque_score(img_tensor_chw_01): if piq is None: return "N/A (PIQ missing)" try: if img_tensor_chw_01.ndim == 3: img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0) else: img_tensor_bchw_01 = img_tensor_chw_01 if img_tensor_bchw_01.shape[1] == 1: img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1) brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.) return round(brisque_loss.item(), 3) except Exception: return "Error" def get_niqe_score(img_pil_rgb): if IQA is None: return "N/A (IQA missing)" try: niqe_metric = IQA(libs='NIQE-PyTorch', device=DEVICE) score = niqe_metric(img_pil_rgb) return round(score.item(), 3) except Exception: return "Error" def get_musiq_nr_score(img_pil_rgb): if IQA is None: return "N/A (IQA missing)" try: musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE) # Example, could be other MUSIQ variants score = musiq_metric(img_pil_rgb) return round(score.item(), 3) except Exception: return "Error" def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder): if piq is None: return "N/A (PIQ missing)" try: set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))] set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))] if not set1_files or not set2_files: return "One or both sets have no valid image files." if len(set1_files) < 2 or len(set2_files) < 2: return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}." fid_metric = piq.FID() set1_features = fid_metric.compute_feats(set1_files, device=DEVICE) set2_features = fid_metric.compute_feats(set2_files, device=DEVICE) if set1_features is None or set2_features is None: return "Could not extract features for one or both sets." if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0: return "Feature extraction resulted in empty tensors." fid_value = fid_metric(set1_features, set2_features) return round(fid_value.item(), 3) except Exception as e: print(f"FID calculation error: {e}") return f"FID Error: {str(e)[:100]}" # --- Helper & Final Score Calculation --- def pil_to_tensor_chw_01(img_pil_rgb): transform = T.Compose([T.ToTensor()]) return transform(img_pil_rgb) def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE): img_copy = img_pil_rgb.copy() img_copy.thumbnail(size) buffered = io.BytesIO() img_copy.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return f"data:image/png;base64,{img_str}" def calculate_final_score(brisque_val, niqe_val, musiq_nr_val): normalized_scores = [] # BRISQUE if isinstance(brisque_val, (float, int)): cfg = METRIC_RANGES["brisque"] val = max(cfg["min"], min(cfg["max"], brisque_val)) # Clip norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"]) normalized_scores.append(norm_score) # NIQE if isinstance(niqe_val, (float, int)): cfg = METRIC_RANGES["niqe"] val = max(cfg["min"], min(cfg["max"], niqe_val)) # Clip norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"]) normalized_scores.append(norm_score) # MUSIQ-NR if isinstance(musiq_nr_val, (float, int)): cfg = METRIC_RANGES["musiq_nr"] val = max(cfg["min"], min(cfg["max"], musiq_nr_val)) # Clip norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"]) normalized_scores.append(norm_score) if not normalized_scores: return "N/A" # Average of normalized scores, then scale to 0-10 final_score_0_10 = (sum(normalized_scores) / len(normalized_scores)) * 10.0 return round(final_score_0_10, 4) # --- Main Processing Functions for Gradio --- def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)): if not uploaded_file_list: return pd.DataFrame(), "Please upload images first." if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH: status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images." uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH] else: status_message = f"Processing {len(uploaded_file_list)} images..." progress(0, desc=status_message) results_data = [] for i, file_obj in enumerate(uploaded_file_list): base_filename = "Unknown File" try: file_path = file_obj.name base_filename = os.path.basename(file_path) img_pil_rgb = Image.open(file_path).convert("RGB") img_tensor_chw_01 = pil_to_tensor_chw_01(img_pil_rgb) brisque_val = get_brisque_score(img_tensor_chw_01) niqe_val = get_niqe_score(img_pil_rgb) musiq_nr_val = get_musiq_nr_score(img_pil_rgb) final_score = calculate_final_score(brisque_val, niqe_val, musiq_nr_val) thumbnail_b64 = create_thumbnail_base64(img_pil_rgb) preview_html = f'{base_filename}' results_data.append({ "Preview": preview_html, "Filename": base_filename, "BRISQUE (PIQ) (↓)": brisque_val, "NIQE (IQA-PyTorch) (↓)": niqe_val, "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val, "Final Score (0-10) (↑)": final_score, }) except Exception as e: results_data.append({ "Preview": "Error processing", "Filename": base_filename, "BRISQUE (PIQ) (↓)": f"Load Err: {str(e)[:30]}", "NIQE (IQA-PyTorch) (↓)": "N/A", "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A", "Final Score (0-10) (↑)": "N/A", }) progress((i + 1) / len(uploaded_file_list), desc=f"Processing {base_filename}") df_results = pd.DataFrame(results_data) status_message += f"\nPer-image metrics calculated for {len(results_data)} images." return df_results, status_message def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)): if not set1_file_list or not set2_file_list: return "Please upload files for both Set 1 and Set 2." set1_dir = tempfile.mkdtemp(prefix="fid_set1_") set2_dir = tempfile.mkdtemp(prefix="fid_set2_") fid_result_text = "Starting FID calculation..." progress(0.1, desc="Preparing image sets for FID...") try: for i, file_obj in enumerate(set1_file_list): shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name))) progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}") for i, file_obj in enumerate(set2_file_list): shutil.copy(file_obj.name, os.path.join(set2_dir, os.path.basename(file_obj.name))) progress(0.3 + 0.2 * (i / len(set2_file_list)), desc=f"Copying Set 2: {os.path.basename(file_obj.name)}") num_set1 = len(os.listdir(set1_dir)); num_set2 = len(os.listdir(set2_dir)) if num_set1 == 0 or num_set2 == 0: return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}." progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...") fid_score = get_fid_score_piq_folders(set1_dir, set2_dir) progress(1, desc="FID calculation complete.") fid_result_text = f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}" except Exception as e: fid_result_text = f"Error during FID preparation or calculation: {str(e)}" finally: if os.path.exists(set1_dir): shutil.rmtree(set1_dir) if os.path.exists(set2_dir): shutil.rmtree(set2_dir) return fid_result_text # --- Gradio UI Definition --- css_custom = """ table {font-size: 0.8em !important; width: 100% !important;} th, td {padding: 4px !important; text-align: left !important;} img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;} """ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo: gr.Markdown(f""" # Image Generation Model Evaluation Tool **Objective:** Automated evaluation and comparison of image quality from different model versions. Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**. (↓) means lower is better, (↑) means higher is better. Final Score is a heuristic combination of available metrics (0-10, higher is better). """) with gr.Tabs(): with gr.TabItem("Per-Image Quality Evaluation"): gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores.") image_upload_input = gr.Files(label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath") evaluate_button_main = gr.Button("🖼️ Evaluate Uploaded Images", variant="primary") gr.Markdown("---") status_output_main = gr.Textbox(label="📊 Evaluation Status", interactive=False, lines=2) gr.Markdown("### 🖼️ Per-Image Evaluation Results") gr.Markdown("Click column headers to sort. Previews are thumbnails.") results_table_output = gr.DataFrame( headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)", "Final Score (0-10) (↑)"], datatype=["html", "str", "number", "number", "number", "number"], # Added "number" for Final Score interactive=False, wrap=True, row_count=(15, "paginate") ) with gr.TabItem("↔️ Calculate FID (Set vs. Set)"): gr.Markdown(""" Calculate Fréchet Inception Distance (FID) between two sets of images. FID measures the similarity of two image distributions. **Lower FID scores are better**. """) with gr.Row(): fid_set1_upload = gr.Files(label="Upload Images for Set 1", file_count="multiple", type="filepath") fid_set2_upload = gr.Files(label="Upload Images for Set 2", file_count="multiple", type="filepath") fid_calculate_button = gr.Button("🔗 Calculate FID between Set 1 and Set 2", variant="primary") fid_result_output = gr.Textbox(label="📈 FID Result", interactive=False, lines=2) evaluate_button_main.click(fn=process_images_for_individual_scores, inputs=[image_upload_input], outputs=[results_table_output, status_output_main]) fid_calculate_button.click(fn=process_fid_for_two_sets, inputs=[fid_set1_upload, fid_set2_upload], outputs=[fid_result_output]) # --- For Hugging Face Spaces --- # Ensure 'requirements.txt' includes: """ gradio torch torchvision Pillow numpy piq>=0.8.0 iqa-pytorch==0.1 timm scikit-image pandas kornia """ if __name__ == "__main__": if piq is None: print("\nWARNING: PIQ library is missing. pip install piq\n") if IQA is None: print("\nERROR: IQA-PyTorch library import failed. pip install iqa-pytorch==0.1 kornia\n") demo.launch(debug=True)