EmbedDior / app.py
HadrienCr's picture
Add examples and changed requirements
b32144d
raw
history blame
3.74 kB
import os
import sys
import gradio as gr
import numpy as np
from PIL import Image
from embedding import get_device, get_model_and_processor, to_embedding
from datasets import Dataset, load_dataset
def initialize_model(
    model_name: str = "patrickjohncyh/fashion-clip",
    dataset_name: str = "HadrienCr/embeddeDior",
    split: str = "train",
):
    """Load the embedding model and the FAISS-indexed reference dataset.

    Args:
        model_name: Model identifier passed to ``get_model_and_processor``.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Dataset split to load.

    Returns:
        A ``(model, processor, ref_dataset, device)`` tuple. ``ref_dataset``
        has a FAISS index built over its "embeddings" column, which is the
        index name ``search`` queries.
    """
    device = get_device()
    model, processor = get_model_and_processor(model_name, device)
    ref_dataset = load_dataset(dataset_name, split=split)
    # The index name defaults to the column name ("embeddings"); search()
    # depends on that exact name, so it is deliberately not parameterized.
    ref_dataset = ref_dataset.add_faiss_index("embeddings")
    return model, processor, ref_dataset, device
def search(
    image: np.ndarray,
    reference_dataset: Dataset,
    model,
    processor,
    device: str,
    num_neighbors: int = 4,
):
    """Embed ``image`` and return its nearest neighbours from ``reference_dataset``.

    Args:
        image: A single image as a numpy array. A leading batch axis is added
            before it is passed to ``to_embedding``.
        reference_dataset: Dataset carrying a FAISS index named "embeddings".
        model: Embedding model forwarded to ``to_embedding``.
        processor: Pre-processor forwarded to ``to_embedding``.
        device: Device identifier forwarded to ``to_embedding``.
        num_neighbors: How many neighbours to retrieve. Coerced to ``int``
            because UI widgets such as sliders may deliver a float, and the
            FAISS search expects an integer result count.

    Returns:
        The retrieved examples from ``Dataset.get_nearest_examples`` (a mapping
        of column name to list of values). The similarity scores are discarded.
    """
    query = to_embedding(np.expand_dims(image, 0), processor, model, device)
    _scores, retrieved_examples = reference_dataset.get_nearest_examples(
        "embeddings",
        query,
        k=int(num_neighbors),
    )
    return retrieved_examples
def process_image(image, num_results, remove_bg, model, processor, ref_dataset, device):
if image is None:
return [] # Return an empty list if no image is provided
# Ensure the input image is a numpy array
if isinstance(image, Image.Image):
image = np.array(image)
# Handle background removal
if remove_bg:
from rembg import remove
image = remove(image)[:,:,0:3]
# Perform the search
results = search(
image,
ref_dataset,
model,
processor,
device,
num_neighbors=num_results
)
images = results['image']
paths = results['path']
# Prepare the output
output_images = []
for img, path in zip(images, paths):
output_images.append((np.array(img), os.path.basename(path)))
return output_images # Return the list of tuples
def main():
    """Build the Gradio image-retrieval UI and launch it.

    The model and indexed reference dataset are loaded once at startup.
    Example images are collected from ``examples/`` if that folder exists.
    The Search button is wired to ``process_image``, and the app is launched
    with ``share=True``.
    """
    print("Initializing model and loading reference dataset...")
    model, processor, ref_dataset, device = initialize_model()
    print("Initialization complete!")

    # Path to the examples folder; a missing folder simply means no examples.
    examples_folder = "examples/"
    # Match on the dotted suffix so names like "foo.xpng" are not picked up.
    image_extensions = (".png", ".jpg", ".jpeg")
    example_files = []
    if os.path.isdir(examples_folder):
        example_files = [
            os.path.join(examples_folder, fname)
            # Sorted so examples appear in a stable order across restarts.
            for fname in sorted(os.listdir(examples_folder))
            if fname.lower().endswith(image_extensions)
        ]

    with gr.Blocks() as demo:
        gr.Markdown("# Image Retrieval System")
        gr.Markdown("Upload an image to find similar images in the reference dataset.")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",  # process_image converts PIL input to numpy
                )
                num_results = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Number of results",
                )
                remove_bg = gr.Checkbox(label="Remove Background")
                submit_btn = gr.Button("Search")

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Retrieved Images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                )

        # Only render the examples panel when there is something to show.
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=input_image,
                label="Example Images",
            )

        # The lambda binds the loaded model/dataset so the UI only supplies
        # the three user-controlled inputs.
        submit_btn.click(
            fn=lambda img, num, bg: process_image(img, num, bg, model, processor, ref_dataset, device),
            inputs=[input_image, num_results, remove_bg],
            outputs=gallery,
        )

    demo.launch(share=True)
if __name__ == "__main__":
    # Start the app only when this file is executed as a script.
    main()