# sam2segment_structure.py
import numpy as np
import sys,os
# sys.path.append("/home/yifan/sam2")
# sys.path.append("/data_sdf/yifan/miniconda3/envs/sam2/lib/python3.10/site-packages")
from huggingface_hub import hf_hub_download
sys.path.append(os.path.join(os.path.dirname(__file__), "sam2"))
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import random
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_checkpoint = hf_hub_download(
repo_id="Evan73/sam2-models",
filename="sam2.1_hiera_large.pt"
)
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
# global sam2_model
predictor = SAM2ImagePredictor(sam2_model)
from ultralytics import YOLO
from diffusers.utils import load_image
import pickle
import math
heatmap_zip = hf_hub_download(
repo_id="Evan73/attention-heatmaps",
filename="attention_heatmaps.zip"
)
import zipfile
with zipfile.ZipFile(heatmap_zip, 'r') as zip_ref:
zip_ref.extractall("heatmaps_lda")
with open("heatmaps_lda/attention_heatmaps.pkl", "rb") as f:
heatmap_dict = pickle.load(f)
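# The pickle is assumed to map an image path to a dict holding a 'cls_heatmap'
# and a 'reg_heatmap' array (that is how heatmap_dict is indexed in
# attention_mask below). A minimal sketch for inspecting one entry; the helper
# name is ours and the key passed in is hypothetical.
def _peek_heatmap_entry(path):
    entry = heatmap_dict.get(path)
    if entry is None:
        print(f"No heatmaps stored for {path}")
        return
    # Both heatmaps are expected to be 2-D numpy arrays on the resized grid.
    print("cls_heatmap shape:", entry['cls_heatmap'].shape)
    print("reg_heatmap shape:", entry['reg_heatmap'].shape)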
def load_yolov5_model():
    # Despite the function name, this uses the Ultralytics loader with YOLO11 weights (the original YOLOv5 hub loader is kept below for reference).
    # model = torch.hub.load('ultralytics/yolov11', 'yolov11s')  # pick a different model size if needed
model = YOLO("yolo11n.pt")
class_names = model.names # class index to name mapping
print("YOLOv11 Class Names:")
for idx, name in class_names.items():
print(f"{idx}: {name}")
return model
# Check whether a point falls inside a detected vehicle region
def is_point_in_car_area(point, model, image):
    """
    Check whether the given point lies inside a detected vehicle box.
    - point: (x, y) coordinates of the point
    - model: YOLO detection model
    - image: input image (RGB)
    Returns False as soon as the point falls inside a vehicle box, True otherwise.
    """
    # Run object detection with the YOLO model
    results = model(image)
    # Vehicle class IDs in COCO: 2 = car, 5 = bus, 7 = truck
    # print("Detected classes:", results[0].boxes.cls.cpu().numpy())
    car_class_id = [2, 5, 7]
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Iterate over detection results (batched API; a single image is assumed here)
for result in results:
        # Extract box xyxy coordinates, confidences, and class IDs
        boxes = result.boxes.xyxy.cpu().numpy()  # top-left and bottom-right corners
confidences = result.boxes.conf.cpu().numpy()
class_ids = result.boxes.cls.cpu().numpy().astype(int)
        # Check each detected box
for box, cls in zip(boxes, class_ids):
if cls in car_class_id:
x_min, y_min, x_max, y_max = box[:4]
                # Draw the detection box (optional)
cv2.rectangle(image_bgr, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)
                # Check whether the point lies inside this box
if (x_min <= point[0] <= x_max) and (y_min <= point[1] <= y_max):
cv2.imwrite("yolo_res.jpg", image_bgr)
return False
cv2.imwrite("yolo_res.jpg", image_bgr)
print(f"检测结果已保存至 yolo_res.jpg")
return True
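# Minimal usage sketch for the check above (the frame path below is
# hypothetical): load an RGB image, run the detector once, and test whether a
# single point stays clear of every detected vehicle box.
def _demo_point_check(image_path="frames/00270.jpg", point=(100, 200)):
    model = load_yolov5_model()
    image = np.array(Image.open(image_path).convert("RGB"))
    clear = is_point_in_car_area(point, model, image)
    print(f"point {point} is clear of detected vehicles: {clear}")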
def show_mask(mask, ax, image_path, random_color=False, borders=True, image=None, save_path=None):
    """
    Randomly pick two diagonal corner points inside the mask region and draw the
    resulting rectangle on the original image.
    Parameters:
    - mask: binary mask region
    - ax: matplotlib axis used for plotting
    - random_color: whether to use a random overlay color
    - borders: whether to draw mask contours
    - image: original image on which the rectangle is drawn
    - save_path: path for saving the result image
    """
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
print("原始二值掩码已保存为 binary_mask.png")
if borders:
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=5)
# print(f"Mask unique values: {np.unique(mask)}")
# print(f"Max value in mask: {mask.max()}, Min value in mask: {mask.min()}")
        # If the original image is provided, draw a colored bounding rectangle per contour
        # size = 100
        colors = [
            (255, 0, 0),      # red
            (0, 255, 0),      # green
            (0, 0, 255),      # blue
            (255, 255, 0),    # yellow
            (255, 0, 255),    # magenta
            (0, 255, 255),    # cyan
            (255, 128, 0),    # orange
            (128, 0, 255),    # purple
            (128, 128, 128),  # gray
            (0, 128, 0)       # dark green
        ]
        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"Colored contour result saved to {middle_save_path}")
if image is not None:
        # Find the mask boundaries
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
# print(x, y, w, h)
if w > 50 and h > 50:
for size in range(90,40,-5):
for _ in range(100):
random_x1 = random.randint(x, x + w - 50)
random_y1 = random.randint(y, y + h - 50)
random_x2 = random_x1 - size
random_y2 = random_y1 - size
# print(random_x1, random_y1,random_x2,random_y2)
                        # Draw the rectangle on the original image
                        # and save the result image
try:
if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
cv2.imwrite(save_path, image)
# generate_gt_mask_from_intersection([(random_x1, random_y1),(random_x2, random_y2)], yolo_boxes, image, sam2_model, threshold_iou=0.01)
print(f"Image with rectangle saved at {save_path}")
return (random_x1,random_y1),(random_x2,random_y2)
                    except Exception:  # a sampled corner may land outside the mask; skip this sample
pass
# cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
# cv2.imwrite(save_path, image)
# print(f"Image with rectangle saved at {save_path}")
# break
for _ in range(100):
random_x1 = random.randint(x, x + w - 50)
random_y1 = random.randint(y, y + h - 50)
random_x2 = random_x1 + size
random_y2 = random_y1 + size
# print(mask[random_y1, random_x1] == 1,mask[random_y2, random_x2] == 1)
                        # Draw the rectangle on the original image
                        # and save the result image
try:
if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
cv2.imwrite(save_path, image)
print(f"Image with rectangle saved at {save_path}")
# generate_gt_mask_from_intersection([(random_x1, random_y1),(random_x2, random_y2)], yolo_boxes, image, sam2_model, threshold_iou=0.01)
return (random_x1,random_y1),(random_x2,random_y2)
                    except Exception:  # a sampled corner may land outside the mask; skip this sample
pass
ax.imshow(mask_image)
plt.axis('off')
def attention_mask(mask, ax, image_path, strategy="LOA", random_color=False, borders=True, image=None, save_path=None):
    """
    Select a candidate box inside the mask region using the attention heatmaps
    (highest mean heat value) and draw it on the original image.
    Parameters:
    - mask: binary mask region
    - ax: matplotlib axis used for plotting
    - strategy: "LDA" scores candidates with the classification heatmap; "LOA"/"LRA" use the regression heatmap
    - random_color: whether to use a random overlay color
    - borders: whether to draw mask contours
    - image: original image on which the rectangle is drawn
    - save_path: path for saving the result image
    """
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
orig_w, orig_h = image.shape[1],image.shape[0]
# print(image.shape)
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
print("原始二值掩码已保存为 binary_mask.png")
candidates = []
path = image_path
cls_heatmap = heatmap_dict[path]['cls_heatmap']
reg_heatmap = heatmap_dict[path]['reg_heatmap']
font = cv2.FONT_HERSHEY_SIMPLEX
if strategy == "LDA":
combined = cls_heatmap.astype(np.float32)
if strategy == "LOA" or strategy == "LRA":
combined = reg_heatmap.astype(np.float32)
print(mask.shape)
mask = cv2.resize(mask, (combined.shape[1], combined.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = (mask > 0.5).astype(np.uint8)
cv2.imwrite("crop_binary_mask.png", (mask * 255).astype(np.uint8))
print("处理后的裁剪二值掩码已保存为 crop_binary_mask.png")
print(combined.shape)
vis_image = cv2.imread(image_path)
vis_image = cv2.resize(vis_image,(combined.shape[1],combined.shape[0]))
mask_image = cv2.resize(mask_image,(combined.shape[1],combined.shape[0]))
image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
if borders:
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
        colors = [
            (255, 0, 0),      # red
            (0, 255, 0),      # green
            (0, 0, 255),      # blue
            (255, 255, 0),    # yellow
            (255, 0, 255),    # magenta
            (0, 255, 255),    # cyan
            (255, 128, 0),    # orange
            (128, 0, 255),    # purple
            (128, 128, 128),  # gray
            (0, 128, 0)       # dark green
        ]
        # print(mask.shape)
        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"Contour {idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"Colored contour result saved to {middle_save_path}")
if image is not None:
        # Find the mask boundaries
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# print(contours)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
print("the contour is:",x, y, w, h)
if w > 50 and h > 50:
for size in range(50,40,-5):
for y_step in range(y, y+h - size,5):
for x_step in range(x, x+w - size,5):
x1, y1, x2, y2 = x_step, y_step, x_step + size, y_step + size
# print(mask[y1:y2, x1:x2].sum())
                            if mask[y1:y2, x1:x2].sum() >= size * size:  # the whole window must lie inside the mask
heat_value = combined[y1:y2, x1:x2].mean()
# print("the heat_value is:",heat_value,y1,y2, x1,x2,combined.shape)
if not math.isnan(heat_value):
candidates.append(((x1, y1, x2, y2), heat_value))
cv2.rectangle(vis_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
cv2.putText(vis_image, f'{heat_value:.1f}', (x1, y1 - 2), font, 0.4, (0, 0, 255), 1)
                    if not candidates:
                        print("⚠️ No candidate box lying fully inside the mask was found")
                    else:
                        break
                cv2.imwrite("attention_vis.jpg", vis_image)
                print("Attention candidate-box visualization saved to attention_vis.jpg")
                # Sort candidates by heat value (highest first) and take the best one
                candidates.sort(key=lambda x: x[1], reverse=True)
                if candidates:
                    print(save_path, candidates[0], candidates[-1])
for (x1, y1, x2, y2), _ in candidates:
try:
if mask[y1, x1] == 1 and mask[y2, x2] == 1:
                            # Visualize and save
if save_path:
image = cv2.imread(image_path)
image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# os.makedirs(os.path.dirname(save_path), exist_ok=True)
cv2.imwrite(save_path, image)
print(f"Image with rectangle saved at {save_path}")
resize_w, resize_h = combined.shape[1],combined.shape[0]
scale_x = orig_w / resize_w
scale_y = orig_h / resize_h
x1_orig = int(x1 * scale_x)
x2_orig = int(x2 * scale_x)
y1_orig = int(y1 * scale_y)
y2_orig = int(y2 * scale_y)
cx = (x1_orig + x2_orig) // 2
cy = (y1_orig + y2_orig) // 2
target_size = 90
half = target_size // 2
x1_exp = max(0, cx - half)
y1_exp = max(0, cy - half)
x2_exp = min(orig_w - 1, cx + half)
y2_exp = min(orig_h - 1, cy + half)
print(f"扩展后的原图坐标: ({x1_exp}, {y1_exp}), ({x2_exp}, {y2_exp})")
image_full = cv2.imread(image_path) # 原图大小读取
cv2.rectangle(image_full, (x1_exp, y1_exp), (x2_exp, y2_exp), (0, 0, 255), 2)
cv2.imwrite("expanded_bbox_on_original.jpg", image_full)
print("📌 扩大后的候选框已绘制到原图并保存为 expanded_bbox_on_original.jpg")
return (x1_exp, y1_exp), (x2_exp, y2_exp)
except Exception as e:
print("the error is:",e)
pass # 若越界等问题,继续下一个
# for _ in range(100):
# random_x1 = random.randint(x, x + w - 50)
# random_y1 = random.randint(y, y + h - 50)
# random_x2 = random_x1 + size
# random_y2 = random_y1 + size
# # print(mask[random_y1, random_x1] == 1,mask[random_y2, random_x2] == 1)
# try:
# if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
# cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
# cv2.imwrite(save_path, image)
# print(f"Image with rectangle saved at {save_path}")
# return (random_x1,random_y1),(random_x2,random_y2)
# except:
# pass
ax.imshow(mask_image)
plt.axis('off')
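# attention_mask scans fixed-size windows inside the mask and keeps the one
# with the highest mean heat value: strategy "LDA" scores against cls_heatmap,
# while "LOA"/"LRA" score against reg_heatmap. A minimal usage sketch
# (hypothetical frame path and pre-computed mask):
#
#   fig, ax = plt.subplots()
#   p1, p2 = attention_mask(mask, ax, "frames/00270.jpg", strategy="LDA",
#                           image=image, save_path="masked_image.jpg")
#
# p1 and p2 are opposite corners of the chosen box, mapped back to
# original-image pixel coordinates.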
def generate_gt_mask_from_intersection(random_rectangle, yolo_boxes, image, mask_img, sam2_model, threshold_iou):
    """
    Check whether the randomly generated rectangle overlaps a YOLO box closely enough
    (IoU >= threshold_iou); if it does, query SAM for a precise mask to use as the GT.
    """
image_np = np.array(image)
x1_rect, y1_rect = random_rectangle[0]
x2_rect, y2_rect = random_rectangle[1]
rect_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
cv2.rectangle(rect_mask, (x1_rect, y1_rect), (x2_rect, y2_rect), color=255, thickness=-1)
rect_box = [min(x1_rect, x2_rect), min(y1_rect, y2_rect), max(x1_rect, x2_rect), max(y1_rect, y2_rect)]
for box in yolo_boxes:
iou = calculate_iou(rect_box, box)
print(f"与YOLO box的IoU为: {iou}, 阈值: {threshold_iou}")
if iou >= threshold_iou:
            # Sample three random points inside the YOLO box
x_min, y_min, x_max, y_max = box
input_point1 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
input_point2 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
input_point3 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            # Use SAM to generate a precise mask
gt_mask = get_gt_mask_from_sam(image, sam2_model, [input_point1, input_point2,input_point3], rect_mask)
mask_img[gt_mask > 0] = 0
            # Save the GT mask
            cv2.imwrite('gt_mask_from_sam.png', gt_mask)
            print("SAM-generated GT mask saved to gt_mask_from_sam.png")
return gt_mask,mask_img
h, w = image_np.shape[:2]
black_mask = np.zeros((h, w), dtype=np.uint8)
no_match_save_path = 'gt_mask_from_sam.png'
cv2.imwrite(no_match_save_path, black_mask)
print("未找到满足阈值条件的YOLO box。")
print(f"未匹配成功,保存空掩码图至 {no_match_save_path}")
return None,mask_img
def calculate_iou(boxA, boxB):
"""计算两个box的IoU."""
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
inter_area = max(0, xB - xA + 1) * max(0, yB - yA + 1)
boxA_area = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
boxB_area = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
iou = inter_area / float(boxA_area + boxB_area - inter_area)
return iou
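# A small sanity check for calculate_iou with hand-picked boxes (illustrative
# numbers, not taken from the pipeline): two 10x10 squares overlapping in a
# 5x5 corner give IoU = 25 / (100 + 100 - 25) ≈ 0.143.
def _check_iou_example():
    boxA = [0, 0, 9, 9]      # 10x10 box under the inclusive (+1) convention above
    boxB = [5, 5, 14, 14]    # 10x10 box shifted by 5 px in x and y
    iou = calculate_iou(boxA, boxB)
    assert abs(iou - 25.0 / 175.0) < 1e-6
    return iou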
def get_gt_mask_from_sam(image, sam2_model, input_points, rect_mask):
"""使用SAM根据两个点生成掩码,并保存选取点和掩码图"""
predictor = SAM2ImagePredictor(sam2_model)
print("load sam2")
predictor.set_image(image)
input_point_np = np.array(input_points)
    input_label = np.array([1] * len(input_points))  # all prompts are positive points
masks, _, _ = predictor.predict(
point_coords=input_point_np,
point_labels=input_label,
multimask_output=False,
)
mask_img = masks[0].astype(np.uint8) * 255
    # mask_img[rect_mask == 255] = 0  # blank out the random_rectangle region
    # Save the SAM-generated mask image
    mask_save_path = 'sam_gt_mask.jpg'
    cv2.imwrite(mask_save_path, mask_img)
    print(f"SAM-generated mask saved to {mask_save_path}")
    # Draw the selected prompt points on the original image
image_with_points = np.array(image).copy()
for point in input_points:
cv2.circle(image_with_points, point, radius=5, color=(255, 0, 0), thickness=-1)
    # Save the original image with the point markers
    point_marked_save_path = 'image_with_points.jpg'
    image_bgr = cv2.cvtColor(image_with_points, cv2.COLOR_RGB2BGR)
    cv2.imwrite(point_marked_save_path, image_bgr)
    print(f"Image with point markers saved to {point_marked_save_path}")
return mask_img
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
def display_mask(mask, ax, random_color=False, borders = True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
cv2.imwrite("check.jpg", mask_image)
ax.imshow(mask_image)
def random_points_below(point, radius, min_distance, model, image, max_attempts=100):
"""
在给定的point偏下方50像素的区域内,随机选择两个点直到满足条件。
参数:
- point: (x, y) 格式的坐标
- radius: 随机点的最大半径
- min_distance: 两个随机点之间的最小距离
- max_attempts: 最大尝试次数,避免死循环
返回:
- 两个随机点的坐标,如果没有找到合适的点则返回None
"""
for _ in range(max_attempts):
        # Sample two points in the region roughly 50 px below the given point
        x1 = random.randint(point[0] - radius, point[0] + radius)
        y1 = random.randint(point[1] + 50, point[1] + 50 + radius)  # 50 px below
        x2 = random.randint(point[0] - radius, point[0] + radius)
        y2 = random.randint(point[1] + 50, point[1] + 50 + radius)  # 50 px below
        # Euclidean distance between the two points
distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
        # Keep the pair only if it is far enough apart and neither point lands on a detected vehicle
        if distance >= min_distance and is_point_in_car_area((x1, y1), model, image) and is_point_in_car_area((x2, y2), model, image):
            return [(x1, y1), (x2, y2)]
    # No suitable pair found within max_attempts; return None
return None
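# Minimal usage sketch for random_points_below (the seed point and frame path
# are hypothetical): sample two points roughly 50 px below the seed that are
# at least min_distance apart and that both stay clear of detected vehicles.
#
#   model = load_yolov5_model()
#   image = np.array(Image.open("frames/00270.jpg").convert("RGB"))
#   pts = random_points_below((770, 325), radius=50, min_distance=40,
#                             model=model, image=image)
#   if pts is None:
#       print("no valid point pair found within max_attempts")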
def show_masks(image, masks, scores, image_path, strategy,point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
display_mask(mask, plt.gca(), borders=borders)
if point_coords is not None:
assert input_labels is not None
show_points(point_coords, input_labels, plt.gca())
if box_coords is not None:
# boxes
show_box(box_coords, plt.gca())
plt.axis('off')
        plt.savefig('check.jpg', bbox_inches='tight', pad_inches=0)  # save the figure
        point1, point2 = attention_mask(mask, plt.gca(), image_path, strategy, borders=borders, image=image, save_path=save_path)
        return point1, point2  # only the first (single) mask is processed
def random_crop(image, target_width, target_height, mask_point1, mask_point2):
# global global_mask_point1_relative, global_mask_point2_relative
"""从两个对角点的中点裁剪指定宽度和高度的区域,避免超出图像边界"""
width, height = image.size
    # Midpoint of the two corner points
center_x = (mask_point1[0] + mask_point2[0]) // 2
center_y = (mask_point1[1] + mask_point2[1]) // 2
    # Top-left and bottom-right of the crop window
left = center_x - target_width // 2
top = center_y - target_height // 2
right = left + target_width
bottom = top + target_height
    # Clamp the crop window to the image bounds
if left < 0:
left = 0
right = target_width
if top < 0:
top = 0
bottom = target_height
if right > width:
right = width
left = width - target_width
if bottom > height:
bottom = height
top = height - target_height
    # Compute padding offsets
top_padding = max(0, top)
left_padding = max(0, left)
    # Crop the image
cropped_image = image.crop((left, top, right, bottom))
global_mask_point1_relative = (mask_point1[0] - left, mask_point1[1] - top)
global_mask_point2_relative = (mask_point2[0] - left, mask_point2[1] - top)
print("裁剪后点的相对位置为:")
print("mask_point1:", global_mask_point1_relative)
print("mask_point2:", global_mask_point2_relative)
return cropped_image, top_padding, left_padding,global_mask_point1_relative,global_mask_point2_relative
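# Worked example of the clamping above (illustrative numbers, not from the
# dataset): for a 1640x590 image, a 512x512 target, and corner points
# (1500, 300) / (1600, 400), the centre is (1550, 350), so the initial window
# (1294, 94)-(1806, 606) spills over the right and bottom edges and is shifted
# to (1128, 78)-(1640, 590). The two points then map to (372, 222) and
# (472, 322) in crop coordinates.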
def get_left_right_points(lane_data,image_path):
lanes = lane_data["lanes"]
h_samples = lane_data["h_samples"]
model = load_yolov5_model()
    # Index of the middle entry in h_samples
mid_idx = len(h_samples) // 2
image = cv2.imread(image_path)
    # Store the left-most and right-most points
left_point = None
right_point = None
points = []
    # Iterate over each lane
for lane in lanes:
        # Drop invalid points marked with -2
valid_points = [(x, y) for x, y in zip(lane, h_samples) if x != -2]
if valid_points:
if lane[mid_idx] != -2:
for i in range(mid_idx-2,0,-1):
left_point = lane[i]
print(left_point)
if lane[i] != -2:
point = (left_point,h_samples[i])
FLAG = is_point_in_car_area(point, model, image)
print(point,FLAG)
if FLAG:
points.append((left_point,h_samples[i]))
break
else:
                point = (1540/2, 590/2+30)  # initial seed point
                radius = 50  # maximum spread of the random points
                min_distance = 40  # minimum distance between the two random points
                points = random_points_below(point, radius, min_distance, model, image)
# first_non_minus_two = next((x for x in lane if x != -2), None)
# if first_non_minus_two:
# idx = lane.index(first_non_minus_two)
# for i in range(idx+5,idx,-1):
# left_point = lane[i]
# if lane[i] != -2:
# point = (left_point,h_samples[i])
# FLAG = is_point_in_car_area(point, model, image)
# if FLAG:
# points.append((left_point,h_samples[i]))
# break
# return left_point, right_point
return points
def sam2segment(image_path,points,strategy):
# print(points)
image = Image.open(image_path)
image = np.array(image.convert("RGB"))
predictor.set_image(image)
# print([points[0][0], points[0][1]])
input_point = np.array([(points[0][0], points[0][1])])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
    # Use the logits of the best mask as a dense prompt for the refinement step
    mask_input = logits[np.argmax(scores), :, :]  # choose the model's best mask
points_set = []
for point in points:
points_set.append((point[0], point[1]))
# print(points_set)
input_point = np.array(points_set)
input_label = np.array([1]*len(points_set))
masks, scores, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
# random_mask_selection(image, masks, mask_index=0,output_path="cropped_image.jpg")
point1,point2 = show_masks(image, masks, scores, image_path, strategy,save_path="masked_image.jpg")
return point1,point2
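# Minimal usage sketch for sam2segment (hypothetical frame path and seed
# point): a first prediction with a single point picks SAM2's best coarse
# mask, whose logits are then fed back as a dense prompt together with all
# points; the return values are the two box corners chosen by attention_mask.
#
#   p1, p2 = sam2segment("frames/00270.jpg", [(770, 325)], "LDA")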
def draw_point(image_path,points):
image = cv2.imread(image_path)
if image is not None:
        # Draw the points
for point in points:
            cv2.circle(image, point, radius=5, color=(0, 255, 0), thickness=-1)  # green dot
        # Save the image
output_path = "output_image_with_points.jpg"
cv2.imwrite(output_path, image)
print(f"Image saved with points at {output_path}")
else:
print("Error: Image could not be loaded.")
def generate_mask(original_img_path, point1, point2):
"""根据坐标生成掩码图像"""
    # Read the original image
original_img = cv2.imread(original_img_path)
    # Original image dimensions
height, width, _ = original_img.shape
    # Create a black mask image with the same size as the original
mask = np.zeros((height, width), dtype=np.uint8)
    # Point 95% of the way from point1 toward point2
    three_quarter_point = (
        int(point1[0] + 0.95 * (point2[0] - point1[0])),  # x coordinate
        int(point1[1] + 0.95 * (point2[1] - point1[1]))   # y coordinate
    )
    # Fill the rectangle between point1 and that point with white
cv2.rectangle(mask, point1, three_quarter_point, color=255, thickness=-1)
    # Save the generated mask image
mask_path = original_img_path.replace('test.jpg', 'mask_test.jpg')
cv2.imwrite(mask_path, mask)
print(mask_path)
return mask_path, point1, three_quarter_point
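# Worked example of the 0.95 interpolation above (illustrative coordinates):
# with point1 = (100, 100) and point2 = (300, 200) the far corner becomes
# (100 + 0.95 * 200, 100 + 0.95 * 100) = (290, 195), so the white rectangle
# stops just short of point2.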
def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
"""
过滤 TuSimple `lanes`,只保留 `crop` 内的部分
"""
cropped_lanes = []
for lane in lane_data["lanes"]:
cropped_lane = []
for x, y in zip(lane, lane_data["h_samples"]):
if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
cropped_lane.append((x, y))
# new_x = x - crop_x_min
# new_y = y - crop_y_min
# cropped_lane.append((new_x, new_y))
if cropped_lane:
cropped_lanes.append(cropped_lane)
return cropped_lanes
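# Small illustrative example for extract_lanes_in_crop (made-up numbers): with
# lanes = [[-2, 600, 650]] and h_samples = [200, 210, 220], a crop window of
# x in [590, 700] and y in [200, 215] keeps only the sample (600, 210); the -2
# placeholder and the out-of-window sample (650, 220) are dropped.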
def generate_trigger_crop(image_path: str, lane_data: dict):
"""
输入一张图像路径,返回处理后的 crop 图像和 crop mask 图像路径。
"""
# 1. 获取触发点
points = get_left_right_points(lane_data, image_path)
print(f"[INFO] 获取 trigger 点: {points}")
draw_point(image_path, points)
    # 2. Get the mask corner points with SAM2
image = load_image(image_path)
mask_point1, mask_point2 = sam2segment(image_path, points, "LDA")
    # 3. Crop the original image
input_image, *_ = random_crop(image, 512, 512, mask_point1, mask_point2)
input_crop_path = "crop.jpg"
input_image.save(input_crop_path)
    # 4. Generate the trigger mask
mask_path, point1, point2 = generate_mask(image_path, mask_point1, mask_point2)
mask_img = load_image(mask_path)
mask_img, *_ = random_crop(mask_img, 512, 512, mask_point1, mask_point2)
crop_mask_path = "crop_mask.jpg"
cv2.imwrite(crop_mask_path, np.array(mask_img))
return input_crop_path, crop_mask_path
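# Minimal usage sketch for generate_trigger_crop (the lane_data dict follows
# the TuSimple-style label format used in __main__ below); crop.jpg and
# crop_mask.jpg are written as side effects:
#
#   crop_path, crop_mask_path = generate_trigger_crop(
#       "driver_182_30frame/06010513_0036.MP4/00270.jpg", lane_data)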
if __name__ == "__main__":
lane_data = {"lanes": [[-2, -2, -2, -2, -2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, 818, 801, 784, 768, 751, 734, 717, 701, 685, 668, 651, 634, 618, 601, 585, 568, 551, 535, 518, 502, 484, 468, 451, 435, 418, 401, 385, 368, 351, 335, 318, 301, 287], [-2, -2, -2, -2, -2, -2, -2, 863, 872, 881, 890, 899, 908, 918, 927, 936, 945, 954, 964, 972, 982, 991, 1000, 1009, 1018, 1027, 1036, 1046, 1055, 1064, 1073, 1082, 1091, 1100, 1109, 1119, 1128, 1137, 1146, 1154]], "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590], "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"}
image_path = "driver_182_30frame/06010513_0036.MP4/00270.jpg"
points = get_left_right_points(lane_data,image_path)
print(points)
draw_point(image_path,points)
# left_point, right_point = get_left_right_points(lane_data)
# print(f"Left point: {left_point}, Right point: {right_point}")
# sam2segment(image_path,left_point, right_point)
image = load_image(image_path)
mask_point1,mask_point2 = sam2segment(image_path,points,"LDA")
input_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(image, 512, 512, mask_point1, mask_point2)
    input_image.save("crop.jpg")  # save directly with PIL's save()
    print("Crop image saved to crop.jpg")
mask_path, point1, point2 = generate_mask('culane_test.jpg', mask_point1, mask_point2)
mask_img = load_image(mask_path)
mask_img,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(mask_img, 512, 512,mask_point1,mask_point2)
mask_img = np.array(mask_img)
# print(mask_img.shape)
model = load_yolov5_model()
yolo_results = model(input_image)
yolo_boxes = []
    car_class_id = [2, 5, 7]  # car, bus, truck class IDs; adjust if needed
for result in yolo_results:
boxes = result.boxes.xyxy.cpu().numpy()
class_ids = result.boxes.cls.cpu().numpy().astype(int)
for box, cls in zip(boxes, class_ids):
if cls in car_class_id:
x_min, y_min, x_max, y_max = box[:4]
yolo_boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
_,mask_img=generate_gt_mask_from_intersection([global_mask_point1_relative,global_mask_point2_relative], yolo_boxes, input_image, mask_img,sam2_model, threshold_iou=0.01)
cv2.imwrite("crop_mask.jpg", mask_img)
print("Mask 已成功保存至 crop_mask.jpg")
crop_x_min = min(mask_point1[0], mask_point2[0])
crop_x_max = max(mask_point1[0], mask_point2[0])
crop_y_min = min(mask_point1[1], mask_point2[1])
crop_y_max = max(mask_point1[1], mask_point2[1])
    # Get the lanes that fall inside the crop window (extract_lanes_in_crop is defined above)
cropped_lanes = extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max)
# print(cropped_lanes)
# def draw_lane_mask(image, lanes):
# """
    #     Draw the `lane_mask` only inside the `crop` image
# """
# height, width, _ = image.shape
# lane_mask = np.zeros((height, width), dtype=np.uint8)
# for lane in lanes:
# points = np.array(lane, dtype=np.int32)
# cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=5)
# return lane_mask
# crop_image = load_image("crop.jpg").convert("RGB")
# crop_image = np.array(crop_image)
# lane_mask = draw_lane_mask(crop_image, cropped_lanes)
def draw_lane_mask_on_original(image, cropped_lanes):
"""
在原图上绘制 **仅包含 cropped_lanes** 的车道线
"""
height, width, _ = image.shape
lane_mask = np.zeros((height, width), dtype=np.uint8)
for lane in cropped_lanes:
points = np.array(lane, dtype=np.int32)
cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=10)
return lane_mask
def random_crop_lane(image, target_width, target_height, mask_point1, mask_point2):
"""从两个对角点的中点裁剪指定宽度和高度的区域,避免超出图像边界"""
# **确保 image 是 NumPy 数组**
if isinstance(image, Image.Image):
image = np.array(image)
height, width = image.shape[:2] # 获取 NumPy 数组的大小
# 计算两个对角点的中点
center_x = (mask_point1[0] + mask_point2[0]) // 2
center_y = (mask_point1[1] + mask_point2[1]) // 2
# 计算裁剪区域的左上角和右下角
left = max(0, center_x - target_width // 2)
top = max(0, center_y - target_height // 2)
right = min(width, left + target_width)
bottom = min(height, top + target_height)
        # Padding needed if the crop window extends past the image bounds
top_padding = max(0, target_height - (bottom - top))
left_padding = max(0, target_width - (right - left))
        # Crop with NumPy slicing
cropped_image = image[top:bottom, left:right]
return cropped_image, top_padding, left_padding
    # Draw the lane_mask on the original image
raw_image = np.array(load_image(image_path).convert("RGB"))
lane_mask = draw_lane_mask_on_original(raw_image, cropped_lanes)
lane_mask_pil = Image.fromarray(lane_mask)
crop_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(lane_mask_pil, 512, 512,mask_point1,mask_point2)
    # Save the cropped lane_mask
    crop_image.save("lane_mask_crop.jpg")
    print("✅ Lane mask saved to lane_mask_crop.jpg")
    crop_img = cv2.imread("crop.jpg")  # read the crop image (BGR)
    mask_img = cv2.imread("crop_mask.jpg", cv2.IMREAD_GRAYSCALE)  # read the mask (grayscale)
if crop_img.shape[:2] != mask_img.shape:
print("⚠️ Resizing mask to match crop image size...")
mask_img = cv2.resize(mask_img, (crop_img.shape[1], crop_img.shape[0]))
    white_overlay = np.ones_like(crop_img) * 255  # all-white image
    masked_result = np.where(mask_img[:, :, None] == 255, white_overlay, crop_img)  # replace only where the mask is white
    # Save the composited image
    cv2.imwrite("crop_with_mask.jpg", masked_result)
    print("✅ Composited mask image saved to crop_with_mask.jpg")