Marcus Vinicius Zerbini Canhaço
feat: atualização do detector com otimizações para GPU T4
577120c
raw
history blame
5.46 kB
import torch
import torch.nn.functional as F
import logging
import os
import gc
import numpy as np
import cv2
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from .base import BaseDetector
logger = logging.getLogger(__name__)
class WeaponDetectorGPU(BaseDetector):
"""Detector de armas otimizado para GPU."""
def __init__(self):
"""Inicializa o detector."""
super().__init__()
self.default_resolution = 640
self.device = None # Será configurado em _initialize
self._initialize()
def _initialize(self):
"""Inicializa o modelo."""
try:
# Configurar device
if not torch.cuda.is_available():
raise RuntimeError("CUDA não está disponível!")
# Configurar device corretamente
self.device = 0 # Usar índice inteiro para GPU
# Carregar modelo e processador
logger.info("Carregando modelo e processador...")
model_name = "google/owlv2-base-patch16"
self.owlv2_processor = Owlv2Processor.from_pretrained(model_name)
self.owlv2_model = Owlv2ForObjectDetection.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map={"": self.device} # Mapear todo o modelo para GPU 0
)
# Otimizar modelo
self.owlv2_model.eval()
# Processar queries
self.text_queries = self._get_detection_queries()
self.processed_text = self.owlv2_processor(
text=self.text_queries,
return_tensors="pt",
padding=True
)
self.processed_text = {
key: val.to(self.device)
for key, val in self.processed_text.items()
}
logger.info("Inicialização GPU completa!")
self._initialized = True
except Exception as e:
logger.error(f"Erro na inicialização GPU: {str(e)}")
raise
def detect_objects(self, image: Image.Image, threshold: float = 0.3) -> list:
"""Detecta objetos em uma imagem."""
try:
# Pré-processar imagem
if image.mode != 'RGB':
image = image.convert('RGB')
# Processar imagem
image_inputs = self.owlv2_processor(
images=image,
return_tensors="pt"
)
image_inputs = {
key: val.to(self.device)
for key, val in image_inputs.items()
}
# Inferência
with torch.no_grad():
inputs = {**image_inputs, **self.processed_text}
outputs = self.owlv2_model(**inputs)
target_sizes = torch.tensor([image.size[::-1]], device=self.device)
results = self.owlv2_processor.post_process_grounded_object_detection(
outputs=outputs,
target_sizes=target_sizes,
threshold=threshold
)[0]
# Processar detecções
detections = []
if len(results["scores"]) > 0:
scores = results["scores"]
boxes = results["boxes"]
labels = results["labels"]
for score, box, label in zip(scores, boxes, labels):
if score.item() >= threshold:
detections.append({
"confidence": score.item(),
"box": [int(x) for x in box.tolist()],
"label": self.text_queries[label]
})
return detections
except Exception as e:
logger.error(f"Erro em detect_objects: {str(e)}")
return []
def _get_best_device(self):
"""Retorna o melhor dispositivo disponível."""
return 0 # Usar índice inteiro para GPU
def _clear_gpu_memory(self):
"""Limpa memória GPU."""
torch.cuda.empty_cache()
gc.collect()
def process_video(self, video_path: str, fps: int = None, threshold: float = 0.3, resolution: int = 640) -> tuple:
"""Processa um vídeo."""
metrics = {
"total_time": 0,
"frames_analyzed": 0,
"detections": []
}
try:
frames = self.extract_frames(video_path, fps or 2, resolution)
metrics["frames_analyzed"] = len(frames)
for i, frame in enumerate(frames):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
detections = self.detect_objects(frame_pil, threshold)
if detections:
metrics["detections"].append({
"frame": i,
"detections": detections
})
return video_path, metrics
return video_path, metrics
except Exception as e:
logger.error(f"Erro ao processar vídeo: {str(e)}")
return video_path, metrics