# Spaces: Sleeping / Sleeping  (page-header text captured together with the source; not program code)
import streamlit as st | |
import clip | |
import torch | |
import numpy as np | |
import os | |
import glob | |
from pathlib import Path | |
from PIL import Image | |
import chromadb | |
import boto3 | |
import botocore | |
from io import BytesIO | |
# Page-level Streamlit settings: wide layout, sidebar expanded on load.
# No other Streamlit command precedes this call in the visible script.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🤖",
    page_title="Поиск изображений Google Open Images по текстовому запросу",
)
# Name of the S3 bucket the result images are downloaded from.
BUCKET_NAME = "open-images-dataset"
# Local folder name for downloaded images; not referenced by any active
# statement in the visible part of this script.
DOWNLOAD_FOLDER = "ds_download"
# Key prefix used when building S3 object keys ("<SPLIT>/<image_id>.jpg").
SPLIT = "validation"
# Run CLIP on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both names stay None when loading fails, so later code can test for that
# instead of hitting an AttributeError.
model = None
preprocess = None
try:
    model, preprocess = clip.load("ViT-B/32", device=device)
except Exception as exc:
    # Best effort: keep the page rendering, but surface the real cause.
    # `except Exception` (not a bare `except:`) lets SystemExit and
    # KeyboardInterrupt propagate.
    st.error(f"Exception loading model: {exc}")
# On-disk location of the persistent Chroma database.
collection_path = "image_embeddings_collection.chroma"
chroma_client = chromadb.PersistentClient(path=collection_path)

# Stays None when the collection cannot be opened or created.
image_collection = None
try:
    # Cosine distance is requested through the collection metadata.
    image_collection = chroma_client.get_or_create_collection(
        "image", metadata={"hnsw:space": "cosine"}
    )
except Exception as exc:
    # Best effort: keep the page rendering, but surface the real cause.
    st.error(f"Exception loading collection: {exc}")

# Number of stored embeddings, shown in the sidebar; report 0 rather than
# failing with AttributeError when the collection could not be loaded.
num_embeddings = image_collection.count() if image_collection is not None else 0
# Main page heading.
st.title("Поиск изображений Google Open Images по текстовому запросу")

# Sidebar: collection size, the text query, the result count and the button
# that starts a search.
st.sidebar.header("Настройки поиска")
st.sidebar.write(f"Количество изображений в БД: {num_embeddings}")
text_input = st.sidebar.text_input(
    value="kite in the sky",
    label="Введите запрос:",
)
search_files_cnt = int(
    st.sidebar.slider(
        label="Количество изображений",
        min_value=1,
        max_value=10,
        value=2,
    )
)
searchStarted = st.sidebar.button("Искать")

# Two columns in the main area; the search results are placed into them.
col1, col2 = st.columns(2)
# Search handler: runs only on the script pass triggered by the button.
if searchStarted and (model is None or image_collection is None):
    # Loading failed earlier in the script; report it instead of crashing
    # with AttributeError on None.
    st.error("Search is unavailable: the model or the image collection failed to load")
elif searchStarted:
    # Encode the query text with CLIP. truncate=True cuts over-long queries
    # to the model's context length instead of raising.
    text_embedding = clip.tokenize(text_input, truncate=True).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        text_features = model.encode_text(text_embedding).cpu().numpy()

    # Nearest-neighbour lookup; the embedding is passed as a plain list of
    # floats under its keyword name.
    result = image_collection.query(
        query_embeddings=text_features.tolist(),
        n_results=search_files_cnt,
    )

    # UNSIGNED signature: requests to the bucket are sent without credentials.
    bucket = boto3.resource(
        's3',
        config=botocore.config.Config(signature_version=botocore.UNSIGNED),
    ).Bucket(BUCKET_NAME)

    # One query was sent, so the hits are in the first entry of each list.
    # Pairing metadata with its distance keeps the two aligned even when an
    # image is skipped.
    hits = zip(result['metadatas'][0], result['distances'][0])
    for cnt, (metadata, distance) in enumerate(hits):
        # The object key is built from the stored file name without its suffix.
        image_id = Path(metadata['name']).with_suffix('')
        object_key = f'{SPLIT}/{image_id}.jpg'

        # Download straight into memory; nothing is written to disk.
        image_data = BytesIO()
        try:
            bucket.download_fileobj(object_key, image_data)
        except botocore.exceptions.ClientError as exception:
            # Show the S3 error and move on to the next hit.
            st.write(str(exception))
            continue
        image_data.seek(0)
        img = Image.open(image_data)

        # Alternate columns: even positions go left, odd positions go right.
        col_ref = col2 if cnt % 2 else col1
        with col_ref:
            st.write('image_id:', image_id)
            st.write('distance:', distance)
            st.image(img, use_column_width=True)