Jsonwu's picture
Update app.py
68135d2 verified
raw
history blame
2.18 kB
import torch
import gradio as gr
import json
from torchvision import transforms
from torchvision.ops import nms
from PIL import Image, ImageDraw, ImageFont
# Locations of the bundled TorchScript detector and of its class-label map.
TORCHSCRIPT_PATH = "res/screenrecognition-web350k-vins.torchscript"
LABELS_PATH = "res/class_map_vins_manual.json"

# The tensors fed to the model come straight from ToTensor() and are never
# moved to another device, so pin the weights to the CPU as well; this also
# lets an archive that was saved from a GPU load on a CPU-only host.
model = torch.jit.load(TORCHSCRIPT_PATH, map_location="cpu")
# The model is only ever used for inference.
model.eval()

# idx2Label maps a stringified class index to its human-readable label.
# JSON is UTF-8 by specification, so do not rely on the locale's default.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    idx2Label = json.load(f)["idx2Label"]

# Converts a PIL image into the tensor form the model is called with.
img_transforms = transforms.ToTensor()
# inter_class_nms implemented by GPT
def inter_class_nms(boxes, scores, iou_threshold=0.5, labels=None):
    """Run non-maximum suppression over all boxes, ignoring their class.

    Args:
        boxes: Tensor of candidate boxes, passed unchanged to
            ``torchvision.ops.nms``.
        scores: Tensor with one confidence score per box.
        iou_threshold: Overlap threshold forwarded to ``nms``.
        labels: Optional tensor with one class index per box. When given,
            it is filtered with the same indices as the boxes so that the
            three outputs stay aligned.

    Returns:
        A dict with the surviving ``'boxes'`` and ``'scores'`` and, only if
        ``labels`` was supplied, the matching ``'labels'``.
    """
    # Indices of the detections that survive suppression.
    keep = nms(boxes, scores, iou_threshold)

    result = {'boxes': boxes[keep], 'scores': scores[keep]}
    if labels is not None:
        result['labels'] = labels[keep]
    return result


def predict(img, conf_thresh=0.4):
    """Detect UI elements in a screenshot and draw them on a copy of it.

    Args:
        img: PIL image to analyse, or ``None`` when nothing was uploaded.
        conf_thresh: Detections whose score is not strictly greater than
            this value are not drawn.

    Returns:
        A copy of ``img`` with a red box and a "<label> <score>" caption for
        every retained detection, or ``None`` if ``img`` is ``None``.
    """
    # The image input is empty when the user submits without uploading.
    if img is None:
        return None

    # Pure inference: there is no reason to record autograd history.
    with torch.no_grad():
        _, pred = model([img_transforms(img)])

    # Labels must go through the same filtering as boxes and scores,
    # otherwise the caption lookup below has nothing to index into.
    # NOTE(review): assumes the model output carries a 'labels' entry, as
    # the label lookup already required - confirm against the model.
    detections = inter_class_nms(
        pred[0]['boxes'], pred[0]['scores'], labels=pred[0]['labels']
    )

    out_img = img.copy()
    draw = ImageDraw.Draw(out_img)
    font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)

    for box, conf_score, label in zip(
        detections['boxes'], detections['scores'], detections['labels']
    ):
        if conf_score > conf_thresh:
            x1, y1, x2, y2 = (int(coord) for coord in box)
            draw.rectangle([x1, y1, x2, y2], outline='red', width=3)

            text = idx2Label[str(int(label))] + " {:.2f}".format(float(conf_score))
            # Paint a filled background behind the caption so that it stays
            # readable on top of the screenshot.
            bbox = draw.textbbox((x1, y1), text, font=font)
            draw.rectangle(bbox, fill="red")
            draw.text((x1, y1), text, font=font, fill="black")

    return out_img
# Sample inputs offered in the UI, one row per example in the same order as
# the interface inputs: [screenshot path, confidence threshold].
example_imgs = [
    [sample_path, 0.4]
    for sample_path in (
        "res/example.jpg",
        "res/screenlane-snapchat-profile.jpg",
        "res/screenlane-snapchat-settings.jpg",
        "res/example_pair1.jpg",
        "res/example_pair2.jpg",
    )
]

# Web UI around predict(): a screenshot and a threshold slider go in, the
# annotated screenshot comes out.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.4),
    ],
    outputs=gr.Image(type="pil", label="Annotated Screenshot").style(height=600),
    examples=example_imgs,
)
interface.launch()