from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import PreTrainedModel  # for type hint
from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer  # Moondream
from transformers import YolosImageProcessor, YolosForObjectDetection  # YOLOS-small-300

# --- YOLOS --- #
yolos_id = "hustvl/yolos-small-300"
yolos_processor: YolosImageProcessor = YolosImageProcessor.from_pretrained(yolos_id)
yolos_model: YolosForObjectDetection = YolosForObjectDetection.from_pretrained(yolos_id)

# --- Moondream --- #
# Moondream does not support the Hugging Face pipeline API, so load it manually.
# The revision is pinned because trust_remote_code executes code from the repo.
moondream_id = "vikhyatk/moondream2"
moondream_revision = "2024-04-02"
moondream_tokenizer = AutoTokenizer.from_pretrained(moondream_id, revision=moondream_revision)
moondream_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    moondream_id, trust_remote_code=True, revision=moondream_revision
)
moondream_model.eval()
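
# Both models load on the CPU by default. A minimal sketch for GPU placement,
# assuming a CUDA-enabled torch build (the YOLOS inputs in detect_objects
# would then also need to be moved with .to("cuda")):
#
#     if torch.cuda.is_available():
#         yolos_model.to("cuda")
#         moondream_model.to("cuda")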


def answer_question(img: Image.Image, prompt: str):
    """
    Streams an answer from the Moondream model for an image and prompt.

    :param img: PIL image to ask about
    :param prompt: question to ask about the image
    :return: yields the accumulated answer string as new tokens arrive
    """
    image_embeds = moondream_model.encode_image(img)
    streamer = TextIteratorStreamer(moondream_tokenizer, skip_special_tokens=True)
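    # Moondream's answer_question call blocks until generation completes, so
    # run it on a worker thread and stream partial text back on this thread.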
    thread = Thread(
        target=moondream_model.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": moondream_tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()
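
# Example of consuming the generator outside Gradio, assuming a local
# "food.jpg" (hypothetical path):
#
#     for partial in answer_question(Image.open("food.jpg"), "What is this?"):
#         print(partial)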


def detect_objects(img: Image.Image):
    """
    Runs YOLOS-small-300 object detection on an image.

    :param img: PIL image to scan for objects
    :return: list of ``(cropped image, caption)`` tuples for the gallery,
        with the uncropped original appended last
    """
    inputs = yolos_processor(img, return_tensors="pt")
    with torch.no_grad():  # inference only, so skip building the autograd graph
        outputs = yolos_model(**inputs)

    # PIL's img.size is (width, height); post-processing expects (height, width)
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    results = yolos_processor.post_process_object_detection(outputs, threshold=0.7, target_sizes=target_sizes)[0]

    box_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {yolos_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
        box_images.append((
            img.crop((box[0], box[1], box[2], box[3])),
            f"{yolos_model.config.id2label[label.item()]} ({round(score.item(), 3)})",
        ))

    box_images.append((img, "original"))  # always offer the uncropped image too
    return box_images
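
# Example, assuming a local "fridge.jpg" (hypothetical path):
#
#     crops = detect_objects(Image.open("fridge.jpg"))
#     print(f"{len(crops) - 1} objects detected")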


def get_selected_index(evt: gr.SelectData) -> int:
    """
    Listener for the gallery selection event.

    :param evt: Gradio selection event payload
    :return: index of the currently selected image
    """
    return evt.index


def to_moondream(images: list[tuple[Image.Image, str | None]], index: int) -> tuple[gr.Tabs, Image.Image]:
    """
    Listener that sends selected gallery image to the moondream model.

    :param images: list of images from yolos_gallery
    :param index: index of selected gallery image
    :return: selected tab and selected image (no caption)
    """
    return gr.Tabs(selected='moondream'), images[index][0]


def enable_button() -> gr.Button:
    """
    Helper function for Gradio event listeners.

    :return: a button with ``interactive=True`` and ``variant="primary"``
    """
    return gr.Button(interactive=True, variant="primary")


def disable_button() -> gr.Button:
    """
    Helper function for Gradio event listeners.

    :return: a button with ``interactive=False`` and ``variant="secondary"``
    """
    return gr.Button(interactive=False, variant="secondary")


if __name__ == "__main__":
    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Food Identifier

            Final project for IAT 481 at Simon Fraser University, Spring 2024.
            
            **Models used:**
            
            - [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
            - [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2)
            """
        )
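        # Hidden component that stores the index of the selected gallery image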
        selected_image = gr.Number(visible=False, precision=0)

        # Referenced: https://github.com/gradio-app/gradio/issues/7726#issuecomment-2028051431
        with gr.Tabs() as tabs:
            with gr.Tab("Object Detection", id='yolos'):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        yolos_submit = gr.Button("Detect Objects", interactive=False)
                        yolos_input = gr.Image(label="Input Image", type="pil", interactive=True, mirror_webcam=False)
                    with gr.Column():
                        proceed_button = gr.Button("Select for Captioning", interactive=False)
                        yolos_gallery = gr.Gallery(label="Detected Objects", object_fit="scale-down", columns=3,
                                                   show_share_button=False, selected_index=None, allow_preview=False,
                                                   type="pil", interactive=False)

            with gr.Tab("Captioning", id='moondream'):
                with gr.Row(equal_height=False):
                    with gr.Column():
                        with gr.Group():
                            moon_prompt = gr.Textbox(label="Ask a question about the image:",
                                                     value="What is this food item? Include any text on labels.")
                            moon_submit = gr.Button("Submit", interactive=False)
                        moon_img = gr.Image(label="Image", type="pil", interactive=True, mirror_webcam=False)
                    moon_output = gr.TextArea(label="Answer", interactive=False)

        # --- YOLOS --- #
        yolos_input.upload(enable_button, None, yolos_submit)
        yolos_input.clear(disable_button, None, yolos_submit)
        yolos_submit.click(detect_objects, yolos_input, yolos_gallery)

        yolos_gallery.select(get_selected_index, None, selected_image)
        yolos_gallery.select(enable_button, None, proceed_button)
        proceed_button.click(to_moondream, [yolos_gallery, selected_image], [tabs, moon_img])
        proceed_button.click(enable_button, None, moon_submit)

        # --- Moondream --- #
        moon_img.upload(enable_button, None, moon_submit)
        moon_img.clear(disable_button, None, moon_submit)
        moon_submit.click(answer_question, [moon_img, moon_prompt], moon_output)

    # queue() enables the event queue, which Gradio needs in order to stream
    # the generator output from answer_question
    app.queue().launch()