import streamlit as st
import clip
import torch
from pathlib import Path
from PIL import Image
import chromadb
import boto3
import botocore
from io import BytesIO

st.set_page_config(
    page_title="Google Open Images search by text query",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

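# S3 bucket that hosts the Open Images dataset (accessed anonymously below)
BUCKET_NAME = 'open-images-dataset'
# Local download folder; only referenced by the commented-out code
DOWNLOAD_FOLDER = 'ds_download'
# Dataset split whose images are fetched from the bucket
SPLIT = 'validation'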

device = "cuda" if torch.cuda.is_available() else "cpu"
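model = None
preprocess = None
try:
    model, preprocess = clip.load("ViT-B/32", device=device)
except Exception as exc:
    # Without the model no query can be encoded, so report the error and stop.
    st.error(f"Exception loading model: {exc}")
    st.stop()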

collection_path = "image_embeddings_collection.chroma"
chroma_client = chromadb.PersistentClient(path=collection_path)
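image_collection = None
try:
    image_collection = chroma_client.get_or_create_collection("image", metadata={"hnsw:space": "cosine"})
except Exception as exc:
    # The app cannot search without the collection, so report the error and stop.
    st.error(f"Exception loading collection: {exc}")
    st.stop()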
    
#if os.path.isdir(DOWNLOAD_FOLDER)==False:
  #os.mkdir(DOWNLOAD_FOLDER)

num_embeddings = image_collection.count()

# Main page heading
st.title("Google Open Images search by text query")

# Sidebar
st.sidebar.header("Search settings")

st.sidebar.write(f"Number of images in the database: {num_embeddings}")
text_input = st.sidebar.text_input(label='Enter a query:', value='kite in the sky')
search_files_cnt = int(st.sidebar.slider(label="Number of images", min_value=1, max_value=10, value=2))
searchStarted = st.sidebar.button('Search')

col1, col2 = st.columns(2)
if searchStarted:
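    # Encode the text query with CLIP; gradients are not needed for inference.
    text_tokens = clip.tokenize(text_input).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).cpu().numpy()
    # Look up the nearest image embeddings in the collection.
    result = image_collection.query(query_embeddings=text_features.tolist(), n_results=search_files_cnt)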
    # Anonymous (unsigned) access to the bucket.
    bucket = boto3.resource(
        's3',
        config=botocore.config.Config(signature_version=botocore.UNSIGNED),
    ).Bucket(BUCKET_NAME)
    for cnt, meta in enumerate(result['metadatas'][0]):
        try:
            # The image ID is the stored file name without its extension.
            image_id = Path(meta['name']).with_suffix('')
            object_key = f'{SPLIT}/{image_id}.jpg'
            # Download the image into memory instead of writing it to disk.
            image_data = BytesIO()
            bucket.download_fileobj(object_key, image_data)
            image_data.seek(0)
            img = Image.open(image_data)
            # Alternate results between the two columns, starting with the left one.
            col_ref = col1 if cnt % 2 == 0 else col2
            with col_ref:
                st.write('image_id:', image_id)
                st.write('distance:', result['distances'][0][cnt])
                st.image(img, use_column_width=True)
        except botocore.exceptions.ClientError as exception:
            st.write(str(exception))