|
import numpy as np |
|
import sys,os |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
sys.path.append(os.path.join(os.path.dirname(__file__), "sam2")) |
|
from sam2.build_sam import build_sam2 |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
import torch |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import cv2 |
|
import random |
|
import warnings |
|
# Suppress FutureWarning messages for the rest of the process.
warnings.filterwarnings("ignore", category=FutureWarning)

# Prefer the GPU, but fall back to the CPU so that merely importing this
# module does not fail on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the SAM 2.1 checkpoint from the Hugging Face Hub; hf_hub_download
# returns the path of the local file.
sam2_checkpoint = hf_hub_download(
    repo_id="Evan73/sam2-models",
    filename="sam2.1_hiera_large.pt",
)
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Module-level predictor; sam2segment() below uses it directly.
predictor = SAM2ImagePredictor(sam2_model)
|
from ultralytics import YOLO |
|
from diffusers.utils import load_image |
|
import pickle |
|
import os |
|
import math |
|
# Fetch the archive of precomputed attention heatmaps from the Hugging Face Hub.
heatmap_zip = hf_hub_download(
    repo_id="Evan73/attention-heatmaps",
    filename="attention_heatmaps.zip",
)
import zipfile
import os

# Unpack the archive into ./heatmaps_lda (relative to the current working
# directory); this runs on every import of the module.
with zipfile.ZipFile(heatmap_zip, 'r') as zip_ref:
    zip_ref.extractall("heatmaps_lda")

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# The pickle comes from a downloaded archive, so only use this with a
# repository you trust (or replace the pickle with a data-only format).
# attention_mask() indexes the loaded dict by image path and reads the
# 'cls_heatmap' / 'reg_heatmap' entries.
with open("heatmaps_lda/attention_heatmaps.pkl", "rb") as f:
    heatmap_dict = pickle.load(f)
|
|
|
def load_yolov5_model():
    """Create the YOLO detector and print its class-id -> name table.

    Despite the function name, the weights that are loaded are
    ``yolo11n.pt``.

    Returns:
        The ``YOLO`` model instance.
    """
    detector = YOLO("yolo11n.pt")
    label_table = detector.names
    print("YOLOv11 Class Names:")
    for class_id, label in label_table.items():
        print(f"{class_id}: {label}")
    return detector
|
|
|
|
|
def is_point_in_car_area(point, model, image):
    """Return True when ``point`` lies OUTSIDE every detected vehicle box.

    NOTE: the name suggests the opposite, but callers in this file rely on
    the actual behaviour: ``True`` means "the point is free of vehicles" and
    ``False`` means "the point falls inside a vehicle box".

    Args:
        point: ``(x, y)`` coordinates to test.
        model: detector; ``model(image)`` must yield results exposing
            ``boxes.xyxy`` and ``boxes.cls`` tensors.
        image: image array handed to the detector and used as the canvas for
            the debug drawing.

    Returns:
        ``False`` as soon as the point is inside (edges inclusive) a box whose
        class id is 2, 5 or 7; ``True`` if no such box contains it.

    Side effects:
        Writes the annotated detections to ``yolo_res.jpg`` on every call.
    """
    results = model(image)

    # Class ids treated as vehicles (presumably car / bus / truck in the
    # label set of the stock weights -- confirm against model.names).
    car_class_id = (2, 5, 7)

    # NOTE(review): get_left_right_points passes a cv2.imread image, which is
    # already BGR, so this conversion swaps the channels of the debug image.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    for result in results:
        boxes = result.boxes.xyxy.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        for box, cls in zip(boxes, class_ids):
            if cls not in car_class_id:
                continue
            x_min, y_min, x_max, y_max = box[:4]
            cv2.rectangle(image_bgr, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)

            if (x_min <= point[0] <= x_max) and (y_min <= point[1] <= y_max):
                # Save what has been drawn so far before bailing out.
                cv2.imwrite("yolo_res.jpg", image_bgr)
                return False

    cv2.imwrite("yolo_res.jpg", image_bgr)
    print("检测结果已保存至 yolo_res.jpg")
    return True
|
|
|
|
|
def show_mask(mask, ax, image_path,random_color=False, borders=True, image=None, save_path=None):
    """
    Overlay ``mask`` on ``ax`` and try to pick a random square whose two
    diagonal corners both lie on mask pixels, drawing it on ``image``.

    Parameters:
    - `mask`: mask array; it is cast to uint8 and its last two dimensions are
      used as (height, width)
    - `ax`: matplotlib axes used to show the coloured overlay
    - `image_path`: not used by this function
    - `random_color`: use a random overlay colour instead of the fixed blue
    - `borders`: draw the contour outlines and per-contour bounding boxes
    - `image`: original image on which rectangles are drawn (modified in place)
    - `save_path`: where the image with the chosen rectangle is saved; when it
      is falsy no corner pair is ever accepted

    Returns:
    - `((x1, y1), (x2, y2))` for the first accepted corner pair; otherwise the
      overlay is shown on `ax` and the function returns None implicitly.

    Side effects: always writes binary_mask.png; with `borders` it also writes
    contours_colored_result.png.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    # RGBA overlay: mask value times the colour vector.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("原始二值掩码已保存为 binary_mask.png")
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=5)

        # Palette cycled through for the per-contour bounding boxes.
        colors = [
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
            (255, 255, 0),
            (255, 0, 255),
            (0, 255, 255),
            (255, 128, 0),
            (128, 0, 255),
            (128, 128, 128),
            (0, 128, 0)
        ]

        # NOTE(review): this draws on `image` before the `image is not None`
        # check below, so borders=True with image=None would fail here --
        # confirm the intended nesting.
        for idx, contour in enumerate(contours):
            # Rebinds w/h (previously the mask size) to the contour's box.
            x, y, w, h = cv2.boundingRect(contour)
            print(f"轮廓{idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"带颜色的轮廓结果已保存至 {middle_save_path}")
    if image is not None:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)

            # Only consider contours whose bounding box exceeds 50x50 pixels.
            if w > 50 and h > 50:
                # Try square sizes 90, 85, ..., 45.
                for size in range(90,40,-5):
                    # First attempt: second corner up-left of the first one.
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 - size
                        random_y2 = random_y1 - size

                        # NOTE(review): negative indices wrap around in NumPy
                        # instead of raising, so an out-of-image corner can be
                        # tested against the opposite side of the mask.
                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)

                                print(f"Image with rectangle saved at {save_path}")
                                return (random_x1,random_y1),(random_x2,random_y2)
                        # Bare except: any failure (e.g. an index past the mask
                        # edge) just skips this sample.
                        except:
                            pass

                    # Second attempt: second corner down-right of the first one.
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 + size
                        random_y2 = random_y1 + size

                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)
                                print(f"Image with rectangle saved at {save_path}")

                                return (random_x1,random_y1),(random_x2,random_y2)
                        except:
                            pass

    # Reached only when no corner pair was accepted.
    ax.imshow(mask_image)
    plt.axis('off')
|
|
|
|
|
|
|
def attention_mask(mask, ax, image_path,strategy="LOA",random_color=False, borders=True, image=None, save_path=None):
    """
    Pick, inside ``mask``, the square window with the highest mean attention
    heat and return a box around its centre in original-image coordinates.

    The heatmaps are looked up in the module-level ``heatmap_dict`` under the
    key ``image_path``. The mask and the images are resized to the heatmap
    resolution, windows fully covered by the mask are scored by their mean
    heat, and the best-scoring window whose corner pixels pass the mask check
    is mapped back to the original resolution.

    Parameters:
    - `mask`: mask array; it is cast to uint8 and its last two dimensions are
      used as (height, width)
    - `ax`: matplotlib axes used to show the overlay on the fall-through path
    - `image_path`: image file path; also the key into `heatmap_dict`
    - `strategy`: "LDA" selects the 'cls_heatmap' entry, "LOA"/"LRA" select
      the 'reg_heatmap' entry
    - `random_color`: use a random overlay colour instead of the fixed blue
    - `borders`: draw the contour outlines and per-contour bounding boxes
    - `image`: original image array; its shape supplies the original size
    - `save_path`: if truthy, the resized image with the chosen window is
      written there

    Returns:
    - `((x1, y1), (x2, y2))`: corners of a box of at most 90x90 pixels centred
      on the chosen window and clamped to the original image; None implicitly
      when no candidate is accepted.

    Side effects: writes binary_mask.png, crop_binary_mask.png,
    contours_colored_result.png (with `borders`), attention_vis.jpg and, on
    success, expanded_bbox_on_original.jpg.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    # Read unconditionally, so `image` must not be None despite its default.
    orig_w, orig_h = image.shape[1],image.shape[0]

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("原始二值掩码已保存为 binary_mask.png")

    # (window, mean heat) pairs collected by the sliding-window search below.
    candidates = []
    path = image_path
    cls_heatmap = heatmap_dict[path]['cls_heatmap']
    reg_heatmap = heatmap_dict[path]['reg_heatmap']
    font = cv2.FONT_HERSHEY_SIMPLEX
    # NOTE(review): any other strategy value leaves `combined` unbound and the
    # first use below raises a NameError.
    if strategy == "LDA":
        combined = cls_heatmap.astype(np.float32)
    if strategy == "LOA" or strategy == "LRA":
        combined = reg_heatmap.astype(np.float32)
    print(mask.shape)
    # From here on everything is handled at the heatmap resolution.
    mask = cv2.resize(mask, (combined.shape[1], combined.shape[0]), interpolation=cv2.INTER_NEAREST)
    mask = (mask > 0.5).astype(np.uint8)
    cv2.imwrite("crop_binary_mask.png", (mask * 255).astype(np.uint8))
    print("处理后的裁剪二值掩码已保存为 crop_binary_mask.png")
    print(combined.shape)
    vis_image = cv2.imread(image_path)
    vis_image = cv2.resize(vis_image,(combined.shape[1],combined.shape[0]))
    mask_image = cv2.resize(mask_image,(combined.shape[1],combined.shape[0]))
    image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
        # Palette cycled through for the per-contour bounding boxes.
        colors = [
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
            (255, 255, 0),
            (255, 0, 255),
            (0, 255, 255),
            (255, 128, 0),
            (128, 0, 255),
            (128, 128, 128),
            (0, 128, 0)
        ]

        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"轮廓{idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"带颜色的轮廓结果已保存至 {middle_save_path}")
    # Always true at this point: `image` was dereferenced and resized above.
    if image is not None:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            print("the contour is:",x, y, w, h)
            if w > 50 and h > 50:
                # Window sizes 50 then 45; the search stops at the first size
                # for which the candidate list is non-empty.
                for size in range(50,40,-5):
                    # Slide the window over the bounding box with stride 5.
                    for y_step in range(y, y+h - size,5):
                        for x_step in range(x, x+w - size,5):
                            x1, y1, x2, y2 = x_step, y_step, x_step + size, y_step + size

                            # Keep only windows entirely covered by the mask.
                            if mask[y1:y2, x1:x2].sum() >= size * size:
                                heat_value = combined[y1:y2, x1:x2].mean()

                                if not math.isnan(heat_value):
                                    candidates.append(((x1, y1, x2, y2), heat_value))
                                    cv2.rectangle(vis_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
                                    cv2.putText(vis_image, f'{heat_value:.1f}', (x1, y1 - 2), font, 0.4, (0, 0, 255), 1)
                    # NOTE(review): `candidates` is shared across contours, so
                    # this also breaks when an earlier contour already
                    # produced candidates -- confirm this is intended.
                    if not candidates:
                        print("⚠️ 没有找到满足掩码内区域的候选框")
                    else:
                        break
        cv2.imwrite("attention_vis.jpg", vis_image)
        print(f"Attention 候选框可视化已保存 attention_vis.jpg")

        # Highest mean heat first.
        candidates.sort(key=lambda x: x[1], reverse=True)
        # NOTE(review): raises IndexError when no candidate was found.
        print(save_path,candidates[0],candidates[-1])
        for (x1, y1, x2, y2), _ in candidates:
            try:
                # (x2, y2) is one past the scored window, so this index can
                # fall outside the mask; the resulting error is caught below
                # and the candidate is skipped.
                if mask[y1, x1] == 1 and mask[y2, x2] == 1:

                    if save_path:
                        image = cv2.imread(image_path)
                        image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
                        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

                        cv2.imwrite(save_path, image)
                        print(f"Image with rectangle saved at {save_path}")
                    # Map the window from heatmap resolution back to the
                    # original image resolution.
                    resize_w, resize_h = combined.shape[1],combined.shape[0]
                    scale_x = orig_w / resize_w
                    scale_y = orig_h / resize_h
                    x1_orig = int(x1 * scale_x)
                    x2_orig = int(x2 * scale_x)
                    y1_orig = int(y1 * scale_y)
                    y2_orig = int(y2 * scale_y)
                    cx = (x1_orig + x2_orig) // 2
                    cy = (y1_orig + y2_orig) // 2
                    # Box of side `target_size` around the window centre,
                    # clamped to the image bounds.
                    target_size = 90
                    half = target_size // 2
                    x1_exp = max(0, cx - half)
                    y1_exp = max(0, cy - half)
                    x2_exp = min(orig_w - 1, cx + half)
                    y2_exp = min(orig_h - 1, cy + half)
                    print(f"扩展后的原图坐标: ({x1_exp}, {y1_exp}), ({x2_exp}, {y2_exp})")
                    image_full = cv2.imread(image_path)
                    cv2.rectangle(image_full, (x1_exp, y1_exp), (x2_exp, y2_exp), (0, 0, 255), 2)
                    cv2.imwrite("expanded_bbox_on_original.jpg", image_full)
                    print("📌 扩大后的候选框已绘制到原图并保存为 expanded_bbox_on_original.jpg")
                    return (x1_exp, y1_exp), (x2_exp, y2_exp)
            except Exception as e:
                print("the error is:",e)
                pass

    # Fall-through: no candidate was accepted. The function then returns None,
    # which show_masks() cannot unpack into two points.
    ax.imshow(mask_image)
    plt.axis('off')
|
|
|
def generate_gt_mask_from_intersection(random_rectangle, yolo_boxes, image, mask_img,sam2_model, threshold_iou):
    """Carve a SAM-segmented object out of ``mask_img`` when a YOLO box overlaps the rectangle.

    The rectangle is compared with each YOLO box in turn. For the first box
    whose IoU reaches ``threshold_iou``, three random points inside that box
    are used as SAM prompts and every pixel of the resulting mask is set to 0
    in ``mask_img`` (in place).

    Args:
        random_rectangle: two diagonal corners ``[(x1, y1), (x2, y2)]``.
        yolo_boxes: iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        image: image handed to SAM; also converted to an array for its size.
        mask_img: array that is modified in place and returned.
        sam2_model: SAM2 model forwarded to ``get_gt_mask_from_sam``.
        threshold_iou: minimum IoU for a box to count as a match.

    Returns:
        ``(gt_mask, mask_img)`` on a match, otherwise ``(None, mask_img)``.
        Either way an image is written to ``gt_mask_from_sam.png`` (the SAM
        mask, or an all-black image when nothing matched).
    """
    frame = np.array(image)
    first_corner = random_rectangle[0]
    second_corner = random_rectangle[1]
    x1_rect, y1_rect = first_corner
    x2_rect, y2_rect = second_corner

    # Filled rectangle mask; forwarded to the SAM helper.
    rect_mask = np.zeros(frame.shape[:2], dtype=np.uint8)
    cv2.rectangle(rect_mask, (x1_rect, y1_rect), (x2_rect, y2_rect), color=255, thickness=-1)

    # Normalise the corners into [x_min, y_min, x_max, y_max] order.
    rect_box = [
        min(x1_rect, x2_rect),
        min(y1_rect, y2_rect),
        max(x1_rect, x2_rect),
        max(y1_rect, y2_rect),
    ]

    for box in yolo_boxes:
        iou = calculate_iou(rect_box, box)
        print(f"与YOLO box的IoU为: {iou}, 阈值: {threshold_iou}")
        if not (iou >= threshold_iou):
            continue

        x_min, y_min, x_max, y_max = box
        # Three random positive prompts inside the matching box.
        prompts = [
            (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            for _ in range(3)
        ]

        gt_mask = get_gt_mask_from_sam(image, sam2_model, prompts, rect_mask)
        mask_img[gt_mask > 0] = 0

        cv2.imwrite('gt_mask_from_sam.png', gt_mask)
        print(f"SAM生成的GT掩码已保存至 gt_mask_from_sam.png")
        return gt_mask,mask_img

    # No box reached the threshold: save an empty mask instead.
    rows, cols = frame.shape[:2]
    no_match_save_path = 'gt_mask_from_sam.png'
    cv2.imwrite(no_match_save_path, np.zeros((rows, cols), dtype=np.uint8))
    print("未找到满足阈值条件的YOLO box。")
    print(f"未匹配成功,保存空掩码图至 {no_match_save_path}")
    return None,mask_img
|
|
|
def calculate_iou(boxA, boxB):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Widths and heights use the inclusive-pixel convention (``x2 - x1 + 1``),
    so boxes that merely share an edge still overlap by one pixel row/column.
    """
    a_left, a_top, a_right, a_bottom = boxA[0], boxA[1], boxA[2], boxA[3]
    b_left, b_top, b_right, b_bottom = boxB[0], boxB[1], boxB[2], boxB[3]

    # Extent of the overlap along each axis, clamped at zero when disjoint.
    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left) + 1)
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top) + 1)
    inter_area = overlap_w * overlap_h

    area_a = (a_right - a_left + 1) * (a_bottom - a_top + 1)
    area_b = (b_right - b_left + 1) * (b_bottom - b_top + 1)

    return inter_area / float(area_a + area_b - inter_area)
|
|
|
def get_gt_mask_from_sam(image, sam2_model, input_points, rect_mask):
    """Segment the object marked by ``input_points`` with SAM2.

    Every point is used as a positive (foreground) prompt. The number of
    labels now follows the number of points; it used to be hard-coded to
    three, which only matched callers passing exactly three points.

    Args:
        image: image handed to ``SAM2ImagePredictor.set_image``.
        sam2_model: SAM2 model used to build a fresh predictor.
        input_points: sequence of ``(x, y)`` prompt points.
        rect_mask: unused; kept for interface compatibility.

    Returns:
        ``uint8`` mask (values 0 or 255) of the single predicted mask.

    Side effects:
        Writes ``sam_gt_mask.jpg`` (the mask) and ``image_with_points.jpg``
        (the image with the prompt points drawn on it).
    """
    predictor = SAM2ImagePredictor(sam2_model)
    print("load sam2")
    predictor.set_image(image)

    input_point_np = np.array(input_points)
    # One positive label per prompt point.
    input_label = np.ones(len(input_points), dtype=int)

    masks, _, _ = predictor.predict(
        point_coords=input_point_np,
        point_labels=input_label,
        multimask_output=False,
    )

    mask_img = masks[0].astype(np.uint8) * 255

    mask_save_path = 'sam_gt_mask.jpg'
    cv2.imwrite(mask_save_path, mask_img)
    print(f"SAM生成的掩码已保存至 {mask_save_path}")

    image_with_points = np.array(image).copy()
    for point in input_points:
        # Callers pass NumPy integer scalars; hand OpenCV plain Python ints.
        center = (int(point[0]), int(point[1]))
        cv2.circle(image_with_points, center, radius=5, color=(255, 0, 0), thickness=-1)

    point_marked_save_path = 'image_with_points.jpg'
    image_bgr = cv2.cvtColor(image_with_points, cv2.COLOR_RGB2BGR)
    cv2.imwrite(point_marked_save_path, image_bgr)
    print(f"带点标记的原图已保存至 {point_marked_save_path}")

    return mask_img
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    ``coords`` is indexed as ``coords[mask][:, 0]`` / ``[:, 1]`` for the x and
    y values; ``labels`` selects the rows belonging to each group.
    """
    groups = [
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    ]
    for selected, colour in groups:
        xs = selected[:, 0]
        ys = selected[:, 1]
        ax.scatter(xs, ys, color=colour, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
|
|
def show_box(box, ax):
    """Draw ``box`` (``x0, y0, x1, y1``) on ``ax`` as an unfilled green rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
def display_mask(mask, ax, random_color=False, borders = True):
    """Show ``mask`` on ``ax`` as a translucent coloured overlay.

    Args:
        mask: mask array; it is cast to ``uint8`` and its last two dimensions
            are used as (height, width).
        ax: matplotlib axes that receives the overlay.
        random_color: use a random RGB colour (alpha 0.6) instead of the
            fixed blue.
        borders: additionally draw the mask contours onto the overlay.

    Side effects:
        Writes the overlay array to ``check.jpg``.
    """
    # Function-scope import, but unconditional: the original imported cv2
    # inside ``if borders:``, which makes ``cv2`` a local name that is unbound
    # (UnboundLocalError) on the ``borders=False`` path.
    import cv2

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # approxPolyDP with a tiny epsilon lightly simplifies the outlines.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    # NOTE(review): mask_image holds floats in [0, 1], so this debug file is
    # nearly black; show_masks() overwrites check.jpg right afterwards anyway.
    cv2.imwrite("check.jpg", mask_image)
    ax.imshow(mask_image)
|
|
|
def random_points_below(point, radius, min_distance, model, image, max_attempts=100):
    """
    Randomly pick two points in a band starting 50 pixels below ``point``.

    Both points are drawn with x in ``[point.x - radius, point.x + radius]``
    and y in ``[point.y + 50, point.y + 50 + radius]``. A pair is accepted
    when the points are at least ``min_distance`` apart and
    ``is_point_in_car_area`` returns True for both (i.e. neither lies inside a
    detected vehicle box).

    Parameters:
    - point: (x, y) coordinates; float values are truncated to integers,
      because ``random.randint`` rejects non-integer bounds on recent Python
      versions and the caller in this file passes floats
    - radius: maximum offset of the random points (truncated to an integer)
    - min_distance: minimum distance between the two random points
    - model: detector forwarded to ``is_point_in_car_area``
    - image: image forwarded to ``is_point_in_car_area``
    - max_attempts: maximum number of attempts, to avoid an endless loop

    Returns:
    - ``[(x1, y1), (x2, y2)]``, or None if no suitable pair was found
    """
    center_x = int(point[0])
    band_top = int(point[1]) + 50
    radius = int(radius)

    for _ in range(max_attempts):
        x1 = random.randint(center_x - radius, center_x + radius)
        y1 = random.randint(band_top, band_top + radius)

        x2 = random.randint(center_x - radius, center_x + radius)
        y2 = random.randint(band_top, band_top + radius)

        distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

        # The distance test comes first so the detector (one inference per
        # call) only runs for pairs that are far enough apart.
        if distance >= min_distance and is_point_in_car_area((x1, y1), model, image) and is_point_in_car_area((x2, y2), model, image):
            return [(x1, y1), (x2, y2)]

    return None
|
|
|
|
|
def show_masks(image, masks, scores, image_path, strategy,point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
    """Render each mask over ``image`` and select a trigger box with ``attention_mask``.

    For every mask a figure is drawn (image, mask overlay, optional prompt
    points and box), saved to ``check.jpg`` and then closed; the original
    never closed the figures, so they accumulated across calls.

    Args:
        image: image array shown under the overlay and passed on to
            ``attention_mask``.
        masks: iterable of masks.
        scores: iterable paired with ``masks``; only used to bound the loop.
        image_path: forwarded to ``attention_mask``.
        strategy: forwarded to ``attention_mask``.
        point_coords: optional prompt points to draw; requires
            ``input_labels``.
        box_coords: optional box to draw.
        input_labels: labels matching ``point_coords``.
        borders: forwarded to ``display_mask`` and ``attention_mask``.
        save_path: forwarded to ``attention_mask``.

    Returns:
        The two points returned by ``attention_mask`` for the last mask.
    """
    for mask, _score in zip(masks, scores):
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            display_mask(mask, plt.gca(), borders=borders)
            if point_coords is not None:
                assert input_labels is not None
                show_points(point_coords, input_labels, plt.gca())
            if box_coords is not None:
                show_box(box_coords, plt.gca())
            plt.axis('off')
            plt.savefig('check.jpg', bbox_inches='tight', pad_inches=0)
            point1, point2 = attention_mask(mask, plt.gca(), image_path, strategy, borders=borders, image=image, save_path=save_path)
        finally:
            # Release the figure even when drawing or selection fails.
            plt.close(fig)
    return point1, point2
|
|
|
def random_crop(image, target_width, target_height, mask_point1, mask_point2):
    """Crop a ``target_width`` x ``target_height`` window centred between two points.

    Despite the name the crop is deterministic: the window is centred on the
    midpoint of the two points and shifted back inside the image when it
    would cross a border.

    Args:
        image: object exposing ``.size`` as ``(width, height)`` and
            ``.crop((left, top, right, bottom))``.
        target_width: crop width.
        target_height: crop height.
        mask_point1: first ``(x, y)`` point.
        mask_point2: second ``(x, y)`` point.

    Returns:
        ``(cropped_image, top_padding, left_padding, point1_rel, point2_rel)``
        where the paddings are the crop's top/left offsets clamped at zero and
        the relative points are the inputs expressed in crop coordinates.
    """
    def _window(center, span, limit):
        # Centre the window, then push it back inside [0, limit].
        start = center - span // 2
        end = start + span
        if start < 0:
            start, end = 0, span
        if end > limit:
            start, end = limit - span, limit
        return start, end

    width, height = image.size
    mid_x = (mask_point1[0] + mask_point2[0]) // 2
    mid_y = (mask_point1[1] + mask_point2[1]) // 2

    left, right = _window(mid_x, target_width, width)
    top, bottom = _window(mid_y, target_height, height)

    top_padding = max(0, top)
    left_padding = max(0, left)

    cropped_image = image.crop((left, top, right, bottom))

    point1_in_crop = (mask_point1[0] - left, mask_point1[1] - top)
    point2_in_crop = (mask_point2[0] - left, mask_point2[1] - top)
    print("裁剪后点的相对位置为:")
    print("mask_point1:", point1_in_crop)
    print("mask_point2:", point2_in_crop)
    return cropped_image, top_padding, left_padding, point1_in_crop, point2_in_crop
|
|
|
def get_left_right_points(lane_data,image_path):
    """Collect lane points that are not covered by a detected vehicle.

    For each lane that has a value at the middle ``h_samples`` index, the
    lane is scanned from index ``mid_idx - 2`` down to index 1 and the first
    annotated point for which ``is_point_in_car_area`` returns True (the
    point is outside every vehicle box) is appended to the result.

    Args:
        lane_data: dict with "lanes" (per-lane x values, -2 meaning "no
            point") and "h_samples" (the y value for each index).
        image_path: image file passed to the detector.

    Returns:
        List of ``(x, y)`` points; may be replaced by the return value of
        ``random_points_below`` (a two-point list or None) on the fallback
        path.
    """
    lanes = lane_data["lanes"]
    h_samples = lane_data["h_samples"]
    # A new detector is loaded on every call.
    model = load_yolov5_model()

    mid_idx = len(h_samples) // 2
    image = cv2.imread(image_path)

    left_point = None
    # Never read below.
    right_point = None
    points = []

    for lane in lanes:

        valid_points = [(x, y) for x, y in zip(lane, h_samples) if x != -2]

        if valid_points:
            if lane[mid_idx] != -2:
                # Scan from just before the middle sample towards index 1
                # (index 0 is never visited).
                for i in range(mid_idx-2,0,-1):
                    left_point = lane[i]
                    print(left_point)
                    if lane[i] != -2:
                        point = (left_point,h_samples[i])
                        # True means the point is NOT inside a vehicle box.
                        FLAG = is_point_in_car_area(point, model, image)
                        print(point,FLAG)
                        if FLAG:
                            points.append((left_point,h_samples[i]))
                            break
            # NOTE(review): the original indentation was lost; this `else` is
            # assumed to pair with `if lane[mid_idx] != -2` -- confirm.
            else:
                # Fallback: two random points below a hard-coded position
                # (presumably tied to the expected frame size -- confirm).
                # NOTE(review): this replaces any points collected so far and
                # may set `points` to None, which would break a later append.
                point = (1540/2, 590/2+30)
                radius = 50
                min_distance = 40
                points = random_points_below(point, radius, min_distance,model,image)

    return points
|
|
|
def sam2segment(image_path,points,strategy):
    """Segment the region around ``points`` with SAM2 and choose a trigger box.

    Runs two prediction passes with the module-level ``predictor``: the first
    is prompted with the first point only and asks for multiple masks; the
    logits of its best-scoring mask are then fed as ``mask_input`` to a second
    pass prompted with all points. The result is handed to ``show_masks``.

    Args:
        image_path: path of the image to segment.
        points: sequence of ``(x, y)`` prompt points (all positive).
        strategy: forwarded to ``show_masks``.

    Returns:
        The two points returned by ``show_masks``.
    """
    rgb_image = np.array(Image.open(image_path).convert("RGB"))
    predictor.set_image(rgb_image)

    # Pass 1: single-point prompt, several candidate masks.
    seed = points[0]
    masks, scores, logits = predictor.predict(
        point_coords=np.array([(seed[0], seed[1])]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    best_first = np.argsort(scores)[::-1]
    masks, scores, logits = masks[best_first], scores[best_first], logits[best_first]

    # Low-resolution logits of the highest-scoring mask seed the second pass.
    mask_input = logits[np.argmax(scores), :, :]

    # Pass 2: every point as a positive prompt, single refined mask.
    prompt_points = [(p[0], p[1]) for p in points]
    masks, scores, _ = predictor.predict(
        point_coords=np.array(prompt_points),
        point_labels=np.array([1] * len(prompt_points)),
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    return show_masks(image=rgb_image, masks=masks, scores=scores, image_path=image_path, strategy=strategy, save_path="masked_image.jpg")
|
|
|
def draw_point(image_path,points):
    """Draw ``points`` as filled green dots on the image and save the result.

    The annotated image is written to ``output_image_with_points.jpg``; when
    the image cannot be loaded only an error message is printed.
    """
    canvas = cv2.imread(image_path)
    if canvas is None:
        print("Error: Image could not be loaded.")
        return

    for center in points:
        cv2.circle(canvas, center, radius=5, color=(0, 255, 0), thickness=-1)

    output_path = "output_image_with_points.jpg"
    cv2.imwrite(output_path, canvas)
    print(f"Image saved with points at {output_path}")
|
|
|
def generate_mask(original_img_path, point1, point2):
    """Create a black image with a filled white rectangle and save it.

    The rectangle spans from ``point1`` to the point 95% of the way from
    ``point1`` towards ``point2``; the mask has the size of the original
    image.

    Args:
        original_img_path: path of the image whose size the mask copies.
        point1: first rectangle corner ``(x, y)``.
        point2: opposite corner before the 95% shrink.

    Returns:
        ``(mask_path, point1, shrunk_corner)``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    original_img = cv2.imread(original_img_path)
    if original_img is None:
        raise FileNotFoundError(f"Could not read image: {original_img_path}")

    height, width = original_img.shape[:2]
    mask = np.zeros((height, width), dtype=np.uint8)

    # Corner 95% of the way from point1 to point2 (the variable name is
    # historical; the factor is not 0.75).
    three_quarter_point = (
        int(point1[0] + 0.95 * (point2[0] - point1[0])),
        int(point1[1] + 0.95 * (point2[1] - point1[1]))
    )

    cv2.rectangle(mask, point1, three_quarter_point, color=255, thickness=-1)

    mask_path = original_img_path.replace('test.jpg', 'mask_test.jpg')
    if mask_path == original_img_path:
        # The path does not contain 'test.jpg', so the replacement was a no-op
        # and writing there would overwrite the original image with the mask.
        # Derive a sibling "<name>_mask<ext>" path instead.
        root, ext = os.path.splitext(original_img_path)
        mask_path = f"{root}_mask{ext or '.jpg'}"
    cv2.imwrite(mask_path, mask)
    print(mask_path)
    return mask_path, point1, three_quarter_point
|
|
|
def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
    """
    Filter the `lanes` of a TuSimple-style record, keeping only the points
    that fall inside the crop rectangle (bounds inclusive).

    Each lane's x values are paired with `h_samples` as y values; x == -2
    marks a missing point. Lanes left without any point are dropped.
    """
    def _inside(x, y):
        return x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max

    per_lane = (
        [(x, y) for x, y in zip(lane, lane_data["h_samples"]) if _inside(x, y)]
        for lane in lane_data["lanes"]
    )
    return [segment for segment in per_lane if segment]
|
|
|
|
|
def generate_trigger_crop(image_path: str, lane_data: dict):
    """
    Build the 512x512 crop and the matching crop mask for one image.

    Steps: choose prompt points on the lanes, segment with SAM2 (strategy
    "LDA") to obtain two box corners, crop the image around them, generate
    the rectangle mask and crop it the same way.

    Returns the pair of output paths ("crop.jpg", "crop_mask.jpg").
    """
    crop_size = 512
    input_crop_path = "crop.jpg"
    crop_mask_path = "crop_mask.jpg"

    # 1. Prompt points on the lane markings.
    points = get_left_right_points(lane_data, image_path)
    print(f"[INFO] 获取 trigger 点: {points}")
    draw_point(image_path, points)

    # 2. Segmentation -> two corners of the trigger box.
    image = load_image(image_path)
    mask_point1, mask_point2 = sam2segment(image_path, points, "LDA")

    # 3. Crop of the image around the box.
    input_image, *_ = random_crop(image, crop_size, crop_size, mask_point1, mask_point2)
    input_image.save(input_crop_path)

    # 4. Full-size mask, cropped with the same window.
    mask_path, _corner_a, _corner_b = generate_mask(image_path, mask_point1, mask_point2)
    full_mask = load_image(mask_path)
    cropped_mask, *_ = random_crop(full_mask, crop_size, crop_size, mask_point1, mask_point2)
    cv2.imwrite(crop_mask_path, np.array(cropped_mask))

    return input_crop_path, crop_mask_path
|
|
|
# Stand-alone demo: runs the pipeline on one hard-coded frame and writes the
# intermediate images (crop, masks, overlay) to the current working directory.
if __name__ == "__main__":
    # Lane annotation for the demo frame: per-lane x values (-2 = no point)
    # paired index-wise with the y values in "h_samples".
    lane_data = {"lanes": [[-2, -2, -2, -2, -2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, 818, 801, 784, 768, 751, 734, 717, 701, 685, 668, 651, 634, 618, 601, 585, 568, 551, 535, 518, 502, 484, 468, 451, 435, 418, 401, 385, 368, 351, 335, 318, 301, 287], [-2, -2, -2, -2, -2, -2, -2, 863, 872, 881, 890, 899, 908, 918, 927, 936, 945, 954, 964, 972, 982, 991, 1000, 1009, 1018, 1027, 1036, 1046, 1055, 1064, 1073, 1082, 1091, 1100, 1109, 1119, 1128, 1137, 1146, 1154]], "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590], "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"}

    image_path = "driver_182_30frame/06010513_0036.MP4/00270.jpg"
    # Step 1: prompt points on the lanes, drawn for inspection.
    points = get_left_right_points(lane_data,image_path)
    print(points)
    draw_point(image_path,points)

    # Step 2: SAM2 segmentation + attention-based box, then a 512x512 crop.
    image = load_image(image_path)
    mask_point1,mask_point2 = sam2segment(image_path,points,"LDA")
    input_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(image, 512, 512, mask_point1, mask_point2)
    input_image.save("crop.jpg")
    print(f"Image saved with points at crop.jpg")
    # NOTE(review): the mask is sized from 'culane_test.jpg', not from
    # image_path -- presumably a same-sized reference image; confirm.
    mask_path, point1, point2 = generate_mask('culane_test.jpg', mask_point1, mask_point2)
    mask_img = load_image(mask_path)
    mask_img,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(mask_img, 512, 512,mask_point1,mask_point2)

    mask_img = np.array(mask_img)

    # Step 3: detect vehicles in the crop and collect their boxes.
    model = load_yolov5_model()
    yolo_results = model(input_image)
    yolo_boxes = []
    car_class_id = [2, 5, 7]

    for result in yolo_results:
        boxes = result.boxes.xyxy.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        for box, cls in zip(boxes, class_ids):
            if cls in car_class_id:
                x_min, y_min, x_max, y_max = box[:4]
                yolo_boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
    # Step 4: remove from the mask any vehicle that overlaps the trigger box.
    _,mask_img=generate_gt_mask_from_intersection([global_mask_point1_relative,global_mask_point2_relative], yolo_boxes, input_image, mask_img,sam2_model, threshold_iou=0.01)
    cv2.imwrite("crop_mask.jpg", mask_img)

    print("Mask 已成功保存至 crop_mask.jpg")
    # Bounds of the trigger box in original-image coordinates.
    crop_x_min = min(mask_point1[0], mask_point2[0])
    crop_x_max = max(mask_point1[0], mask_point2[0])
    crop_y_min = min(mask_point1[1], mask_point2[1])
    crop_y_max = max(mask_point1[1], mask_point2[1])

    # NOTE(review): same body as the module-level extract_lanes_in_crop;
    # this definition rebinds that name when the script runs.
    def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
        """
        Filter the TuSimple-style `lanes`, keeping only the points inside the crop.
        """
        cropped_lanes = []
        for lane in lane_data["lanes"]:
            cropped_lane = []
            for x, y in zip(lane, lane_data["h_samples"]):
                if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
                    cropped_lane.append((x, y))

            if cropped_lane:
                cropped_lanes.append(cropped_lane)

        return cropped_lanes

    cropped_lanes = extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max)

    def draw_lane_mask_on_original(image, cropped_lanes):
        """
        Draw only the lane segments in `cropped_lanes` onto a blank
        single-channel mask with the size of `image` (10-pixel-thick lines).
        """
        height, width, _ = image.shape
        lane_mask = np.zeros((height, width), dtype=np.uint8)

        for lane in cropped_lanes:
            points = np.array(lane, dtype=np.int32)
            cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=10)

        return lane_mask

    # Not called in the code below; the lane mask is cropped with random_crop.
    def random_crop_lane(image, target_width, target_height, mask_point1, mask_point2):
        """Crop a region of the given size around the midpoint of two diagonal points, clamped to the image bounds."""

        if isinstance(image, Image.Image):
            image = np.array(image)

        height, width = image.shape[:2]

        center_x = (mask_point1[0] + mask_point2[0]) // 2
        center_y = (mask_point1[1] + mask_point2[1]) // 2

        left = max(0, center_x - target_width // 2)
        top = max(0, center_y - target_height // 2)
        right = min(width, left + target_width)
        bottom = min(height, top + target_height)

        # Number of rows/columns by which the crop falls short of the target.
        top_padding = max(0, target_height - (bottom - top))
        left_padding = max(0, target_width - (right - left))

        cropped_image = image[top:bottom, left:right]

        return cropped_image, top_padding, left_padding

    # Step 5: lane mask restricted to the trigger box, cropped like the image.
    raw_image = np.array(load_image(image_path).convert("RGB"))
    lane_mask = draw_lane_mask_on_original(raw_image, cropped_lanes)
    lane_mask_pil = Image.fromarray(lane_mask)
    crop_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(lane_mask_pil, 512, 512,mask_point1,mask_point2)

    crop_image.save("lane_mask_crop.jpg")
    print("✅ 车道 Mask 已保存为 lane_mask_crop.jpg")

    # Step 6: paint the masked area white on the crop for a visual check.
    crop_img = cv2.imread("crop.jpg")
    mask_img = cv2.imread("crop_mask.jpg", cv2.IMREAD_GRAYSCALE)
    if crop_img.shape[:2] != mask_img.shape:
        print("⚠️ Resizing mask to match crop image size...")
        mask_img = cv2.resize(mask_img, (crop_img.shape[1], crop_img.shape[0]))
    white_overlay = np.ones_like(crop_img) * 255
    # Only pixels that are exactly 255 count as mask (the mask went through a
    # lossy JPEG round trip, so border pixels may not qualify).
    masked_result = np.where(mask_img[:, :, None] == 255, white_overlay, crop_img)

    cv2.imwrite("crop_with_mask.jpg", masked_result)
    print("✅ 叠加后的 Mask 图像已保存至 crop_with_mask.jpg")