# Clip-Model / app.py
# Author: ktllc — commit b05aa6a ("Create app.py"), 2.85 kB.
import os

import clip
import gradio as gr
import numpy as np
import torch
from PIL import Image
# Choose the compute device before loading so CLIP is loaded straight onto it
# (instead of letting clip.load pick its own default and moving it afterwards).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP ViT-B/32 model together with its matching image preprocessing
# transform, and put the model in inference mode.
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

# Name of the business listing this search is configured for.
Business_Listing = "Air Guide"
# Directory scanned for candidate images when the caller does not pass one.
_DEFAULT_IMAGE_DIR = "/content/sample_data/Tourism"

# File extensions treated as images (matched case-insensitively). Pillow cannot
# decode every one of these out of the box (e.g. .svg, .pdf; .eps needs
# Ghostscript), so files that fail to open are skipped instead of aborting.
_IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff",
    ".webp", ".ico", ".svg", ".eps", ".pdf",
)


def _load_images(image_dir):
    """Return ``(filenames, tensors)`` for every readable image in *image_dir*.

    Both lists are filled in the same pass, so ``filenames[i]`` is always the
    file that produced ``tensors[i]``. Entries are visited in sorted order so
    results do not depend on the filesystem's listing order.
    """
    filenames = []
    tensors = []
    for filename in sorted(os.listdir(image_dir)):
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        path = os.path.join(image_dir, filename)
        try:
            # The context manager closes the file handle Image.open keeps open.
            with Image.open(path) as img:
                rgb = img.convert("RGB")
        except OSError:
            # Unreadable or undecodable entry (PIL's UnidentifiedImageError is
            # an OSError subclass, as are directory/permission errors): skip it.
            continue
        filenames.append(filename)
        tensors.append(preprocess(rgb))
    return filenames, tensors


def find_similar_images(text_input, image_dir=_DEFAULT_IMAGE_DIR):
    """Find the image(s) in *image_dir* that CLIP rates most similar to *text_input*.

    The query is embedded as the prompt ``"This is {text_input}"`` and compared
    with every readable image in the directory (non-recursive) by cosine
    similarity of CLIP embeddings.

    Args:
        text_input: Free-text query.
        image_dir: Directory to scan. Defaults to the original hard-coded
            sample directory.

    Returns:
        ``(filenames, score)`` where *filenames* lists every file sharing the
        highest similarity (more than one only on an exact tie) and *score* is
        that similarity as a Python float. If no readable image is found,
        returns ``([], None)``.

    Raises:
        OSError: if *image_dir* itself cannot be listed.
    """
    filenames, tensors = _load_images(image_dir)
    if not tensors:
        # Stacking an empty list raises; report "no match" instead of crashing.
        return [], None

    image_input = torch.stack(tensors).to(device)
    # truncate=True: long queries would otherwise raise once they exceed
    # CLIP's 77-token context window.
    text_tokens = clip.tokenize([f"This is {text_input}"], truncate=True).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        text_features = model.encode_text(text_tokens).float()

    # L2-normalise so the dot product below is the cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    scores = (text_features.cpu().numpy() @ image_features.cpu().numpy().T)[0]

    best = scores.max()
    best_filenames = [filenames[i] for i in np.flatnonzero(scores == best)]
    return best_filenames, float(best)
# Gradio UI: one text box in; the best-matching filename list and its
# similarity score out. live=True re-runs the search whenever the query
# changes, and every run reloads and re-encodes all images in the directory.
#
# `interpretation` is intentionally not passed: Gradio 4.0 removed that
# argument, and in Gradio 3.x its default method only supported a Label or
# Number first output (this interface's first output is text).
iface = gr.Interface(
    fn=find_similar_images,
    inputs="text",
    outputs=["text", "number"],
    live=True,
    title="CLIP Model Image Search",
)

if __name__ == "__main__":
    iface.launch()