import io
import json
import zipfile
from zipfile import ZipFile

import faiss
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# Configure the page before any other Streamlit command runs.
st.set_page_config(page_title='Image Search App', layout='wide')

# Dataset source (Images.zip is expected in the working directory):
# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="shivangibithel/Flickr8k", filename="Images.zip")

# Pre-trained sentence encoder. search() embeds the user's query with `model`
# and compares it against the vectors in sbert_text_features.npy, so this
# checkpoint must be the one that produced that file (presumably — TODO confirm).
model_name = "sentence-transformers/all-distilroberta-v1"

# NOTE(review): `tokenizer` is not referenced anywhere else in the visible code;
# search() passes raw text straight to model.encode().
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    SentenceTransformer(model_name),
)

# # Load the FAISS index
# index_name = 'index.faiss'
# index_url = './'
# wget.download(index_url, index_name)
# index = faiss.read_index(faiss_flickr8k.index)

# Path to the archive that holds the dataset images.
zip_path = "Images.zip"

# Kept open for the whole script run: search() reads result images from it.
zip_file = zipfile.ZipFile(zip_path)

# Display every image in the archive.
# NOTE(review): this renders the entire archive on every Streamlit rerun (i.e.
# on every widget interaction) and places it above the search UI; consider
# gating it behind a widget or removing it if it was only a debugging aid.
for image_name in zip_file.namelist():
    # Directory entries (names ending in "/") hold no data and are not images.
    if image_name.endswith("/"):
        continue
    image_data = zip_file.read(image_name)
    # ZipFile.read() returns raw bytes; PIL needs a file-like object.
    image = Image.open(io.BytesIO(image_data))
    st.image(image, caption=image_name)

# Precomputed sentence embeddings; search() maps row i to image_list[i], so the
# rows are expected to follow the key order of captions.json (TODO confirm).
# FAISS only accepts C-contiguous float32 arrays, so coerce here — this is a
# no-op (no copy) when the file is already stored that way.
vectors = np.ascontiguousarray(np.load("./sbert_text_features.npy"), dtype=np.float32)
vector_dimension = vectors.shape[1]

# Exact (brute-force) L2 index. Because both the stored vectors and the query
# are L2-normalised, ranking by L2 distance matches ranking by cosine similarity.
index = faiss.IndexFlatL2(vector_dimension)
faiss.normalize_L2(vectors)  # normalises in place
index.add(vectors)

# Map FAISS row positions to image file names. json.load() preserves the key
# order of the file, and search() relies on key i matching row i of
# sbert_text_features.npy (assumed — TODO confirm how the file was written).
image_map_name = 'captions.json'

# JSON is UTF-8 text; don't depend on the platform's default encoding.
with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())

def labels_to_image_ids(labels, id_list):
    """Translate FAISS result labels into image identifiers.

    FAISS pads its result with label ``-1`` when fewer than ``k`` vectors are
    available. Those slots are dropped rather than being used as a negative
    (wrap-around) index into ``id_list``.

    Args:
        labels: Iterable of integer row positions returned by ``index.search``.
        id_list: Sequence mapping row position -> image identifier.

    Returns:
        List of image identifiers converted to ``str``, in result order.
    """
    return [str(id_list[label]) for label in labels if label >= 0]


def search(query, k=5):
    """Embed *query*, find the nearest caption vectors and display their images.

    Args:
        query: Free-text search string.
        k: Maximum number of results to show.

    Returns:
        The image identifiers that were displayed, nearest first.
    """
    # FAISS needs a 2-D float32 batch; normalise so L2 ranking equals cosine.
    query_vector = np.asarray([model.encode(query)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    _, labels = index.search(query_vector, k)

    image_ids = labels_to_image_ids(labels[0], image_list)
    for image_id in image_ids:
        # NOTE(review): assumes captions.json keys are the exact member names
        # inside Images.zip (e.g. no "Images/" prefix) — verify against the archive.
        image_data = zip_file.read(image_id)
        # ZipFile.read() returns raw bytes; PIL needs a file-like object.
        image = Image.open(io.BytesIO(image_data))
        st.image(image, caption=image_id, width=200)
    return image_ids

st.title("Image Search App")

query = st.text_input("Enter your search query here:")
if st.button("Search"):
    # Blank or whitespace-only input would still be embedded and return k
    # arbitrary "nearest" images, so reject it with feedback instead.
    cleaned_query = query.strip()
    if cleaned_query:
        search(cleaned_query)
    else:
        st.warning("Please enter a search query.")

# There is no separate entry point: `streamlit run <this file>` executes the
# module top to bottom on every interaction, and the UI is built by the
# module-level code above.