image / app.py
mgbam's picture
Update app.py
48ac3b6 verified
raw
history blame
5.67 kB
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoProcessor, AutoModel, CLIPVisionModel
import torch.nn.functional as F
#---------------------------------
#++++++++ Model ++++++++++
#---------------------------------
def load_biomedclip_model():
    """Load the BiomedCLIP vision and text encoders plus the processor.

    Returns:
        ``(vision_model, text_model, processor)``:
        - ``vision_model``: a ``CLIPVisionModel`` loaded in float16, moved to
          CUDA and put in eval mode.
        - ``text_model``: an ``AutoModel`` loaded without an explicit
          ``torch_dtype``, moved to CUDA and put in eval mode.
        - ``processor``: the checkpoint's ``AutoProcessor``.

    NOTE(review): the upstream model card loads this checkpoint with
    ``open_clip``; confirm that the transformers Auto* classes, and the
    sub-config ``_name_or_path`` values used below, actually resolve for it.
    """
    from transformers import AutoConfig  # local import: only needed here

    biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    processor = AutoProcessor.from_pretrained(biomedclip_model_name)
    # Only the config is needed to locate the sub-model checkpoints. Loading it
    # directly avoids instantiating (and loading weights for) the full model
    # just to read its ``.config`` and throw the model away.
    config = AutoConfig.from_pretrained(biomedclip_model_name)
    vision_model = CLIPVisionModel.from_pretrained(
        config.vision_config._name_or_path, torch_dtype=torch.float16
    ).cuda().eval()
    text_model = AutoModel.from_pretrained(config.text_config._name_or_path).cuda().eval()
    return vision_model, text_model, processor
def compute_similarity(image, text, vision_model, text_model, biomedclip_processor):
    """Compute a cosine similarity between an image and a text query.

    Args:
        image: Image accepted by ``biomedclip_processor(images=...)``.
        text: Query string accepted by ``biomedclip_processor(text=...)``.
        vision_model: Vision encoder whose output has ``last_hidden_state``.
        text_model: Text encoder whose output has ``last_hidden_state``.
        biomedclip_processor: Processor producing ``pixel_values`` for images
            and token tensors for text.

    Returns:
        A float32 tensor: the cosine similarity of the first-token states of
        the two encoders. For one image and one query it is 0-dimensional,
        because of ``squeeze()``.

    NOTE(review): these are raw first-token hidden states from two separately
    loaded encoders, not projected CLIP embeddings; confirm the resulting
    cosine is meaningful for this checkpoint.
    """
    with torch.no_grad():
        # Tokenize the text on its own. A combined text+image batch would carry
        # ``pixel_values``, which the text encoder does not accept.
        text_inputs = biomedclip_processor(text=text, return_tensors="pt", padding=True).to(text_model.device)
        text_embeds = text_model(**text_inputs).last_hidden_state[:, 0, :]  # first ([CLS]) token

        # Match the vision tower's device *and* dtype: the processor emits
        # float32 pixels, but the vision model may be loaded in float16.
        pixel_values = biomedclip_processor(images=image, return_tensors="pt")["pixel_values"].to(
            device=vision_model.device, dtype=vision_model.dtype
        )
        image_embeds = vision_model(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        # Normalize in float32 so fp16 image and fp32 text embeddings can be
        # multiplied together without a dtype mismatch.
        image_embeds = F.normalize(image_embeds.float(), dim=-1)
        text_embeds = F.normalize(text_embeds.float(), dim=-1)
        similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
    return similarity
#---------------------------------
#++++++++ Gradio ++++++++++
#---------------------------------
def gradio_reset(chat_state, img_list, similarity_output):
    """Return the interface to its initial, pre-upload state.

    ``similarity_output`` is accepted but not read. The returned tuple holds,
    in order: the chatbot value (None), the image update, the query-box
    update, the upload-button update, the chat state, the image list, and the
    similarity-box update.
    """
    # Empty any stored conversation in place so the same state object is reused.
    if chat_state is not None:
        chat_state.messages = []
    # A missing list stays None; an existing one is replaced by a fresh empty list.
    fresh_images = None if img_list is None else []

    image_update = gr.update(value=None, interactive=True)
    query_update = gr.update(placeholder='Please upload your medical image first', interactive=False)
    button_update = gr.update(value="Upload & Start Analysis", interactive=True)
    score_update = gr.update(value="", visible=False)
    return None, image_update, query_update, button_update, chat_state, fresh_images, score_update
def upload_img(gr_img, text_input, chat_state, similarity_output):
    """Store the uploaded image and unlock the query box.

    ``text_input`` and ``similarity_output`` are accepted but not read.
    Returns updates for, in order: image, query box, upload button,
    chat state, image list, similarity box.
    """
    if gr_img is not None:
        return (
            gr.update(interactive=False),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Analysis", interactive=False),
            chat_state,
            [gr_img],
            gr.update(visible=True),
        )
    # Nothing uploaded: clear image and query, keep the upload button usable,
    # and hide the similarity box.
    return None, None, gr.update(interactive=True), chat_state, None, gr.update(visible=False)
def gradio_ask(user_message, chatbot, chat_state):
    """Append the user's query to the chat history as a pending turn.

    Args:
        user_message: Text submitted from the query box.
        chatbot: Chat history as a list of ``[user, bot]`` pairs. ``None``
            is treated as an empty history.
        chat_state: Passed through unchanged.

    Returns:
        ``(query-box value or update, chatbot, chat_state)``.
        - Empty message: the query box shows a warning placeholder and the
          history is returned unchanged.
        - Otherwise: the query box is cleared and a ``[message, None]`` turn
          is appended, with the reply slot left as ``None``.
    """
    if not user_message:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    # Build a new list (the caller's history is not mutated) and tolerate a
    # None history instead of failing on ``None + list``.
    chatbot = list(chatbot or []) + [[user_message, None]]
    return '', chatbot, chat_state
@spaces.GPU
def gradio_answer(chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output):
    """Score the latest pending chat query against the uploaded image.

    Decorated with ``@spaces.GPU``, presumably so the encoders run with a GPU
    attached for the duration of the call (NOTE(review): confirm against the
    ``spaces`` package docs).

    Args:
        chatbot: Chat history as ``[user, bot]`` pairs. The last pair's user
            text is scored and its bot slot receives the result.
        chat_state: Passed through unchanged.
        img_list: List whose first element is the image to score.
        vision_model, text_model, biomedclip_processor: Forwarded to
            ``compute_similarity``.
        similarity_output: Current similarity-box value, returned unchanged
            when nothing is scored.

    Returns:
        ``(chatbot, chat_state, img_list, similarity-box value or update)``.
    """
    # Nothing to score in these cases:
    # - no image has been uploaded;
    # - the history is empty (an empty submission leaves it unchanged, and
    #   indexing ``chatbot[-1]`` would raise);
    # - the last turn already has a reply, so re-scoring it would be wasted work.
    if not img_list or not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state, img_list, similarity_output
    raw_score = compute_similarity(img_list[0], chatbot[-1][0], vision_model, text_model, biomedclip_processor)
    print(f'Similarity Score is: {raw_score}')
    # Convert the (0-d) tensor to a plain float before formatting.
    similarity_score = float(raw_score)
    similarity_text = f"Similarity Score: {similarity_score:.3f}"
    chatbot[-1][1] = similarity_text
    return chatbot, chat_state, img_list, gr.update(value=similarity_text, visible=True)
# Page heading and subtitle (HTML snippets rendered through Markdown).
title = '<h1 align="center">Medical Image Analysis Tool</h1>'
description = '<h3>Upload medical images, ask questions, and receive a similarity score.</h3>'

# Example [image path, query] pairs; paths are relative to the working directory.
examples_list = [
    ["./case1.png", "Analyze the X-ray for any abnormalities."],
    ["./case2.jpg", "What type of disease may be present?"],
    ["./case1.png", "What is the anatomical structure shown here?"],
]
# Load models and related resources outside of the Gradio block for loading on startup
vision_model, text_model, biomedclip_processor = load_biomedclip_model()


def _answer_with_loaded_models(chatbot, chat_state, img_list, similarity_output):
    """Event handler that calls ``gradio_answer`` with the startup models.

    Gradio event ``inputs`` must be components, so the model objects cannot be
    listed there. They are bound here from the module-level globals instead.
    """
    return gradio_answer(
        chatbot, chat_state, img_list,
        vision_model, text_model, biomedclip_processor,
        similarity_output,
    )


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Integer scales 1:2 keep the original 0.5:1 width ratio; Gradio
        # layout scales are meant to be integers.
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Medical Image")
            upload_button = gr.Button(value="Upload & Start Analysis", interactive=True, variant="primary")
            clear = gr.Button("Restart")
        with gr.Column(scale=2):
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Medical Analysis')
            text_input = gr.Textbox(label='Analysis Query', placeholder='Please upload your medical image first', interactive=False)
            similarity_output = gr.Textbox(label="Similarity Score", visible=False, interactive=False)
    gr.Examples(examples=examples_list, inputs=[image, text_input])
    upload_button.click(
        upload_img,
        [image, text_input, chat_state, similarity_output],
        [image, text_input, upload_button, chat_state, img_list, similarity_output],
    )
    text_input.submit(
        gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]
    ).then(
        _answer_with_loaded_models,
        [chatbot, chat_state, img_list, similarity_output],
        [chatbot, chat_state, img_list, similarity_output],
    )
    clear.click(
        gradio_reset,
        [chat_state, img_list, similarity_output],
        [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output],
        queue=False,
    )
demo.launch()