import torch
import torch.nn as nn
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import transformers
from transformers import RobertaModel, RobertaTokenizer
import timm
import pandas as pd
import matplotlib.pyplot as plt
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from model import Model
from output import visualize_output

# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize the pretrained models (ViT image encoder, RoBERTa text encoder)
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")

model = Model(vit, roberta, tokenizer, device).to(device)
model.eval()

# Load the trained model weights
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])

# Create the transform configuration for input images
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'

# Inference function
def query_image(input_img, query, binarize, eval_threshold, crop_mode, crop_pct):
    # 'center' is timm's default crop mode, selected by passing crop_mode=None
    if crop_mode == 'center':
        crop_mode = None
    config['crop_pct'] = crop_pct
    config['crop_mode'] = crop_mode
    transform = create_transform(**config)

    PIL_image = Image.fromarray(input_img, "RGB")
    img = transform(PIL_image)
    img = torch.unsqueeze(img, 0).to(device)
    with torch.no_grad():
        output = model(img, query)
    img = visualize_output(img, output, binarize, eval_threshold)
    return img
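
# Optional local sanity check (illustration only, not part of the Gradio app):
# runs a single query without starting the web UI. The image path and query are
# taken from the example list below; the crop settings mirror the interface
# defaults. Set RUN_SANITY_CHECK = True to enable it.
RUN_SANITY_CHECK = False
if RUN_SANITY_CHECK:
    test_img = np.array(Image.open('examples/imga.jpeg').convert('RGB'))
    test_out = query_image(test_img, "Find a person.", True, 0.45, 'squash', 1.0)
    print("sanity check output type:", type(test_out))
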
# Gradio interface
description = """
Gradio demo for an object detection architecture, introduced in my bachelor thesis (link will be added). \n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and
enter any query you want. It can be a single word or a sentence. The model is trained to recognize only the
80 categories (classes) of the COCO Detection 2017 dataset. Refer to this website or the original COCO paper
to see the full list of categories. \n\n
Best results are obtained using one of these sentences, which were used during training: \n\n
When the binarize option is turned off, the model outputs the probability of the requested {class} for each patch.
When the binarize option is turned on, the model binarizes each probability using the set eval_threshold. \n\n
Each input image is transformed to size 224x224 so it can be processed by the ViT. During this transformation,
different crop_modes and crop_percentages can be selected. No part of the image is lost if crop_pct = 1.0 and
crop_mode = 'squash' or 'border'. The model was trained using crop_mode = 'center' and crop_pct = 0.9.
For an explanation of the different crop_modes, please refer to this website, lines 155-172.
"""

demo = gr.Interface(
    query_image,
    #inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    #inputs=[gr.Image(type='numpy', label='input_img').style(height=250, width=600), "text", "checkbox", gr.Slider(0, 1, value=0.25),
    #        gr.Radio(["center", "squash", "border"], value='center', label='crop_mode'), gr.Slider(0.7, 1, value=0.9, step=0.01)],
    inputs=["image", "text", "checkbox", gr.Slider(0, 1, value=0.25),
            gr.Radio(["center", "squash", "border"], value='squash', label='crop_mode'),
            gr.Slider(0.7, 1, value=1, step=0.01)],
    outputs="image",
    #outputs=gr.Image(type='numpy', label='output').style(height=600, width=600),
    title="Text-Based Object Detection",
    description=description,
    examples=[
        ["examples/imga.jpeg", "Find a person.", True, 0.45],
        ["examples/imgb.jpeg", "Could you mark a horse?", False, 0.25],
        ["examples/imgc.jpeg", "There should be a cat in this picture, where?", True, 0.25],
        ["examples/imgd.jpeg", "Mark a tv in this image.", False, 0.1],
        ["examples/imge.jpeg", "Is there a zebra in this picture?", True, 0.4],
        ["examples/imgf.jpeg", "Look for a stop sign.", True, 0.5],
    ],
    cache_examples=False,
    allow_flagging="never",
    css="""
    .column { float: left; padding: 10px; }
    .left { width: 25%; }
    .right { width: 75%; }
    """
)

demo.launch()
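
# Note on the binarize option (assumed behaviour, shown only for clarity and never
# executed by the app): with binarize enabled, visualize_output is expected to
# threshold the per-patch probability map at eval_threshold, roughly:
#
#   probs = np.random.rand(14, 14)                        # ViT-B/16 yields a 14x14 patch grid at 224x224
#   mask = (probs >= eval_threshold).astype(np.float32)   # 1 where the probability reaches the threshold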