import gradio as gr
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import cv2

# NOTE: "MCG-NJU/videomae-base" is the self-supervised pre-training checkpoint,
# so its classification head is randomly initialized. For meaningful
# predictions, swap in a checkpoint fine-tuned for classification
# (e.g. on UCF101 or Kinetics).
model_id = "MCG-NJU/videomae-base"

# VideoMAE consumes fixed-length clips of 16 frames at 224x224 resolution.
TARGET_FRAME_COUNT = 16
FRAME_SIZE = (224, 224)

# Load the model and processor once at startup rather than on every request.
model = VideoMAEForVideoClassification.from_pretrained(model_id)
processor = VideoMAEImageProcessor.from_pretrained(model_id)

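# Optional (a sketch, commented out): move inference to a GPU when available.
# The rest of the script would then need `inputs` moved to the same device,
# e.g. inputs = {k: v.to(device) for k, v in inputs.items()}.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
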
def analyze_video(video):
    """Run the classifier over sampled frames and aggregate into a play call."""
    frames = extract_key_frames(video, TARGET_FRAME_COUNT)
    if not frames:
        return "Could not read any frames from the video."

    frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

    # The processor normalizes the clip and returns pixel_values of shape
    # (batch, num_frames, channels, height, width).
    inputs = processor(images=frames, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # One prediction per clip in the batch (a single clip here).
    results = [analyze_predictions_ucf101(p.item()) for p in predictions]
    return aggregate_results(results)

def extract_key_frames(video, target_frame_count):
    """Sample up to target_frame_count frames spread across the video."""
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Read at a fixed stride so the samples span the whole clip.
    interval = max(1, frame_count // target_frame_count)

    for i in range(0, frame_count, interval):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; the model expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if len(frames) >= target_frame_count:
            break

    cap.release()
    return frames

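# Alternative sampling (a sketch, not wired in above): computing evenly spaced
# indices up front yields exactly target_frame_count samples even when the
# stride does not divide the clip evenly; short clips repeat frames rather
# than coming up short. `sample_frame_indices` is an illustrative helper,
# not a library API.
def sample_frame_indices(frame_count, target_frame_count):
    if frame_count <= 0:
        return []
    step = frame_count / target_frame_count
    return [min(int(i * step), frame_count - 1) for i in range(target_frame_count)]
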
def analyze_predictions_ucf101(prediction):
    # Placeholder mapping: these ids and labels are illustrative and do NOT
    # correspond to the real UCF101 class indices.
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
    }
    action = action_labels.get(prediction, "unknown")

    # Map the recognized action to a tentative play call.
    if action == "sliding":
        return "potentially safe"
    if action == "running":
        return "potentially out"
    return "inconclusive"

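# With a fine-tuned checkpoint, the hard-coded table above could be replaced
# by the model's own label mapping (a sketch; assumes the loaded checkpoint
# ships an id2label table, as fine-tuned Hugging Face classifiers usually do):
def analyze_predictions_from_config(prediction):
    action = model.config.id2label.get(prediction, "unknown")
    if action == "sliding":
        return "potentially safe"
    if action == "running":
        return "potentially out"
    return "inconclusive"
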
def aggregate_results(results):
    """Majority vote over the per-clip calls."""
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")

    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"

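# A possible refinement (a sketch, not wired into the app): weight the vote by
# softmax confidence instead of counting hard argmax labels. The 0.5 threshold
# is an illustrative choice, not a tuned value.
def aggregate_with_confidence(logits):
    probs = torch.softmax(logits, dim=-1)     # (batch, num_labels)
    confidences, indices = probs.max(dim=-1)  # top probability per clip
    votes = [
        analyze_predictions_ucf101(idx.item())
        for idx, conf in zip(indices, confidences)
        if conf.item() > 0.5
    ]
    return aggregate_results(votes) if votes else "Inconclusive"
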
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description=(
        "Upload a video of a baseball play (safe/out at a base). This app "
        "explores using a video classification model (UCF101 subset) for "
        "analysis. Note: the model may not be specifically trained for "
        "baseball plays."
    ),
)

if __name__ == "__main__":
    interface.launch(share=True)