import gradio as gr
import torch
import argparse
import pickle as pkl
import decord
from decord import VideoReader
import numpy as np
import yaml
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
import pandas as pd
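# Channel-wise normalization constants (scaled to the 0-255 range):
# mean/std are the standard ImageNet statistics applied to the technical and
# aesthetic views; mean_clip/std_clip match CLIP's preprocessing statistics and
# are applied to the semantic view.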
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)
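# Roughly one frame is kept for every `sample_interval` input frames when the
# sampler parameters are derived in get_sampler_params() below.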
sample_interval = 30
normalization_array = {
    "semantic" : [-0.1477, -0.0181],
    "technical": [-1.8762, 1.2428],
    "aesthetic": [-1.2899, 0.5290],
    "overall"  : [-3.2538, 1.6728],
}
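# Note: normalization_array is not referenced below; the displayed scores are
# normalized against the per-branch 'Mos' ranges loaded from CSV files in
# inference_one_video() instead.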
comparison_array = {
    "semantic" : [],  # example arrays, filled from the YT-UGC CSVs at inference time
    "technical": [],
    "aesthetic": [],
    "overall"  : [],
}
def get_sampler_params(video_path):
    """Derive sampler parameters (total frame count, clip length, number of
    temporal fragments) so that roughly one frame is sampled per
    `sample_interval` frames of the video."""
    vr = VideoReader(video_path)
    total_frames = len(vr)
    clip_len = (total_frames + sample_interval // 2) // sample_interval
    if clip_len == 0:
        clip_len = 1
    t_frag = clip_len
    return total_frames, clip_len, t_frag
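# Example: a 300-frame video with sample_interval = 30 gives
# total_frames=300, clip_len=10, t_frag=10.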
def fuse_results(results: list):
    """Map the three branch outputs to named scores; the overall score is their sum."""
    x = results[0] + results[1] + results[2]
    return {
        "semantic" : results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall"  : x,
    }
def normalize_score(score, min_score, max_score):
    # Linearly rescale `score` from [min_score, max_score] to [0, 5].
    return (score - min_score) / (max_score - min_score) * 5
def compare_score(score, score_list):
    better_than = sum(1 for s in score_list if score > s)
    percentage = better_than / len(score_list) * 100
    return (
        f"Better than {percentage:.0f}% of videos in YT-UGC"
        if percentage > 50
        else f"Worse than {100 - percentage:.0f}% of videos in YT-UGC"
    )
def create_bar_chart(scores, comparisons):
    labels = ['Semantic', 'Technical', 'Aesthetic', 'Overall']
    base_colors = ['#d62728', '#1f77b4', '#ff7f0e', '#bcbd22']
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create vertical bars
    bars = ax.bar(labels, scores, color=base_colors, edgecolor='black', width=0.6)

    # Add the score labels above each bar
    for bar, score in zip(bars, scores):
        height = bar.get_height()
        ax.annotate(f'{score:.1f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom',
                    color='black')

    # Comparison text annotations (currently disabled)
    # for i, (bar, score) in enumerate(zip(bars, scores)):
    #     ax.annotate(comparisons[i],
    #                 xy=(bar.get_x() + bar.get_width(), bar.get_height() / 2),
    #                 xytext=(5, 0),  # 5 points horizontal offset
    #                 textcoords="offset points",
    #                 ha='left', va='center',
    #                 color=base_colors[i])

    ax.set_xlabel('Categories')
    ax.set_ylabel('Scores')
    ax.set_ylim(0, 5)
    ax.set_title('Video Quality Scores')

    plt.tight_layout()
    image_path = "./scores_bar_chart.png"
    plt.savefig(image_path)
    plt.close(fig)
    return image_path
def inference_one_video(input_video):
    """
    BASIC SETTINGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)
    dopt = opt["data"]["val-ytugc"]["args"]

    temporal_samplers = {}
    # Decide the sampler parameters automatically from the input video
    total_frames, clip_len, t_frag = get_sampler_params(input_video)
    for stype, sopt in dopt["sample_types"].items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        if stype == 'technical' or stype == 'aesthetic':
            if total_frames > 1:
                sopt["clip_len"] = clip_len * 2
        if stype == 'technical':
            sopt["aligned"] = sopt["clip_len"]
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
""" | |
LOAD MODEL | |
""" | |
evaluator = COVER(**opt["model"]["args"]).to(device) | |
state_dict = torch.load(opt["test_load_path"], map_location=device) | |
# set strict=False here to avoid error of missing | |
# weight of prompt_learner in clip-iqa+, cross-gate | |
evaluator.load_state_dict(state_dict['state_dict'], strict=False) | |
""" | |
TESTING | |
""" | |
views, _ = spatial_temporal_view_decomposition( | |
input_video, dopt["sample_types"], temporal_samplers | |
) | |
for k, v in views.items(): | |
num_clips = dopt["sample_types"][k].get("num_clips", 1) | |
if k == 'technical' or k == 'aesthetic': | |
views[k] = ( | |
((v.permute(1, 2, 3, 0) - mean) / std) | |
.permute(3, 0, 1, 2) | |
.reshape(v.shape[0], num_clips, -1, *v.shape[2:]) | |
.transpose(0, 1) | |
.to(device) | |
) | |
elif k == 'semantic': | |
views[k] = ( | |
((v.permute(1, 2, 3, 0) - mean_clip) / std_clip) | |
.permute(3, 0, 1, 2) | |
.reshape(v.shape[0], num_clips, -1, *v.shape[2:]) | |
.transpose(0, 1) | |
.to(device) | |
) | |
results = [r.mean().item() for r in evaluator(views)] | |
pred_score = fuse_results(results) | |
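    # Reference 'Mos' scores over the YouTube-UGC set: their min/max define the
    # 0-5 normalization range, and their distribution drives the percentile text.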
comparison_array["semantic"] = pd.read_csv('./prediction_results/youtube_ugc/smos.csv')['Mos'] | |
comparison_array["technical"] = pd.read_csv('./prediction_results/youtube_ugc/tmos.csv')['Mos'] | |
comparison_array["aesthetic"] = pd.read_csv('./prediction_results/youtube_ugc/amos.csv')['Mos'] | |
comparison_array["overall"] = pd.read_csv('./prediction_results/youtube_ugc/overall.csv')['Mos'] | |
normalized_scores = [ | |
normalize_score(pred_score["semantic"] , comparison_array["semantic"].min() , comparison_array["semantic"].max() ), | |
normalize_score(pred_score["technical"], comparison_array["technical"].min(), comparison_array["technical"].max()), | |
normalize_score(pred_score["aesthetic"], comparison_array["aesthetic"].min(), comparison_array["aesthetic"].max()), | |
normalize_score(pred_score["overall"] , comparison_array["overall"].min() , comparison_array["overall"].max() ) | |
] | |
comparisons = [ | |
compare_score(pred_score["semantic"], comparison_array["semantic"]), | |
compare_score(pred_score["technical"], comparison_array["technical"]), | |
compare_score(pred_score["aesthetic"], comparison_array["aesthetic"]), | |
compare_score(pred_score["overall"], comparison_array["overall"]) | |
] | |
image_path = create_bar_chart(normalized_scores, comparisons) | |
return image_path | |
# Define the input and output types for Gradio using the new API
video_input = gr.Video(label="Input Video")
output_image = gr.Image(label="Scores")

# Create the Gradio interface
gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_image)

if __name__ == "__main__":
    gradio_app.launch()
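# On Hugging Face Spaces the interface above is served automatically; running
# this file locally starts the same demo, and gradio_app.launch(share=True)
# could be used instead to expose a temporary public link.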