# Trent
# Navigation
# 6e03e5d
# raw
# history blame
# 466 Bytes
import streamlit as st
from sentence_transformers import SentenceTransformer
from .config import MODELS_ID
# NOTE(review): `st.cache` is deprecated in newer Streamlit releases;
# `st.cache_resource` is the intended replacement for caching model objects.
# Switch once the project's pinned Streamlit version provides it.
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load the SentenceTransformer model(s) registered under ``model_name``.

    ``MODELS_ID[model_name]`` may be either a single model identifier
    (a ``str``) or an iterable of identifiers. The call is wrapped in
    Streamlit's cache, so repeated calls with the same ``model_name`` reuse
    the previously constructed object(s) instead of loading them again.

    Args:
        model_name: A key of ``MODELS_ID``.

    Returns:
        A single ``SentenceTransformer`` when the registered value is a
        string; otherwise a list of ``SentenceTransformer`` instances, in the
        same order as the registered identifiers.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODELS_ID``.
        TypeError: If the registered value is neither a string nor an
            iterable of identifiers.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if model_name not in MODELS_ID:
        raise ValueError(
            f"Unknown model name {model_name!r}; "
            f"expected one of {sorted(MODELS_ID)!r}"
        )

    # Lazy loading: models are constructed here, on first request,
    # rather than at import time.
    models = MODELS_ID[model_name]

    # The str check must come first: strings are iterable too, so the
    # iterable branch would otherwise build one model per character.
    if isinstance(models, str):
        return SentenceTransformer(models)
    if hasattr(models, "__iter__"):
        return [SentenceTransformer(model) for model in models]
    raise TypeError(
        f"MODELS_ID[{model_name!r}] must be a model id string or an "
        f"iterable of model id strings, got {type(models).__name__}"
    )