import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import numpy as np
import torch
from transformers import CLIPModel, AutoProcessor

# Only needed by the commented-out direct-image fallback in app().
from PIL import Image
import requests
from io import BytesIO
def embed_texts(model, texts, processor):
    """Embed a list of text queries with the PLIP text encoder."""
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings
def load_embeddings(embeddings_path):
    """Load the precomputed image embeddings for the retrieval corpus."""
    print("loading embeddings")
    return np.load(embeddings_path)


def load_path_clip():
    """Load PLIP, a CLIP model trained on pathology image-text pairs."""
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor
def app():
    st.title('PLIP Image Search')

    # Image URLs and source-tweet links for the retrieval corpus.
    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")

    model, processor = load_path_clip()
    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.text_input('Search Query', '')

    if query:
        # Embed the query and L2-normalize it so the dot product below is a cosine similarity.
        text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
        text_embedding = text_embedding / np.linalg.norm(text_embedding)

        # Sort image IDs by cosine similarity, from high to low.
        similarity_scores = text_embedding.dot(image_embedding.T)
        id_sorted = np.argsort(similarity_scores)[::-1]
        best_id = id_sorted[0]
        score = similarity_scores[best_id]

        target_url = plip_imgURL.iloc[best_id]["imageURL"]
        target_weblink = plip_weblink.iloc[best_id]["weblink"]

        st.caption('Most relevant image (similarity = %.4f)' % score)

        # Direct-image fallback, kept for reference:
        # response = requests.get(target_url)
        # img = Image.open(BytesIO(response.content))
        # st.image(img)

        # Embed the source tweet so the image is shown with its original context.
        components.html('''
            <blockquote class="twitter-tweet">
                <a href="%s"></a>
            </blockquote>
            <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
            </script>
            ''' % target_weblink,
            height=600)
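

# Assumed entry point (not shown in the original listing): lets the script run
# directly with `streamlit run app.py`; the Space itself may instead call app()
# from a multipage wrapper.
if __name__ == "__main__":
    app()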
