File size: 1,637 Bytes
d5c53f9
 
5668d1d
 
 
 
 
7067e59
d5c53f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import subprocess
import time

import gradio as gr
# Clone the SAM 2 repository only when it is not already present: running
# `git clone` into an existing non-empty `sam2` directory just fails.
if not os.path.exists("sam2"):
    print("add sam2")
    # Change this URL to your own sam2 repository if you use a fork.
    try:
        # subprocess.run blocks until git exits, so no extra wait is needed.
        clone = subprocess.run(
            ["git", "clone", "https://github.com/facebookresearch/sam2.git"],
            check=False,
        )
    except OSError as err:  # e.g. the git executable is not installed
        print(f"could not run git clone for sam2: {err}")
    else:
        if clone.returncode != 0:
            print(f"git clone for sam2 exited with code {clone.returncode}")
from sam2segment_structure import generate_trigger_crop
# Simulated lane annotation for the demo (later this could be loaded
# dynamically from a JSON file or a user upload). The keys resemble a
# TuSimple-style label record; -2 presumably marks sample rows with no lane
# point — confirm against generate_trigger_crop.
dummy_lane_data = dict(
    # One lane: x coordinate per sample row, aligned index-by-index with h_samples.
    lanes=[[-2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2]],
    # 20 evenly spaced sample rows: 200, 210, ..., 390.
    h_samples=list(range(200, 400, 10)),
    raw_file="driver_182_30frame/06010513_0036.MP4/00270.jpg",
)

def process_trigger_with_path(input_image, save_path):
    """Save an uploaded image to ``save_path`` and run the trigger crop on it.

    Args:
        input_image: PIL image from the Gradio upload component, or ``None``
            when nothing was uploaded.
        save_path: Destination path typed by the user. Relative paths are
            resolved against the current working directory.

    Returns:
        The ``(crop_path, mask_path)`` pair returned by
        ``generate_trigger_crop``.

    Raises:
        gr.Error: If no image was uploaded, no path was given, or the path
            resolves outside the current working directory.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if not save_path or not save_path.strip():
        raise gr.Error("Please provide a path to save the image.")

    # The path comes from a public web form: refuse anything that would write
    # outside the working directory (e.g. "../..." or an absolute path elsewhere).
    base_dir = os.path.realpath(os.getcwd())
    target = os.path.realpath(save_path)
    try:
        inside_base = os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside_base = False
    if not inside_base:
        raise gr.Error("Save path must stay inside the working directory.")

    # dirname() is "" for a bare file name, and os.makedirs("") would raise.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    input_image.save(save_path)

    # Build a per-request copy instead of mutating the shared module-level
    # dict, so concurrent requests cannot overwrite each other's raw_file.
    lane_data = {**dummy_lane_data, "raw_file": save_path}

    crop_path, mask_path = generate_trigger_crop(save_path, lane_data)
    return crop_path, mask_path

# Inputs: the image to process plus the path it should be saved under.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Path to Save Image (e.g. driver_182_30frame/06010513_0036.MP4/00270.jpg)"),
]

# Outputs: file paths returned by process_trigger_with_path (crop, then mask).
_output_components = [
    gr.Image(type="filepath", label="Cropped Image"),
    gr.Image(type="filepath", label="Cropped Mask"),
]

demo = gr.Interface(
    fn=process_trigger_with_path,
    inputs=_input_components,
    outputs=_output_components,
    title="DBDLD Trigger Demo",
    description="Upload an image and specify the target save path. The crop and mask will be generated accordingly.",
)

# Start the web server.
demo.launch()