import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from koclip import FlaxHybridCLIP


@st.cache(allow_output_mutation=True)
def load_index(img_file):
    """Load precomputed image embeddings and build an ANN search index.

    Each line of *img_file* is expected to be tab-separated:
        <filename>\t<comma-separated embedding floats>

    Args:
        img_file: Path to the TSV file of filenames and embeddings.

    Returns:
        Tuple ``(filenames, index)``: the image filenames in file order, and
        an nmslib HNSW index over the embeddings (cosine-similarity space).
        Query results from the index are positions into ``filenames``.
    """
    filenames, embeddings = [], []
    # Explicit encoding: the platform default may not be UTF-8, which would
    # break on non-ASCII filenames in the TSV.
    with open(img_file, "r", encoding="utf-8") as f:
        for line in f:
            cols = line.strip().split("\t")
            filenames.append(cols[0])
            embeddings.append([float(x) for x in cols[1].split(",")])
    embeddings = np.array(embeddings)
    # HNSW graph index over cosine similarity; "post": 2 runs extra
    # post-processing at build time for better graph quality.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip-base"):
    """Load a KoCLIP model together with a matching processor.

    The processor starts from the OpenAI CLIP ViT-B/32 processor, with its
    tokenizer swapped for the Korean KLUE RoBERTa-large tokenizer. For the
    large KoCLIP variant, the feature extractor is also swapped for the
    ViT-large one so image preprocessing matches the vision backbone.

    Args:
        model_name: One of ``"koclip/koclip-base"`` or
            ``"koclip/koclip-large"``.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        ValueError: If *model_name* is not a supported KoCLIP checkpoint.
    """
    # Validate with an explicit raise: `assert` is stripped under `python -O`,
    # which would let an unknown model name through silently.
    valid_names = {"koclip/koclip-base", "koclip/koclip-large"}
    if model_name not in valid_names:
        raise ValueError(
            f"model_name must be one of {sorted(valid_names)}, got {model_name!r}"
        )
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor