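This Space wraps Sapiens-Lite pose estimation: it loads a TorchScript checkpoint, resizes and normalizes the input, decodes keypoints from the predicted heatmaps, and draws them back onto the original image or video frame.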
```python
import torch
import numpy as np
from PIL import Image, ImageDraw

from config import SAPIENS_LITE_MODELS_PATH
```
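The script expects a `config.py` module exposing `SAPIENS_LITE_MODELS_PATH`, a nested dict mapping task and version to a TorchScript checkpoint. The real paths depend on the deployment; a minimal sketch with purely illustrative file names:

```python
# config.py -- illustrative only; actual checkpoint locations vary per deployment.
SAPIENS_LITE_MODELS_PATH = {
    "pose": {
        "sapiens_1b": "checkpoints/pose/sapiens_1b_torchscript.pt2",
    },
}
```

With that mapping in place, `load_model` resolves the checkpoint, picks a device, and enables TF32 on recent GPUs: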
```python
def load_model(task, version):
    """Load a TorchScript Sapiens-Lite checkpoint for the given task/version."""
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Enable TF32 matmuls on Ampere-class (compute capability >= 8.0) GPUs.
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        model = torch.jit.load(model_path)
        model.eval()
        model.to(device)
        return model, device
    except KeyError as e:
        print(f"Error: invalid task or version. {e}")
        return None, None
```
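A quick smoke test, assuming the config above points at a real checkpoint:

```python
model, device = load_model("pose", "sapiens_1b")
if model is not None:
    print(f"TorchScript model ready on {device}")
```

`preprocess_image` then converts a PIL image into the tensor layout the model expects: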
```python
def preprocess_image(image, input_shape):
    """Resize, channel-swap, and normalize a PIL image into a 1x3xHxW tensor."""
    img = image.resize((input_shape[2], input_shape[1]))  # PIL expects (width, height)
    img = np.array(img).transpose(2, 0, 1)                # HWC -> CHW
    img = torch.from_numpy(img).float()
    img = img[[2, 1, 0], ...]                             # RGB -> BGR
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    img = (img - mean) / std
    return img.unsqueeze(0)
```
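A sanity check on a dummy image confirms the expected batch shape (any RGB image works):

```python
dummy = Image.new("RGB", (640, 480), color="gray")
batch = preprocess_image(dummy, (3, 1024, 768))
print(batch.shape)  # torch.Size([1, 3, 1024, 768])
```

The model's raw output is a stack of per-keypoint heatmaps at a quarter of the input resolution, decoded here with a simplified argmax pass: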
```python
def udp_decode(heatmap, img_size, heatmap_size):
    """Argmax-decode keypoints from heatmaps.

    Simplified placeholder; the full UDP (unbiased data processing) decode
    would refine peak locations sub-pixel.
    """
    h, w = heatmap_size
    keypoints = np.zeros((heatmap.shape[0], 2))
    keypoint_scores = np.zeros(heatmap.shape[0])
    for i in range(heatmap.shape[0]):
        hm = heatmap[i]
        idx = np.unravel_index(np.argmax(hm), hm.shape)
        # Map (row, col) in heatmap space to (x, y) in image space.
        keypoints[i] = [idx[1] * img_size[1] / w, idx[0] * img_size[0] / h]
        keypoint_scores[i] = hm[idx]
    return keypoints, keypoint_scores
```
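A synthetic single-peak heatmap shows the mapping from heatmap cells to input coordinates (each cell covers a 4x4 pixel patch at 1024x768):

```python
hm = np.zeros((1, 256, 192), dtype=np.float32)
hm[0, 64, 48] = 1.0  # single peak at row 64, col 48
kpts, scores = udp_decode(hm, (1024, 768), (256, 192))
print(kpts)    # [[192. 256.]]
print(scores)  # [1.]
```

Keypoints above a confidence threshold are drawn as small red dots: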
```python
def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
    """Draw a red dot for every keypoint whose score exceeds the threshold."""
    draw = ImageDraw.Draw(image)
    for (x, y), score in zip(keypoints, keypoint_scores):
        if score > threshold:
            draw.ellipse([(x - 2, y - 2), (x + 2, y + 2)], fill='red', outline='red')
    return image
```
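The driver below ties everything together: it loads the model once, then dispatches on whether the input is a PIL image or a numpy video frame.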
```python
def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    input_shape = (3, 1024, 768)

    def process_frame(frame):
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)
        if frame.mode == 'RGBA':
            frame = frame.convert('RGB')
        img = preprocess_image(frame, input_shape)
        with torch.no_grad():
            heatmap = model(img.to(device))
        # Heatmaps come out at 1/4 of the model input resolution.
        keypoints, keypoint_scores = udp_decode(
            heatmap[0].cpu().float().numpy(),
            input_shape[1:],
            (input_shape[1] // 4, input_shape[2] // 4),
        )
        # Rescale keypoints from model-input coordinates to the original frame size.
        scale_x = frame.width / input_shape[2]
        scale_y = frame.height / input_shape[1]
        keypoints[:, 0] *= scale_x
        keypoints[:, 1] *= scale_y
        return visualize_keypoints(frame, keypoints, keypoint_scores)

    if isinstance(input_data, np.ndarray):  # Video frame
        return process_frame(input_data)
    elif isinstance(input_data, Image.Image):  # Image
        return process_frame(input_data)
    else:
        print("Unsupported input type. Please provide a PIL image or a numpy video frame.")
        return None
```
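Putting it all together, a minimal end-to-end call might look like this (file names are illustrative):

```python
if __name__ == "__main__":
    image = Image.open("person.jpg")  # hypothetical input
    result = process_image_or_video(image, task="pose", version="sapiens_1b")
    if result is not None:
        result.save("person_pose.jpg")
```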