File size: 2,292 Bytes
9ff0cd2
ba2ab36
 
7e39033
ba2ab36
9ff0cd2
 
ba2ab36
 
 
9ff0cd2
 
ba2ab36
 
 
9ff0cd2
 
ba2ab36
 
 
 
 
 
 
9ff0cd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2ab36
 
 
 
 
 
 
 
 
 
7e39033
 
9ff0cd2
ba2ab36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import streamlit as st
import yaml
import torch
import torch.nn.functional as F

from transformers import DetrImageProcessor, DetrForObjectDetection

from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA
from PIL import Image
from pathlib import Path

from easydict import EasyDict


@st.cache_resource
def get_model():
    """Load the IRRA retrieval model described by model/configs.yaml.

    Cached by Streamlit so the model is built only once per session.
    Returns the model in inference mode (``training`` forced to False).
    """
    # Context manager ensures the config file handle is closed
    # (the original `open(...)` leaked the descriptor).
    with open('model/configs.yaml') as f:
        args = yaml.load(f, Loader=yaml.FullLoader)

    args = EasyDict(args)
    args['training'] = False  # build for inference, not training

    model = build_model(args)

    return model


@st.cache_resource
def get_detr():
    """Load the pretrained DETR ResNet-50 detector and its processor.

    Cached by Streamlit so the weights are downloaded/loaded only once.
    Returns a ``(model, processor)`` pair.
    """
    checkpoint = "facebook/detr-resnet-50"

    detr_processor = DetrImageProcessor.from_pretrained(
        checkpoint, revision="no_timm")
    detr_model = DetrForObjectDetection.from_pretrained(
        checkpoint, revision="no_timm")

    return detr_model, detr_processor


def segment_images(model, processor, images: list[str],
                   threshold: float = 0.9, min_size: int = 70):
    """Detect people in each image and save their crops under ``segments/``.

    Args:
        model: a DETR object-detection model (e.g. from :func:`get_detr`).
        processor: the matching ``DetrImageProcessor``.
        images: paths of the images to scan.
        threshold: detection confidence cutoff (default 0.9, as before).
        min_size: minimum width and height in pixels for a crop to be kept
            (default 70, as before).

    Returns:
        List of POSIX-style paths of the saved person crops.
    """
    segments = []
    crop_id = 0  # renamed from `id`, which shadowed the builtin

    out_dir = Path('segments')
    out_dir.mkdir(exist_ok=True)

    for path in images:
        # Context manager releases the underlying file handle
        # (PIL keeps it open lazily otherwise).
        with Image.open(path) as img:
            inputs = processor(images=img, return_tensors="pt")

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = model(**inputs)

            # post-processing expects (height, width); PIL size is (w, h).
            target_sizes = torch.tensor([img.size[::-1]])
            results = processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold)[0]

            for score, label, box in zip(results["scores"], results["labels"],
                                         results["boxes"]):
                box = [round(i, 2) for i in box.tolist()]
                name = model.config.id2label[label.item()]

                # Keep only sufficiently large "person" detections.
                if box[2] - box[0] > min_size and box[3] - box[1] > min_size:
                    if name == 'person':
                        file = out_dir / f'img_{crop_id}.jpg'
                        img.crop(box).save(file)
                        segments.append(file.as_posix())

                        crop_id += 1

    return segments


def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score a text query against a set of images with the IRRA model.

    Encodes the query and the images, L2-normalizes both embedding sets,
    and returns their cosine-similarity matrix (1 x num_images).
    """
    tokens = tokenize(text, SimpleTokenizer())
    batch = prepare_images(images)

    # Embed both modalities into the shared space.
    img_emb = F.normalize(model.encode_image(batch), p=2, dim=1)
    txt_emb = F.normalize(model.encode_text(tokens.unsqueeze(0)), p=2, dim=1)

    # Cosine similarities via inner product of unit vectors.
    return txt_emb @ img_emb.t()