# koclip/utils.py

from threading import Lock

import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from global_session import GlobalState
from koclip import FlaxHybridCLIP


@st.cache(allow_output_mutation=True)
def load_index(img_file):
    """Load precomputed image embeddings from a TSV file and build an
    approximate nearest-neighbor index over them with nmslib."""
    filenames, embeddings = [], []
    with open(img_file, "r") as f:
        for line in f:
            # Each line is "<filename>\t<comma-separated embedding values>".
            cols = line.strip().split("\t")
            filename = cols[0]
            embedding = [float(x) for x in cols[1].split(",")]
            filenames.append(filename)
            embeddings.append(embedding)
    embeddings = np.array(embeddings)
    # HNSW index over cosine similarity; "post": 2 selects the most
    # extensive post-construction graph optimization pass.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
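
# Usage sketch (hypothetical, not called by this module): given a query
# embedding from the matching CLIP text encoder, the nearest images can be
# retrieved with nmslib's knnQuery, e.g.
#
#   ids, distances = index.knnQuery(query_embedding, k=10)
#   matches = [filenames[i] for i in ids]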


def load_model(model_name="koclip/koclip-base"):
    """Load a KoCLIP model, serializing concurrent loads of the same model
    so the Streamlit cache entry is populated exactly once."""
    state = GlobalState(model_name)
    # GlobalState is shared across sessions, so every session that requests
    # this model name contends on the same lock.
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Locking load of model {model_name} to avoid concurrent caching.")
    with state._lock:
        cached_model = load_model_cached(model_name)
    print(f"Unlocking load of model {model_name}.")
    return cached_model
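
# Note: the hasattr check-then-set above is not itself atomic. A more
# defensive variant (a sketch, not what this app uses) would guard lock
# creation with a module-level lock:
#
#   _creation_lock = Lock()
#   ...
#   with _creation_lock:
#       if not hasattr(state, "_lock"):
#           state._lock = Lock()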


@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
    model = FlaxHybridCLIP.from_pretrained(model_name)
    # Start from the OpenAI CLIP processor, then swap in a Korean tokenizer.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large model uses a ViT-large image backbone, which needs a
        # matching feature extractor.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
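
# Usage sketch (hypothetical): wiring the loaders together for text-to-image
# search. get_text_features comes from the upstream Flax hybrid-CLIP
# implementation; its exact signature and the "embeddings.tsv" path are
# assumptions, not part of this module.
#
#   model, processor = load_model("koclip/koclip-base")
#   filenames, index = load_index("embeddings.tsv")
#   inputs = processor(
#       text=["바닷가를 달리는 강아지"],  # "a dog running on the beach"
#       return_tensors="jax",
#       padding=True,
#   )
#   text_emb = model.get_text_features(**inputs)
#   ids, _ = index.knnQuery(np.asarray(text_emb)[0], k=5)
#   print([filenames[i] for i in ids])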