File size: 917 Bytes
ba2ab36 7e39033 ba2ab36 7e39033 ba2ab36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import streamlit as st
import yaml
import torch
import torch.nn.functional as F
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA
from easydict import EasyDict
@st.cache_resource
def get_model():
    """Build the IRRA model described by ``model/configs.yaml``.

    Decorated with ``st.cache_resource`` so Streamlit reuses the built model
    across script reruns instead of rebuilding it each time.

    Returns:
        Whatever ``build_model`` returns for the loaded config (used as an
        ``IRRA`` instance by ``get_similarities`` in this module).
    """
    # Use a context manager so the config file handle is closed as soon as
    # it has been parsed (previously it stayed open until garbage collection).
    with open('model/configs.yaml', encoding='utf-8') as config_file:
        # NOTE(review): FullLoader is kept for compatibility with the existing
        # config; if the file uses no Python-specific YAML tags,
        # yaml.safe_load would be the safer choice.
        args = yaml.load(config_file, Loader=yaml.FullLoader)
    args = EasyDict(args)
    # Presumably makes build_model skip training-only setup — confirm
    # against lib.IRRA.model.build.
    args['training'] = False
    return build_model(args)
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score one text query against a batch of images.

    Both feature sets are L2-normalized along dim 1 before the matrix
    product, so each entry is a cosine similarity.

    Args:
        text: Query text, tokenized with ``SimpleTokenizer``.
        images: Image references handed to ``prepare_images``
            (presumably file paths — confirm against ``lib.IRRA.image``).
        model: IRRA model providing ``encode_image`` and ``encode_text``.

    Returns:
        ``text_feats @ image_feats.t()`` — presumably a tensor of shape
        (1, len(images)); confirm against the encoders' output shapes.
        Computed under ``torch.no_grad()``, so it carries no autograd graph.
    """
    tokenizer = SimpleTokenizer()
    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)
    # Pure inference: disabling autograd avoids building a graph on every
    # Streamlit rerun and lets callers call .numpy() on the result directly.
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        # unsqueeze(0) adds a leading batch dimension for the single query.
        text_feats = model.encode_text(txt.unsqueeze(0))
        image_feats = F.normalize(image_feats, p=2, dim=1)
        text_feats = F.normalize(text_feats, p=2, dim=1)
        return text_feats @ image_feats.t()
|