import torch
import torchvision.transforms.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st

# Define dictionaries to map class indices to their corresponding names
object_dict = {
    0: 'background',
    1: 'task',
    2: 'exclusiveGateway',
    3: 'event',
    4: 'parallelGateway',
    5: 'messageEvent',
    6: 'pool',
    7: 'lane',
    8: 'dataObject',
    9: 'dataStore',
    10: 'subProcess',
    11: 'eventBasedGateway',
    12: 'timerEvent',
}

arrow_dict = {
    0: 'background',
    1: 'sequenceFlow',
    2: 'dataAssociation',
    3: 'messageFlow',
}

class_dict = {
    0: 'background',
    1: 'task',
    2: 'exclusiveGateway',
    3: 'event',
    4: 'parallelGateway',
    5: 'messageEvent',
    6: 'pool',
    7: 'lane',
    8: 'dataObject',
    9: 'dataStore',
    10: 'subProcess',
    11: 'eventBasedGateway',
    12: 'timerEvent',
    13: 'sequenceFlow',
    14: 'dataAssociation',
    15: 'messageFlow',
}

def is_inside(box1, box2):
    """Check if the center of box1 is inside box2."""
    x_center = (box1[0] + box1[2]) / 2
    y_center = (box1[1] + box1[3]) / 2
    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]

def is_vertical(box):
    """Determine if the text in the bounding box is vertically aligned."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return height > 2 * width

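# Illustrative checks (hypothetical coordinates, not from the project):
# a box centred at (5, 5) lies inside [0, 0, 10, 10], and a 10x50 box is
# more than twice as tall as it is wide, so it counts as vertical text.
# >>> is_inside([4, 4, 6, 6], [0, 0, 10, 10])
# True
# >>> is_vertical([0, 0, 10, 50])
# True
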
def rescale_boxes(scale, boxes):
    """Rescale the bounding boxes by a given scale factor."""
    for i in range(len(boxes)):
        boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
    return boxes

def iou(box1, box2):
    """Calculate the Intersection over Union (IoU) of two bounding boxes."""
    inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
    inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area
    return inter_area / union_area

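# Worked example (hypothetical boxes): two 10x10 boxes overlapping on a
# 5x5 patch share 25 px of intersection over a 175 px union, so
# iou([0, 0, 10, 10], [5, 5, 15, 15]) == 25 / 175 ≈ 0.143.
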
def proportion_inside(box1, box2):
    """Calculate the proportion of the smaller box inside the larger box."""
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    big_box, small_box = (box1, box2) if box1_area > box2_area else (box2, box1)
    inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
    inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
    proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
    return min(proportion, 1.0)

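# Worked example (hypothetical boxes): when the smaller box lies entirely
# within the larger one, the proportion is 1.0 regardless of argument order:
# proportion_inside([10, 10, 20, 20], [0, 0, 100, 100]) == 1.0.
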
def resize_boxes(boxes, original_size, target_size):
    """
    Resizes bounding boxes according to a new image size.

    Parameters:
    - boxes (np.array): The original bounding boxes as a numpy array of shape [N, 4].
    - original_size (tuple): The original size of the image as (width, height).
    - target_size (tuple): The desired size to resize the image to as (width, height).

    Returns:
    - np.array: The resized bounding boxes as a numpy array of shape [N, 4].
    """
    orig_width, orig_height = original_size
    target_width, target_height = target_size
    width_ratio = target_width / orig_width
    height_ratio = target_height / orig_height
    boxes[:, 0] *= width_ratio
    boxes[:, 1] *= height_ratio
    boxes[:, 2] *= width_ratio
    boxes[:, 3] *= height_ratio
    return boxes

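# Usage sketch (hypothetical sizes): mapping boxes from a 1333x800 model
# input back to a 2000x1200 source image scales x by 2000/1333 and y by
# 1200/800. Note the array is modified in place; pass a copy if the
# originals are still needed:
# resized = resize_boxes(boxes.copy(), (1333, 800), (2000, 1200))
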
def resize_keypoints(keypoints, original_size, target_size):
    """
    Resize keypoints based on the original and target dimensions of an image.

    Parameters:
    - keypoints (np.ndarray): The array of keypoints, where each keypoint is represented by its (x, y) coordinates.
    - original_size (tuple): The width and height of the original image (width, height).
    - target_size (tuple): The width and height of the target image (width, height).

    Returns:
    - np.ndarray: The resized keypoints.
    """
    orig_width, orig_height = original_size
    target_width, target_height = target_size
    width_ratio = target_width / orig_width
    height_ratio = target_height / orig_height
    keypoints[:, 0] *= width_ratio
    keypoints[:, 1] *= height_ratio
    return keypoints

def write_results(name_model, metrics_list, start_epoch):
    """Write training results to a text file, one comma-separated line per epoch."""
    with open('./results/' + name_model + '.txt', 'w') as f:
        for i in range(len(metrics_list[0])):
            values = ','.join(str(metrics[i]) for metrics in metrics_list)
            f.write(f"{i + 1 + start_epoch},{values}\n")

def find_other_keypoint(idx, keypoints, boxes):
    """
    Mirror the average keypoint through the center of the box to find the opposite keypoint.

    Parameters:
    - idx (int): The index of the box and keypoints.
    - keypoints (np.ndarray): The array of keypoints.
    - boxes (np.ndarray): The array of bounding boxes.

    Returns:
    - tuple: The coordinates of the new keypoint followed by those of the average keypoint.
    """
    box = boxes[idx]
    key1, key2 = keypoints[idx]
    x1, y1, x2, y2 = box
    center = ((x1 + x2) // 2, (y1 + y2) // 2)
    average_keypoint = (key1 + key2) // 2
    # Reflect the average keypoint through the box center on each axis
    if average_keypoint[0] < center[0]:
        x = center[0] + abs(center[0] - average_keypoint[0])
    else:
        x = center[0] - abs(center[0] - average_keypoint[0])
    if average_keypoint[1] < center[1]:
        y = center[1] + abs(center[1] - average_keypoint[1])
    else:
        y = center[1] - abs(center[1] - average_keypoint[1])
    return x, y, average_keypoint[0], average_keypoint[1]

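# Worked example (hypothetical values): for a box [0, 0, 10, 10] the centre
# is (5, 5); if the two detected keypoints average to (2, 2), the average is
# mirrored through the centre and the function returns (8, 8, 2, 2).
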
def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
    """
    Filters overlapping boxes based on the Intersection over Union (IoU) metric, keeping only the boxes with the highest scores.

    Parameters:
    - boxes (np.ndarray): Array of bounding boxes with shape (N, 4), where each row contains [x_min, y_min, x_max, y_max].
    - scores (np.ndarray): Array of scores for each box, reflecting the confidence of detection.
    - labels (np.ndarray): Array of labels corresponding to each box.
    - keypoints (np.ndarray): Array of keypoints associated with each box.
    - iou_threshold (float): Threshold for IoU above which a box is considered overlapping.

    Returns:
    - tuple: Filtered boxes, scores, labels, and keypoints.
    """
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]
    keypoints = keypoints[keep]
    return boxes, scores, labels, keypoints

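# Usage sketch (hypothetical arrays): a greedy non-maximum suppression pass.
# The two near-identical boxes below overlap with IoU ≈ 0.68 > 0.5, so only
# the higher-scoring one survives, along with the disjoint third box:
# boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=float)
# scores = np.array([0.9, 0.6, 0.8])
# labels = np.array([1, 1, 2])
# keypoints = np.zeros((3, 2, 3))
# filter_overlap_boxes(boxes, scores, labels, keypoints)  # keeps indices 0 and 2
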
def draw_annotations(image,
                     target=None,
                     prediction=None,
                     full_prediction=None,
                     text_predictions=None,
                     model_dict=class_dict,
                     draw_keypoints=False,
                     draw_boxes=False,
                     draw_text=False,
                     draw_links=False,
                     draw_twins=False,
                     write_class=False,
                     write_score=False,
                     write_text=False,
                     write_idx=False,
                     score_threshold=0.4,
                     keypoints_correction=False,
                     only_print=None,
                     axis=False,
                     return_image=False,
                     new_size=(1333, 800),
                     resize=False):
    """
    Draws annotations on images including bounding boxes, keypoints, links, and text.

    Parameters:
    - image (np.array): The image on which annotations will be drawn.
    - target (dict): Ground truth data containing boxes, labels, etc.
    - prediction (dict): Prediction data from a model.
    - full_prediction (dict): Additional detailed prediction data, potentially including relationships.
    - text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
    - model_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Flag to draw keypoints.
    - draw_boxes (bool): Flag to draw bounding boxes.
    - draw_text (bool): Flag to draw text annotations.
    - draw_links (bool): Flag to draw links between annotations.
    - draw_twins (bool): Flag to draw twin keypoints.
    - write_class (bool): Flag to write class names near the annotations.
    - write_score (bool): Flag to write scores near the annotations.
    - write_text (bool): Flag to write OCR recognized text.
    - write_idx (bool): Flag to write the index of each box.
    - score_threshold (float): Threshold for scores above which annotations will be drawn.
    - keypoints_correction (bool): Flag to mirror near-coincident keypoints across the box center.
    - only_print (str): Specific class name to filter annotations by.
    - axis (bool): Flag to show the plot axes.
    - return_image (bool): Flag to return the annotated image in addition to displaying it.
    - new_size (tuple): The size the predictions were computed at; used when resize is True.
    - resize (bool): Whether to resize annotations to fit the image size.

    Returns:
    - np.array (optional): The annotated image, if return_image is True.
    """
    # If no prediction is given, the input is assumed to be a tensor: convert it to an HWC numpy image
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000
    # Helper function to draw annotations based on provided data
    def draw(data, is_prediction=False):
        for i in range(len(data['boxes'])):
            box = data['boxes'][i].tolist()
            x1, y1, x2, y2 = box
            if resize:
                x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
            if is_prediction:
                score = data['scores'][i].item()
                if score < score_threshold:
                    continue
            if draw_boxes:
                if only_print is not None:
                    if data['labels'][i] != list(model_dict.values()).index(only_print):
                        continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2 * scale))
            if is_prediction and write_score:
                cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
            if write_class and 'labels' in data:
                class_id = data['labels'][i].item()
                cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)
            if write_idx:
                cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
        # Draw keypoints if available
        if draw_keypoints and 'keypoints' in data:
            if is_prediction and keypoints_correction:
                for idx, (key1, key2) in enumerate(data['keypoints']):
                    if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
                                                   list(model_dict.values()).index('messageFlow'),
                                                   list(model_dict.values()).index('dataAssociation')]:
                        continue
                    distance = np.linalg.norm(key1[:2] - key2[:2])
                    if distance < 5:
                        x_new, y_new, x, y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
                        data['keypoints'][idx][0] = torch.tensor([x_new, y_new, 1])
                        data['keypoints'][idx][1] = torch.tensor([x, y, 1])
                        print("keypoint has been changed")
            for i in range(len(data['keypoints'])):
                kp = data['keypoints'][i]
                for j in range(kp.shape[0]):
                    if is_prediction and data['labels'][i] not in [list(model_dict.values()).index('sequenceFlow'),
                                                                   list(model_dict.values()).index('messageFlow'),
                                                                   list(model_dict.values()).index('dataAssociation')]:
                        continue
                    if is_prediction:
                        score = data['scores'][i]
                        if score < score_threshold:
                            continue
                    x, y, v = np.array(kp[j])
                    if resize:
                        x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
                    if j == 0:
                        cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (0, 0, 255), -1)
                    else:
                        cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (255, 0, 0), -1)
    # Draw OCR text predictions if available
    if (draw_text or write_text) and text_predictions is not None:
        for i in range(len(text_predictions[0])):
            x1, y1, x2, y2 = text_predictions[0][i]
            text = text_predictions[1][i]
            if resize:
                x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2 * scale))
            if write_text:
                cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)
    def draw_with_links(full_prediction):
        """Draws links between objects based on the full prediction data."""
        # Mark "twin" keypoints that collapsed onto each other
        if draw_twins and full_prediction is not None:
            circle_color = (0, 255, 0)
            circle_radius = int(10 * scale)
            for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
                if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
                                                          list(model_dict.values()).index('messageFlow'),
                                                          list(model_dict.values()).index('dataAssociation')]:
                    continue
                distance = np.linalg.norm(key1[:2] - key2[:2])
                if distance < 10:
                    x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
                    cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                    cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)
        # Draw each link as two segments: source center -> flow center -> target center
        if draw_links and full_prediction is not None:
            for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
                if start_idx is None or end_idx is None:
                    continue
                start_box = full_prediction['boxes'][start_idx]
                end_box = full_prediction['boxes'][end_idx]
                current_box = full_prediction['boxes'][i]
                start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
                end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
                current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
                cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2 * scale))
                cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2 * scale))
    if target is not None:
        draw(target, is_prediction=False)
    if prediction is not None:
        draw(prediction, is_prediction=True)
    if full_prediction is not None:
        draw_with_links(full_prediction)
    image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()
    if return_image:
        return image_copy

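# Usage sketch (hypothetical names): overlay a model's output on its input.
# Here `img` is an HWC BGR numpy array and `pred` the dict produced by the
# detection model (boxes, labels, scores, keypoints):
# annotated = draw_annotations(img, prediction=pred,
#                              draw_boxes=True, draw_keypoints=True,
#                              write_class=True, score_threshold=0.5,
#                              return_image=True)
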
def find_closest_object(keypoint, boxes, labels):
    """
    Find the closest object to a keypoint based on their proximity.

    Parameters:
    - keypoint (np.ndarray): The coordinates of the keypoint.
    - boxes (np.ndarray): The bounding boxes of the objects.
    - labels (np.ndarray): The class labels of the objects; flow and lane elements are skipped.

    Returns:
    - tuple: The index of the closest object (or None if no object is found) and the side
      of its box ('left', 'top', 'right' or 'bottom') that is closest to the keypoint.
    """
    closest_object_idx = None
    best_point = None
    min_distance = float('inf')
    for i, box in enumerate(boxes):
        # Skip connecting elements and lanes: keypoints should attach to flow objects
        if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
                         list(class_dict.values()).index('messageFlow'),
                         list(class_dict.values()).index('dataAssociation'),
                         list(class_dict.values()).index('lane')]:
            continue
        x1, y1, x2, y2 = box
        top = ((x1 + x2) / 2, y1)
        bottom = ((x1 + x2) / 2, y2)
        left = (x1, (y1 + y2) / 2)
        right = (x2, (y1 + y2) / 2)
        points = [left, top, right, bottom]
        pos_dict = {0: 'left', 1: 'top', 2: 'right', 3: 'bottom'}
        for pos, point in enumerate(points):
            distance = np.linalg.norm(keypoint[:2] - point)
            if distance < min_distance:
                min_distance = distance
                closest_object_idx = i
                best_point = pos_dict[pos]
    return closest_object_idx, best_point

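# Worked example (hypothetical values): with a single task box [0, 0, 10, 10]
# and a keypoint at (12, 5), the nearest of the four edge midpoints is the
# right one at (10, 5), so the call returns (0, 'right'):
# find_closest_object(np.array([12, 5, 1]), np.array([[0, 0, 10, 10]]), np.array([1]))
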
def error(text='There is an error in the detection'):
    """Display an error message using Streamlit."""
    st.error(text, icon="🚨")

def warning(text='Some elements may not have been detected. Verify the results, try modifying the parameters, or add the missing element in the method and style step.'):
    """Display a warning message using Streamlit."""
    st.warning(text, icon="⚠️")