|
import streamlit as st |
|
import yaml |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer |
|
from lib.IRRA.image import prepare_images |
|
from lib.IRRA.model.build import build_model, IRRA |
|
|
|
from easydict import EasyDict |
|
|
|
@st.cache_resource
def get_model():
    """Build the model described by ``model/configs.yaml``.

    The YAML file is parsed into an ``EasyDict``, its ``training`` entry is
    forced to ``False``, and the result is handed to ``build_model``.
    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned object instead of rebuilding it on every script rerun.

    Returns:
        Whatever ``build_model`` returns for the loaded configuration.
    """
    # Use a context manager so the config file handle is always closed,
    # and pin the encoding instead of relying on the platform default.
    with open('model/configs.yaml', encoding='utf-8') as config_file:
        # NOTE(review): FullLoader is kept for compatibility with the existing
        # config; switch to yaml.safe_load if the file only uses plain YAML
        # types, and never point this at untrusted input.
        args = yaml.load(config_file, Loader=yaml.FullLoader)

    args = EasyDict(args)
    args['training'] = False

    model = build_model(args)

    return model
|
|
|
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score a text query against a set of images with the given model.

    The text is tokenized and the images are prepared with the project
    helpers, both are encoded by ``model``, the resulting features are
    L2-normalized along dim 1, and the text features are multiplied with the
    transposed image features (i.e. cosine similarity).

    Args:
        text: The query string to tokenize and encode.
        images: Image inputs forwarded unchanged to ``prepare_images``.
        model: Model exposing ``encode_image`` and ``encode_text``.

    Returns:
        The similarity matrix ``text_feats @ image_feats.t()``; presumably of
        shape ``(1, num_images)`` since a single text is batched with
        ``unsqueeze(0)`` - TODO confirm against the encoder outputs.
    """
    tokenizer = SimpleTokenizer()

    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    # Pure inference: disable autograd so no graph is built and the returned
    # tensor does not carry a grad_fn (saves memory, allows .numpy()).
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        # encode_text is given a batch of one tokenized query.
        text_feats = model.encode_text(txt.unsqueeze(0))

        # Unit-normalize so the dot product below is a cosine similarity.
        image_feats = F.normalize(image_feats, p=2, dim=1)
        text_feats = F.normalize(text_feats, p=2, dim=1)

        similarities = text_feats @ image_feats.t()

    return similarities
|
|