import json
import os

import cv2
import gradio as gr
import imagehash
import numpy as np
import plotly.graph_objects as go
from gradio_imageslider import ImageSlider
from PIL import Image
from scipy.stats import pearsonr
from skimage.metrics import mean_squared_error as mse_skimage
from skimage.metrics import peak_signal_noise_ratio as psnr_skimage
from skimage.metrics import structural_similarity as ssim

class FrameMetrics:
    """Class to compute and store frame-by-frame metrics"""

    def __init__(self):
        self.metrics = {}

    def compute_ssim(self, frame1, frame2):
        """Compute SSIM between two frames"""
        if frame1 is None or frame2 is None:
            return None
        try:
            # Convert to grayscale for SSIM computation
            gray1 = (
                cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY)
                if len(frame1.shape) == 3
                else frame1
            )
            gray2 = (
                cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY)
                if len(frame2.shape) == 3
                else frame2
            )
            # Ensure both frames have the same dimensions
            if gray1.shape != gray2.shape:
                # Resize to match the smaller dimension
                h = min(gray1.shape[0], gray2.shape[0])
                w = min(gray1.shape[1], gray2.shape[1])
                gray1 = cv2.resize(gray1, (w, h))
                gray2 = cv2.resize(gray2, (w, h))
            # Compute SSIM
            ssim_value = ssim(gray1, gray2, data_range=255)
            return ssim_value
        except Exception as e:
            print(f"SSIM computation failed: {e}")
            return None
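
    # For reference, the SSIM of two image windows x and y is defined as
    #   SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2))
    #                / ((mu_x**2 + mu_y**2 + C1) * (sigma_x**2 + sigma_y**2 + C2))
    # with stabilizing constants C1 = (0.01 * data_range)**2 and
    # C2 = (0.03 * data_range)**2; skimage averages this over a sliding window.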

    def compute_ms_ssim(self, frame1, frame2):
        """Approximate MS-SSIM between two frames.

        Note: this is single-scale SSIM with an adaptive window size, used as a
        stand-in for true multi-scale SSIM.
        """
        if frame1 is None or frame2 is None:
            return None
        try:
            # Convert to grayscale for the comparison
            gray1 = (
                cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY)
                if len(frame1.shape) == 3
                else frame1
            )
            gray2 = (
                cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY)
                if len(frame2.shape) == 3
                else frame2
            )
            # Ensure both frames have the same dimensions
            if gray1.shape != gray2.shape:
                h = min(gray1.shape[0], gray2.shape[0])
                w = min(gray1.shape[1], gray2.shape[1])
                gray1 = cv2.resize(gray1, (w, h))
                gray2 = cv2.resize(gray2, (w, h))
            # Ensure a minimum size for multi-scale analysis
            min_size = 32
            if min(gray1.shape) < min_size:
                return None
            # Use a window size that works with the image dimensions;
            # structural_similarity requires an odd win_size >= 3
            win_size = min(7, min(gray1.shape) // 4)
            if win_size % 2 == 0:
                win_size -= 1
            win_size = max(win_size, 3)
            ms_ssim_val = ssim(gray1, gray2, data_range=255, win_size=win_size)
            return ms_ssim_val
        except Exception as e:
            print(f"MS-SSIM computation failed: {e}")
            return None

    def compute_psnr(self, frame1, frame2):
        """Compute PSNR between two frames"""
        if frame1 is None or frame2 is None:
            return None
        try:
            # Ensure both frames have the same dimensions
            if frame1.shape != frame2.shape:
                h = min(frame1.shape[0], frame2.shape[0])
                w = min(frame1.shape[1], frame2.shape[1])
                c = (
                    min(frame1.shape[2], frame2.shape[2])
                    if len(frame1.shape) == 3
                    else 1
                )
                if len(frame1.shape) == 3:
                    frame1 = cv2.resize(frame1, (w, h))[:, :, :c]
                    frame2 = cv2.resize(frame2, (w, h))[:, :, :c]
                else:
                    frame1 = cv2.resize(frame1, (w, h))
                    frame2 = cv2.resize(frame2, (w, h))
            # Compute PSNR
            return psnr_skimage(frame1, frame2, data_range=255)
        except Exception as e:
            print(f"PSNR computation failed: {e}")
            return None
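
    # For reference: with 8-bit frames (data_range=255), PSNR relates to MSE as
    #   PSNR = 10 * log10(255**2 / MSE)  [dB]
    # so identical frames yield infinite PSNR and larger pixel errors lower it.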

    def compute_mse(self, frame1, frame2):
        """Compute MSE between two frames"""
        if frame1 is None or frame2 is None:
            return None
        try:
            # Ensure both frames have the same dimensions
            if frame1.shape != frame2.shape:
                h = min(frame1.shape[0], frame2.shape[0])
                w = min(frame1.shape[1], frame2.shape[1])
                c = (
                    min(frame1.shape[2], frame2.shape[2])
                    if len(frame1.shape) == 3
                    else 1
                )
                if len(frame1.shape) == 3:
                    frame1 = cv2.resize(frame1, (w, h))[:, :, :c]
                    frame2 = cv2.resize(frame2, (w, h))[:, :, :c]
                else:
                    frame1 = cv2.resize(frame1, (w, h))
                    frame2 = cv2.resize(frame2, (w, h))
            # Compute MSE
            return mse_skimage(frame1, frame2)
        except Exception as e:
            print(f"MSE computation failed: {e}")
            return None

    def compute_phash(self, frame1, frame2):
        """Compute perceptual hash similarity between two frames"""
        if frame1 is None or frame2 is None:
            return None
        try:
            # Convert to PIL Images for imagehash
            pil1 = Image.fromarray(frame1)
            pil2 = Image.fromarray(frame2)
            # Compute perceptual hashes
            hash1 = imagehash.phash(pil1)
            hash2 = imagehash.phash(pil2)
            # Calculate similarity (lower Hamming distance = more similar)
            hamming_distance = hash1 - hash2
            # Convert to a similarity score (0-1, where 1 is identical)
            max_distance = len(str(hash1)) * 4  # 4 bits per hex char
            similarity = 1 - (hamming_distance / max_distance)
            return similarity
        except Exception as e:
            print(f"pHash computation failed: {e}")
            return None
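
    # Worked example (assuming imagehash's default 64-bit pHash, hash_size=8):
    # str(hash1) is 16 hex chars, so max_distance = 16 * 4 = 64 bits, and a
    # Hamming distance of 6 maps to similarity 1 - 6/64 = 0.90625.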

    def compute_color_histogram_correlation(self, frame1, frame2):
        """Compute color histogram correlation between two frames"""
        if frame1 is None or frame2 is None:
            return None
        try:
            # Ensure both frames have the same dimensions
            if frame1.shape != frame2.shape:
                h = min(frame1.shape[0], frame2.shape[0])
                w = min(frame1.shape[1], frame2.shape[1])
                frame1 = cv2.resize(frame1, (w, h))
                frame2 = cv2.resize(frame2, (w, h))
            # Compute histograms for each channel
            correlations = []
            if len(frame1.shape) == 3:  # Color image
                for i in range(3):  # R, G, B channels
                    hist1 = cv2.calcHist([frame1], [i], None, [256], [0, 256])
                    hist2 = cv2.calcHist([frame2], [i], None, [256], [0, 256])
                    # Flatten histograms
                    hist1 = hist1.flatten()
                    hist2 = hist2.flatten()
                    # Compute correlation (skip constant histograms, where
                    # Pearson correlation is undefined)
                    if np.std(hist1) > 0 and np.std(hist2) > 0:
                        corr, _ = pearsonr(hist1, hist2)
                        correlations.append(corr)
                # Return the average correlation across channels
                return np.mean(correlations) if correlations else 0.0
            else:  # Grayscale
                hist1 = cv2.calcHist([frame1], [0], None, [256], [0, 256]).flatten()
                hist2 = cv2.calcHist([frame2], [0], None, [256], [0, 256]).flatten()
                if np.std(hist1) > 0 and np.std(hist2) > 0:
                    corr, _ = pearsonr(hist1, hist2)
                    return corr
                else:
                    return 0.0
        except Exception as e:
            print(f"Color histogram correlation computation failed: {e}")
            return None

    def compute_sharpness(self, frame):
        """Compute sharpness using the Laplacian variance method"""
        if frame is None:
            return None
        # Convert to grayscale if needed
        gray = (
            cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if len(frame.shape) == 3 else frame
        )
        # Compute Laplacian variance (higher values = sharper)
        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
        sharpness = laplacian.var()
        return sharpness
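
    # Note: with the default aperture, cv2.Laplacian filters with the 3x3 kernel
    #   [[0, 1, 0], [1, -4, 1], [0, 1, 0]]
    # whose response variance grows with high-frequency detail, which is why a
    # low variance suggests a blurry frame.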

    def compute_frame_metrics(self, frame1, frame2, frame_idx):
        """Compute all metrics for a frame pair"""
        metrics = {
            "frame_index": frame_idx,
            "ssim": self.compute_ssim(frame1, frame2),
            "psnr": self.compute_psnr(frame1, frame2),
            "mse": self.compute_mse(frame1, frame2),
            "phash": self.compute_phash(frame1, frame2),
            "color_hist_corr": self.compute_color_histogram_correlation(frame1, frame2),
            "sharpness1": self.compute_sharpness(frame1),
            "sharpness2": self.compute_sharpness(frame2),
        }
        # Compute the average sharpness for the pair
        if metrics["sharpness1"] is not None and metrics["sharpness2"] is not None:
            metrics["sharpness_avg"] = (
                metrics["sharpness1"] + metrics["sharpness2"]
            ) / 2
            metrics["sharpness_diff"] = abs(
                metrics["sharpness1"] - metrics["sharpness2"]
            )
        else:
            metrics["sharpness_avg"] = None
            metrics["sharpness_diff"] = None
        return metrics

    def compute_all_metrics(self, frames1, frames2):
        """Compute metrics for all frame pairs (paired strictly by index)"""
        all_metrics = []
        max_frames = max(len(frames1), len(frames2))
        for i in range(max_frames):
            frame1 = frames1[i] if i < len(frames1) else None
            frame2 = frames2[i] if i < len(frames2) else None
            if frame1 is not None or frame2 is not None:
                metrics = self.compute_frame_metrics(frame1, frame2, i)
                all_metrics.append(metrics)
            else:
                # Handle cases where both frames are missing
                all_metrics.append(
                    {
                        "frame_index": i,
                        "ssim": None,
                        "psnr": None,
                        "mse": None,
                        "phash": None,
                        "color_hist_corr": None,
                        "sharpness1": None,
                        "sharpness2": None,
                        "sharpness_avg": None,
                        "sharpness_diff": None,
                    }
                )
        return all_metrics

    def get_metric_summary(self, metrics_list):
        """Compute summary statistics for all metrics"""
        metric_names = [
            "ssim",
            "psnr",
            "mse",
            "phash",
            "color_hist_corr",
            "sharpness1",
            "sharpness2",
            "sharpness_avg",
            "sharpness_diff",
        ]
        summary = {
            "total_frames": len(metrics_list),
            "valid_frames": len([m for m in metrics_list if m.get("ssim") is not None]),
        }
        # Compute statistics for each metric
        for metric_name in metric_names:
            valid_values = [
                m[metric_name] for m in metrics_list if m.get(metric_name) is not None
            ]
            if valid_values:
                summary.update(
                    {
                        f"{metric_name}_mean": np.mean(valid_values),
                        f"{metric_name}_min": np.min(valid_values),
                        f"{metric_name}_max": np.max(valid_values),
                        f"{metric_name}_std": np.std(valid_values),
                    }
                )
        return summary
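
    # Shape of the returned dict, with hypothetical values shown for "ssim":
    #   {"total_frames": 81, "valid_frames": 81,
    #    "ssim_mean": 0.94, "ssim_min": 0.91, "ssim_max": 0.97,
    #    "ssim_std": 0.01, ...}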

    def create_individual_metric_plots(self, metrics_list, current_frame=0):
        """Create individual plots for each metric with frame on the x-axis"""
        if not metrics_list:
            return None
        # Extract frame indices
        frame_indices = [m["frame_index"] for m in metrics_list]

        # Helper function to get valid data
        def get_valid_data(metric_name):
            values = [m.get(metric_name) for m in metrics_list]
            valid_indices = [i for i, v in enumerate(values) if v is not None]
            valid_values = [values[i] for i in valid_indices]
            valid_frames = [frame_indices[i] for i in valid_indices]
            return valid_frames, valid_values

        # Create individual plots for each metric
        plots = {}
        # 1. SSIM Plot
        ssim_frames, ssim_values = get_valid_data("ssim")
        if ssim_values:
            # Calculate a dynamic y-axis range for SSIM to highlight differences
            min_ssim = min(ssim_values)
            max_ssim = max(ssim_values)
            ssim_range = max_ssim - min_ssim
            if ssim_range < 0.05:
                # For small variations, zoom in to show differences better
                center = (min_ssim + max_ssim) / 2
                padding = max(
                    0.02, ssim_range * 2
                )  # At least 0.02 range or 2x the actual range
                y_min = max(0, center - padding)
                y_max = min(1, center + padding)
            else:
                # For larger variations, add some padding
                padding = ssim_range * 0.15  # 15% padding
                y_min = max(0, min_ssim - padding)
                y_max = min(1, max_ssim + padding)
            fig_ssim = go.Figure()
            # Add an area fill to emphasize the curve
            fig_ssim.add_trace(
                go.Scatter(
                    x=ssim_frames,
                    y=[y_min] * len(ssim_frames),
                    mode="lines",
                    line=dict(
                        color="rgba(0,0,255,0)"
                    ),  # Transparent line for the area base
                    showlegend=False,
                    hoverinfo="skip",
                )
            )
            fig_ssim.add_trace(
                go.Scatter(
                    x=ssim_frames,
                    y=ssim_values,
                    mode="lines+markers",
                    name="SSIM",
                    line=dict(color="blue", width=3),
                    marker=dict(
                        size=6, color="blue", line=dict(color="darkblue", width=1)
                    ),
                    hovertemplate="<b>Frame %{x}</b><br>SSIM: %{y:.5f}<extra></extra>",
                    fill="tonexty",
                    fillcolor="rgba(0,0,255,0.1)",  # Light blue fill
                )
            )
            if current_frame is not None:
                fig_ssim.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_ssim.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=False,
                dragmode=False,
                hovermode="x unified",
            )
            fig_ssim.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_ssim.update_yaxes(
                title_text="SSIM",
                range=[y_min, y_max],
                gridcolor="rgba(128,128,128,0.4)",
                fixedrange=True,
            )
            plots["ssim"] = fig_ssim

        # 2. PSNR Plot
        psnr_frames, psnr_values = get_valid_data("psnr")
        if psnr_values:
            fig_psnr = go.Figure()
            fig_psnr.add_trace(
                go.Scatter(
                    x=psnr_frames,
                    y=psnr_values,
                    mode="lines+markers",
                    name="PSNR",
                    line=dict(color="green", width=3),
                    marker=dict(size=6),
                    hovertemplate="<b>Frame %{x}</b><br>PSNR: %{y:.2f} dB<extra></extra>",
                )
            )
            if current_frame is not None:
                fig_psnr.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_psnr.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=False,
                dragmode=False,
                hovermode="x unified",
            )
            fig_psnr.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_psnr.update_yaxes(
                title_text="PSNR (dB)",
                gridcolor="rgba(128,128,128,0.4)",
                fixedrange=True,
            )
            plots["psnr"] = fig_psnr
        # 3. MSE Plot
        mse_frames, mse_values = get_valid_data("mse")
        if mse_values:
            fig_mse = go.Figure()
            fig_mse.add_trace(
                go.Scatter(
                    x=mse_frames,
                    y=mse_values,
                    mode="lines+markers",
                    name="MSE",
                    line=dict(color="red", width=3),
                    marker=dict(size=6),
                    hovertemplate="<b>Frame %{x}</b><br>MSE: %{y:.2f}<extra></extra>",
                )
            )
            if current_frame is not None:
                fig_mse.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_mse.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=False,
                dragmode=False,
                hovermode="x unified",
            )
            fig_mse.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_mse.update_yaxes(
                title_text="MSE", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            plots["mse"] = fig_mse

        # 4. pHash Plot
        phash_frames, phash_values = get_valid_data("phash")
        if phash_values:
            fig_phash = go.Figure()
            fig_phash.add_trace(
                go.Scatter(
                    x=phash_frames,
                    y=phash_values,
                    mode="lines+markers",
                    name="pHash",
                    line=dict(color="purple", width=3),
                    marker=dict(size=6),
                    hovertemplate="<b>Frame %{x}</b><br>pHash: %{y:.4f}<extra></extra>",
                )
            )
            if current_frame is not None:
                fig_phash.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_phash.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=False,
                dragmode=False,
                hovermode="x unified",
            )
            fig_phash.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_phash.update_yaxes(
                title_text="pHash Similarity",
                gridcolor="rgba(128,128,128,0.4)",
                fixedrange=True,
            )
            plots["phash"] = fig_phash
        # 5. Color Histogram Correlation Plot
        hist_frames, hist_values = get_valid_data("color_hist_corr")
        if hist_values:
            fig_hist = go.Figure()
            fig_hist.add_trace(
                go.Scatter(
                    x=hist_frames,
                    y=hist_values,
                    mode="lines+markers",
                    name="Color Histogram",
                    line=dict(color="orange", width=3),
                    marker=dict(size=6),
                    hovertemplate="<b>Frame %{x}</b><br>Color Histogram: %{y:.4f}<extra></extra>",
                )
            )
            if current_frame is not None:
                fig_hist.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_hist.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=False,
                dragmode=False,
                hovermode="x unified",
            )
            fig_hist.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_hist.update_yaxes(
                title_text="Color Histogram Correlation",
                gridcolor="rgba(128,128,128,0.4)",
                fixedrange=True,
            )
            plots["color_hist"] = fig_hist

        # 6. Sharpness Comparison Plot
        sharp1_frames, sharp1_values = get_valid_data("sharpness1")
        sharp2_frames, sharp2_values = get_valid_data("sharpness2")
        if sharp1_values or sharp2_values:
            fig_sharp = go.Figure()
            if sharp1_values:
                fig_sharp.add_trace(
                    go.Scatter(
                        x=sharp1_frames,
                        y=sharp1_values,
                        mode="lines+markers",
                        name="Video 1",
                        line=dict(color="darkgreen", width=3),
                        marker=dict(size=6),
                        hovertemplate="<b>Frame %{x}</b><br>Video 1 Sharpness: %{y:.1f}<extra></extra>",
                    )
                )
            if sharp2_values:
                fig_sharp.add_trace(
                    go.Scatter(
                        x=sharp2_frames,
                        y=sharp2_values,
                        mode="lines+markers",
                        name="Video 2",
                        line=dict(color="darkblue", width=3),
                        marker=dict(size=6),
                        hovertemplate="<b>Frame %{x}</b><br>Video 2 Sharpness: %{y:.1f}<extra></extra>",
                    )
                )
            if current_frame is not None:
                fig_sharp.add_vline(
                    x=current_frame,
                    line_dash="dash",
                    line_color="red",
                    line_width=2,
                )
            fig_sharp.update_layout(
                height=300,
                margin=dict(t=20, b=40, l=60, r=20),
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
                showlegend=True,
                legend=dict(
                    orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5
                ),
                dragmode=False,
                hovermode="x unified",
            )
            fig_sharp.update_xaxes(
                title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
            )
            fig_sharp.update_yaxes(
                title_text="Sharpness",
                gridcolor="rgba(128,128,128,0.4)",
                fixedrange=True,
            )
            plots["sharpness"] = fig_sharp

        # 7. Overall Quality Score Plot (combination of metrics)
        # Calculate the overall quality score by combining normalized metrics
        if ssim_values and psnr_values and len(ssim_values) == len(psnr_values):
            # Get data for the metrics that contribute to the overall score
            phash_frames_overall, phash_values_overall = get_valid_data("phash")
            # Ensure we have the same frames for all metrics
            common_frames = set(ssim_frames) & set(psnr_frames)
            if phash_values_overall:
                common_frames = common_frames & set(phash_frames_overall)
            common_frames = sorted(list(common_frames))
            if common_frames:
                # Extract the values for the common frames
                ssim_common = [
                    ssim_values[ssim_frames.index(f)]
                    for f in common_frames
                    if f in ssim_frames
                ]
                psnr_common = [
                    psnr_values[psnr_frames.index(f)]
                    for f in common_frames
                    if f in psnr_frames
                ]
                # Normalize PSNR to a 0-1 scale using min-max normalization
                if psnr_common:
                    psnr_min = min(psnr_common)
                    psnr_max = max(psnr_common)
                    if psnr_max > psnr_min:
                        psnr_normalized = [
                            (p - psnr_min) / (psnr_max - psnr_min) for p in psnr_common
                        ]
                    else:
                        psnr_normalized = [0.0 for _ in psnr_common]
                else:
                    psnr_normalized = []
                # Start with SSIM and normalized PSNR
                quality_components = [ssim_common, psnr_normalized]
                component_names = ["SSIM", "PSNR"]
                # Add pHash if available
                if phash_values_overall:
                    phash_common = [
                        phash_values_overall[phash_frames_overall.index(f)]
                        for f in common_frames
                        if f in phash_frames_overall
                    ]
                    if len(phash_common) == len(ssim_common):
                        quality_components.append(phash_common)
                        component_names.append("pHash")
                # Calculate the average across all components
                overall_quality = []
                for i in range(len(common_frames)):
                    frame_scores = [
                        component[i]
                        for component in quality_components
                        if i < len(component)
                    ]
                    overall_quality.append(sum(frame_scores) / len(frame_scores))
                # Calculate a dynamic y-axis range to emphasize differences
                min_quality = min(overall_quality)
                max_quality = max(overall_quality)
                quality_range = max_quality - min_quality
                if quality_range < 0.08:
                    # For small variations, zoom in to show differences better
                    center = (min_quality + max_quality) / 2
                    padding = max(
                        0.04, quality_range * 2
                    )  # At least 0.04 range or 2x the actual range
                    y_min = max(0, center - padding)
                    y_max = min(1, center + padding)
                else:
                    # For larger variations, add some padding
                    padding = quality_range * 0.15  # 15% padding
                    y_min = max(0, min_quality - padding)
                    y_max = min(1, max_quality + padding)
                fig_overall = go.Figure()
                # Add an area fill to emphasize the quality curve
                fig_overall.add_trace(
                    go.Scatter(
                        x=common_frames,
                        y=[y_min] * len(common_frames),
                        mode="lines",
                        line=dict(
                            color="rgba(255,215,0,0)"
                        ),  # Transparent line for the area base
                        showlegend=False,
                        hoverinfo="skip",
                    )
                )
                fig_overall.add_trace(
                    go.Scatter(
                        x=common_frames,
                        y=overall_quality,
                        mode="lines+markers",
                        name="Overall Quality",
                        line=dict(color="gold", width=4),
                        marker=dict(
                            size=8, color="gold", line=dict(color="orange", width=2)
                        ),
                        hovertemplate="<b>Frame %{x}</b><br>Overall Quality: %{y:.5f}<br><i>Combined from: "
                        + ", ".join(component_names)
                        + "</i><extra></extra>",
                        fill="tonexty",
                        fillcolor="rgba(255,215,0,0.15)",  # Semi-transparent gold fill
                    )
                )
                # Highlight the currently selected frame
                if current_frame is not None:
                    fig_overall.add_vline(
                        x=current_frame,
                        line_dash="dash",
                        line_color="red",
                        line_width=2,
                    )
                fig_overall.update_layout(
                    height=300,
                    margin=dict(t=20, b=40, l=60, r=20),
                    plot_bgcolor="rgba(0,0,0,0)",
                    paper_bgcolor="rgba(0,0,0,0)",
                    showlegend=False,
                    dragmode=False,
                    hovermode="x unified",
                )
                fig_overall.update_xaxes(
                    title_text="Frame",
                    gridcolor="rgba(128,128,128,0.4)",
                    fixedrange=True,
                )
                fig_overall.update_yaxes(
                    title_text="Overall Quality Score",
                    range=[y_min, y_max],
                    gridcolor="rgba(128,128,128,0.4)",
                    fixedrange=True,
                )
                plots["overall"] = fig_overall
        return plots

    def create_modern_plot(self, metrics_list, current_frame=0):
        """Create individual metric plots instead of a combined dashboard"""
        return self.create_individual_metric_plots(metrics_list, current_frame)
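
    # The returned dict maps metric names to Plotly figures; the possible keys
    # are "ssim", "psnr", "mse", "phash", "color_hist", "sharpness" and
    # "overall", and a key is present only when that metric had valid values.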


class VideoFrameComparator:
    def __init__(self):
        self.video1_frames = []
        self.video2_frames = []
        self.max_frames = 0
        self.frame_metrics = FrameMetrics()
        self.computed_metrics = []
        self.metrics_summary = {}

    def extract_frames(self, video_path):
        """Extract all frames from a video file or URL"""
        if not video_path:
            return []
        # Check whether it's a URL or a local file
        is_url = video_path.startswith(("http://", "https://"))
        if not is_url and not os.path.exists(video_path):
            print(f"Warning: Local video file not found: {video_path}")
            return []
        frames = []
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(
                f"Error: Could not open video {'URL' if is_url else 'file'}: {video_path}"
            )
            return []
        try:
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Convert BGR to RGB for display
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame_rgb)
                frame_count += 1
                # Progress feedback for URLs (which might be slower)
                if is_url and frame_count % 30 == 0:
                    print(f"Processed {frame_count} frames from URL...")
        except Exception as e:
            print(f"Error processing video: {e}")
        finally:
            cap.release()
        print(
            f"Successfully extracted {len(frames)} frames from {'URL' if is_url else 'file'}: {video_path}"
        )
        return frames

    def is_comparison_in_data_json(
        self, video1_path, video2_path, json_file_path="data.json"
    ):
        """Check whether this video comparison exists in data.json"""
        try:
            with open(json_file_path, "r") as f:
                data = json.load(f)
            for comparison in data.get("comparisons", []):
                videos = comparison.get("videos", [])
                if len(videos) == 2:
                    # Check both orders (works for both local files and URLs)
                    if (videos[0] == video1_path and videos[1] == video2_path) or (
                        videos[0] == video2_path and videos[1] == video1_path
                    ):
                        return True
            return False
        except Exception:
            return False

    def load_videos(self, video1_path, video2_path):
        """Load both videos and extract frames"""
        if not video1_path and not video2_path:
            return "Please upload at least one video.", 0, None, None, "", None
        # Extract frames from both videos
        self.video1_frames = self.extract_frames(video1_path) if video1_path else []
        self.video2_frames = self.extract_frames(video2_path) if video2_path else []
        # Determine the maximum number of frames
        self.max_frames = max(len(self.video1_frames), len(self.video2_frames))
        if self.max_frames == 0:
            return (
                "No valid frames found in the uploaded videos.",
                0,
                None,
                None,
                "",
                None,
            )
        # Compute metrics if both videos are present and not in data.json
        metrics_info = ""
        plots = None
        if (
            video1_path
            and video2_path
            and not self.is_comparison_in_data_json(video1_path, video2_path)
        ):
            print("Computing comprehensive frame-by-frame metrics...")
            self.computed_metrics = self.frame_metrics.compute_all_metrics(
                self.video1_frames, self.video2_frames
            )
            self.metrics_summary = self.frame_metrics.get_metric_summary(
                self.computed_metrics
            )
            # Build the metrics info string
            metrics_info = "\n\n📊 Computed Metrics Summary:\n"
            metric_display = {
                "ssim": ("SSIM", ".4f", "", "↑ Higher=Better"),
                "psnr": ("PSNR", ".2f", " dB", "↑ Higher=Better"),
                "mse": ("MSE", ".2f", "", "↓ Lower=Better"),
                "phash": ("pHash", ".4f", "", "↑ Higher=Better"),
                "color_hist_corr": ("Color Hist", ".4f", "", "↑ Higher=Better"),
                "sharpness_avg": ("Sharpness", ".1f", "", "↑ Higher=Better"),
            }
            for metric_key, (
                display_name,
                format_str,
                unit,
                direction,
            ) in metric_display.items():
                if self.metrics_summary.get(f"{metric_key}_mean") is not None:
                    mean_val = self.metrics_summary[f"{metric_key}_mean"]
                    std_val = self.metrics_summary[f"{metric_key}_std"]
                    metrics_info += f"{display_name}: μ={mean_val:{format_str}}{unit}, σ={std_val:{format_str}}{unit} ({direction})\n"
            metrics_info += f"Valid Frames: {self.metrics_summary['valid_frames']}/{self.metrics_summary['total_frames']}"
            # Generate the initial plots
            plots = self.frame_metrics.create_individual_metric_plots(
                self.computed_metrics, 0
            )
        else:
            self.computed_metrics = []
            self.metrics_summary = {}
            if video1_path and video2_path:
                metrics_info = "\n\n📋 Note: This comparison is predefined in data.json (metrics not computed)"
        # Get the initial frames
        frame1 = (
            self.video1_frames[0]
            if self.video1_frames
            else np.zeros((480, 640, 3), dtype=np.uint8)
        )
        frame2 = (
            self.video2_frames[0]
            if self.video2_frames
            else np.zeros((480, 640, 3), dtype=np.uint8)
        )
        status_msg = "Videos loaded successfully!\n"
        status_msg += f"Video 1: {len(self.video1_frames)} frames\n"
        status_msg += f"Video 2: {len(self.video2_frames)} frames\n"
        status_msg += (
            f"Use the slider to navigate through frames (0-{self.max_frames - 1})"
        )
        status_msg += metrics_info
        return (
            status_msg,
            self.max_frames - 1,
            frame1,
            frame2,
            self.get_current_frame_info(0),
            plots,
        )

    def get_frames_at_index(self, frame_index):
        """Get the frames at a specific index from both videos"""
        frame_index = int(frame_index)
        # Get the frame from video 1
        if frame_index < len(self.video1_frames):
            frame1 = self.video1_frames[frame_index]
        else:
            # Create a placeholder if the frame doesn't exist
            frame1 = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.putText(
                frame1,
                f"Frame {frame_index} not available",
                (50, 240),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 255, 255),
                2,
            )
        # Get the frame from video 2
        if frame_index < len(self.video2_frames):
            frame2 = self.video2_frames[frame_index]
        else:
            # Create a placeholder if the frame doesn't exist
            frame2 = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.putText(
                frame2,
                f"Frame {frame_index} not available",
                (50, 240),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 255, 255),
                2,
            )
        return frame1, frame2

    def get_current_frame_info(self, frame_index):
        """Get information about the current frame, including metrics"""
        frame_index = int(frame_index)
        info = f"Current Frame: {frame_index} / {self.max_frames - 1}"
        # Add metrics info if available
        if self.computed_metrics and frame_index < len(self.computed_metrics):
            metrics = self.computed_metrics[frame_index]
            # === COMPARISON METRICS (Between Videos) ===
            comparison_metrics = []
            # SSIM with quality assessment
            if metrics.get("ssim") is not None:
                ssim_val = metrics["ssim"]
                if ssim_val >= 0.9:
                    quality = "🟢 Excellent"
                elif ssim_val >= 0.8:
                    quality = "🔵 Good"
                elif ssim_val >= 0.6:
                    quality = "🟡 Fair"
                else:
                    quality = "🔴 Poor"
                comparison_metrics.append(
                    f"SSIM: {ssim_val:.4f} ({quality} similarity)"
                )
            # PSNR with quality indicator
            if metrics.get("psnr") is not None:
                psnr_val = metrics["psnr"]
                if psnr_val >= 40:
                    psnr_quality = "🟢 Excellent"
                elif psnr_val >= 30:
                    psnr_quality = "🔵 Good"
                elif psnr_val >= 20:
                    psnr_quality = "🟡 Fair"
                else:
                    psnr_quality = "🔴 Poor"
                comparison_metrics.append(
                    f"PSNR: {psnr_val:.1f}dB ({psnr_quality} signal quality)"
                )
            # MSE with quality indicator (lower is better)
            if metrics.get("mse") is not None:
                mse_val = metrics["mse"]
                if mse_val <= 50:
                    mse_quality = "🟢 Very Similar"
                elif mse_val <= 100:
                    mse_quality = "🔵 Similar"
                elif mse_val <= 200:
                    mse_quality = "🟡 Moderately Different"
                else:
                    mse_quality = "🔴 Very Different"
                comparison_metrics.append(f"MSE: {mse_val:.1f} ({mse_quality})")
            # pHash with quality indicator
            if metrics.get("phash") is not None:
                phash_val = metrics["phash"]
                if phash_val >= 0.95:
                    phash_quality = "🟢 Nearly Identical"
                elif phash_val >= 0.9:
                    phash_quality = "🔵 Very Similar"
                elif phash_val >= 0.8:
                    phash_quality = "🟡 Somewhat Similar"
                else:
                    phash_quality = "🔴 Different"
                comparison_metrics.append(
                    f"pHash: {phash_val:.3f} ({phash_quality} perceptually)"
                )
            # Color histogram correlation
            if metrics.get("color_hist_corr") is not None:
                color_val = metrics["color_hist_corr"]
                if color_val >= 0.9:
                    color_quality = "🟢 Very Similar Colors"
                elif color_val >= 0.8:
                    color_quality = "🔵 Similar Colors"
                elif color_val >= 0.6:
                    color_quality = "🟡 Moderate Color Diff"
                else:
                    color_quality = "🔴 Different Colors"
                comparison_metrics.append(f"Color: {color_val:.3f} ({color_quality})")
            # Add the comparison metrics to the info string
            if comparison_metrics:
                info += "\n📊 Comparison Analysis: " + " | ".join(comparison_metrics)

            # === INDIVIDUAL VIDEO QUALITY ===
            individual_metrics = []
            # Individual sharpness for each video
            if metrics.get("sharpness1") is not None:
                sharp1 = metrics["sharpness1"]
                if sharp1 >= 200:
                    sharp1_quality = "🟢 Sharp"
                elif sharp1 >= 100:
                    sharp1_quality = "🔵 Moderate"
                elif sharp1 >= 50:
                    sharp1_quality = "🟡 Soft"
                else:
                    sharp1_quality = "🔴 Blurry"
                individual_metrics.append(
                    f"V1 Sharpness: {sharp1:.0f} ({sharp1_quality})"
                )
            if metrics.get("sharpness2") is not None:
                sharp2 = metrics["sharpness2"]
                if sharp2 >= 200:
                    sharp2_quality = "🟢 Sharp"
                elif sharp2 >= 100:
                    sharp2_quality = "🔵 Moderate"
                elif sharp2 >= 50:
                    sharp2_quality = "🟡 Soft"
                else:
                    sharp2_quality = "🔴 Blurry"
                individual_metrics.append(
                    f"V2 Sharpness: {sharp2:.0f} ({sharp2_quality})"
                )
            # Sharpness comparison
            if (
                metrics.get("sharpness1") is not None
                and metrics.get("sharpness2") is not None
            ):
                sharp1 = metrics["sharpness1"]
                sharp2 = metrics["sharpness2"]
                # Calculate the difference percentage (guard against two
                # perfectly flat frames, where the denominator would be zero)
                denom = max(sharp1, sharp2)
                diff_pct = abs(sharp1 - sharp2) / denom * 100 if denom > 0 else 0.0
                # Determine significance with clear labels
                if diff_pct > 20:
                    significance = "🔴 MAJOR difference"
                elif diff_pct > 10:
                    significance = "🟡 MODERATE difference"
                elif diff_pct > 5:
                    significance = "🔵 MINOR difference"
                else:
                    significance = "🟢 NEGLIGIBLE difference"
                # Determine which video is sharper
                if sharp1 > sharp2:
                    comparison = "V1 is sharper"
                elif sharp2 > sharp1:
                    comparison = "V2 is sharper"
                else:
                    comparison = "Equal sharpness"
                individual_metrics.append(f"Sharpness: {comparison} ({significance})")
            # Add the individual metrics to the info string
            if individual_metrics:
                info += "\n🎯 Individual Quality: " + " | ".join(individual_metrics)

            # === OVERALL QUALITY ASSESSMENT ===
            # Calculate a combined quality score from multiple metrics
            quality_score = 0
            quality_count = 0
            metric_contributions = []
            # SSIM contribution
            if metrics.get("ssim") is not None:
                quality_score += metrics["ssim"]
                quality_count += 1
                metric_contributions.append(f"SSIM({metrics['ssim']:.3f})")
            # PSNR contribution (normalized to a 0-1 scale)
            if metrics.get("psnr") is not None:
                psnr_norm = min(metrics["psnr"] / 50, 1.0)
                quality_score += psnr_norm
                quality_count += 1
                metric_contributions.append(f"PSNR({psnr_norm:.3f})")
            # pHash contribution
            if metrics.get("phash") is not None:
                quality_score += metrics["phash"]
                quality_count += 1
                metric_contributions.append(f"pHash({metrics['phash']:.3f})")
            if quality_count > 0:
                avg_quality = quality_score / quality_count
                # Add the overall assessment with a formula explanation
                if avg_quality >= 0.9:
                    overall = "✨ Excellent Overall"
                    quality_indicator = "🟢"
                elif avg_quality >= 0.8:
                    overall = "✅ Good Overall"
                    quality_indicator = "🔵"
                elif avg_quality >= 0.6:
                    overall = "⚠️ Fair Overall"
                    quality_indicator = "🟡"
                else:
                    overall = "❌ Poor Overall"
                    quality_indicator = "🔴"
                # Calculate the quality variation across all frames
                quality_variation = ""
                if self.computed_metrics and len(self.computed_metrics) > 1:
                    # Compute the overall quality for every frame to show variation
                    all_quality_scores = []
                    for metric in self.computed_metrics:
                        frame_quality = 0
                        frame_quality_count = 0
                        if metric.get("ssim") is not None:
                            frame_quality += metric["ssim"]
                            frame_quality_count += 1
                        if metric.get("psnr") is not None:
                            frame_quality += min(metric["psnr"] / 50, 1.0)
                            frame_quality_count += 1
                        if metric.get("phash") is not None:
                            frame_quality += metric["phash"]
                            frame_quality_count += 1
                        if frame_quality_count > 0:
                            all_quality_scores.append(
                                frame_quality / frame_quality_count
                            )
                    if len(all_quality_scores) > 1:
                        min_qual = min(all_quality_scores)
                        max_qual = max(all_quality_scores)
                        variation = max_qual - min_qual
                        if variation > 0.08:
                            quality_variation = (
                                f" | 📊 High Variation (Δ{variation:.4f})"
                            )
                        elif variation > 0.04:
                            quality_variation = (
                                f" | 📊 Moderate Variation (Δ{variation:.4f})"
                            )
                        elif variation > 0.02:
                            quality_variation = (
                                f" | 📊 Low Variation (Δ{variation:.4f})"
                            )
                        else:
                            quality_variation = (
                                f" | 📊 Stable Quality (Δ{variation:.4f})"
                            )
                info += f"\n🎯 Overall Quality: {quality_indicator} {avg_quality:.5f} ({overall}){quality_variation}"
                info += f"\n   💡 Formula: Average of {' + '.join(metric_contributions)} = {avg_quality:.5f}"
        return info
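
    # Worked example of the overall score with hypothetical per-frame values:
    # SSIM 0.952, PSNR 38.1 dB -> 0.762 after the /50 cap, pHash 0.969, so
    # overall = (0.952 + 0.762 + 0.969) / 3 ≈ 0.894.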

    def get_updated_plot(self, frame_index):
        """Get updated plots with the current frame highlighted"""
        if self.computed_metrics:
            return self.frame_metrics.create_individual_metric_plots(
                self.computed_metrics, int(frame_index)
            )
        return None


def load_examples_from_json(json_file_path="data.json"):
    """Load example video pairs from the JSON configuration file"""
    try:
        with open(json_file_path, "r") as f:
            data = json.load(f)
        examples = []
        # Extract video pairs from the comparisons
        for comparison in data.get("comparisons", []):
            videos = comparison.get("videos", [])
            # Validate that the video files/URLs exist or are accessible
            valid_videos = []
            for video_path in videos:
                if video_path:  # Skip empty/None entries
                    if video_path.startswith(("http://", "https://")):
                        # For URLs, assume validity (checking would require a
                        # download); OpenCV validates during actual loading
                        valid_videos.append(video_path)
                        print(f"Added video URL: {video_path}")
                    else:
                        # Convert local files to absolute paths
                        abs_path = os.path.abspath(video_path)
                        if os.path.exists(abs_path):
                            valid_videos.append(abs_path)
                            print(f"Added local video file: {abs_path}")
                        elif os.path.exists(video_path):
                            # Fall back to the relative path
                            valid_videos.append(video_path)
                            print(f"Added local video file: {video_path}")
                        else:
                            print(
                                f"Warning: Local video file not found: {video_path} (abs: {abs_path})"
                            )
            # Add to the examples if we have valid videos
            if len(valid_videos) == 2:
                examples.append(valid_videos)
            elif len(valid_videos) == 1:
                # Single-video example (compared against None)
                examples.append([valid_videos[0], None])
        return examples
    except FileNotFoundError:
        print(f"Warning: {json_file_path} not found. No examples will be loaded.")
        return []
    except json.JSONDecodeError as e:
        print(f"Error parsing {json_file_path}: {e}")
        return []
    except Exception as e:
        print(f"Error loading examples: {e}")
        return []
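
# For reference, a minimal data.json this loader understands (the paths and URL
# below are illustrative, not files shipped with the app):
# {
#   "comparisons": [
#     {"videos": ["videos/clip_a.mp4", "https://example.com/clip_b.mp4"]}
#   ]
# }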


def get_all_videos_from_json(json_file_path="data.json"):
    """Get a list of all unique videos mentioned in the JSON file"""
    try:
        with open(json_file_path, "r") as f:
            data = json.load(f)
        all_videos = set()
        # Extract all unique video paths/URLs from the comparisons
        for comparison in data.get("comparisons", []):
            videos = comparison.get("videos", [])
            for video_path in videos:
                if video_path:  # Only add non-empty paths
                    if video_path.startswith(("http://", "https://")):
                        # Add URLs directly
                        all_videos.add(video_path)
                    elif os.path.exists(video_path):
                        # For local files, check existence before adding
                        all_videos.add(video_path)
        return sorted(list(all_videos))
    except FileNotFoundError:
        print(f"Warning: {json_file_path} not found.")
        return []
    except json.JSONDecodeError as e:
        print(f"Error parsing {json_file_path}: {e}")
        return []
    except Exception as e:
        print(f"Error loading videos: {e}")
        return []


def create_app():
    comparator = VideoFrameComparator()
    example_pairs = load_examples_from_json()
    print(f"DEBUG: Loaded {len(example_pairs)} example pairs")
    for i, pair in enumerate(example_pairs):
        print(f"  Example {i + 1}: {pair}")

    with gr.Blocks(
        title="Frame Arena - Video Frame Comparator",
        # theme=gr.themes.Soft(),
        fill_width=True,
        css="""
        /* Ensure plots adapt to theme */
        .plotly .main-svg {
            color: var(--body-text-color, #000) !important;
        }
        /* Grid visibility for both themes */
        .plotly .gridlayer .xgrid, .plotly .gridlayer .ygrid {
            stroke-opacity: 0.4 !important;
        }
        /* Axis text color adaptation */
        .plotly .xtick text, .plotly .ytick text {
            fill: var(--body-text-color, #000) !important;
        }
        /* Axis title color adaptation - multiple selectors for better coverage */
        .plotly .g-xtitle, .plotly .g-ytitle,
        .plotly .xtitle, .plotly .ytitle,
        .plotly text[class*="xtitle"], .plotly text[class*="ytitle"],
        .plotly .infolayer .g-xtitle, .plotly .infolayer .g-ytitle {
            fill: var(--body-text-color, #000) !important;
        }
        /* Additional axis title selectors */
        .plotly .subplot .xtitle, .plotly .subplot .ytitle,
        .plotly .cartesianlayer .xtitle, .plotly .cartesianlayer .ytitle {
            fill: var(--body-text-color, #000) !important;
        }
        /* SVG text elements in plots */
        .plotly svg text {
            fill: var(--body-text-color, #000) !important;
        }
        /* Legend text color */
        .plotly .legendtext, .plotly .legend text {
            fill: var(--body-text-color, #000) !important;
        }
        /* Hover label adaptation */
        .plotly .hoverlayer .hovertext, .plotly .hovertext {
            fill: var(--body-text-color, #000) !important;
            color: var(--body-text-color, #000) !important;
        }
        /* Annotation text */
        .plotly .annotation-text, .plotly .annotation {
            fill: var(--body-text-color, #000) !important;
        }
        /* Disable plot interactions except hover */
        .plotly .modebar {
            display: none !important;
        }
        .plotly .plot-container .plotly {
            pointer-events: none !important;
        }
        .plotly .plot-container .plotly .hoverlayer {
            pointer-events: auto !important;
        }
        .plotly .plot-container .plotly .hovertext {
            pointer-events: auto !important;
        }
        """,
    ) as app:
        gr.Markdown("""
        # 🎬 Frame Arena: Frame-by-frame comparisons of any videos

        > 🎉 This tool has been created to celebrate our Wan 2.2 [text-to-video](https://replicate.com/wan-video/wan-2.2-t2v-480p-fast) and [image-to-video](https://replicate.com/wan-video/wan-2.2-i2v-a14b) endpoints on Replicate. Want to know more? Check out [our blog](https://www.wan22.com/blog/video-optimization-on-replicate)!

        - Upload videos in common formats (MP4, AVI, MOV, etc.) or use URLs; for a meaningful comparison, both videos should have the same number of frames
        - **7 Quality Metrics**: SSIM, PSNR, MSE, pHash, Color Histogram, Sharpness + Overall Quality
        - **Individual Visualization**: Each metric gets its own dedicated plot
        - **Real-time Analysis**: Navigate frames with live metric updates
        - **Smart Comparisons**: Understand differences between videos per metric

        **Perfect for**: Analyzing compression effects, processing artifacts, visual quality assessment, and compression algorithm comparisons.
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Video 1")
                video1_input = gr.File(
                    label="Upload Video 1",
                    file_types=[
                        ".mp4",
                        ".avi",
                        ".mov",
                        ".mkv",
                        ".wmv",
                        ".flv",
                        ".webm",
                    ],
                    type="filepath",
                )
            with gr.Column():
                gr.Markdown("### Video 2")
                video2_input = gr.File(
                    label="Upload Video 2",
                    file_types=[
                        ".mp4",
                        ".avi",
                        ".mov",
                        ".mkv",
                        ".wmv",
                        ".flv",
                        ".webm",
                    ],
                    type="filepath",
                )
        # Add examples at the top for better UX
        if example_pairs:
            gr.Markdown("### 📁 Example Video Comparisons")
            gr.Examples(
                examples=example_pairs,
                inputs=[video1_input, video2_input],
                label="Click any example to load video pairs:",
                examples_per_page=10,
                run_on_click=False,  # We'll handle this manually
            )
        load_btn = gr.Button("🔄 Load Videos", variant="primary", size="lg")

        # Frame comparison section (visible by default; toggled by the load handlers)
        frame_display = gr.Row(visible=True)
        with frame_display:
            with gr.Column():
                gr.Markdown("### Video 1 - Current Frame")
                frame1_output = gr.Image(
                    label="Video 1 Frame",
                    type="numpy",
                    interactive=False,
                    # height=400,
                )
            with gr.Column():
                gr.Markdown("### Frame Slider (Left: Video 1, Right: Video 2)")
                image_slider = ImageSlider(
                    label="Drag to compare frames",
                    type="numpy",
                    interactive=True,
                    # height=400,
                )
            with gr.Column():
                gr.Markdown("### Video 2 - Current Frame")
                frame2_output = gr.Image(
                    label="Video 2 Frame",
                    type="numpy",
                    interactive=False,
                    # height=400,
                )
        # Frame navigation controls, placed underneath the frames
        frame_controls = gr.Row(visible=True)
        with frame_controls:
            frame_slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Frame Number",
                interactive=True,
            )
        # Comprehensive metrics visualization
        metrics_section = gr.Row(visible=True)
        with metrics_section:
            with gr.Column():
                gr.Markdown("### 📊 Metric Analysis")
                # Overall quality plot
                with gr.Row():
                    overall_plot = gr.Plot(
                        label="Overall Quality (Combined Metric [SSIM + normalized_PSNR + pHash])",
                        show_label=True,
                    )
                # Frame info shown below the overall quality plot
                frame_info = gr.Textbox(
                    label="Frame Information & Metrics",
                    interactive=False,
                    value="",
                    lines=3,
                )

                # Comprehensive usage guide underneath the frame information & metrics
                with gr.Accordion("📖 Usage Guide & Metrics Reference", open=False):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("""
                            ### 📊 Metrics Explained

                            - **SSIM**: Structural Similarity (1.0 = identical structure, 0.0 = completely different)
                            - **PSNR**: Peak Signal-to-Noise Ratio in dB (higher = better quality, less noise)
                            - **MSE**: Mean Squared Error (lower = more similar pixel values)
                            - **pHash**: Perceptual Hash similarity (1.0 = visually identical)
                            - **Color Histogram**: Color distribution correlation (1.0 = identical color patterns)
                            - **Sharpness**: Laplacian variance per video (higher = sharper/more detailed images)
                            - **Overall Quality**: Combined metric averaging SSIM, normalized PSNR, and pHash
                            """)
                        with gr.Column() as info_section:
                            status_output = gr.Textbox(
                                label="Status", interactive=False, lines=16
                            )
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("""
                            ### 🎯 Quality Assessment Scale (Research-Based Thresholds)

                            **SSIM Scale** (based on human perception studies):
                            - 🟢 **Excellent (≥0.9)**: Visually indistinguishable differences
                            - 🔵 **Good (≥0.8)**: Minor visible differences, still high quality
                            - 🟡 **Fair (≥0.6)**: Noticeable differences, acceptable quality
                            - 🔴 **Poor (<0.6)**: Significant visible artifacts and differences

                            **PSNR Scale** (standard video quality benchmarks):
                            - 🟢 **Excellent (≥40dB)**: Professional broadcast quality
                            - 🔵 **Good (≥30dB)**: High consumer video quality
                            - 🟡 **Fair (≥20dB)**: Acceptable for web streaming
                            - 🔴 **Poor (<20dB)**: Low quality with visible compression artifacts

                            **MSE Scale** (pixel difference thresholds):
                            - 🟢 **Very Similar (≤50)**: Minimal pixel-level differences
                            - 🔵 **Similar (≤100)**: Small differences, good quality preservation
                            - 🟡 **Moderately Different (≤200)**: Noticeable but acceptable differences
                            - 🔴 **Very Different (>200)**: Significant pixel-level changes
                            """)
                        with gr.Column():
                            gr.Markdown("""
                            ### 🔍 Understanding Comparisons

                            **Comparison Analysis**: Shows how similar/different the videos are
                            - Most metrics indicate **similarity** - not which video "wins"
                            - Higher SSIM/PSNR/pHash/Color = more similar videos
                            - Lower MSE = more similar videos

                            **Individual Quality**: Shows the quality of each video separately
                            - Sharpness comparison shows which video has more detail
                            - Significance levels: 🔴 MAJOR (>20%), 🟡 MODERATE (10-20%), 🔵 MINOR (5-10%), 🟢 NEGLIGIBLE (<5%)

                            **Overall Quality**: Combines multiple metrics to provide a single similarity score
                            - **Formula**: Average of [SSIM + normalized_PSNR + pHash]
                            - **PSNR Normalization**: in the per-frame score, PSNR is divided by 50dB and capped at 1.0 (the Overall Quality plot instead min-max normalizes PSNR across frames)
                            - **Range**: 0.0 to 1.0 (higher = more similar videos overall)
                            - **Purpose**: Provides a single metric when you need one overall assessment
                            - **Limitation**: Different metrics may disagree; check individual metrics for details
                            """)

                # Individual metric plots
                with gr.Row():
                    ssim_plot = gr.Plot(label="SSIM", show_label=True)
                    psnr_plot = gr.Plot(label="PSNR", show_label=True)
                with gr.Row():
                    mse_plot = gr.Plot(label="MSE", show_label=True)
                    phash_plot = gr.Plot(label="pHash", show_label=True)
                with gr.Row():
                    color_plot = gr.Plot(label="Color Histogram", show_label=True)
                    sharpness_plot = gr.Plot(label="Sharpness", show_label=True)

        # Connect examples to auto-loading
        if example_pairs:
            # gr.Examples does not expose a click event to attach to directly,
            # so loading is triggered by the input .change events below; this
            # helper is kept for manual invocation if needed.
            def examples_manual_handler(video1, video2):
                print("DEBUG: Examples clicked manually!")
                return load_videos_handler(video1, video2)

        # Event handlers
        def load_videos_handler(video1, video2):
            print(
                f"DEBUG: load_videos_handler called with video1={video1}, video2={video2}"
            )
            status, max_frames, frame1, frame2, info, plots = comparator.load_videos(
                video1, video2
            )
            # Update the slider (load_videos returns the highest valid frame index)
            slider_update = gr.Slider(
                minimum=0,
                maximum=max_frames,
                step=1,
                value=0,
                interactive=max_frames > 0,
            )
            # Show/hide sections based on whether videos were loaded successfully
            videos_loaded = max_frames > 0
            # Extract the individual plots from the plots dictionary
            ssim_fig = plots.get("ssim") if plots else None
            psnr_fig = plots.get("psnr") if plots else None
            mse_fig = plots.get("mse") if plots else None
            phash_fig = plots.get("phash") if plots else None
            color_fig = plots.get("color_hist") if plots else None
            sharpness_fig = plots.get("sharpness") if plots else None
            overall_fig = plots.get("overall") if plots else None
            return (
                status,  # status_output
                slider_update,  # frame_slider
                frame1,  # frame1_output
                (frame1, frame2),  # image_slider
                frame2,  # frame2_output
                info,  # frame_info
                ssim_fig,  # ssim_plot
                psnr_fig,  # psnr_plot
                mse_fig,  # mse_plot
                phash_fig,  # phash_plot
                color_fig,  # color_plot
                sharpness_fig,  # sharpness_plot
                overall_fig,  # overall_plot
                gr.Row(visible=videos_loaded),  # frame_controls
                gr.Row(visible=videos_loaded),  # frame_display
                gr.Row(visible=videos_loaded and plots is not None),  # metrics_section
                gr.Column(visible=videos_loaded),  # info_section (a Column, not a Row)
            )

        def update_frames(frame_index):
            if comparator.max_frames == 0:
                return (
                    None,
                    None,
                    None,
                    "No videos loaded",
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                )
            frame1, frame2 = comparator.get_frames_at_index(frame_index)
            info = comparator.get_current_frame_info(frame_index)
            plots = comparator.get_updated_plot(frame_index)
            # Extract the individual plots from the plots dictionary
            ssim_fig = plots.get("ssim") if plots else None
            psnr_fig = plots.get("psnr") if plots else None
            mse_fig = plots.get("mse") if plots else None
            phash_fig = plots.get("phash") if plots else None
            color_fig = plots.get("color_hist") if plots else None
            sharpness_fig = plots.get("sharpness") if plots else None
            overall_fig = plots.get("overall") if plots else None
            return (
                frame1,
                (frame1, frame2),
                frame2,
                info,
                ssim_fig,
                psnr_fig,
                mse_fig,
                phash_fig,
                color_fig,
                sharpness_fig,
                overall_fig,
            )

        # Auto-load when examples populate the inputs
        def auto_load_when_examples_change(video1, video2):
            print(
                f"DEBUG: auto_load_when_examples_change called with video1={video1}, video2={video2}"
            )
            # Only auto-load if both inputs are provided (from examples)
            if video1 and video2:
                print("DEBUG: Both videos present, calling load_videos_handler")
                return load_videos_handler(video1, video2)
            # With one video or none, return the default empty state
            print("DEBUG: Not both videos present, returning default state")
            return (
                "Please upload videos or select an example",  # status_output
                gr.Slider(
                    minimum=0, maximum=0, step=1, value=0, interactive=False
                ),  # frame_slider
                None,  # frame1_output
                (None, None),  # image_slider
                None,  # frame2_output
                "",  # frame_info
                None,  # ssim_plot
                None,  # psnr_plot
                None,  # mse_plot
                None,  # phash_plot
                None,  # color_plot
                None,  # sharpness_plot
                None,  # overall_plot
                gr.Row(visible=True),  # frame_controls
                gr.Row(visible=True),  # frame_display
                gr.Row(visible=True),  # metrics_section
                gr.Column(visible=True),  # info_section (a Column, not a Row)
            )

        # De-duplicate rapid input-change events so the same pair isn't reloaded twice
        last_processed_pair = {"video1": None, "video2": None}

        def enhanced_auto_load(video1, video2):
            print(f"DEBUG: Input change detected! video1={video1}, video2={video2}")
            # Skip if this exact video pair was just processed
            if (
                last_processed_pair["video1"] == video1
                and last_processed_pair["video2"] == video2
            ):
                print("DEBUG: Same video pair already processed, skipping...")
                # Return the current state without recomputing
                return (
                    gr.update(),  # status_output
                    gr.update(),  # frame_slider
                    gr.update(),  # frame1_output
                    gr.update(),  # image_slider
                    gr.update(),  # frame2_output
                    gr.update(),  # frame_info
                    gr.update(),  # ssim_plot
                    gr.update(),  # psnr_plot
                    gr.update(),  # mse_plot
                    gr.update(),  # phash_plot
                    gr.update(),  # color_plot
                    gr.update(),  # sharpness_plot
                    gr.update(),  # overall_plot
                    gr.update(),  # frame_controls
                    gr.update(),  # frame_display
                    gr.update(),  # metrics_section
                    gr.update(),  # info_section
                )
            last_processed_pair["video1"] = video1
            last_processed_pair["video2"] = video2
            return auto_load_when_examples_change(video1, video2)

        # Auto-load when either video input changes (triggered by examples)
        video1_input.change(
            fn=enhanced_auto_load,
            inputs=[video1_input, video2_input],
            outputs=[
                status_output,
                frame_slider,
                frame1_output,
                image_slider,
                frame2_output,
                frame_info,
                ssim_plot,
                psnr_plot,
                mse_plot,
                phash_plot,
                color_plot,
                sharpness_plot,
                overall_plot,
                frame_controls,
                frame_display,
                metrics_section,
                info_section,
            ],
        )
        video2_input.change(
            fn=enhanced_auto_load,
            inputs=[video1_input, video2_input],
            outputs=[
                status_output,
                frame_slider,
                frame1_output,
                image_slider,
                frame2_output,
                frame_info,
                ssim_plot,
                psnr_plot,
                mse_plot,
                phash_plot,
                color_plot,
                sharpness_plot,
                overall_plot,
                frame_controls,
                frame_display,
                metrics_section,
                info_section,
            ],
        )

        # Manual load button event handler with debug output
        def debug_load_videos_handler(video1, video2):
            print(f"DEBUG: Load button clicked! video1={video1}, video2={video2}")
            return load_videos_handler(video1, video2)

        load_btn.click(
            fn=debug_load_videos_handler,
            inputs=[video1_input, video2_input],
            outputs=[
                status_output,
                frame_slider,
                frame1_output,
                image_slider,
                frame2_output,
                frame_info,
                ssim_plot,
                psnr_plot,
                mse_plot,
                phash_plot,
                color_plot,
                sharpness_plot,
                overall_plot,
                frame_controls,
                frame_display,
                metrics_section,
                info_section,
            ],
        )
        frame_slider.change(
            fn=update_frames,
            inputs=[frame_slider],
            outputs=[
                frame1_output,
                image_slider,
                frame2_output,
                frame_info,
                ssim_plot,
                psnr_plot,
                mse_plot,
                phash_plot,
                color_plot,
                sharpness_plot,
                overall_plot,
            ],
        )
    return app


def main():
    app = create_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )


if __name__ == "__main__":
    main()
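
# To run locally (assuming the Space's Python dependencies are installed and
# this file is saved as app.py): `python app.py`, then open
# http://localhost:7860 in a browser.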