import gradio as gr
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import cv2  # OpenCV for video processing

# Base VideoMAE checkpoint. Note: this checkpoint is a self-supervised backbone
# without a fine-tuned classification head, so its predictions are placeholders;
# for meaningful results, swap in a checkpoint fine-tuned on an action-recognition
# dataset (e.g., a UCF101 subset).
model_id = "MCG-NJU/videomae-base"

# Parameters for frame extraction
TARGET_FRAME_COUNT = 16   # VideoMAE expects 16 frames per clip
FRAME_SIZE = (224, 224)   # Expected frame size for the model

# Load the model and image processor once at startup instead of on every request
model = VideoMAEForVideoClassification.from_pretrained(model_id)
processor = VideoMAEImageProcessor.from_pretrained(model_id)


def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video, TARGET_FRAME_COUNT)
    if not frames:
        return "Could not read any frames from the uploaded video."

    # Resize frames to the model's expected input size
    # (the processor would also resize, but doing it here keeps inputs consistent)
    frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

    # Prepare the clip for the model (a list of frames is treated as one video)
    inputs = processor(images=frames, return_tensors="pt")

    # Make predictions; the model classifies the whole 16-frame clip,
    # so there is one prediction per clip in the batch (a single clip here)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # Analyze predictions for insights related to the play
    results = []
    for prediction in predictions:
        result = analyze_predictions_ucf101(prediction.item())
        results.append(result)

    # Aggregate results across clips and provide a final analysis
    final_result = aggregate_results(results)
    return final_result


def extract_key_frames(video, target_frame_count):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Calculate the sampling interval so frames are spread across the video
    interval = max(1, frame_count // target_frame_count)

    for i in range(0, frame_count, interval):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert BGR to RGB
        if len(frames) >= target_frame_count:
            break
    cap.release()

    # Pad short clips by repeating the last frame so the model always
    # receives exactly target_frame_count frames
    while frames and len(frames) < target_frame_count:
        frames.append(frames[-1])

    return frames


def analyze_predictions_ucf101(prediction):
    # Map prediction indices to action labels (this mapping is hypothetical;
    # a fine-tuned checkpoint exposes its real mapping via model.config.id2label)
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")

    relevant_actions = ["running", "sliding", "jumping"]
    if action in relevant_actions:
        if action == "sliding":
            return "potentially safe"
        elif action == "running":
            return "potentially out"
        else:
            return "inconclusive"
    else:
        return "inconclusive"


def aggregate_results(results):
    # Combine per-clip insights into a single verdict by majority vote
    safe_count = results.count("potentially safe")
    out_count = results.count("potentially out")

    if safe_count > out_count:
        return "Safe"
    elif out_count > safe_count:
        return "Out"
    else:
        return "Inconclusive"


# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description=(
        "Upload a video of a baseball play (safe/out at a base). This app explores "
        "using a video classification model (UCF101 subset) for analysis. "
        "Note: The model might not be specifically trained for baseball plays."
    ),
)

interface.launch(share=True)
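
# ---------------------------------------------------------------------------
# Optional sanity check (a sketch, not part of the app): the action_labels
# mapping above is hypothetical. A checkpoint that ships with a fine-tuned
# classification head exposes its real label names via model.config.id2label,
# which can be inspected like this before wiring up the mapping:
#
#     for idx, name in model.config.id2label.items():
#         print(idx, name)
#
# For the plain "MCG-NJU/videomae-base" checkpoint this typically prints only
# generic placeholder labels (LABEL_0, LABEL_1, ...), another reminder that a
# fine-tuned checkpoint is needed for meaningful predictions.
# ---------------------------------------------------------------------------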