Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoProcessor,
+from transformers import AutoProcessor, AutoModel, CLIPVisionModel
 import torch.nn.functional as F
 
 #---------------------------------
@@ -13,16 +13,18 @@ def load_biomedclip_model():
     """Loads the BiomedCLIP model and tokenizer."""
     biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
     processor = AutoProcessor.from_pretrained(biomedclip_model_name)
-
-
+    config = AutoModel.from_pretrained(biomedclip_model_name).config
+    vision_model = CLIPVisionModel.from_pretrained(config.vision_config._name_or_path, torch_dtype=torch.float16).cuda().eval()
+    text_model = AutoModel.from_pretrained(config.text_config._name_or_path).cuda().eval()
+    return vision_model, text_model, processor
 
-def compute_similarity(image, text,
+def compute_similarity(image, text, vision_model, text_model, biomedclip_processor):
     """Computes similarity scores using BiomedCLIP."""
     with torch.no_grad():
-        inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(
-
-
-
+        inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(text_model.device)
+        text_embeds = text_model(**inputs).last_hidden_state[:,0,:] # Extract the [CLS] token
+        image_inputs = biomedclip_processor(images=image, return_tensors="pt").to(vision_model.device)
+        image_embeds = vision_model(**image_inputs).last_hidden_state[:,0,:] # Extract the image embedding
         image_embeds = F.normalize(image_embeds, dim=-1)
         text_embeds = F.normalize(text_embeds, dim=-1)
         similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
@@ -55,12 +57,12 @@ def gradio_ask(user_message, chatbot, chat_state):
     return '', chatbot, chat_state
 
 @spaces.GPU
-def gradio_answer(chatbot, chat_state, img_list,
+def gradio_answer(chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output):
     """Computes and displays similarity scores."""
     if not img_list:
         return chatbot, chat_state, img_list, similarity_output
 
-    similarity_score = compute_similarity(img_list[0], chatbot[-1][0],
+    similarity_score = compute_similarity(img_list[0], chatbot[-1][0], vision_model, text_model, biomedclip_processor)
     print(f'Similarity Score is: {similarity_score}')
 
     similarity_text = f"Similarity Score: {similarity_score:.3f}"
@@ -77,7 +79,7 @@ examples_list=[
 ]
 
 # Load models and related resources outside of the Gradio block for loading on startup
-
+vision_model, text_model, biomedclip_processor = load_biomedclip_model()
 
 with gr.Blocks() as demo:
     gr.Markdown(title)
@@ -100,7 +102,7 @@ with gr.Blocks() as demo:
     upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])
 
     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list,
+        gradio_answer, [chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
     )
     clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)
 
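A note on the score itself: compute_similarity follows the standard CLIP recipe, one embedding per tower, L2-normalized so that the dot product equals the cosine similarity. A minimal, self-contained sketch of just that step, with random tensors standing in for the tower outputs (both BiomedCLIP towers, ViT-B/16 and PubMedBERT, have hidden size 768, so the shapes line up):

import torch
import torch.nn.functional as F

# Stand-ins for the [CLS] embeddings taken from each tower.
text_embeds = torch.randn(1, 768)   # (batch, hidden)
image_embeds = torch.randn(1, 768)  # (batch, hidden)

# L2-normalize so the dot product is exactly the cosine similarity.
text_embeds = F.normalize(text_embeds, dim=-1)
image_embeds = F.normalize(image_embeds, dim=-1)

# (1, 768) @ (768, 1) -> (1, 1); squeeze() yields a scalar in [-1, 1].
similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
print(f"Similarity Score: {similarity:.3f}")

One caveat on the committed version: last_hidden_state[:,0,:] is a pre-projection hidden state. In CLIP-style models the shared image-text space comes from the projection heads, so cosine scores on raw hidden states may be poorly calibrated even when the dimensions happen to match.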
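On the loading side, this checkpoint is published on the Hub as an open_clip model, and its model card loads it through the open_clip package rather than the transformers Auto classes, so whether AutoProcessor.from_pretrained resolves this repo depends on which config files it ships. For comparison, a sketch of the open_clip route (assumes the open_clip_torch package is installed; 'example.png' and the prompt text are placeholders):

import torch
import open_clip
from PIL import Image

# The hf-hub: prefix pulls the checkpoint straight from the Hugging Face Hub.
name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
model, preprocess = open_clip.create_model_from_pretrained(name)
tokenizer = open_clip.get_tokenizer(name)
model.eval()

image = preprocess(Image.open('example.png')).unsqueeze(0)  # (1, 3, 224, 224)
tokens = tokenizer(['chest X-ray'], context_length=256)     # this model uses a 256-token context

with torch.no_grad():
    # The forward pass returns L2-normalized features plus the learned temperature.
    image_features, text_features, logit_scale = model(image, tokens)
    similarity = (image_features @ text_features.T).squeeze()

print(f'Similarity Score: {similarity:.3f}')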
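Finally, the event wiring: the diff lists vision_model, text_model, and biomedclip_processor inside the inputs of .then(...), but Gradio event inputs and outputs are expected to be Gradio components rather than arbitrary Python objects. The usual pattern is to load models once at module scope (or hold them in gr.State) and let handlers pick them up by closure. A minimal runnable sketch of that pattern, with a dummy scorer standing in for the real models:

import gradio as gr

def load_model():
    # Stand-in for load_biomedclip_model(); returns a toy scoring function.
    return lambda text: (len(text) % 10) / 10

scorer = load_model()  # loaded once at startup, captured by the handler below

def answer(text):
    """The handler reaches `scorer` via module scope; only components are event inputs."""
    return f"Similarity Score: {scorer(text):.3f}"

with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Prompt")
    similarity_output = gr.Textbox(label="Similarity")
    text_input.submit(answer, [text_input], [similarity_output])

if __name__ == "__main__":
    demo.launch()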
|