# app.py -- initial setup and model loading for the violence-detection demo.
# Hugging Face file-viewer metadata: author Abdo-Alshoki, commit d620482
# (verified), file size 3.11 kB.
import torch
import gradio as gr
import torch.nn as nn
import torchvision
import cv2
import numpy as np
import tempfile
class MyModel(nn.Module):
    """Video classifier built on torchvision's R3D-18 backbone.

    The backbone's final fully connected layer is replaced by a linear layer
    with ``num_classes`` outputs (a single logit by default, which ``test``
    turns into a score with a sigmoid).
    """

    # (width, height) every sampled frame is resized to before inference.
    FRAME_SIZE = (112, 112)

    def __init__(self, num_classes=1):
        """Create the R3D-18 backbone and swap in a ``num_classes``-wide head.

        Args:
            num_classes: Number of outputs of the replacement ``fc`` layer.
        """
        super().__init__()
        self.model = torchvision.models.video.r3d_18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    @staticmethod
    def _frames_to_tensor(frames, num_frames):
        """Stack channels-first frames into a zero-padded 5-D float tensor.

        Args:
            frames: Sequence of arrays shaped ``(3, 112, 112)``. It may be
                empty or shorter than ``num_frames``; extra frames beyond
                ``num_frames`` are ignored.
            num_frames: Length of the temporal axis of the result.

        Returns:
            A ``float32`` tensor shaped ``(1, 3, num_frames, 112, 112)`` whose
            trailing, unfilled frames are all zeros.
        """
        width, height = MyModel.FRAME_SIZE
        # Pre-allocating the padded clip avoids concatenating an empty list
        # with a 4-D padding array, which NumPy rejects.
        clip = np.zeros((num_frames, 3, height, width), dtype=np.float32)
        count = min(len(frames), num_frames)
        if count:
            clip[:count] = np.stack(frames[:count])
        # (frames, channels, H, W) -> (1, channels, frames, H, W)
        return torch.from_numpy(clip).permute(1, 0, 2, 3).unsqueeze(0)

    def preprocess_video(self, video_path, num_frames=40):
        """Sample frames evenly from a video and return them as a model input.

        Frames are resized to 112x112 and moved to channels-first layout.
        No colour conversion or pixel scaling is applied.
        NOTE(review): this must match the preprocessing used when the
        checkpoint was trained -- confirm before adding normalization.

        Args:
            video_path: Path of the video file to read.
            num_frames: Number of frames to sample across the video.

        Returns:
            A ``float32`` tensor shaped ``(1, 3, num_frames, 112, 112)``;
            frames that could not be read are zero-padded at the end.

        Raises:
            ValueError: If the video cannot be opened or no frame can be read.
        """
        cap = cv2.VideoCapture(video_path)
        sampled_frames = []
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path!r}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Clamp so a container reporting zero frames still yields index 0.
            last_index = max(total_frames - 1, 0)
            frame_indices = np.linspace(0, last_index, num=num_frames, dtype=int)
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    # Stop at the first unreadable frame; the remainder of the
                    # clip is zero-padded below.
                    break
                frame = cv2.resize(frame, self.FRAME_SIZE)  # 112x112 for R3D-18
                sampled_frames.append(np.transpose(frame, (2, 0, 1)))  # channels-first
        finally:
            # Always free the decoder handle, even when reading fails.
            cap.release()

        if not sampled_frames:
            raise ValueError(f"No frames could be read from video: {video_path!r}")
        return self._frames_to_tensor(sampled_frames, num_frames)

    def forward(self, x):
        """Run the backbone on ``x`` and return its raw outputs (logits)."""
        return self.model(x)

    def test(self, video_paths):
        """Score each video with the model.

        Args:
            video_paths: An iterable of video file paths. A single path given
                as a ``str`` is treated as a one-element collection instead
                of being iterated character by character.

        Returns:
            A list with one entry per video: the sigmoid of the model output,
            as a ``float`` when the model has a single output, otherwise as a
            list of floats.
        """
        if isinstance(video_paths, str):
            video_paths = [video_paths]

        predictions = []
        # Inference must run in eval mode so BatchNorm layers use their stored
        # statistics; the previous mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for video_path in video_paths:
                    X = self.preprocess_video(video_path)
                    output = self.model(X)
                    pred = torch.sigmoid(output)  # sigmoid for binary classification
                    if pred.numel() == 1:
                        predictions.append(pred.item())
                    else:
                        predictions.append(pred.squeeze(0).tolist())
        finally:
            self.train(was_training)
        return predictions

    def save_model(self, filepath):
        """Write the model's ``state_dict`` to ``filepath``.

        The checkpoint is a dict with the single key ``'model_state_dict'``.
        """
        torch.save({
            'model_state_dict': self.state_dict(),
        }, filepath)

    @staticmethod
    def load_model(filepath, num_classes=1):
        """Build a ``MyModel`` and load weights written by ``save_model``.

        Args:
            filepath: Path of a checkpoint containing ``'model_state_dict'``.
            num_classes: Number of outputs the checkpoint's head was built with.

        Returns:
            The model with the checkpoint's weights loaded.
        """
        model = MyModel(num_classes)
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a GPU.
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model
# Load the fine-tuned weights once at import time. The loader is called on the
# class rather than on a throwaway MyModel() instance, which would build the
# backbone a second time. Despite the ".h5" suffix, the file is read with
# torch.load by MyModel.load_model.
model = MyModel.load_model('pre_3D_model.h5')
# The app only performs inference, so keep BatchNorm layers in eval mode.
model.eval()
def classify_video(video):
    """Classify one uploaded video as violent or non-violent.

    Args:
        video: Path of the video file to score (presumably the file path
            supplied by the ``gr.Video`` input -- confirm against the
            installed Gradio version).

    Returns:
        A ``(label, prob)`` tuple: the text label and the model's sigmoid
        score as a ``float``.

    Raises:
        ValueError: If no video was provided.
    """
    if not video:
        raise ValueError("No video was provided.")

    # model.test scores a collection of paths and returns one score per path,
    # so wrap the single upload in a list and unpack the single result.
    prob = float(model.test([video])[0])
    # NOTE(review): scores >= 0.5 map to "Non-violent"; this assumes the
    # checkpoint was trained with non-violent as the positive class -- confirm.
    label = "Non-violent" if prob >= 0.5 else "Violent"
    return label, prob
# Set up the Gradio interface: one video upload in, a text label and a
# numeric score out.
interface = gr.Interface(
    fn=classify_video,  # Callback run on every submitted video
    inputs=gr.Video(),  # Allows video upload
    outputs=[gr.Text(), gr.Number()],  # Outputs classification and probability
    title="Violence Detection in Videos",
    description="Upload a video to classify it as violent or non-violent with a probability score.",
)

interface.launch(share=True, debug=True)