# Source: Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import pandas as pd | |
import torch | |
import torchvision.transforms as T | |
from PIL import Image | |
import numpy as np | |
import io | |
import base64 | |
import os | |
import shutil | |
import tempfile | |
# PIQ imports | |
try: | |
import piq | |
except ImportError: | |
print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.") | |
piq = None | |
# IQA-PyTorch imports | |
try: | |
from iqa_pytorch import IQA | |
# Available models in IQA-PyTorch (examples for NR): | |
# "MUSIQ-L2N-lessons", "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR" | |
# "BRISQUE-PyTorch", "NIQE-PyTorch" | |
# "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic) | |
except ImportError: | |
print("Warning: IQA-PyTorch library not found. Some metrics (NIQE, MUSIQ-NR) will be unavailable.") | |
IQA = None | |
# --- Configuration --- | |
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Upper bound on how many uploaded images one evaluation run will process.
MAX_IMAGES_PER_BATCH = 100
# Bounding box (width, height) used for the preview thumbnails.
THUMBNAIL_SIZE = (64, 64)
# --- Metric Functions --- | |
def get_brisque_score(img_tensor_chw_01):
    """Return the BRISQUE score of an image, computed with PIQ.

    Args:
        img_tensor_chw_01: Image tensor with values in [0, 1]. A (C, H, W)
            tensor gets a batch dimension added; a (H, W) tensor gets both a
            channel and a batch dimension; anything else is passed through
            unchanged. Single-channel input is repeated to three channels.

    Returns:
        The score rounded to three decimals, ``"N/A (PIQ missing)"`` when the
        ``piq`` import failed, or ``"Error"`` when the computation raised.
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        batch = img_tensor_chw_01
        if batch.ndim == 2:
            # Bare grayscale (H, W): add the channel axis first.
            batch = batch.unsqueeze(0)
        if batch.ndim == 3:
            # piq.brisque is called with a (B, C, H, W) batch.
            batch = batch.unsqueeze(0)
        if batch.shape[1] == 1:
            # Expand grayscale to three identical channels.
            batch = batch.repeat(1, 3, 1, 1)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            brisque_loss = piq.brisque(batch.to(DEVICE), data_range=1.)
        return round(brisque_loss.item(), 3)
    except Exception as e:
        # Best-effort metric: report the failure instead of hiding it, but
        # keep the table cell value callers already expect.
        shape = getattr(img_tensor_chw_01, "shape", None)
        print(f"BRISQUE Error: {e} for tensor shape {shape}")
        return "Error"
def get_niqe_score(img_pil_rgb):
    """Return the no-reference NIQE score of a PIL RGB image via IQA-PyTorch.

    The metric object is built on first use and then reused, so a batch of
    images does not re-create the model for every single image.

    Args:
        img_pil_rgb: PIL image in RGB mode.

    Returns:
        The score rounded to three decimals, ``"N/A (IQA missing)"`` when the
        ``IQA`` import failed, or ``"Error"`` when the computation raised.
    """
    # NOTE(review): `IQA` comes from `iqa_pytorch` in this file's import block;
    # the IQA-PyTorch project is usually imported as `pyiqa` - confirm the
    # module name, otherwise this function always reports "N/A".
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        niqe_metric = getattr(get_niqe_score, "_metric", None)
        if niqe_metric is None:
            niqe_metric = IQA(libs='NIQE-PyTorch', device=DEVICE)
            # Cache on the function object so later calls reuse the instance.
            get_niqe_score._metric = niqe_metric
        score = niqe_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        # Best-effort metric: surface the reason, keep the cell value stable.
        print(f"NIQE Error: {e}")
        return "Error"
def get_musiq_nr_score(img_pil_rgb):
    """Return the no-reference MUSIQ score of a PIL RGB image via IQA-PyTorch.

    Uses the ``MUSIQ-L2N-lessons`` model name. The metric object is built on
    first use and then reused, so the model is not re-created for each image.

    Args:
        img_pil_rgb: PIL image in RGB mode.

    Returns:
        The score rounded to three decimals, ``"N/A (IQA missing)"`` when the
        ``IQA`` import failed, or ``"Error"`` when the computation raised.
    """
    # NOTE(review): `IQA` comes from `iqa_pytorch` in this file's import block;
    # the IQA-PyTorch project is usually imported as `pyiqa` - confirm the
    # module and model names, otherwise this always reports "N/A"/"Error".
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        musiq_metric = getattr(get_musiq_nr_score, "_metric", None)
        if musiq_metric is None:
            # Other model names mentioned by the author:
            # "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR".
            musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
            # Cache on the function object so later calls reuse the instance.
            get_musiq_nr_score._metric = musiq_metric
        score = musiq_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        # Best-effort metric: surface the reason, keep the cell value stable.
        print(f"MUSIQ-NR Error: {e}")
        return "Error"
_FID_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')


class _FidImageDataset(torch.utils.data.Dataset):
    """Serve image files as same-sized tensors in the dict layout PIQ's FID reads."""

    def __init__(self, image_paths, size=(299, 299)):
        self._paths = list(image_paths)
        # A fixed size lets the default collate function stack images that
        # originally had different resolutions.
        self._transform = T.Compose([T.Resize(size), T.ToTensor()])

    def __len__(self):
        return len(self._paths)

    def __getitem__(self, index):
        with Image.open(self._paths[index]) as img:
            tensor = self._transform(img.convert("RGB"))
        # piq.FID.compute_feats reads each batch through the 'images' key.
        return {"images": tensor}


def _list_fid_image_files(folder):
    """Return the sorted paths of the image files directly inside ``folder``."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_FID_IMAGE_EXTENSIONS)
    )


def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
    """Calculate FID between two folders of images using PIQ.

    Each folder is wrapped in a ``DataLoader`` that yields ``{'images': ...}``
    batches, which is the input ``piq.FID.compute_feats`` iterates over
    (a plain list of file paths is not accepted).

    Args:
        path_to_set1_folder: Directory holding the first image set.
        path_to_set2_folder: Directory holding the second image set.

    Returns:
        The FID value rounded to three decimals, or a human-readable string
        describing why it could not be computed (PIQ missing, too few images,
        empty features, or a truncated exception message).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        set1_files = _list_fid_image_files(path_to_set1_folder)
        set2_files = _list_fid_image_files(path_to_set2_folder)
        if not set1_files or not set2_files:
            return "One or both sets have no valid image files."
        # Covariance statistics need at least two samples per set.
        if len(set1_files) < 2 or len(set2_files) < 2:
            return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."

        fid_metric = piq.FID()
        features = []
        with torch.no_grad():
            for files in (set1_files, set2_files):
                loader = torch.utils.data.DataLoader(
                    _FidImageDataset(files), batch_size=16, shuffle=False, num_workers=0
                )
                features.append(fid_metric.compute_feats(loader, device=DEVICE))
        set1_features, set2_features = features

        if set1_features is None or set2_features is None:
            return "Could not extract features for one or both sets (check image validity and count)."
        if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0:
            return "Feature extraction resulted in empty tensors."

        fid_value = fid_metric(set1_features, set2_features)
        return round(fid_value.item(), 3)
    except Exception as e:
        print(f"FID calculation error: {e}")
        return f"FID Error: {str(e)[:100]}"
# --- Helper Functions --- | |
def pil_to_tensor_chw_01(img_pil_rgb):
    """Turn a PIL RGB image into a (C, H, W) PyTorch tensor scaled to [0, 1]."""
    # ToTensor maps PIL [0, 255] pixel data to a float tensor in [0, 1].
    to_tensor = T.ToTensor()
    return to_tensor(img_pil_rgb)
def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
    """Return a ``data:image/png;base64,...`` URI holding a PNG thumbnail.

    The given image is not modified; the in-place ``thumbnail`` resize is
    applied to a copy.
    """
    thumb = img_pil_rgb.copy()
    thumb.thumbnail(size)
    with io.BytesIO() as png_buffer:
        thumb.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
# --- Main Processing Functions for Gradio --- | |
def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
    """Compute per-image quality scores for a batch of uploaded images.

    Args:
        uploaded_file_list: Uploaded files as delivered by ``gr.Files``. Each
            item may be a plain path string (``type="filepath"``) or an object
            exposing the path through a ``name`` attribute.
        progress: Gradio progress tracker, called after every image.

    Returns:
        A 4-tuple ``(results_dataframe, status_message, is_text, fid_text)``.
        The DataFrame has one row per processed image with an HTML thumbnail,
        the file name and the BRISQUE / NIQE / MUSIQ-NR values. The last two
        items are informational strings.
    """
    import html  # local import: only needed to escape the file name below

    if not uploaded_file_list:
        return pd.DataFrame(), "Please upload images first.", "IS: N/A (Not Implemented)", "FID: N/A (Use FID Tab)"

    # Cap the batch size so a single request cannot run unbounded.
    if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
        status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
        uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
    else:
        status_message = f"Processing {len(uploaded_file_list)} images..."
    progress(0, desc=status_message)

    total = len(uploaded_file_list)
    results_data = []
    for i, file_obj in enumerate(uploaded_file_list):
        # Path strings have no `.name`; file wrappers do. Accept both.
        file_path = str(getattr(file_obj, "name", file_obj))
        base_filename = os.path.basename(file_path) or "Unknown File"
        try:
            # The context manager releases the file handle once the pixel
            # data has been copied into the converted image.
            with Image.open(file_path) as opened:
                img_pil_rgb = opened.convert("RGB")

            # PIQ BRISQUE works on a tensor; the IQA metrics take the PIL image.
            brisque_val = get_brisque_score(pil_to_tensor_chw_01(img_pil_rgb))
            niqe_val = get_niqe_score(img_pil_rgb)
            musiq_nr_val = get_musiq_nr_score(img_pil_rgb)

            thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
            # The file name is user-controlled: escape it before it lands in
            # an HTML attribute.
            safe_alt = html.escape(base_filename, quote=True)
            preview_html = f'<img src="{thumbnail_b64}" alt="{safe_alt}">'

            results_data.append({
                "Preview": preview_html,
                "Filename": base_filename,
                "BRISQUE (PIQ) (β)": brisque_val,
                "NIQE (IQA-PyTorch) (β)": niqe_val,
                "MUSIQ-NR (IQA-PyTorch) (β)": musiq_nr_val,
            })
        except Exception as e:
            # One unreadable file must not abort the batch: log it and emit a
            # placeholder row instead.
            print(f"Failed to process {base_filename}: {e}")
            results_data.append({
                "Preview": "Error processing",
                "Filename": base_filename,
                "BRISQUE (PIQ) (β)": "Load Err",
                "NIQE (IQA-PyTorch) (β)": "N/A",
                "MUSIQ-NR (IQA-PyTorch) (β)": "N/A",
            })
        progress((i + 1) / total, desc=f"Processing {base_filename}")

    df_results = pd.DataFrame(results_data)
    status_message += f"\nPer-image metrics calculated for {len(results_data)} images."

    # Batch-level metrics are not produced here (IS is not implemented, FID
    # has its own tab); these strings only say so.
    is_text = "IS (PIQ): Not implemented in this version."
    fid_text_batch_info = "FID (PIQ): Use the 'Calculate FID (Set vs Set)' tab for FID scores."
    return df_results, status_message, is_text, fid_text_batch_info
def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
    """Compute FID between two sets of uploaded images.

    Both sets are copied into their own temporary directory, scored with
    ``get_fid_score_piq_folders``, and the directories are removed again
    whether or not the calculation succeeds.

    Args:
        set1_file_list: Uploaded files of the first set. Items may be path
            strings or objects exposing the path through a ``name`` attribute.
        set2_file_list: Uploaded files of the second set, same conventions.
        progress: Gradio progress tracker.

    Returns:
        A human-readable result or error message.
    """
    if not set1_file_list or not set2_file_list:
        return "Please upload files for both Set 1 and Set 2."

    progress(0.1, desc="Preparing image sets for FID...")
    try:
        # The context managers delete both directories on every exit path.
        with tempfile.TemporaryDirectory(prefix="fid_set1_") as set1_dir, \
                tempfile.TemporaryDirectory(prefix="fid_set2_") as set2_dir:
            staging = (
                (set1_file_list, set1_dir, "Set 1", 0.1),
                (set2_file_list, set2_dir, "Set 2", 0.3),
            )
            for file_list, target_dir, label, start in staging:
                for i, file_obj in enumerate(file_list):
                    # Path strings have no `.name`; file wrappers do.
                    src_path = str(getattr(file_obj, "name", file_obj))
                    src_name = os.path.basename(src_path)
                    # The index prefix keeps uploads that share a base name
                    # from overwriting each other.
                    shutil.copy(src_path, os.path.join(target_dir, f"{i:05d}_{src_name}"))
                    progress(start + 0.2 * (i / len(file_list)), desc=f"Copying {label}: {src_name}")

            num_set1 = len(os.listdir(set1_dir))
            num_set2 = len(os.listdir(set2_dir))
            if num_set1 == 0 or num_set2 == 0:
                return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."

            progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
            fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
            progress(1, desc="FID calculation complete.")
            return f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
    except Exception as e:
        return f"Error during FID preparation or calculation: {str(e)}"
# --- Gradio UI Definition --- | |
# Compact table styling; thumbnails are capped at 64x64 to match the previews.
css_custom = """
table {font-size: 0.8em !important; width: 100% !important;}
th, td {padding: 4px !important; text-align: left !important;}
img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;}
"""

# NOTE(review): several user-facing strings below (button/label icons and the
# "(β)" direction markers) look like mis-decoded UTF-8. They are kept verbatim
# because the column names must match the keys produced by
# process_images_for_individual_scores; repair both places together.
with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
    gr.Markdown(f"""
# Image Generation Model Evaluation Tool
**Objective:** Automated evaluation and comparison of image quality from different model versions.
Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
(β) means lower is better, (β) means higher is better.
""")
    with gr.Tabs():
        with gr.TabItem("Per-Image Quality Evaluation"):
            gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores. Images are processed in the browser's session.")
            image_upload_input = gr.Files(
                label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
                file_count="multiple",
                type="filepath",
            )
            evaluate_button_main = gr.Button("πΌοΈ Evaluate Uploaded Images", variant="primary")
            gr.Markdown("---")
            status_output_main = gr.Textbox(label="π Evaluation Status", interactive=False, lines=2)
            gr.Markdown("### πΌοΈ Per-Image Evaluation Results")
            gr.Markdown("Click column headers to sort. Previews are thumbnails.")
            # `overflow_row_behaviour` is intentionally not passed: the Gradio
            # release line that accepts `type="filepath"` above no longer has
            # that parameter, and "paginate" was the default where it existed.
            results_table_output = gr.DataFrame(
                headers=["Preview", "Filename", "BRISQUE (PIQ) (β)", "NIQE (IQA-PyTorch) (β)", "MUSIQ-NR (IQA-PyTorch) (β)"],
                datatype=["html", "str", "number", "number", "number"],
                interactive=False,
                wrap=True,
            )
        with gr.TabItem("βοΈ Calculate FID (Set vs. Set)"):
            gr.Markdown("""
Calculate Fréchet Inception Distance (FID) between two sets of images.
FID measures the similarity of two image distributions (e.g., generated vs. real, or version A vs. version B).
**Lower FID scores are better**, indicating more similarity.
""")
            with gr.Row():
                fid_set1_upload = gr.Files(label="Upload Images for Set 1 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
                fid_set2_upload = gr.Files(label="Upload Images for Set 2 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
            fid_calculate_button = gr.Button("π Calculate FID between Set 1 and Set 2", variant="primary")
            fid_result_output = gr.Textbox(label="π FID Result", interactive=False, lines=2)

    # Wire components.
    # NOTE(review): the handler returns four values but only two outputs are
    # bound here; the two trailing informational strings are not displayed.
    # Presumably Gradio drops the extras with a warning - confirm.
    evaluate_button_main.click(
        fn=process_images_for_individual_scores,
        inputs=[image_upload_input],
        outputs=[results_table_output, status_output_main],
    )
    fid_calculate_button.click(
        fn=process_fid_for_two_sets,
        inputs=[fid_set1_upload, fid_set2_upload],
        outputs=[fid_result_output],
    )
# --- For Hugging Face Spaces: requirements.txt ---
# The string below is never used at runtime; it only records the content that
# should go into the Space's 'requirements.txt' file.
"""
gradio
torch
torchvision
Pillow
numpy
piq>=0.8.0 # Specify version if known good, or just piq
iqa-pytorch>=0.2.1 # Specify version if known good
timm # A dependency for some iqa-pytorch models like MUSIQ
scikit-image # Often a transitive dependency, good to include
pandas
"""

if __name__ == "__main__":
    # Warn up front when either optional metric backend failed to import.
    if any(backend is None for backend in (piq, IQA)):
        print("\n\nWARNING: One or more core libraries (PIQ, IQA-PyTorch) are missing.")
        print("Please install them by creating a 'requirements.txt' file with the content above and running: pip install -r requirements.txt\n\n")
    # debug=True is meant for development; use debug=False in production.
    demo.launch(debug=True)