import streamlit as st
import yaml
import torch
import torch.nn.functional as F
from transformers import DetrImageProcessor, DetrForObjectDetection
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA
from PIL import Image
from pathlib import Path
from easydict import EasyDict


@st.cache_resource
def get_model():
    # Load the IRRA configuration and build the model in inference mode.
    args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
    args = EasyDict(args)
    args['training'] = False

    model = build_model(args)
    return model

@st.cache_resource
def get_detr():
    # Pre-trained DETR (ResNet-50 backbone) used to detect people in gallery images.
    processor = DetrImageProcessor.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm")

    return model, processor

def segment_images(model, processor, images: list[str]):
    segments = []
    id = 0

    p = Path('segments')
    p.mkdir(exist_ok=True)

    for image in images:
        image = Image.open(image)

        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)

        # Convert raw DETR outputs to detections in (x0, y0, x1, y1) pixel coordinates.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9)[0]

        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            box = [round(i, 2) for i in box.tolist()]
            label = model.config.id2label[label.item()]

            # Keep only sufficiently large 'person' detections and save each crop to disk.
            if box[2] - box[0] > 70 and box[3] - box[1] > 70:
                if label == 'person':
                    file = p / f'img_{id}.jpg'
                    image.crop(box).save(file)
                    segments.append(file.as_posix())
                    id += 1

    return segments

def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    # Tokenize the query text and preprocess the candidate images.
    tokenizer = SimpleTokenizer()
    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    # Encode both modalities with the IRRA model.
    image_feats = model.encode_image(imgs)
    text_feats = model.encode_text(txt.unsqueeze(0))

    # Cosine similarity: L2-normalize both sets of features, then take the dot product.
    image_feats = F.normalize(image_feats, p=2, dim=1)
    text_feats = F.normalize(text_feats, p=2, dim=1)

    return text_feats @ image_feats.t()
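

# --- Example usage (assumption, not part of the original module) ---
# A minimal sketch of how these helpers could be wired into a Streamlit page:
# gallery images are cropped to person segments with DETR, then ranked against
# the text query by IRRA cosine similarity. The widget labels and the
# 'gallery/' path below are illustrative assumptions.
#
#   model = get_model()
#   detr, processor = get_detr()
#
#   query = st.text_input("Describe the person to search for")
#   gallery = [p.as_posix() for p in Path('gallery').glob('*.jpg')]
#
#   if query and gallery:
#       crops = segment_images(detr, processor, gallery)
#       scores = get_similarities(query, crops, model)  # shape: (1, len(crops))
#       best = int(scores.argmax(dim=1).item())
#       st.image(crops[best], caption=f"Best match (score {scores[0, best]:.3f})")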