# File-viewer residue (not part of the program): author dmibor, commit cd39c4a ("закинул публикацию"), raw / history / blame links, 3.17 kB.
import streamlit as st
import clip
import torch
import numpy as np
import os
import glob
from pathlib import Path
from PIL import Image
import chromadb
import boto3
import botocore
from io import BytesIO
# Page-level settings: browser-tab title and icon, wide layout, sidebar open.
_page_settings = {
    "page_title": "Поиск изображений Google Open Images по текстовому запросу",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)
# S3 bucket the result images are downloaded from.
BUCKET_NAME = 'open-images-dataset'
# Local download directory; only the commented-out download code refers to it.
DOWNLOAD_FOLDER = 'ds_download'
# Key prefix inside the bucket: objects are addressed as '<SPLIT>/<image_id>.jpg'.
SPLIT = 'validation'

# Run CLIP on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@st.cache_resource
def _load_clip_model(device_name):
    """Load the CLIP ViT-B/32 model and its preprocessing transform.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects as a shared resource avoids reloading the
    model weights on each rerun.

    Args:
        device_name: Torch device string the model is placed on.

    Returns:
        The ``(model, preprocess)`` pair returned by ``clip.load``.
    """
    return clip.load("ViT-B/32", device=device_name)


# Both names stay None when loading fails, so later code can test for that.
model = None
preprocess = None
try:
    model, preprocess = _load_clip_model(device)
except Exception as exc:
    # Catch Exception rather than using a bare except, so BaseException-derived
    # control-flow signals are not swallowed, and show the cause to the user.
    st.write("Exception loading model")
    st.exception(exc)
# On-disk location of the Chroma database that stores the image embeddings.
collection_path = "image_embeddings_collection.chroma"
chroma_client = chromadb.PersistentClient(path=collection_path)

# Stays None when the collection cannot be opened, so later code can test for that.
image_collection = None
try:
    # Open (or create) the "image" collection, configured for cosine distance.
    image_collection = chroma_client.get_or_create_collection(
        "image", metadata={"hnsw:space": "cosine"}
    )
except Exception as exc:
    # Catch Exception rather than using a bare except, so BaseException-derived
    # control-flow signals are not swallowed, and show the cause to the user.
    st.write("Exception loading collection")
    st.exception(exc)
#if os.path.isdir(DOWNLOAD_FOLDER)==False:
#os.mkdir(DOWNLOAD_FOLDER)
# image_collection is None when opening the collection failed; report zero
# stored embeddings in that case instead of raising AttributeError.
num_embeddings = image_collection.count() if image_collection is not None else 0

# Main page heading
st.title("Поиск изображений Google Open Images по текстовому запросу")

# Sidebar: database size, query text, number of results and the search button.
st.sidebar.header("Настройки поиска")
st.sidebar.write(f"Количество изображений в БД: {num_embeddings}")
text_input = st.sidebar.text_input(label='Введите запрос:', value='kite in the sky')
search_files_cnt = int(
    st.sidebar.slider(label="Количество изображений", min_value=1, max_value=10, value=2)
)
searchStarted = st.sidebar.button('Искать')

# Two side-by-side columns that the search results are rendered into.
col1, col2 = st.columns(2)
def _encode_query(query_text):
    """Return the CLIP text embedding of *query_text* as a nested list of floats.

    ``truncate=True`` makes an over-long query get cut to the model's context
    length instead of raising. Encoding runs under ``torch.no_grad()`` because
    only the forward pass is needed.
    """
    tokens = clip.tokenize(query_text, truncate=True).to(device)
    with torch.no_grad():
        features = model.encode_text(tokens)
    # A plain nested list is accepted by Collection.query whether or not the
    # installed Chroma version also accepts NumPy arrays.
    return features.cpu().tolist()


def _download_image(bucket, object_key):
    """Download *object_key* from *bucket* into memory and open it as a PIL image.

    Raises:
        botocore.exceptions.ClientError: when S3 rejects the request.
    """
    buffer = BytesIO()
    bucket.download_fileobj(object_key, buffer)
    buffer.seek(0)
    return Image.open(buffer)


if searchStarted:
    if model is None or image_collection is None:
        # Either name is still None when loading failed earlier in the script;
        # searching would otherwise fail with AttributeError.
        st.error("Search is unavailable: the model or the image collection failed to load")
    else:
        result = image_collection.query(
            query_embeddings=_encode_query(text_input),
            n_results=search_files_cnt,
        )
        # Unsigned (anonymous) access to the bucket that holds the images.
        bucket = boto3.resource(
            's3',
            config=botocore.config.Config(signature_version=botocore.UNSIGNED),
        ).Bucket(BUCKET_NAME)

        columns = (col1, col2)
        # Only one query was sent, so its hits are at index 0 of each result list.
        hits = zip(result['metadatas'][0], result['distances'][0])
        for position, (metadata, distance) in enumerate(hits):
            # The stored name carries a file extension; the bucket key is built
            # from the name without it.
            image_id = Path(metadata['name']).with_suffix('')
            try:
                img = _download_image(bucket, f'{SPLIT}/{image_id}.jpg')
            except botocore.exceptions.ClientError as exception:
                # Report the failed download and move on to the next hit; the
                # position still advances so column alternation is unchanged.
                st.write(str(exception))
                continue
            # Alternate hits between the left and the right column.
            with columns[position % 2]:
                st.write('image_id:', image_id)
                st.write('distance:', distance)
                st.image(img, use_column_width=True)