File size: 3,729 Bytes
56de2d4
 
b8466ce
 
 
2dc6183
10696ac
53189f9
 
 
2dc6183
8d1f721
2dc6183
 
a29b529
 
 
 
 
 
 
 
 
 
2dc6183
 
a29b529
 
 
 
 
a6c8793
56de2d4
 
 
 
 
 
 
 
 
 
 
2dc6183
 
56de2d4
 
 
 
 
 
2dc6183
 
 
 
 
 
 
56de2d4
 
 
 
 
53189f9
a29b529
 
 
53189f9
56de2d4
b8466ce
56de2d4
 
a29b529
56de2d4
 
2dc6183
2c5687c
56de2d4
 
 
a23243f
56de2d4
 
 
 
b8466ce
56de2d4
b8466ce
a29b529
 
b8466ce
 
56de2d4
b8466ce
a29b529
56de2d4
2dc6183
1bc2256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a23243f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
import numpy as np
from transformers import AutoProcessor, AutoModel
from PIL import Image
import cv2

# Hugging Face checkpoint used for both the processor and the model below.
MODEL_NAME = "microsoft/xclip-base-patch16-zero-shot"
# Number of frames sampled from every uploaded video (see model_interface).
CLIP_LEN = 32

# Processor and model are created a single time at import so that every call
# made by the interface reuses the same instances instead of reloading them.
processor, model = (
    AutoProcessor.from_pretrained(MODEL_NAME),
    AutoModel.from_pretrained(MODEL_NAME),
)


def get_video_length(file_path):
    """Return the number of frames OpenCV reports for a video file.

    Args:
        file_path: Path of the video to inspect.

    Returns:
        The reported frame count as a non-negative ``int``. OpenCV returns 0
        for properties it cannot query (e.g. when the file could not be
        opened), and some backends may report a negative value for streams;
        negatives are clamped to 0 so callers only have one "no frames"
        value to handle.
    """
    cap = cv2.VideoCapture(file_path)
    try:
        # The property is returned as a float; it can be an estimate taken
        # from container metadata rather than an exact decoded-frame count.
        count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture handle even if the query/conversion fails.
        cap.release()
    return max(count, 0)

def read_video_opencv(file_path, indices):
    """Decode the frames at ``indices`` from a video file as RGB arrays.

    Args:
        file_path: Path of the video to read.
        indices: Iterable of frame indices to seek to, in the order the
            frames should be returned (duplicates are allowed).

    Returns:
        A list of RGB NumPy arrays (converted from OpenCV's BGR order), one
        per index. A frame that fails to decode is replaced by the most
        recently decoded frame; failures before the first successful read
        are filled with that first decoded frame. If no frame can be read at
        all, the list is empty.
    """
    cap = cv2.VideoCapture(file_path)
    frames = []
    # Reads that failed before any frame was decoded; back-filled later.
    leading_failures = 0
    try:
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if leading_failures:
                    frames.extend([rgb] * leading_failures)
                    leading_failures = 0
                frames.append(rgb)
            elif frames:
                # Keep one entry per index: concatenate_frames fills a fixed
                # grid and the model presumably expects a fixed frame count,
                # while the reported frame count can overshoot so tail
                # reads may fail.
                frames.append(frames[-1])
            else:
                leading_failures += 1
    finally:
        cap.release()
    return frames

def sample_uniform_frame_indices(clip_len, seg_len):
    """Pick ``clip_len`` frame indices spread evenly over ``seg_len`` frames.

    Args:
        clip_len: Number of indices to return; must be positive.
        seg_len: Number of frames available; must be positive.

    Returns:
        A 1-D ``np.int64`` array of length ``clip_len``. When the segment is
        shorter than the clip, indices cycle ``0, 1, ..., seg_len - 1, 0, ...``
        (the clip loops). Otherwise index ``i`` is
        ``floor(i * seg_len / clip_len)``, which starts at frame 0 and spans
        the whole segment. When ``seg_len`` is an exact multiple of
        ``clip_len`` this equals the fixed stride ``i * (seg_len // clip_len)``.

    Raises:
        ValueError: If ``clip_len`` or ``seg_len`` is not positive (e.g. a
            video that could not be opened reports 0 frames).
    """
    if clip_len <= 0:
        raise ValueError(f"clip_len must be positive, got {clip_len}")
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")

    positions = np.arange(clip_len, dtype=np.int64)
    if seg_len < clip_len:
        # Loop over the short segment as many times as needed, truncated.
        indices = positions % seg_len
    else:
        # A truncated fixed stride (seg_len // clip_len) can leave a large
        # tail unsampled: for seg_len=63, clip_len=32 it stops at frame 31.
        # Scaling each position keeps the spacing even across all frames.
        indices = positions * seg_len // clip_len
    return indices.astype(np.int64)

def concatenate_frames(frames, clip_len):
    """Tile sampled video frames into one grid image for display.

    Args:
        frames: Indexable sequence of RGB frames as NumPy arrays of shape
            ``(H, W, 3)``. A frame with a leading batch axis
            ``(1, H, W, 3)`` is also accepted and its first element used.
            All cells are sized from the first frame; frames are assumed to
            share that size.
        clip_len: Number of grid cells. 32 keeps the original 4 x 8 layout;
            any other positive value gets a near-square grid.

    Returns:
        An RGB ``PIL.Image.Image``. Cells without a frame (when fewer than
        ``clip_len`` frames are supplied) stay black; frames beyond
        ``clip_len`` are ignored.

    Raises:
        ValueError: If ``clip_len`` is not positive or ``frames`` is empty.
    """
    import math

    layout = {32: (4, 8)}
    if clip_len <= 0:
        raise ValueError(f"clip_len must be positive, got {clip_len}")
    if len(frames) == 0:
        raise ValueError("concatenate_frames needs at least one frame")

    if clip_len in layout:
        rows, cols = layout[clip_len]
    else:
        # Smallest square-ish grid: cols = ceil(sqrt(n)), rows = ceil(n / cols).
        cols = math.isqrt(clip_len - 1) + 1
        rows = math.ceil(clip_len / cols)

    # Drop a leading batch axis *before* measuring, so the canvas size is
    # computed from the actual (H, W) of the image.
    tiles = [
        arr[0] if arr.ndim == 4 else arr
        for arr in (np.asarray(f) for f in frames[:clip_len])
    ]
    cell_h, cell_w = tiles[0].shape[:2]

    combined_image = Image.new('RGB', (cell_w * cols, cell_h * rows))
    for pos, tile in enumerate(tiles):
        row, col = divmod(pos, cols)
        combined_image.paste(Image.fromarray(tile), (col * cell_w, row * cell_h))
    return combined_image

def model_interface(uploaded_video, activity):
    """Score an uploaded video against a user-supplied activity label.

    The video is sampled to ``CLIP_LEN`` frames, the frames are tiled into a
    preview image, and the loaded model scores the clip against two
    candidate texts: the requested ``activity`` and the catch-all "other".

    Args:
        uploaded_video: The value produced by the Gradio ``Video`` input; it
            is passed straight to OpenCV, so presumably a file path.
        activity: Free-text activity description to test for.

    Returns:
        A 4-tuple ``(preview, probabilities, raw_scores, top)``: the tiled
        frame image, a list of ``(label, "Probability: NN.NN%")`` pairs, a
        list of ``(label, "Raw Score: N.NN")`` pairs, and
        ``[best_label, best_probability_in_percent]``.
    """
    frame_count = get_video_length(uploaded_video)
    frame_ids = sample_uniform_frame_indices(CLIP_LEN, seg_len=frame_count)
    frames = read_video_opencv(uploaded_video, frame_ids)
    preview = concatenate_frames(frames, CLIP_LEN)

    candidates = [activity, "other"]
    encoded = processor(
        text=candidates,
        videos=list(frames),
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        raw_scores = model(**encoded).logits_per_video

    probabilities = raw_scores.softmax(dim=1)
    best = torch.argmax(probabilities[0]).item()

    # Only one video is scored, so row 0 holds every candidate's value;
    # convert each row to plain Python floats once.
    prob_row = probabilities[0].tolist()
    score_row = raw_scores[0].tolist()

    prob_entries = [
        (label, f"Probability: {p * 100:.2f}%")
        for label, p in zip(candidates, prob_row)
    ]
    score_entries = [
        (label, f"Raw Score: {s:.2f}")
        for label, s in zip(candidates, score_row)
    ]
    top = [candidates[best], prob_row[best] * 100]

    return preview, prob_entries, score_entries, top

# Gradio UI: a video plus an activity prompt in; the sampled-frame grid,
# per-label probabilities, raw scores and the top prediction out.
iface = gr.Interface(
    fn=model_interface,
    inputs=[
        gr.components.Video(label="Upload a video file"),
        # Initial text is set with `value`; the older `default` keyword was
        # removed in Gradio 3.0.
        gr.components.Textbox(value="dancing", label="Desired Activity to Recognize"),
    ],
    outputs=[
        gr.components.Image(type="pil", label="Sampled Frames"),
        gr.components.Textbox(type="text", label="Probabilities"),
        gr.components.Textbox(type="text", label="Raw Scores"),
        gr.components.Textbox(type="text", label="Top Prediction")
    ],
    live=False
)

iface.launch()