File size: 917 Bytes
ba2ab36
 
 
7e39033
ba2ab36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e39033
 
 
ba2ab36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import streamlit as st 
import yaml
import torch
import torch.nn.functional as F

from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA

from easydict import EasyDict

@st.cache_resource
def get_model():
    """Build the IRRA model described by ``model/configs.yaml`` for inference.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same object is reused across Streamlit reruns.

    Returns:
        The model produced by ``build_model`` for the loaded config, with
        ``training`` forced to ``False`` in the config and the module switched
        to evaluation mode.
    """
    # Use a context manager so the config file handle is closed right away
    # instead of lingering until garbage collection.
    # NOTE(review): FullLoader is kept for compatibility with whatever tags the
    # config uses; this is trusted local config. Switch to yaml.safe_load if it
    # only contains plain data.
    with open('model/configs.yaml', encoding='utf-8') as config_file:
        args = yaml.load(config_file, Loader=yaml.FullLoader)
    args = EasyDict(args)
    args['training'] = False

    model = build_model(args)
    # The app only runs inference: disable training-time behaviour such as
    # dropout. Assumes build_model returns a torch nn.Module — confirm.
    model.eval()

    return model

@st.cache_resource
def _get_tokenizer() -> SimpleTokenizer:
    """Return one shared SimpleTokenizer instead of rebuilding it per query."""
    return SimpleTokenizer()


def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score every image against a text query using the IRRA encoders.

    Args:
        text: Free-text query to encode.
        images: Image references passed to ``prepare_images`` (presumably file
            paths — confirm against ``lib.IRRA.image``).
        model: Model exposing ``encode_image`` and ``encode_text``.

    Returns:
        ``text_feats @ image_feats.t()`` where both feature sets are
        L2-normalised along dim 1, i.e. cosine similarities between the query
        and each image. Expected shape is ``(1, len(images))``, assuming the
        encoders return ``(batch, dim)`` features — TODO confirm.
    """
    # Tokenizer construction is presumably non-trivial (CLIP-style tokenizers
    # read a BPE vocabulary), so reuse a cached instance across calls.
    tokenizer = _get_tokenizer()

    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    # Pure inference: disable autograd so no graph is recorded through the
    # model's parameters (saves memory and keeps the result grad-free, so
    # callers can e.g. call .numpy() on it).
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        # unsqueeze(0) adds a batch dimension of 1 for the single query.
        text_feats = model.encode_text(txt.unsqueeze(0))

        image_feats = F.normalize(image_feats, p=2, dim=1)
        text_feats = F.normalize(text_feats, p=2, dim=1)

        return text_feats @ image_feats.t()