import torch
import gradio as gr
from PIL import Image
from transformers import RobertaModel, RobertaTokenizer
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
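# Local modules: the detection model and the output visualization helper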
from model import Model
from output import visualize_output
# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the pretrained backbones: a ViT image encoder and a RoBERTa text encoder.
# num_classes=0 and global_pool='' make the ViT return raw patch embeddings
# instead of classification logits.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")
model = Model(vit, roberta, tokenizer, device).to(device)
# Load the trained weights; map_location='cpu' keeps the load device-agnostic,
# since load_state_dict copies the tensors onto the model's device anyway.
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])
model.eval()
# Create transform for input image
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
transform = create_transform(**config)
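# The resulting eval transform resizes and center-crops to the ViT's 224x224
# input, converts to a tensor, and normalizes with the backbone's expected
# mean/std from its pretrained config.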
# Inference function
def query_image(input_img, query, binarize, eval_threshold):
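    """Run the model on one image/query pair and visualize the result.

    input_img: RGB image as a numpy array (supplied by Gradio).
    query: textual query describing the object to detect.
    binarize: whether to binarize the prediction (forwarded to visualize_output).
    eval_threshold: detection threshold in [0, 1] (forwarded to visualize_output).
    """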
    pil_image = Image.fromarray(input_img, "RGB")
    img = transform(pil_image)
    img = torch.unsqueeze(img, 0).to(device)
    with torch.no_grad():
        output = model(img, query)
    return visualize_output(img, output, binarize, eval_threshold)
# Gradio interface
description = """
Gradio demo for an object detection architecture,
introduced in <a href="https://www.google.com/">my bachelor thesis (link will be added)</a>.
\n\nLorem ipsum ....
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
"""
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    outputs="image",
    title="Object Detection Using Textual Queries",
    description=description,
    examples=[
        ["examples/img1.jpeg", "Find a person.", True, 0.25],
    ],
    allow_flagging="never",
    cache_examples=False,
    css="""
    body {background-color: grey}
    """,
)
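# debug=True blocks the main thread until the server is closed and prints
# full tracebacks for errors raised during inference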
demo.launch(debug=True)