Detector / app.py
SamiKhokhar's picture
Update app.py
e2fbe95 verified
raw
history blame
2.93 kB
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import cv2
import numpy as np
import time
import gradio as gr
# Device setup: prefer a CUDA GPU, then Apple's MPS backend, else fall back to CPU.
# Every branch yields a torch.device so downstream code sees one consistent type.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # torch.backends.mps only exists on newer torch builds; guard the attribute access.
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Hugging Face Hub checkpoint providing both the detector weights and its matching
# image preprocessor; the model is moved onto the device chosen above.
# NOTE(review): the checkpoint name suggests a fashion/clothing detector, while the
# rest of the app talks about helmet violations — confirm its label set covers helmets.
ckpt = 'yainage90/fashion-object-detection'
image_processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModelForObjectDetection.from_pretrained(ckpt)
model = model.to(device)
def detect_objects(frame, threshold=0.4):
    """Run the object detector on a single frame and return its detections.

    Args:
        frame: NumPy image array in OpenCV channel order — BGR ``(H, W, 3)``,
            BGRA ``(H, W, 4)``, or grayscale ``(H, W)`` / ``(H, W, 1)``.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to 0.4, the value this function always used.

    Returns:
        A list of ``(score, label_id, box)`` tuples: ``score`` is a float,
        ``label_id`` is an int key into ``model.config.id2label``, and ``box``
        is a list of four floats in pixel coordinates as returned by
        ``image_processor.post_process_object_detection``.
    """
    # Convert OpenCV channel order to the RGB layout PIL / the processor expect.
    if frame.ndim == 2 or frame.shape[2] == 1:
        rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[2] == 4:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(rgb)

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))
        # PIL reports (width, height); the post-processor wants (height, width)
        # so boxes are rescaled to the original image size.
        target_sizes = torch.tensor([[image.height, image.width]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    items = []
    for score_t, label_t, box_t in zip(results["scores"], results["labels"], results["boxes"]):
        score = score_t.item()
        label = label_t.item()
        box = box_t.tolist()  # plain Python floats, same as per-element .item()
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))
    return items
def process_image(image):
    """Run detection on an image from the Gradio UI and record any detections.

    Args:
        image: The uploaded image. The interface is configured with
            ``gr.Image(type="pil")``, so this is normally a PIL image; it may be
            ``None`` when the input is empty. A NumPy array is passed through
            unchanged and assumed to already be in OpenCV BGR order.

    Returns:
        ``{"items_detected": items}`` where ``items`` is the list returned by
        ``detect_objects`` (empty when there is no image).
    """
    if image is None:
        # Nothing to analyse (e.g. the input was cleared while in live mode).
        return {"items_detected": []}

    if isinstance(image, Image.Image):
        # PIL gives RGB (or RGBA / L / P ...). Normalise to 3-channel RGB, then
        # flip to BGR: detect_objects converts BGR->RGB itself, and cv2.imwrite
        # in save_data expects BGR. Passing RGB straight through swapped the
        # red and blue channels for both the model and the saved file.
        frame = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    else:
        frame = np.asarray(image)

    items = detect_objects(frame)
    if items:  # only persist frames that actually contain detections
        save_data(frame, items)
    return {"items_detected": items}
def save_data(frame, items):
    """Write the frame to disk, then extract a plate number and record the result.

    Args:
        frame: Image array written with ``cv2.imwrite`` (OpenCV expects BGR order).
        items: Detections from ``detect_objects``; forwarded to ``save_to_database``.

    Raises:
        OSError: If OpenCV fails to write the image file.
    """
    # Nanosecond timestamp: the UI runs live, so this may be called several
    # times within one second; a per-second name silently overwrote captures.
    filename = f"helmet_violation_{time.time_ns()}.jpg"
    # cv2.imwrite reports failure by returning False instead of raising.
    if not cv2.imwrite(filename, frame):
        raise OSError(f"Failed to write image file {filename!r}")
    # Further processing hook: plate extraction, then persistence.
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)
def extract_plate_number(frame):
    """Return the license plate text for *frame* (placeholder implementation).

    ``frame`` is currently ignored and a fixed dummy value is returned.
    """
    # TODO: swap in a real license plate recognition method.
    return "XYZ 1234"
def save_to_database(image_filename, plate_number, items):
    """Record a detection result (currently by printing it to stdout).

    Args:
        image_filename: Path of the saved image.
        plate_number: Plate text associated with the image.
        items: Detected items to report alongside.
    """
    summary = f"Plate Number: {plate_number}, Image saved as {image_filename}"
    print(summary)
    print("Detected items:", items)
# Gradio UI: a PIL image goes in, and the detection result comes out as JSON.
# With live=True the interface re-runs process_image when the input changes.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
    live=True,
)
# Start the web app; see the Gradio launch() docs for what debug=True does
# (it blocks the calling thread and surfaces errors).
interface.launch(debug=True)