Spaces:
Sleeping
Sleeping
File size: 14,977 Bytes
8c2a1e0 6a98506 8c2a1e0 6a98506 8c2a1e0 6a98506 8c2a1e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 |
import base64
import functools
import html
import io
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
# PIQ imports
try:
import piq
except ImportError:
print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.")
piq = None
# IQA-PyTorch imports
try:
from iqa_pytorch import IQA
# Available models in IQA-PyTorch (examples for NR):
# "MUSIQ-L2N-lessons", "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
# "BRISQUE-PyTorch", "NIQE-PyTorch"
# "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
except ImportError:
print("Warning: IQA-PyTorch library not found. Some metrics (NIQE, MUSIQ-NR) will be unavailable.")
IQA = None
# --- Configuration ---
# Compute device shared by every metric function below (tensors/models are moved here).
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Maximum number of uploaded images scored per "Evaluate" click; extra files are ignored.
MAX_IMAGES_PER_BATCH = 100
# Bounding box, as (width, height), for the preview thumbnails shown in the results table.
THUMBNAIL_SIZE = (64, 64)
# --- Metric Functions ---
def get_brisque_score(img_tensor_chw_01):
    """Calculate the BRISQUE no-reference quality score with PIQ.

    Args:
        img_tensor_chw_01: Image tensor with values in [0, 1]. Expected shape is
            (C, H, W); a bare (H, W) tensor is treated as a single-channel image,
            and a tensor that already has 4 dimensions is used as (B, C, H, W).

    Returns:
        The score rounded to 3 decimals, ``"N/A (PIQ missing)"`` when PIQ is not
        installed, or ``"Error"`` when the computation fails (the cause is logged).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        x = img_tensor_chw_01
        # piq.brisque expects a batch: promote (H, W) -> (1, H, W) -> (1, 1, H, W).
        if x.ndim == 2:
            x = x.unsqueeze(0)
        if x.ndim == 3:
            x = x.unsqueeze(0)
        # Replicate a single grayscale channel so the input matches the RGB path.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        with torch.no_grad():  # inference only; no autograd graph needed
            score = piq.brisque(x.to(DEVICE), data_range=1.0)
        return round(score.item(), 3)
    except Exception as e:
        # Best-effort per image: the table shows "Error", but keep the cause in
        # the server log instead of discarding it.
        shape = getattr(img_tensor_chw_01, "shape", None)
        print(f"BRISQUE Error: {e} for tensor shape {tuple(shape) if shape is not None else None}")
        return "Error"
@functools.lru_cache(maxsize=1)
def _get_niqe_metric():
    """Construct the IQA-PyTorch NIQE metric once and reuse it for every image.

    Constructing the metric may load model data, so doing it per image is wasted
    work. ``lru_cache`` does not cache exceptions, so a failed construction is
    simply retried on the next call.
    """
    # NOTE(review): confirm this constructor against the installed package; the
    # IQA-PyTorch project is commonly distributed as ``pyiqa`` and exposes
    # ``pyiqa.create_metric(...)`` rather than ``iqa_pytorch.IQA``.
    return IQA(libs='NIQE-PyTorch', device=DEVICE)


def get_niqe_score(img_pil_rgb):
    """Calculate the no-reference NIQE score with IQA-PyTorch.

    Args:
        img_pil_rgb: PIL image already converted to RGB.

    Returns:
        The score rounded to 3 decimals, ``"N/A (IQA missing)"`` when the
        library is not installed, or ``"Error"`` when scoring fails (the cause
        is logged).
    """
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        niqe_metric = _get_niqe_metric()
        with torch.no_grad():  # inference only
            score = niqe_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        print(f"NIQE Error: {e}")
        return "Error"
@functools.lru_cache(maxsize=1)
def _get_musiq_nr_metric():
    """Construct the IQA-PyTorch no-reference MUSIQ metric once and reuse it.

    MUSIQ is a learned model, so building it per image can mean reloading its
    weights every time. ``lru_cache`` does not cache exceptions, so a failed
    construction is retried on the next call.
    """
    # Other variants listed by the original author: "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR".
    # NOTE(review): confirm this constructor against the installed package; the
    # IQA-PyTorch project is commonly distributed as ``pyiqa`` and exposes
    # ``pyiqa.create_metric(...)`` rather than ``iqa_pytorch.IQA``.
    return IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)


def get_musiq_nr_score(img_pil_rgb):
    """Calculate the no-reference MUSIQ score with IQA-PyTorch.

    Args:
        img_pil_rgb: PIL image already converted to RGB.

    Returns:
        The score rounded to 3 decimals, ``"N/A (IQA missing)"`` when the
        library is not installed, or ``"Error"`` when scoring fails (the cause
        is logged).
    """
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        musiq_metric = _get_musiq_nr_metric()
        with torch.no_grad():  # inference only
            score = musiq_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        print(f"MUSIQ-NR Error: {e}")
        return "Error"
# File extensions (lower-case) accepted when scanning an FID set folder.
_FID_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')


def _list_fid_images(folder):
    """Return the sorted paths of supported image files directly inside *folder*."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_FID_IMAGE_EXTENSIONS)
    )


def _make_fid_loader(image_paths, batch_size=32, image_size=299):
    """Wrap image paths in a DataLoader yielding ``{"images": (B, 3, S, S)}`` batches in [0, 1].

    ``piq.FID.compute_feats`` iterates a loader and reads ``batch["images"]``,
    so a plain list of paths cannot be passed to it. Images are resized to a
    common square so that each batch can be stacked into one tensor.
    """
    transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])

    def collate(batch_paths):
        tensors = []
        for path in batch_paths:
            with Image.open(path) as img:  # release the file handle promptly
                tensors.append(transform(img.convert("RGB")))
        return {"images": torch.stack(tensors)}

    # A list is a valid map-style dataset; images are decoded lazily in collate.
    return torch.utils.data.DataLoader(
        image_paths, batch_size=batch_size, shuffle=False, collate_fn=collate
    )


def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
    """Calculate FID between the images of two folders using PIQ.

    Args:
        path_to_set1_folder: Folder containing the first image set.
        path_to_set2_folder: Folder containing the second image set.

    Returns:
        The FID rounded to 3 decimals, or a human-readable message describing
        why it could not be computed (missing library, too few images, or an
        error during feature extraction).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        set1_files = _list_fid_images(path_to_set1_folder)
        set2_files = _list_fid_images(path_to_set2_folder)
        if not set1_files or not set2_files:
            return "One or both sets have no valid image files."
        # Feature statistics (mean/covariance) need at least two samples per set.
        if len(set1_files) < 2 or len(set2_files) < 2:
            return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."

        fid_metric = piq.FID()
        with torch.no_grad():
            set1_features = fid_metric.compute_feats(_make_fid_loader(set1_files), device=DEVICE)
            set2_features = fid_metric.compute_feats(_make_fid_loader(set2_files), device=DEVICE)
            if set1_features is None or set2_features is None:
                return "Could not extract features for one or both sets (check image validity and count)."
            if (set1_features.dim() == 0 or set2_features.dim() == 0
                    or set1_features.numel() == 0 or set2_features.numel() == 0):
                return "Feature extraction resulted in empty tensors."
            fid_value = fid_metric(set1_features, set2_features)
        return round(fid_value.item(), 3)
    except Exception as e:
        print(f"FID calculation error: {e}")
        return f"FID Error: {str(e)[:100]}"
# --- Helper Functions ---
def pil_to_tensor_chw_01(img_pil_rgb):
    """Turn a PIL image into a float tensor in channel-first (C, H, W) layout.

    ``ToTensor`` rescales 8-bit pixel values from [0, 255] to [0.0, 1.0].
    """
    to_tensor = T.ToTensor()
    return to_tensor(img_pil_rgb)
def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
    """Return a ``data:image/png;base64,...`` URI containing a thumbnail of the image.

    ``Image.thumbnail`` resizes in place, so it is applied to a copy and the
    caller's image is left unchanged.
    """
    preview = img_pil_rgb.copy()
    preview.thumbnail(size)
    with io.BytesIO() as png_buffer:
        preview.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
# --- Main Processing Functions for Gradio ---
def _uploaded_file_path(file_obj):
    """Return the filesystem path of one ``gr.Files`` entry.

    Entries may be plain path strings (what ``type="filepath"`` delivers) or
    file-like wrappers that expose the path as ``.name``.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
    """Score each uploaded image with BRISQUE, NIQE and MUSIQ-NR for the results table.

    Args:
        uploaded_file_list: Value of the ``gr.Files`` input (path strings or
            file wrappers). Only the first ``MAX_IMAGES_PER_BATCH`` are scored.
        progress: Gradio progress tracker injected by the event handler.

    Returns:
        A 4-tuple ``(results_df, status_message, is_text, fid_info_text)``.
        Images that cannot be loaded still get a row, marked as an error.
    """
    if not uploaded_file_list:
        return pd.DataFrame(), "Please upload images first.", "IS: N/A (Not Implemented)", "FID: N/A (Use FID Tab)"

    # Cap the batch size; files beyond the limit are ignored.
    if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
        status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
        uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
    else:
        status_message = f"Processing {len(uploaded_file_list)} images..."
    progress(0, desc=status_message)

    total = len(uploaded_file_list)
    results_data = []
    for i, file_obj in enumerate(uploaded_file_list):
        try:
            file_path = _uploaded_file_path(file_obj)
            base_filename = os.path.basename(file_path)
            # Decode inside a context manager so the file handle is closed right away.
            with Image.open(file_path) as img:
                img_pil_rgb = img.convert("RGB")

            # BRISQUE (PIQ) takes a (C, H, W) tensor in [0, 1]; NIQE / MUSIQ take the PIL image.
            brisque_val = get_brisque_score(pil_to_tensor_chw_01(img_pil_rgb))
            niqe_val = get_niqe_score(img_pil_rgb)
            musiq_nr_val = get_musiq_nr_score(img_pil_rgb)

            thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
            # The filename is user-controlled and rendered as HTML: escape it.
            safe_alt = html.escape(base_filename, quote=True)
            preview_html = f'<img src="{thumbnail_b64}" alt="{safe_alt}">'
            results_data.append({
                "Preview": preview_html,
                "Filename": base_filename,
                "BRISQUE (PIQ) (↓)": brisque_val,
                "NIQE (IQA-PyTorch) (↓)": niqe_val,
                "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
            })
        except Exception as e:
            # Best-effort: one bad file must not abort the batch, but log the cause.
            print(f"Failed to process {file_obj!r}: {e}")
            try:
                base_filename = os.path.basename(_uploaded_file_path(file_obj))
            except (AttributeError, TypeError):
                base_filename = "Unknown File"
            results_data.append({
                "Preview": "Error processing",
                "Filename": base_filename,
                "BRISQUE (PIQ) (↓)": "Load Err",
                "NIQE (IQA-PyTorch) (↓)": "N/A",
                "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
            })
        progress((i + 1) / total, desc=f"Processing {base_filename}")

    df_results = pd.DataFrame(results_data)
    status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
    # Batch-level metrics: IS is not implemented; FID lives in its own tab.
    is_text = "IS (PIQ): Not implemented in this version."
    fid_text_batch_info = "FID (PIQ): Use the 'Calculate FID (Set vs. Set)' tab for FID scores."
    return df_results, status_message, is_text, fid_text_batch_info
def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
    """Compute FID between two uploaded image sets.

    Each set is copied into its own scratch folder, which is deleted afterwards,
    and the folders are handed to ``get_fid_score_piq_folders``.

    Args:
        set1_file_list: Value of the Set 1 ``gr.Files`` input (path strings or
            file wrappers exposing ``.name``).
        set2_file_list: Value of the Set 2 ``gr.Files`` input.
        progress: Gradio progress tracker injected by the event handler.

    Returns:
        A human-readable result or error message for the result textbox.
    """
    if not set1_file_list or not set2_file_list:
        return "Please upload files for both Set 1 and Set 2."

    def upload_path(file_obj):
        # type="filepath" yields plain strings; file wrappers expose the path as .name.
        if isinstance(file_obj, (str, os.PathLike)):
            return os.fspath(file_obj)
        return file_obj.name

    def stage(file_list, dest_dir, label, progress_start):
        """Copy one uploaded set into dest_dir, reporting progress over a 0.2-wide band."""
        for i, file_obj in enumerate(file_list):
            src_path = upload_path(file_obj)
            base_name = os.path.basename(src_path)
            # Prefix with the upload index so identically named files from
            # different source folders do not silently overwrite each other.
            shutil.copy(src_path, os.path.join(dest_dir, f"{i:05d}_{base_name}"))
            progress(progress_start + 0.2 * (i / len(file_list)), desc=f"Copying {label}: {base_name}")

    progress(0.1, desc="Preparing image sets for FID...")
    try:
        # One scratch root, removed automatically even on early return or error.
        with tempfile.TemporaryDirectory(prefix="fid_sets_") as work_dir:
            set1_dir = os.path.join(work_dir, "set1")
            set2_dir = os.path.join(work_dir, "set2")
            os.mkdir(set1_dir)
            os.mkdir(set2_dir)
            stage(set1_file_list, set1_dir, "Set 1", 0.1)
            stage(set2_file_list, set2_dir, "Set 2", 0.3)

            num_set1 = len(os.listdir(set1_dir))
            num_set2 = len(os.listdir(set2_dir))
            if num_set1 == 0 or num_set2 == 0:
                return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."

            progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
            fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
            progress(1, desc="FID calculation complete.")
            return f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
    except Exception as e:
        return f"Error during FID preparation or calculation: {str(e)}"
# --- Gradio UI Definition ---
# Compact table styling; also caps every <img> (the preview thumbnails) at 64x64.
css_custom = """
table {font-size: 0.8em !important; width: 100% !important;}
th, td {padding: 4px !important; text-align: left !important;}
img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
    gr.Markdown(f"""
# Image Generation Model Evaluation Tool
**Objective:** Automated evaluation and comparison of image quality from different model versions.
Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
(↓) means lower is better, (↑) means higher is better.
""")
    with gr.Tabs():
        with gr.TabItem("Per-Image Quality Evaluation"):
            gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores. Images are processed on the server for your session.")
            image_upload_input = gr.Files(
                label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
                file_count="multiple",
                type="filepath",
            )
            evaluate_button_main = gr.Button("Evaluate Uploaded Images", variant="primary")
            gr.Markdown("---")
            status_output_main = gr.Textbox(label="Evaluation Status", interactive=False, lines=2)
            # process_images_for_individual_scores returns four values, so these
            # two boxes receive the batch-level IS / FID notes.
            with gr.Row():
                batch_is_output = gr.Textbox(label="Inception Score (IS)", interactive=False)
                batch_fid_output_info = gr.Textbox(label="FID", interactive=False)
            gr.Markdown("### Per-Image Evaluation Results")
            gr.Markdown("Click column headers to sort. Previews are thumbnails.")
            # overflow_row_behaviour / max_rows are not gr.Dataframe parameters in the
            # Gradio API used here (type="filepath" on gr.Files); pass height=... to
            # bound the table instead if needed.
            results_table_output = gr.DataFrame(
                headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
                datatype=["html", "str", "number", "number", "number"],
                interactive=False,
                wrap=True,
            )
        with gr.TabItem("Calculate FID (Set vs. Set)"):
            gr.Markdown("""
Calculate Fréchet Inception Distance (FID) between two sets of images.
FID measures the similarity of two image distributions (e.g., generated vs. real, or version A vs. version B).
**Lower FID scores are better**, indicating more similarity.
""")
            with gr.Row():
                fid_set1_upload = gr.Files(label="Upload Images for Set 1 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
                fid_set2_upload = gr.Files(label="Upload Images for Set 2 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
            fid_calculate_button = gr.Button("Calculate FID between Set 1 and Set 2", variant="primary")
            fid_result_output = gr.Textbox(label="FID Result", interactive=False, lines=2)

    # Wire components: outputs must match the number of values each handler returns.
    evaluate_button_main.click(
        fn=process_images_for_individual_scores,
        inputs=[image_upload_input],
        outputs=[results_table_output, status_output_main, batch_is_output, batch_fid_output_info],
    )
    fid_calculate_button.click(
        fn=process_fid_for_two_sets,
        inputs=[fid_set1_upload, fid_set2_upload],
        outputs=[fid_result_output],
    )
# --- For Hugging Face Spaces: requirements.txt ---
# Ensure this content is in your 'requirements.txt' file in the HF Space:
"""
gradio
torch
torchvision
Pillow
numpy
piq>=0.8.0 # Specify version if known good, or just piq
iqa-pytorch>=0.2.1 # Specify version if known good
timm # A dependency for some iqa-pytorch models like MUSIQ
scikit-image # Often a transitive dependency, good to include
pandas
"""

if __name__ == "__main__":
    # Warn (but still launch) when an optional metric library failed to import.
    if any(lib is None for lib in (piq, IQA)):
        for warning_line in (
            "\n\nWARNING: One or more core libraries (PIQ, IQA-PyTorch) are missing.",
            "Please install them by creating a 'requirements.txt' file with the content above and running: pip install -r requirements.txt\n\n",
        ):
            print(warning_line)
    demo.launch(debug=True)  # Set debug=False for production