import gradio as gr
import torch
import numpy as np
from transformers import AutoProcessor, AutoModel
from PIL import Image
import cv2
import PyNvCodec as nvc

MODEL_NAME = "microsoft/xclip-base-patch16-zero-shot"
CLIP_LEN = 32
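# The zero-shot X-CLIP checkpoint expects exactly 32 frames per clip, so
# CLIP_LEN frames are sampled uniformly from every input video.
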
# Check if GPU is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Load model and processor once and move them to the device
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)

def get_video_length(file_path):
    decoder = nvc.PyNvDecoder(file_path, 0)  # second argument is the GPU ordinal
    return decoder.Numframes()
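
# Note: container-reported frame counts can be estimates for some streams;
# the decode loop below stops early if the real stream ends first.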
def read_video_nvcodec(file_path, indices):
    """Decode the video on the GPU with NVDEC and return the frames at the
    requested indices as HxWx3 RGB uint8 arrays (repeated indices preserved)."""
    gpu_id = 0
    decoder = nvc.PyNvDecoder(file_path, gpu_id)
    width, height = decoder.Width(), decoder.Height()
    # NVDEC outputs NV12 surfaces: convert to RGB on the GPU, then download
    # to host memory (the standard VPF converter/downloader pattern).
    to_rgb = nvc.PySurfaceConverter(width, height, nvc.PixelFormat.NV12, nvc.PixelFormat.RGB, gpu_id)
    downloader = nvc.PySurfaceDownloader(width, height, nvc.PixelFormat.RGB, gpu_id)
    cc_ctx = nvc.ColorspaceConversionContext(nvc.ColorSpace.BT_601, nvc.ColorRange.MPEG)
    wanted = {int(i) for i in indices}
    decoded = {}
    for i in range(max(wanted) + 1):
        surface = decoder.DecodeSingleSurface()
        if surface.Empty():
            break
        if i in wanted:
            rgb_surface = to_rgb.Execute(surface, cc_ctx)
            frame = np.ndarray(shape=(height * width * 3,), dtype=np.uint8)
            downloader.DownloadSingleSurface(rgb_surface, frame)
            decoded[i] = frame.reshape(height, width, 3)
    # indices may contain repeats for short videos, so build the output by lookup.
    return [decoded[int(i)] for i in indices if int(i) in decoded]

def get_frame(file_path, index):
    # CPU fallback: grab a single frame with OpenCV (currently unused by the app).
    cap = cv2.VideoCapture(file_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, index)
    ret, frame = cap.read()
    cap.release()
    if ret:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None

def sample_uniform_frame_indices(clip_len, seg_len):
    if seg_len < clip_len:
        # Video is shorter than the clip: cycle through its frames until we
        # have enough indices.
        repeat_factor = int(np.ceil(clip_len / seg_len))
        indices = (np.arange(seg_len).tolist() * repeat_factor)[:clip_len]
    else:
        spacing = seg_len // clip_len
        indices = [i * spacing for i in range(clip_len)]
    return np.array(indices).astype(np.int64)
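
# Example: a 320-frame video with CLIP_LEN=32 gives spacing 10 and indices
# [0, 10, 20, ..., 310]; a 10-frame video repeats 0..9 until 32 indices are
# collected.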

def concatenate_frames(frames, clip_len):
    layout = {32: (4, 8)}
    rows, cols = layout[clip_len]
    frame_h, frame_w = frames[0].shape[0], frames[0].shape[1]
    combined_image = Image.new('RGB', (frame_w * cols, frame_h * rows))
    frame_iter = iter(frames)
    y_offset = 0
    for _ in range(rows):
        x_offset = 0
        for _ in range(cols):
            img = Image.fromarray(next(frame_iter))
            combined_image.paste(img, (x_offset, y_offset))
            x_offset += frame_w
        y_offset += frame_h
    return combined_image
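
# Example: with CLIP_LEN=32, concatenate_frames tiles the sampled frames on a
# 4-row by 8-column grid, so 224x224 frames produce a 1792x896 preview image.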

def model_interface(uploaded_video, activity):
    video_length = get_video_length(uploaded_video)
    indices = sample_uniform_frame_indices(CLIP_LEN, seg_len=video_length)
    video = read_video_nvcodec(uploaded_video, indices)
    concatenated_image = concatenate_frames(video, CLIP_LEN)

    # Zero-shot setup: the user's activity competes against a generic "other" label.
    activities_list = [activity, "other"]
    inputs = processor(
        text=activities_list,
        videos=list(video),
        return_tensors="pt",
        padding=True,
    )
    # Move the input tensors to the same device as the model.
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            inputs[key] = value.to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_video = outputs.logits_per_video
    probs = logits_per_video.softmax(dim=1)

    results_probs = []
    results_logits = []
    max_prob_index = torch.argmax(probs[0]).item()
    for i, current_activity in enumerate(activities_list):
        prob = float(probs[0][i].cpu())  # move to CPU for formatting
        logit = float(logits_per_video[0][i].cpu())
        results_probs.append((current_activity, f"Probability: {prob * 100:.2f}%"))
        results_logits.append((current_activity, f"Raw Score: {logit:.2f}"))

    likely_label = activities_list[max_prob_index]
    likely_probability = float(probs[0][max_prob_index].cpu()) * 100

    return concatenated_image, results_probs, results_logits, [likely_label, likely_probability]
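
# With two candidate texts, the softmax in model_interface turns the two
# similarity logits into complementary probabilities: P(activity) + P("other") = 1.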

iface = gr.Interface(
    fn=model_interface,
    inputs=[
        gr.components.Video(label="Upload a video file"),
        gr.components.Textbox(value="dancing", label="Desired Activity to Recognize"),
    ],
    outputs=[
        gr.components.Image(type="pil", label="Sampled Frames"),
        gr.components.Textbox(label="Probabilities"),
        gr.components.Textbox(label="Raw Scores"),
        gr.components.Textbox(label="Top Prediction"),
    ],
    live=False,
)

iface.launch()