"""Streamlit app: search Google Open Images by a free-text query.

The query is embedded with OpenAI CLIP (ViT-B/32) and matched against image
embeddings stored in a persistent ChromaDB collection that uses cosine
distance. The matching images are then downloaded anonymously from the public
Open Images S3 bucket and shown in two columns.
"""
import glob
import os
from io import BytesIO
from pathlib import Path

import boto3
import botocore
import botocore.config
import botocore.exceptions
import chromadb
import clip
import numpy as np
import streamlit as st
import torch
from PIL import Image, UnidentifiedImageError

# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Поиск изображений Google Open Images по текстовому запросу",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

BUCKET_NAME = 'open-images-dataset'
DOWNLOAD_FOLDER = 'ds_download'  # legacy: only used by the former on-disk download path
SPLIT = 'validation'  # S3 key prefix the images are fetched from

device = "cuda" if torch.cuda.is_available() else "cpu"
collection_path = "image_embeddings_collection.chroma"


@st.cache_resource
def _load_clip(model_name, target_device):
    """Load a CLIP model and its preprocess transform once per server process.

    Streamlit re-executes the whole script on every interaction; caching avoids
    reloading the model weights on each rerun. A raised exception is not cached,
    so a failed load is retried on the next rerun.
    """
    return clip.load(model_name, device=target_device)


@st.cache_resource
def _load_collection(path):
    """Open the persistent Chroma store at *path* and get/create the 'image' collection.

    Returns a ``(client, collection)`` tuple; the collection uses cosine distance.
    """
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection("image", metadata={"hnsw:space": "cosine"})
    return client, collection


def _fetch_image(bucket, image_id):
    """Download ``<SPLIT>/<image_id>.jpg`` from *bucket* into memory and open it with PIL.

    Raises botocore.exceptions.ClientError if the object cannot be fetched and
    PIL.UnidentifiedImageError if the bytes are not a decodable image.
    """
    buffer = BytesIO()
    bucket.download_fileobj(f'{SPLIT}/{image_id}.jpg', buffer)
    buffer.seek(0)  # download leaves the cursor at EOF; rewind before decoding
    return Image.open(buffer)


model = None
preprocess = None
try:
    model, preprocess = _load_clip("ViT-B/32", device)
except Exception as exc:  # top-level UI boundary: report and halt instead of crashing later
    st.error(f"Exception loading model: {exc}")
    st.stop()

chroma_client = None
image_collection = None
try:
    chroma_client, image_collection = _load_collection(collection_path)
except Exception as exc:  # top-level UI boundary: report and halt instead of crashing later
    st.error(f"Exception loading collection: {exc}")
    st.stop()

num_embeddings = image_collection.count()

# Main page heading
st.title("Поиск изображений Google Open Images по текстовому запросу")

# Sidebar
st.sidebar.header("Настройки поиска")
st.sidebar.write(f"Количество изображений в БД: {num_embeddings}")
text_input = st.sidebar.text_input(label='Введите запрос:', value='kite in the sky')
search_files_cnt = int(st.sidebar.slider(label="Количество изображений", min_value=1, max_value=10, value=2))
searchStarted = st.sidebar.button('Искать')

col1, col2 = st.columns(2)

if searchStarted:
    if not text_input.strip():
        st.warning("Введите текстовый запрос.")
    elif num_embeddings == 0:
        st.warning("В базе данных нет изображений.")
    else:
        # truncate=True: queries longer than CLIP's 77-token context would
        # otherwise make clip.tokenize raise RuntimeError.
        tokens = clip.tokenize(text_input, truncate=True).to(device)
        with torch.no_grad():
            text_features = model.encode_text(tokens).cpu().numpy()

        # Chroma cannot return more neighbours than the collection holds.
        n_results = min(search_files_cnt, num_embeddings)
        result = image_collection.query(query_embeddings=text_features.tolist(), n_results=n_results)

        # Anonymous (unsigned) access to the public bucket. Created per search
        # rather than cached, because boto3 resources must not be shared across
        # the threads Streamlit runs sessions on.
        bucket = boto3.resource(
            's3', config=botocore.config.Config(signature_version=botocore.UNSIGNED)
        ).Bucket(BUCKET_NAME)

        # Results for the single query are at index 0.
        metadatas = result['metadatas'][0]
        distances = result['distances'][0]
        for rank, (meta, distance) in enumerate(zip(metadatas, distances)):
            image_id = Path(meta['name']).with_suffix('')
            # Alternate columns: even ranks go left, odd ranks go right.
            target_col = col1 if rank % 2 == 0 else col2
            try:
                img = _fetch_image(bucket, image_id)
            except (botocore.exceptions.ClientError, UnidentifiedImageError) as exc:
                st.write(str(exc))
                continue
            with target_col:
                st.write('image_id:', image_id)
                st.write('distance:', distance)
                st.image(img, use_column_width=True)