import streamlit as st | |
import yaml | |
import torch | |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer | |
from lib.IRRA.image import prepare_images | |
from lib.IRRA.model.build import build_model, IRRA | |
from easydict import EasyDict | |
def get_model(config_path: str = 'model/configs.yaml'):
    """Build a model from a YAML config file via ``build_model``.

    Args:
        config_path: Path to the YAML config. It defaults to
            ``'model/configs.yaml'``, the original hard-coded path, which
            is resolved relative to the current working directory.

    Returns:
        Whatever ``build_model`` returns for the loaded config. This is
        presumably an ``IRRA`` instance; confirm against
        ``lib.IRRA.model.build``.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original left it open).
    with open(config_path, encoding='utf-8') as config_file:
        # NOTE(review): FullLoader is kept for compatibility with whatever the
        # config contains; if it is plain data, yaml.safe_load is the safer
        # choice.
        args = yaml.load(config_file, Loader=yaml.FullLoader)
    args = EasyDict(args)
    # Override the config's `training` flag before building. Presumably this
    # selects an inference-only setup inside build_model; confirm.
    args['training'] = False
    return build_model(args)
def get_similarities(
    text: str,
    images: list[str],
    model: IRRA,
    *,
    normalize: bool = False,
) -> torch.Tensor:
    """Score one text query against a list of images.

    Args:
        text: The query text.
        images: Image inputs passed unchanged to ``prepare_images``.
            These are presumably file paths; confirm against
            ``lib.IRRA.image``.
        model: Model providing ``encode_image`` and ``encode_text``.
        normalize: If True, L2-normalise both feature sets along the last
            dimension before the product, so the scores are cosine
            similarities. The default (False) keeps the original raw
            dot-product scores.

    Returns:
        ``text_feats @ image_feats.t()``, computed without autograd
        tracking. The shape is presumably ``(1, len(images))``; confirm
        against the encoders' output shapes.
    """
    tokenizer = SimpleTokenizer()
    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)
    # Inference only: disable autograd so no graph is built. Otherwise the
    # result would carry requires_grad=True and .numpy() on it would fail.
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        # A leading batch dimension of 1 is added here. Presumably tokenize
        # returns a single unbatched token sequence; confirm.
        text_feats = model.encode_text(txt.unsqueeze(0))
        if normalize:
            image_feats = torch.nn.functional.normalize(image_feats, dim=-1)
            text_feats = torch.nn.functional.normalize(text_feats, dim=-1)
        return text_feats @ image_feats.t()