File size: 3,467 Bytes
83c71a6
 
22d7c64
5756dad
ada211a
22d7c64
ada211a
 
5756dad
ada211a
6dafc63
77b3326
 
83c71a6
ada211a
bf22d27
6dafc63
ada211a
77b3326
ada211a
 
 
 
2f92f19
ada211a
 
 
 
 
 
 
 
 
 
83c71a6
 
ada211a
83c71a6
 
ada211a
83c71a6
 
 
 
 
 
 
 
 
 
2f92f19
 
927201f
22d7c64
ada211a
 
22d7c64
ada211a
 
da68fcd
ada211a
 
 
 
 
 
1dc8ee5
da68fcd
0a5c36c
ada211a
 
5756dad
60a2be3
 
 
 
 
 
 
 
ada211a
 
 
22d7c64
0a5c36c
ada211a
 
 
 
 
 
 
 
 
 
 
da68fcd
60a2be3
83c71a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from threading import Thread

from PIL import Image
import gradio as gr
import torch
from transformers import PreTrainedModel, AutoImageProcessor  # for type hint
from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer  # Moondream
from transformers import YolosImageProcessor, YolosForObjectDetection  # YOLOS-small-300

# --- Moondream --- #
# Moondream does not support the HuggingFace pipeline system, so we have to do it manually
# Both the tokenizer and the model are loaded once, at import time, and are
# shared by every request handled by answer_question().
moondream_id = "vikhyatk/moondream2"
# Pin the hub revision so the tokenizer and the model code/weights always match.
moondream_revision = "2024-04-02"
moondream_tokenizer = AutoTokenizer.from_pretrained(moondream_id, revision=moondream_revision)
# trust_remote_code=True is needed because the model class lives in the hub
# repository rather than in the transformers package itself.
moondream_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    moondream_id, trust_remote_code=True, revision=moondream_revision
)
# Inference only: put the module in evaluation mode.
moondream_model.eval()

# --- YOLOS --- #
# Object detector plus its matching image pre/post-processor, also loaded once
# at import time and shared by detect_objects().
yolos_id = "hustvl/yolos-small-300"
yolos_processor: YolosImageProcessor = YolosImageProcessor.from_pretrained(yolos_id)
yolos_model: YolosForObjectDetection = YolosForObjectDetection.from_pretrained(yolos_id)


def answer_question(img, prompt):
    """
    Submits an image and prompt to the Moondream model and streams the answer.

    Generation runs on a background thread that feeds a ``TextIteratorStreamer``;
    this generator drains the streamer and re-yields the accumulated text, so the
    caller can refresh its output as new chunks arrive.

    :param img: the image to ask about (the UI passes a PIL image), or ``None``
        when no image has been provided
    :param prompt: the question/instruction text forwarded to the model
    :return: yields the output buffer string (whitespace-stripped) after every
        new chunk; when ``img`` is ``None`` a single hint message is yielded
        instead and the model is never invoked
    """
    if img is None:
        # An empty image component arrives here as None; encoding it would fail
        # with an opaque error, so tell the user what is missing instead.
        yield "Please upload an image first."
        return

    image_embeds = moondream_model.encode_image(img)
    streamer = TextIteratorStreamer(moondream_tokenizer, skip_special_tokens=True)
    # Generation blocks until it is complete, so it runs on a worker thread
    # while this generator consumes the streamer concurrently.
    worker = Thread(
        target=moondream_model.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": moondream_tokenizer,
            "streamer": streamer,
        },
    )
    worker.start()

    # Yield the whole text so far (not just the delta) on every chunk.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()

    # The streamer is exhausted once generation has ended; reap the worker so
    # no finished thread is left dangling per request.
    worker.join()


def detect_objects(img: "Image.Image"):
    """
    Runs YOLOS object detection on an image and crops out every detection.

    Each detection scoring above the 0.7 threshold is logged to stdout and the
    corresponding region is cropped from the input image.

    :param img: the PIL image to analyse, or ``None`` when no image was provided
    :return: a list of cropped PIL images, one per detected object, in the order
        reported by the post-processor; an empty list when ``img`` is ``None``
        or nothing was detected
    """
    if img is None:
        # Nothing to analyse (e.g. the UI was submitted without an image).
        return []

    inputs = yolos_processor(img, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = yolos_model(**inputs)

    # PIL reports (width, height) while the post-processor wants (height, width).
    target_sizes = torch.tensor([img.size[::-1]])
    results = yolos_processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes
    )[0]

    box_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        # Box is (x_min, y_min, x_max, y_max) in pixels of the original image.
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {yolos_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
        box_images.append(img.crop(tuple(box)))

    return box_images


if __name__ == "__main__":
    # Script entry point: lay out the two-tab UI, attach the handlers, then
    # serve the app with request queuing enabled.
    with gr.Blocks() as app:
        page_header = """
            # Food Identifier

            Final project for IAT 481 at Simon Fraser University, Spring 2024.
            """
        gr.Markdown(page_header)

        # Tab 1: upload an image, get a gallery of cropped detections.
        with gr.Tab("Object Detection"):
            with gr.Row():
                detect_image = gr.Image(type="pil")
                detect_gallery = gr.Gallery(label="Detected Objects", show_label=True)
            detect_submit = gr.Button("Submit")

        # Tab 2: ask Moondream a free-form question about an image.
        with gr.Tab("Inference"):
            with gr.Row():
                vqa_prompt = gr.Textbox(label="Input", value="Describe this image.")
                vqa_submit = gr.Button("Submit")
            with gr.Row():
                vqa_image = gr.Image(label="Image", type="pil")
                vqa_answer = gr.TextArea(label="Output")

        # Event wiring; answer_question is a generator, so its output streams.
        vqa_submit.click(
            fn=answer_question,
            inputs=[vqa_image, vqa_prompt],
            outputs=vqa_answer,
        )
        detect_submit.click(
            fn=detect_objects,
            inputs=[detect_image],
            outputs=detect_gallery,
        )

    queued_app = app.queue()
    queued_app.launch()