File size: 5,150 Bytes
83c71a6
 
5756dad
ada211a
698fd5a
 
ada211a
 
5756dad
ada211a
6dafc63
77b3326
 
83c71a6
ada211a
bf22d27
6dafc63
ada211a
77b3326
ada211a
 
 
 
2f92f19
6879c60
e881756
ada211a
 
 
 
 
 
 
 
 
 
83c71a6
 
ada211a
83c71a6
 
ada211a
83c71a6
 
 
 
 
 
 
 
 
 
2f92f19
 
927201f
22d7c64
ada211a
 
22d7c64
ada211a
 
da68fcd
ada211a
 
 
 
 
 
698fd5a
 
 
 
da68fcd
0a5c36c
ada211a
 
e881756
cbb61c3
0721dd9
 
 
cbb61c3
0721dd9
1d43352
cbb61c3
 
e881756
 
 
 
 
 
 
 
 
 
5756dad
60a2be3
 
 
 
 
 
 
 
ada211a
e881756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ada211a
cbb61c3
e881756
 
 
cbb61c3
 
 
60a2be3
83c71a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import PreTrainedModel  # for type hint
from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer  # Moondream
from transformers import YolosImageProcessor, YolosForObjectDetection  # YOLOS-small-300

# --- Moondream --- #
# Moondream does not support the HuggingFace pipeline system, so we have to do it manually
# Hub repository and revision are fixed here so the tokenizer and the model are
# always loaded from the same snapshot.
moondream_id = "vikhyatk/moondream2"
moondream_revision = "2024-04-02"
moondream_tokenizer = AutoTokenizer.from_pretrained(moondream_id, revision=moondream_revision)
# trust_remote_code=True is passed because the model class is not one of the
# architectures bundled with transformers (see the note above).
moondream_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    moondream_id, trust_remote_code=True, revision=moondream_revision
)
# The model is only used for inference in this file.
moondream_model.eval()

# --- YOLOS --- #
# Processor (image -> tensors, raw outputs -> boxes) and detection model, both
# loaded from the same checkpoint id.
yolos_id = "hustvl/yolos-small-300"
yolos_processor: YolosImageProcessor = YolosImageProcessor.from_pretrained(yolos_id)
yolos_model: YolosForObjectDetection = YolosForObjectDetection.from_pretrained(yolos_id)

# Gallery index read by to_moondream. This is a plain module-level int, so it is
# shared by the whole process rather than kept per browser session.
# NOTE(review): no function in this file assigns to it after this line.
selected_image = 0


def answer_question(img, prompt):
    """
    Submits an image and prompt to the Moondream model and streams the answer.

    Generation runs on a worker thread that feeds a ``TextIteratorStreamer``; this
    generator consumes the streamer and re-yields everything received so far each
    time a new piece of text arrives.

    :param img: image to ask about, passed to ``moondream_model.encode_image``
    :param prompt: question text, passed to the model as ``question``
    :return: yields the output buffer string (stripped of surrounding whitespace)
    :raises gr.Error: if no image was provided
    """
    if img is None:
        # The image input is empty when Submit is pressed before an image is loaded.
        # Report that to the user instead of failing inside encode_image.
        raise gr.Error("Please provide an image first.")

    image_embeds = moondream_model.encode_image(img)
    streamer = TextIteratorStreamer(moondream_tokenizer, skip_special_tokens=True)
    # Generation happens off-thread so that the loop below can yield partial text
    # while the model is still producing tokens.
    thread = Thread(
        target=moondream_model.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": moondream_tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()


def detect_objects(img: Image.Image):
    """
    Runs YOLOS object detection on an image and crops out every detected object.

    :param img: image to analyse; ``None`` (nothing uploaded) produces no detections
    :return: list of ``(cropped image, "label (score)")`` tuples, one per detection
        kept by the 0.7 score threshold
    """
    if img is None:
        # Submit was pressed without an image: nothing to detect.
        return []

    inputs = yolos_processor(img, return_tensors="pt")
    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = yolos_model(**inputs)

    # img.size is reversed before being handed to the post-processor as the target size.
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    results = yolos_processor.post_process_object_detection(outputs, threshold=0.7, target_sizes=target_sizes)[0]

    box_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        # Resolve the label and score once; both the log line and the caption use them.
        label_name = yolos_model.config.id2label[label.item()]
        confidence = round(score.item(), 3)
        print(f"Detected {label_name} with confidence {confidence} at location {box}")
        box_images.append((
            img.crop((box[0], box[1], box[2], box[3])),
            f"{label_name} ({confidence})",
        ))

    return box_images


def gallery_selected(evt: gr.SelectData) -> tuple[int, gr.Button]:
    """
    Handles a selection event raised by the gallery.

    Logs the event's index, value and selected flag, then reports the index back
    together with a button whose ``interactive`` flag mirrors the selected flag.

    :param evt: selection event data supplied by Gradio
    :return: the event's index and a Button made interactive only when ``evt.selected`` is set
    """
    index, is_selected = evt.index, evt.selected
    print(f"Index: {index}, Value: {evt.value}, Selected: {is_selected}")

    button_update = gr.Button(interactive=is_selected)
    return index, button_update


def to_moondream(images: list[tuple[Image.Image, str | None]]) -> tuple[gr.Tabs, Image.Image]:
    """
    Forwards one gallery image to the Moondream tab.

    The entry is picked with the module-level ``selected_image`` index, and only
    its image part (first tuple element) is passed on.

    :param images: list of (image, caption) entries from yolos_gallery
    :return: a Tabs update selecting the 'moondream' tab, and the chosen image (no caption)
    """
    tab_switch = gr.Tabs(selected='moondream')
    chosen_entry = images[selected_image]
    return tab_switch, chosen_entry[0]


if __name__ == "__main__":
    # Two-tab UI: the first tab runs YOLOS and shows the crops in a gallery, the
    # second tab asks Moondream a question about an image. A crop chosen in the
    # gallery can be forwarded to the second tab.
    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Food Identifier

            Final project for IAT 481 at Simon Fraser University, Spring 2024.
            """
        )

        # Index of the gallery item the user last clicked (None = nothing selected).
        # Event outputs have to be Gradio components, so the index is kept in a
        # per-session State rather than in a plain module-level int.
        selected_index = gr.State(value=None)

        # Referenced: https://github.com/gradio-app/gradio/issues/7726#issuecomment-2028051431
        with gr.Tabs(selected='yolos') as tabs:
            with gr.Tab("Object Detection", id='yolos'):
                with gr.Row():
                    with gr.Column():
                        yolos_input = gr.Image(type="pil", scale=1)
                        yolos_submit = gr.Button("Submit", scale=0)

                    with gr.Column():
                        yolos_gallery = gr.Gallery(label="Detected Objects", object_fit="scale-down", columns=3,
                                                   scale=2, show_share_button=False, selected_index=None,
                                                   allow_preview=False, type="pil", interactive=False)
                        # Stays disabled until gallery_selected reports a selection.
                        proceed_button = gr.Button("To Moondream", interactive=False)

            with gr.Tab("Inference", id='moondream'):
                with gr.Row():
                    moon_prompt = gr.Textbox(label="Input", value="Describe this image.")
                    moon_submit = gr.Button("Submit")
                with gr.Row():
                    moon_img = gr.Image(label="Image", type="pil", interactive=True)
                    moon_output = gr.TextArea(label="Output")

        def forward_selection(images, index):
            """
            Switches to the Moondream tab and hands over the gallery image at ``index``.

            Done here rather than through to_moondream because that function reads the
            module-level ``selected_image``, which no event ever updates.

            :param images: list of (image, caption) entries from yolos_gallery
            :param index: gallery index stored by the last selection event, or None
            :return: Tabs update selecting 'moondream' and the chosen image (no caption)
            :raises gr.Error: if there is no valid selection to forward
            """
            if not images or index is None or not 0 <= index < len(images):
                raise gr.Error("Select a detected object first.")
            return gr.Tabs(selected='moondream'), images[index][0]

        # --- YOLOS --- #
        # A fresh detection replaces the gallery contents, so the previous selection
        # is cleared and the hand-off button disabled until the user picks again.
        yolos_submit.click(detect_objects, [yolos_input], yolos_gallery).then(
            lambda: (None, gr.Button(interactive=False)), None, [selected_index, proceed_button]
        )
        yolos_gallery.select(gallery_selected, None, [selected_index, proceed_button])
        proceed_button.click(forward_selection, [yolos_gallery, selected_index], [tabs, moon_img])

        # --- Moondream --- #
        moon_submit.click(answer_question, [moon_img, moon_prompt], moon_output)

    app.queue().launch()