FoodIdentifier / app.py
rdezwart's picture
Begin implementing object detection
ada211a
raw
history blame
3.24 kB
from threading import Thread
import gradio as gr
import torch
from transformers import PreTrainedModel # for type hint
from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer # Moondream
from transformers import YolosImageProcessor, YolosForObjectDetection # YOLOS-small-300
# --- Moondream --- #
# The HuggingFace pipeline system does not support Moondream, so the tokenizer
# and the model are loaded by hand. Both are pinned to the same revision.
moondream_id = "vikhyatk/moondream2"
moondream_revision = "2024-04-02"

moondream_tokenizer = AutoTokenizer.from_pretrained(
    moondream_id,
    revision=moondream_revision,
)

# NOTE(review): trust_remote_code=True is required here because the model
# class comes from the checkpoint repository - keep the revision pinned.
moondream_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    moondream_id,
    revision=moondream_revision,
    trust_remote_code=True,
)
# Inference only: switch the module to evaluation mode.
moondream_model.eval()
# --- YOLOS --- #
# Image processor and detection model both come from the same checkpoint id.
yolos_id = "hustvl/yolos-small-300"

yolos_processor: YolosImageProcessor = YolosImageProcessor.from_pretrained(
    yolos_id
)
yolos_model: YolosForObjectDetection = YolosForObjectDetection.from_pretrained(
    yolos_id
)
def answer_question(img, prompt):
    """
    Submits an image and prompt to the Moondream model and streams the answer.

    The model generates on a worker thread while this generator relays the
    text collected so far each time a new chunk arrives.

    :param img: image to ask about, handed to ``moondream_model.encode_image``;
        when ``None`` the generator finishes without yielding anything
    :param prompt: question text passed to the model
    :return: yields the output buffer string (stripped) after every new chunk
    """
    if img is None:
        # Nothing to encode (e.g. the handler fired before an image was
        # loaded) - end quietly instead of failing inside encode_image.
        return

    image_embeds = moondream_model.encode_image(img)
    streamer = TextIteratorStreamer(moondream_tokenizer, skip_special_tokens=True)

    # Generation blocks until it is finished, so it runs on its own thread and
    # pushes text into the streamer while we iterate over it here.
    thread = Thread(
        target=moondream_model.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": moondream_tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()

    # The streamer has signalled its end; wait for the worker to exit so no
    # thread is left dangling after the request.
    thread.join()
def detect_objects(img):
    """
    Runs YOLOS object detection on an image and prints each detection.

    Detections scoring above the 0.7 threshold are written to stdout with
    their label, confidence and box. Nothing is drawn yet, so the function
    returns ``None``.

    :param img: image to analyse - either a PIL-style image (``.size`` is
        ``(width, height)``) or an array exposing ``.shape`` (assumed to be
        laid out as ``(height, width, ...)``); ``None`` is accepted and ignored
    :return: None
    """
    if img is None:
        # No image supplied (e.g. the button was pressed on an empty input).
        return None

    # post_process_object_detection expects (height, width). An array's .size
    # is a plain element count, so read the dimensions from .shape instead.
    if hasattr(img, "shape"):
        height, width = img.shape[:2]
    else:
        width, height = img.size
    target_sizes = torch.tensor([[height, width]])

    inputs = yolos_processor(images=img, return_tensors="pt")
    # Inference only - no need to track gradients.
    with torch.no_grad():
        outputs = yolos_model(**inputs)

    results = yolos_processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes
    )[0]

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {yolos_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
    return None
if __name__ == "__main__":
    # Build the two-tab Gradio UI and wire each button to its handler.
    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Food Identifier
            Final project for IAT 481 at Simon Fraser University, Spring 2024.
            """
        )

        with gr.Tab("Object Detection"):
            with gr.Row():
                # type="pil" matches moon_img below and the PIL-style
                # img.size access in detect_objects.
                yolos_input = gr.Image(type="pil")
                yolos_output = gr.Image()
            yolos_button = gr.Button("Submit")

        with gr.Tab("Inference"):
            with gr.Row():
                moon_prompt = gr.Textbox(label="Input", value="Describe this image.")
                moon_submit = gr.Button("Submit")
            with gr.Row():
                moon_img = gr.Image(label="Image", type="pil")
                moon_output = gr.TextArea(label="Output")

        moon_submit.click(answer_question, [moon_img, moon_prompt], moon_output)
        yolos_button.click(detect_objects, [yolos_input], yolos_output)

    # answer_question is a generator, so the queue is enabled before launching.
    app.queue().launch()