MBase / app.py
import gradio as gr
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import cv2 # OpenCV for video processing
# Model ID for video classification (UCF101 subset)
model_id = "MCG-NJU/videomae-base"
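# NOTE (assumption for this exploration): MCG-NJU/videomae-base is the self-supervised
# pretrained checkpoint; its classification head is not fine-tuned, so the class indices
# predicted below are only illustrative. A fine-tuned checkpoint such as
# "MCG-NJU/videomae-base-finetuned-kinetics" could be swapped in for meaningful action labels.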
# Parameters for frame extraction
TARGET_FRAME_COUNT = 16
FRAME_SIZE = (224, 224) # Expected frame size for the model
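# VideoMAE expects a fixed-length clip (16 frames for the base checkpoint), so
# TARGET_FRAME_COUNT should match model.config.num_frames.
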
def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video, TARGET_FRAME_COUNT)

    # Resize frames to the size expected by the model
    frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

    # Load the model and image processor (loading these once at module scope would
    # avoid re-initializing them on every request)
    model = VideoMAEForVideoClassification.from_pretrained(model_id)
    processor = VideoMAEImageProcessor.from_pretrained(model_id)

    # Prepare the frames for the model; the list of frames is treated as a single clip
    inputs = processor(images=frames, return_tensors="pt")

    # Make predictions
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # Map each predicted class to a play-related insight
    results = []
    for prediction in predictions:
        result = analyze_predictions_ucf101(prediction.item())
        results.append(result)

    # Aggregate the per-prediction results into a final call
    final_result = aggregate_results(results)
    return final_result

def extract_key_frames(video, target_frame_count):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Sample frames at a regular interval across the whole clip
    interval = max(1, frame_count // target_frame_count)
    for i in range(0, frame_count, interval):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert BGR to RGB
        if len(frames) >= target_frame_count:
            break
    cap.release()

    # Pad short videos by repeating the last frame so the model always receives
    # exactly target_frame_count frames
    while frames and len(frames) < target_frame_count:
        frames.append(frames[-1])
    return frames

def analyze_predictions_ucf101(prediction):
    # Map the predicted class index to an action label (this mapping is hypothetical)
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")

    relevant_actions = ["running", "sliding", "jumping"]
    if action in relevant_actions:
        if action == "sliding":
            return "potentially safe"
        elif action == "running":
            return "potentially out"
        else:
            return "inconclusive"
    else:
        return "inconclusive"

def aggregate_results(results):
    # Combine the per-prediction insights into a final call
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")
    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"

# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays.",
)

interface.launch(share=True)