Waste-Detector / utils.py
Hector Lopez
refactor: Using streamlit again
f890c24
raw
history blame
2.06 kB
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
from classifier import CustomViT
from model import get_model
# Colour drawn for each class id. The tuples are written straight into the
# image array and displayed by matplotlib's imshow, which reads channels as
# RGB (assumes the incoming image is RGB rather than OpenCV BGR -- confirm).
_LABEL_COLORS = {
    0: (255, 255, 0),
    1: (255, 0, 0),
    2: (0, 0, 255),
    3: (0, 128, 0),
    4: (255, 165, 0),
    5: (230, 230, 250),
    6: (192, 192, 192),
}

# Text drawn next to each box for each class id.
_LABEL_TEXTS = {
    0: 'plastic',
    1: 'dangerous',
    2: 'carton',
    3: 'glass',
    4: 'organic',
    5: 'rest',
    6: 'other',
}


def plot_img_no_mask(image: np.ndarray, boxes: torch.Tensor, labels, *,
                     output_path: str = "img.png"):
    """Draw labelled bounding boxes on an image and save the figure to disk.

    Args:
        image (np.ndarray): Image to annotate. Anything ``np.array`` accepts
            works. It is copied before drawing, so the caller's object is
            never modified.
        boxes (torch.Tensor): One ``(x1, y1, x2, y2)`` row per detection,
            in pixel coordinates. Moved to CPU and cast to integers before
            drawing.
        labels: Class ids aligned with ``boxes`` by index (keys of
            ``_LABEL_COLORS`` / ``_LABEL_TEXTS``). Python ints, NumPy
            integers and 0-d tensors are all accepted.
        output_path (str): File the rendered figure is written to. Defaults
            to ``"img.png"``, the name this function always used.

    Raises:
        KeyError: If a label is not one of the known class ids (0-6).
        IndexError: If ``labels`` has fewer entries than ``boxes``.
    """
    # Copy once, up front. The original author noted that cv2.rectangle
    # errors without a copy (presumably a read-only or non-contiguous input
    # array). The copy also keeps the caller's image untouched, because cv2
    # draws in place.
    canvas = np.array(image).copy()
    coords = boxes.cpu().detach().numpy().astype(np.int32)

    for i, box in enumerate(coords):
        # int() so tensor / NumPy labels match the plain-int dict keys.
        label = int(labels[i])
        color = _LABEL_COLORS[label]
        # Plain Python ints, because some OpenCV builds reject NumPy integer
        # scalars as point coordinates.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness=5)
        # The label is drawn 10 px above the box's top-left corner.
        cv2.putText(canvas, _LABEL_TEXTS[label], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.axis('off')
    ax.imshow(canvas)
    fig.savefig(output_path, bbox_inches='tight')
    # Release the figure. Otherwise every call leaves one more open figure
    # in pyplot's global registry, which is a leak in a long-running app.
    plt.close(fig)
def get_models(
    detection_ckpt: str,
    classifier_ckpt: str,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Build the detector and the classifier and put both in eval mode.

    Args:
        detection_ckpt (str): Checkpoint for the detector. It is passed
            unchanged to ``get_model``.
        classifier_ckpt (str): File holding the classifier's ``state_dict``.
            Its tensors are loaded onto the CPU.

    Returns:
        tuple: ``(detector, classifier)``, both already switched to
        ``eval()`` mode.
    """
    # --- detector ---
    print('Loading the detection model')
    detector = get_model(detection_ckpt)
    detector.eval()

    # --- classifier ---
    print('Loading the classifier model')
    # target_size=7 presumably matches the 7 class ids used for plotting in
    # this module -- keep the two in sync.
    vit = CustomViT(target_size=7, pretrained=False)
    weights = torch.load(classifier_ckpt, map_location='cpu')
    vit.load_state_dict(weights)
    vit.eval()

    return detector, vit