import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr
# Initialize the model and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def get_embedding(image_or_text):
    """Return a flat CLIP embedding for an image URL or a text string."""
    if image_or_text.startswith(("http:", "https:")):
        # Image URL: download the image and embed it with the vision tower
        response = requests.get(image_or_text)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs).cpu().numpy()
    else:
        # Text input: embed it with the text tower
        inputs = processor(text=[image_or_text], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            features = model.get_text_features(**inputs).cpu().numpy()
    return features.flatten().tolist()
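
# A minimal usage sketch, kept as comments so the app file is unchanged at
# import time: CLIP image and text embeddings share one space, so they are
# typically compared with cosine similarity. The URL below is hypothetical.
#
#   import numpy as np
#   img_vec = np.array(get_embedding("https://example.com/cat.jpg"))
#   txt_vec = np.array(get_embedding("a photo of a cat"))
#   cosine = np.dot(img_vec, txt_vec) / (np.linalg.norm(img_vec) * np.linalg.norm(txt_vec))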
# Define the Gradio interface
interface = gr.Interface(
    fn=get_embedding,
    inputs="text",
    outputs="json",
    title="CLIP Model Embeddings",
    description="Enter an image URL or text to get embeddings from CLIP.",
)
if __name__ == "__main__":
    interface.launch(share=True)
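
# Once the app is running, the endpoint can also be called programmatically.
# A sketch using the gradio_client package; the local URL and the default
# "/predict" endpoint name are assumptions about this deployment:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict("a photo of a cat", api_name="/predict")
#   print(len(result))  # ViT-B/32 embeddings are 512-dimensional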