import gradio as gr
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import cv2 # OpenCV for video processing
# VideoMAE checkpoint to load. Note: "MCG-NJU/videomae-base" is the self-supervised
# pre-trained backbone, so its classification head is randomly initialized; a checkpoint
# fine-tuned for action classification (e.g., on Kinetics or UCF101) would give
# meaningful labels.
model_id = "MCG-NJU/videomae-base"
# Parameters for frame extraction
TARGET_FRAME_COUNT = 16
FRAME_SIZE = (224, 224) # Expected frame size for the model

def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video, TARGET_FRAME_COUNT)
    if not frames:
        return "Could not read any frames from the uploaded video."

    # Resize frames to the size expected by the model
    frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

    # Load the model and image processor (re-loaded on every call for simplicity)
    model = VideoMAEForVideoClassification.from_pretrained(model_id)
    processor = VideoMAEImageProcessor.from_pretrained(model_id)

    # Prepare the frames for the model (the processor treats them as one video clip)
    inputs = processor(images=frames, return_tensors="pt")

    # Make predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

    # Analyze predictions for insights related to the play
    # (the model classifies the whole clip, so there is one prediction per clip)
    results = []
    for prediction in predictions:
        result = analyze_predictions_ucf101(prediction.item())
        results.append(result)

    # Aggregate results across clips and provide a final analysis
    final_result = aggregate_results(results)
    return final_result

def extract_key_frames(video, target_frame_count):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Calculate interval for frame extraction
    interval = max(1, frame_count // target_frame_count)

    for i in range(0, frame_count, interval):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to RGB
        if len(frames) >= target_frame_count:
            break
    cap.release()
    return frames
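
# VideoMAE checkpoints are typically configured for a fixed clip length
# (16 frames for this model), so a very short video may yield fewer frames than
# the model expects. A minimal sketch of one way to handle that by repeating the
# last frame; it is illustrative only and not wired into analyze_video:
def pad_frames(frames, target_frame_count):
    # Repeat the final frame until the clip reaches the target length
    while frames and len(frames) < target_frame_count:
        frames.append(frames[-1])
    return frames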

def analyze_predictions_ucf101(prediction):
    # Map prediction to action labels (this mapping is hypothetical;
    # see the id2label sketch below for a checkpoint-driven alternative)
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")

    relevant_actions = ["running", "sliding", "jumping"]
    if action in relevant_actions:
        if action == "sliding":
            return "potentially safe"
        elif action == "running":
            return "potentially out"
        else:
            return "inconclusive"
    else:
        return "inconclusive"
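
# With a checkpoint that has actually been fine-tuned for classification, the
# hardcoded hypothetical mapping above could be replaced by the label names stored
# in the model config. A minimal sketch (assumes `model` is a loaded
# VideoMAEForVideoClassification with a fine-tuned head); it is not called by the app:
def label_from_config(model, prediction):
    # id2label maps class indices to human-readable label names
    return model.config.id2label.get(prediction, "unknown")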

def aggregate_results(results):
    # Combine the per-clip insights by simple vote counting
    # (confidence scores could also be incorporated; see the sketch below)
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")

    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"
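
# A minimal sketch of how per-clip confidence could be derived from the model's raw
# logits via softmax, as mentioned above. It is illustrative only and not wired into
# analyze_video; `logits` is assumed to be the tensor returned by the model.
def prediction_confidences(logits):
    # Convert raw logits into probabilities over the label set
    probabilities = torch.softmax(logits, dim=-1)
    # Highest probability and its class index for each clip in the batch
    top_probs, top_classes = probabilities.max(dim=-1)
    return [
        {"class_id": cls.item(), "confidence": prob.item()}
        for cls, prob in zip(top_classes, top_probs)
    ]
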
# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description=(
        "Upload a video of a baseball play (safe/out at a base). This app explores "
        "using a video classification model (UCF101 subset) for analysis. "
        "Note: The model might not be specifically trained for baseball plays."
    ),
)
interface.launch(share=True)