import os
import cv2
import spaces
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf

# set up environment
from utils.env_utils import set_random_seed, use_lower_vram
from utils.timer_utils import Timer

set_random_seed(1024)
timer = Timer()
timer.start()
# use_lower_vram()
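# use_lower_vram() is left disabled by default; judging by its name it trades
# speed for a smaller GPU-memory footprint, so enable it on GPUs with limited
# VRAM (assumption; see utils/env_utils.py for the actual behavior).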

# import the pipeline components: labels, RAM tagging, BLIP-2 captioning, LLM refinement, Grounded-SAM
from utils.labels_utils import Labels
from utils.ram_utils import ram_inference
from utils.blip2_utils import blip2_caption
from utils.llms_utils import pre_refinement, make_prompt, init_model
from utils.grounded_sam_utils import run_grounded_sam


# hard-coded detection parameters for Grounded-SAM (G-SAM)
box_threshold  = 0.18
text_threshold = 0.15
iou_threshold  = 0.8
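# (Assumed semantics, following the standard Grounded-SAM pipeline: box_threshold
# and text_threshold filter GroundingDINO box/phrase confidences, while
# iou_threshold is the overlap cutoff for de-duplicating boxes.)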

# Shared state for process(): the LLM is loaded once at startup so it is not
# reloaded on every inference call. Only Meta-Llama-3-8B-Instruct is supported
# in the online demo, but any model can be used locally with the released code.
llm = init_model("Meta-Llama-3-8B-Instruct")
current_config = ""
L = None
system_prompt = None

def load_config(config_type):
    # Load the dataset-specific labels and build the LLM system prompt.
    config = OmegaConf.load(os.path.join(os.path.dirname(__file__), f"configs/{config_type}.yaml"))
    L = Labels(config=config)
    system_prompt = make_prompt(", ".join(L.LABELS))
    return L, system_prompt
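
# Illustrative shape of configs/<setting>.yaml (field names are assumptions;
# the released configs define the label list and IDs consumed by Labels):
#   LABELS: ["person", "bicycle", "car", ...]
#   IDS: [1, 2, 3, ...]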

@spaces.GPU(duration=120)
def process(image_ori, config_type):
    global current_config, L, llm, system_prompt
    # Reload the labels and system prompt only when the dataset setting changes.
    if current_config != config_type:
        L, system_prompt = load_config(config_type)
        current_config = config_type
    image_ori = cv2.cvtColor(image_ori, cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(image_ori)
    # Tag the image with RAM and append a BLIP-2 caption to form the raw label text.
    labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
    # Ask the LLM to map the raw tags/caption onto the closed label set, then
    # keep only labels that exist in the chosen dataset.
    converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
    labels_llm = L.check_labels(converted_labels)[0]
    print("labels_ram: ", labels_ram)
    print("llm_output: ", llm_output)
    print("labels_llm: ", labels_llm)

    # run Grounded-SAM with the LLM-refined labels
    label_res, bboxes, output_labels, output_prob_maps, output_points = run_grounded_sam(
        input_image    = {"image": image_pil, "mask": None},
        text_prompt    = labels_llm,
        box_threshold  = box_threshold,
        text_threshold = text_threshold,
        iou_threshold  = iou_threshold,
        LABELS         = L.LABELS,
        IDS            = L.IDS,
        llm            = llm,
        timer          = timer,
    )
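    # Only label_res is used below; the remaining outputs (bboxes, output_labels,
    # output_prob_maps, output_points) are unused in this demo.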

    # draw the label mask on the image and return it as RGB for Gradio
    ours = L.draw_mask(label_res, image_ori, print_label=True, tag="Ours")
    return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)
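
# Minimal local-usage sketch (illustrative, not part of the demo; assumes the
# bundled example image and a BGR input array, matching the BGR-to-RGB
# conversion inside process()):
#   result = process(cv2.imread("examples/COCO-81_eg.jpg"), "COCO-81")
#   cv2.imwrite("result.png", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))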


# options for different dataset settings; defined at module level because the
# UI below is built at import time
dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
default_option = "COCO-81"

with gr.Blocks() as demo:
    gr.HTML(
        """
            <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
                Training-Free Zero-Shot Semantic Segmentation with LLM Refinement
            </h1>
            <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
                <a style="text-align: center; display:inline-block"
                    href="https://sky24h.github.io/websites/bmvc2024_training-free-semseg-with-LLM/">
                    <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
                    alt="Paper Page">
                </a>
                <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Training-Free_Zero-Shot_Semantic_Segmentation_with_LLM_Refinement?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
                </a>
            </p>
            """
    )
    gr.Interface(
        fn=process,
        inputs=[gr.Image(type="numpy", height="384"), gr.Dropdown(choices=dropdown_options, label="Refinement Type", value=default_option)],
        outputs="image",
        description="""<html>
        <p style="text-align:center;"> This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024). </p>
        <p style="text-align:center;"> Usage: Select or upload an image and choose a dataset setting for semantic segmentation refinement.</p>
        </html>""",
        allow_flagging="never",
        examples=[
            ["examples/Cityscapes_eg.jpg", "Cityscapes"],
            ["examples/DRAM_eg.jpg", "DRAM"],
            ["examples/COCO-81_eg.jpg", "COCO-81"],
            ["examples/VOC2012_eg.jpg", "VOC2012"],
        ],
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()