File size: 3,446 Bytes
c592663
e006d42
 
93e8a7e
99f44fe
93e8a7e
16756b8
99f44fe
e7b04ff
 
 
 
93e8a7e
 
e7b04ff
 
 
 
e006d42
16756b8
 
e006d42
 
 
 
7fcc53a
e006d42
 
 
7fcc53a
e006d42
 
 
 
93e8a7e
e006d42
 
93e8a7e
 
 
 
 
 
 
e7b04ff
93e8a7e
 
 
 
e7b04ff
 
 
 
 
93e8a7e
e7b04ff
e006d42
e7b04ff
 
93e8a7e
 
 
 
e006d42
 
 
 
 
 
 
 
 
93e8a7e
 
e006d42
 
 
 
 
 
 
93e8a7e
 
 
 
 
 
 
 
 
 
 
 
 
 
99f44fe
 
 
93e8a7e
 
99f44fe
93e8a7e
d2d1207
99f44fe
 
d2d1207
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import functools

import cv2  # OpenCV for video processing
import gradio as gr
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Hugging Face Hub model ID loaded by analyze_video().
# NOTE(review): earlier comments described this as a "UCF101 subset" classifier,
# but "videomae-base" looks like a self-supervised pre-training checkpoint with
# no fine-tuned classification head — confirm on the model card. If so, the
# classification head is freshly initialised and its logits carry no action meaning.
model_id = "MCG-NJU/videomae-base"

# Parameters for frame extraction
# Number of frames sampled from each uploaded video; presumably chosen to match
# the clip length the model expects (TODO confirm against the model config).
TARGET_FRAME_COUNT = 16
# cv2.resize takes (width, height); both are 224 here, so the order is moot.
FRAME_SIZE = (224, 224)  # Expected frame size for the model

@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Load the VideoMAE model and its image processor once, then reuse them.

    ``from_pretrained`` reads (and on first use downloads) the checkpoint, which
    is far too slow to repeat on every Gradio request.
    """
    model = VideoMAEForVideoClassification.from_pretrained(model_id)
    processor = VideoMAEImageProcessor.from_pretrained(model_id)
    return model, processor


def analyze_video(video):
    """Classify an uploaded clip and turn the prediction into a safe/out verdict.

    Args:
        video: The value from the Gradio ``"video"`` input; it is handed
            straight to ``cv2.VideoCapture``, so presumably a file path.
            ``None`` when the user submits without uploading anything.

    Returns:
        The verdict string produced by ``aggregate_results``.

    Raises:
        gr.Error: If no video was supplied or no frame could be decoded, so the
            user sees a readable message instead of a stack trace.
    """
    if not video:
        raise gr.Error("Please upload a video.")

    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video, TARGET_FRAME_COUNT)
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")

    # Short clips yield fewer than TARGET_FRAME_COUNT frames; repeat the last one
    # so every clip fed to the model has the same fixed length.
    frames += [frames[-1]] * (TARGET_FRAME_COUNT - len(frames))

    # Resize frames to the expected size
    frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

    # Loaded once per process and cached (see _load_model_and_processor).
    model, processor = _load_model_and_processor()

    # A flat list of frames is treated as a single video, i.e. a batch of 1.
    inputs = processor(images=frames, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # One class index per video in the batch (not one per frame).
    predictions = torch.argmax(logits, dim=-1)
    results = [analyze_predictions_ucf101(p.item()) for p in predictions]

    # Aggregate the per-video insights into a final verdict.
    return aggregate_results(results)

def extract_key_frames(video, target_frame_count):
    """Sample up to ``target_frame_count`` RGB frames spread evenly over a video.

    Args:
        video: Anything ``cv2.VideoCapture`` accepts, typically a file path.
        target_frame_count: Maximum number of frames to return.

    Returns:
        A list of at most ``target_frame_count`` RGB frames in temporal order.
        The list is shorter for videos with fewer frames or when decoding a
        frame fails. It is empty if the video cannot be opened, reports no
        frames, or ``target_frame_count`` is not positive.
    """
    if target_frame_count <= 0:
        return []

    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            return []

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0:
            return []

        sample_count = min(target_frame_count, frame_count)
        frames = []
        for i in range(sample_count):
            # i * frame_count // sample_count spreads samples over the whole clip.
            # A fixed stride of frame_count // target_frame_count would leave the
            # tail unsampled whenever the division is not exact.
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * frame_count // sample_count)
            ok, frame = cap.read()
            if ok:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to RGB
        return frames
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()

def analyze_predictions_ucf101(prediction):
    """Translate a predicted class index into a coarse safe/out hint.

    Args:
        prediction: Class index from the model's argmax.

    Returns:
        "potentially safe" for the sliding class, "potentially out" for the
        running class, and "inconclusive" for anything else (including the
        jumping class and unmapped indices).
    """
    # Class index -> action name. This mapping is hypothetical: it is not taken
    # from the model's own label config.
    index_to_action = {0: "running", 1: "sliding", 2: "jumping"}
    # Only these actions carry a safe/out signal; every other action is inconclusive.
    verdict_for_action = {
        "sliding": "potentially safe",
        "running": "potentially out",
    }

    action_name = index_to_action.get(prediction, "unknown")
    return verdict_for_action.get(action_name, "inconclusive")

def aggregate_results(results):
    """Majority-vote the per-prediction hints into a final verdict.

    Args:
        results: Sequence of hint strings from ``analyze_predictions_ucf101``.

    Returns:
        "Safe" or "Out" when that hint strictly outnumbers the other, and
        "Inconclusive" on a tie (including when neither hint appears).
    """
    safe_votes = results.count("potentially safe")
    out_votes = results.count("potentially out")

    if safe_votes == out_votes:
        return "Inconclusive"
    return "Safe" if safe_votes > out_votes else "Out"

# Gradio interface: one video upload in, one text verdict out.
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis (UCF101 Subset Exploration)",
    description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays."
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or another app) does not start a server as a side effect.
if __name__ == "__main__":
    interface.launch(share=True)