# "Spaces: Sleeping" — Hugging Face Spaces status banner captured with the
# page during extraction; not part of the application code.
import chromadb | |
from chromadb.config import Settings | |
import torchvision.models as models | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import logging | |
import streamlit as st | |
import requests | |
import json | |
import uuid | |
import os | |
try: | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
def load_mobilenet_model():
    """Load the distilled MobileNetV3-Small embedding model on CPU.

    The final classifier layer is replaced (1024 -> 768) so the model emits
    embeddings in the same 768-dim space as the text encoder, then the
    distilled weights are restored from the local state-dict file.

    Returns:
        torch.nn.Module: the model in eval mode on CPU.
    """
    device = 'cpu'
    # `weights=None` replaces the deprecated `pretrained=False` flag
    # (torchvision >= 0.13); our own distilled weights are loaded below.
    model = models.mobilenet_v3_small(weights=None)
    model.classifier[3] = torch.nn.Linear(1024, 768)
    model.load_state_dict(torch.load(
        'mobilenet_v3_small_distilled_new_state_dict.pth', map_location=device))
    model.eval().to(device)
    return model
def load_chromadb():
    """Open the persistent ChromaDB store under ``data/`` and return the
    pre-existing ``images`` collection (telemetry disabled)."""
    client = chromadb.PersistentClient(
        path='data',
        settings=Settings(anonymized_telemetry=False),
    )
    return client.get_collection(name='images')
# One-time app initialization: load the embedding model, open the vector
# store, and build the preprocessing pipeline used by get_image_embedding.
model = load_mobilenet_model()
logger.info("MobileNet loaded")
collection = load_chromadb()
logger.info("ChromaDB loaded")
# Lazy %-style args defer formatting until the record is actually emitted.
logger.info(
    "Connected to ChromaDB collection images with %s items", collection.count())
# Standard ImageNet normalization statistics, resized to the model's
# 224x224 input.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def get_image_embedding(image):
    """Embed an image into the shared 768-dim text/image space.

    Args:
        image: a filesystem path or a file-like object — anything accepted
            by ``PIL.Image.open``.

    Returns:
        list[float]: L2-normalized 768-dim embedding.
    """
    # PIL.Image.open accepts both paths and file-like objects, so the
    # original isinstance branch (whose two arms were identical) is removed.
    img = Image.open(image).convert('RGB')
    input_tensor = preprocess(img).unsqueeze(0).to('cpu')
    with torch.no_grad():
        student_embedding = model(input_tensor)
    return torch.nn.functional.normalize(
        student_embedding, p=2, dim=1).squeeze(0).tolist()
def save_image(image_file):
    """Persist an uploaded file under ``images/`` and return its path.

    Args:
        image_file: uploaded file object exposing ``.name`` and
            ``.getbuffer()`` (e.g. a Streamlit ``UploadedFile``).

    Returns:
        str: the path the file was written to.
    """
    # Create the target directory if missing so the first upload on a fresh
    # deployment does not crash with FileNotFoundError.
    os.makedirs('images', exist_ok=True)
    # NOTE(review): the raw upload name is kept so it matches the 'path'
    # metadata stored in ChromaDB; two uploads sharing a name overwrite
    # each other — confirm whether that is acceptable.
    save_path = os.path.join('images', image_file.name)
    with open(save_path, "wb") as f:
        f.write(image_file.getbuffer())
    return save_path
def resize_image(image_path, size=(224, 224)):
    """Open an image and resize it with high-quality LANCZOS resampling.

    Args:
        image_path: a filesystem path or file-like object — anything
            accepted by ``PIL.Image.open``.
        size: target (width, height); defaults to the model input size.

    Returns:
        PIL.Image.Image: the resized RGB image.
    """
    # PIL.Image.open handles both paths and file objects, so the original
    # two identical isinstance branches collapse into one call.
    img = Image.open(image_path).convert("RGB")
    return img.resize(size, Image.LANCZOS)
# Sidebar: upload images and add them to the ChromaDB collection.
st.sidebar.header("Upload Images")
image_files = st.sidebar.file_uploader(
    "Upload images (Please do not upload personal data.)", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
num_images = st.sidebar.slider(
    "Number of results to return", min_value=1, max_value=10, value=3)
if image_files:
    st.sidebar.subheader(
        "Add Images to collection")
    if st.sidebar.button("Add uploaded images"):
        # Plain iteration: the index from the original enumerate was unused.
        for image_file in image_files:
            image_embedding = get_image_embedding(image_file)
            saved_path = save_image(image_file)
            unique_id = str(uuid.uuid4())
            # Reuse the path save_image() actually wrote to instead of
            # re-deriving it, so stored metadata can never drift from disk.
            metadata = {
                'path': saved_path, "type": "photo"
            }
            collection.add(
                embeddings=[image_embedding],
                ids=[unique_id],
                metadatas=[metadata]
            )
            st.sidebar.success(
                f"Image {image_file.name} added to the collection")
# Main page: text query -> remote text-embedding API -> nearest-image search.
st.title('Image Search Using Text')
st.write(
    "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
st.write('Enter the text to search for images with matching description')
text_input = st.text_input("Description", "Road")
if st.button("Search"):
    if text_input.strip():
        params = {'text': text_input}
        # A timeout keeps the app from hanging forever when the embedding
        # service is unreachable; the outer handler reports the failure.
        response = requests.get(
            'https://ashish-001-text-embedding-api.hf.space/embedding',
            params=params, timeout=30)
        if response.status_code == 200:
            logger.info("Embedding returned by API successfully")
            data = json.loads(response.content)
            embedding = data['embedding']
            results = collection.query(
                query_embeddings=[embedding],
                n_results=num_images
            )
            # Iterate the hit list directly instead of indexing by range.
            hits = results['metadatas'][0]
            images = [meta['path'] for meta in hits]
            distances = results['distances'][0][:len(hits)]
            if images:
                # Lay results out in a grid, three thumbnails per row.
                cols_per_row = 3
                rows = (len(images) + cols_per_row - 1) // cols_per_row
                for row in range(rows):
                    cols = st.columns(cols_per_row)
                    for col_idx, col in enumerate(cols):
                        img_idx = row * cols_per_row + col_idx
                        if img_idx < len(images):
                            resized_img = resize_image(
                                images[img_idx], size=(224, 224))
                            col.image(resized_img,
                                      caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}", use_container_width=True)
            else:
                st.write("No image found")
        else:
            st.write("Please try again later")
            # Lazy %-args avoid f-string work when the record is filtered.
            logger.info("status code %s returned", response.status_code)
    else:
        st.write("Please enter text in the text area")
except Exception as e: | |
logger.info(f"Exception occured: {e}") | |