Detection / app.py
SamiKhokhar's picture
Create app.py
94886ef verified
raw
history blame
2.89 kB
import os
import time
import uuid

import cv2
import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# The Flask application object; the @app.route handlers in this module register on it.
app = Flask(import_name=__name__)
# Device setup: prefer CUDA, then Apple MPS, and fall back to CPU.
# Every branch yields a torch.device (the CPU fallback used to be a bare
# string), so code that inspects `device` always sees one type.
# `torch.backends.mps` is looked up defensively because torch builds without
# MPS support do not have that attribute at all.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Hugging Face checkpoint that supplies both the preprocessing config and the
# detector weights; the model is moved onto the device selected above.
# NOTE(review): the checkpoint name suggests a fashion/apparel detector, while
# the rest of this file talks about helmet violations -- confirm that
# model.config.id2label actually covers the classes you need.
ckpt = 'yainage90/fashion-object-detection'
image_processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModelForObjectDetection.from_pretrained(ckpt)
model = model.to(device)
def detect_objects(frame, threshold=0.4):
    """Run the object-detection model on a single video frame.

    Args:
        frame: Image array in OpenCV's BGR channel order (e.g. the result of
            ``cv2.imdecode``).
        threshold: Minimum confidence a detection must reach to be kept.
            Defaults to 0.4, the value that used to be hard-coded.

    Returns:
        A list of ``(score, label_id, box)`` tuples, one per detection at or
        above ``threshold``. ``score`` is a float, ``label_id`` is an int key
        into ``model.config.id2label``, and ``box`` is a list of floats as
        produced by the image processor's ``post_process_object_detection``,
        rescaled to the frame's pixel size.

    Raises:
        ValueError: If ``frame`` is ``None`` (``cv2.imdecode`` returns ``None``
            for bytes it cannot decode).
    """
    if frame is None:
        raise ValueError("frame is None; the image could not be decoded")

    # The Hugging Face processor works on RGB PIL images; OpenCV frames are BGR.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    width, height = image.size

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt").to(device)
        outputs = model(**inputs)

    # Post-processing rescales boxes to (height, width). Build the sizes tensor on
    # the model's device so the rescale never mixes CPU and accelerator tensors,
    # whether or not the post-processor moves it itself.
    target_sizes = torch.tensor([[height, width]], device=device)
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    items = []
    for score, label, box in zip(
        results["scores"].tolist(),
        results["labels"].tolist(),
        results["boxes"].tolist(),
    ):
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))
    return items
def save_data(frame, items):
    """Save a frame that had detections and record it with its plate number.

    Writes ``frame`` as a JPEG into the current working directory, runs
    ``extract_plate_number`` on it, and passes the filename, the plate number
    and ``items`` to ``save_to_database``.

    Args:
        frame: Image array to write with ``cv2.imwrite``.
        items: Detections for this frame, forwarded unchanged to
            ``save_to_database``.

    Returns:
        The filename the image was written to. The original returned ``None``,
        and existing callers can ignore the value.

    Raises:
        OSError: If OpenCV fails to write the image file.
    """
    # A per-second timestamp on its own collides when several frames arrive in
    # the same second, silently overwriting earlier evidence. The random suffix
    # keeps each name unique while preserving the sortable timestamp prefix.
    filename = f"helmet_violation_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"

    # cv2.imwrite reports failure through its boolean return value, not by
    # raising. Without this check, a non-existent file would be recorded.
    if not cv2.imwrite(filename, frame):
        raise OSError(f"could not write image file {filename!r}")

    # Plate recognition and persistence are delegated to the helpers below.
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)
    return filename
def extract_plate_number(frame):
    """Return the licence-plate text for *frame* (placeholder implementation).

    The frame is not inspected yet. Every call returns the same fixed dummy
    plate string.
    """
    # TODO: replace this constant with a real licence-plate recognition step.
    return "XYZ 1234"
def save_to_database(image_filename, plate_number, items):
    """Record one violation; for now this only echoes it to stdout.

    Args:
        image_filename: Name of the saved image file.
        plate_number: Plate text associated with the image.
        items: Detections for the image, printed as-is.
    """
    summary = f"Plate Number: {plate_number}, Image saved as {image_filename}"
    print(summary)
    print("Detected items:", items)
@app.route("/process_frame", methods=["POST"])
def process_frame():
    """Decode an uploaded frame, run detection, and record any hits.

    Expects a multipart upload whose file field is named ``frame`` and holds
    encoded image bytes (anything ``cv2.imdecode`` can read).

    Returns:
        JSON ``{"status": "processed"}`` on success. Returns a JSON error body
        with HTTP 400 when the ``frame`` field is missing, or when its bytes
        are empty or cannot be decoded as an image.
    """
    upload = request.files.get("frame")
    if upload is None:
        return jsonify({"status": "error", "message": "missing 'frame' file field"}), 400

    np_array = np.frombuffer(upload.read(), np.uint8)
    # cv2.imdecode rejects an empty buffer with an exception. For undecodable
    # bytes it returns None. Normalise both cases to None and reject them as
    # client errors instead of letting them crash detection with a 500.
    img = cv2.imdecode(np_array, cv2.IMREAD_COLOR) if np_array.size else None
    if img is None:
        return jsonify({"status": "error", "message": "could not decode 'frame' as an image"}), 400

    # Detect objects (e.g., helmets) in the frame
    items = detect_objects(img)
    if items:  # Only frames with detections are saved and recorded.
        save_data(img, items)
    return jsonify({"status": "processed"})
if __name__ == "__main__":
    # Werkzeug's interactive debugger allows arbitrary code execution from the
    # browser. It must never be enabled by default while the server listens on
    # every interface (0.0.0.0). Opt in explicitly with FLASK_DEBUG=1 for local
    # development. PORT overrides the default port 5000.
    debug = os.environ.get("FLASK_DEBUG", "0") == "1"
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug, host="0.0.0.0", port=port)