Jyothirmai's picture
Update app.py
dd914ca verified
raw
history blame
3.57 kB
import gradio as gr
from PIL import Image
import clipGPT
import vitGPT
import skimage.io as io
import PIL.Image
import difflib
import ViTCoAtt
from build_vocab import Vocabulary
# Caption generation functions
def generate_caption_clipgpt(image, max_tokens, temperature):
    """Caption ``image`` with the CLIP-GPT2 model.

    Forwards ``image``, ``max_tokens`` and ``temperature`` positionally to
    ``clipGPT.generate_caption_clipgpt`` and returns its result unchanged.
    """
    return clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)
def generate_caption_vitgpt(image, max_tokens, temperature):
    """Caption ``image`` with the ViT-GPT2 model.

    Forwards ``image``, ``max_tokens`` and ``temperature`` positionally to
    ``vitGPT.generate_caption`` and returns its result unchanged.
    """
    return vitGPT.generate_caption(image, max_tokens, temperature)
def generate_caption_vitCoAtt(image):
    """Caption ``image`` with the ViT co-attention model.

    Unlike the GPT2-based models, this one takes no max-token or temperature
    setting. It calls ``ViTCoAtt.CaptionSampler.main`` with the image and
    returns its result unchanged.
    """
    sampler = ViTCoAtt.CaptionSampler
    return sampler.main(image)
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖</h1>")
    gr.HTML("<p style='text-align: center;'>You can generate captions by uploading an X-Ray and selecting a model of your choice below</p>")

    # Example studies listed in the table next to the upload widget. Each dict
    # is one table row keyed by column header; the "image" value is the URL
    # that predict_from_table downloads when the row is selected.
    # NOTE(review): these are imgur *page* URLs (no i.imgur.com host, no file
    # extension); skimage.io.imread likely needs a direct image link - confirm.
    sample_data = [
        {'image': 'https://imgur.com/W1pIr9b', 'max token, temp': '75, 0.7', 'model supported': 'CLIP-GPT2, ViT-GPT2, ViT-CoAttention', 'ground truth': '...'},
        {'image': 'https://imgur.com/MLJaWnf', 'max token, temp': '50, 0.8', 'model supported': 'CLIP-GPT2, ViT-CoAttention', 'ground truth': '...'},
    ]
    table_headers = ['image', 'max token, temp', 'model supported', 'ground truth']

    with gr.Row():
        image = gr.Image(label="Upload Chest X-ray", type="pil")
        # gr.Dataframe takes rows as lists, not dicts, so flatten each sample
        # in header order. "picture" is not a Dataframe datatype; the image
        # column holds a URL string, so it is declared as "str".
        image_table = gr.Dataframe(
            [[row[header] for header in table_headers] for row in sample_data],
            headers=table_headers,
            datatype=['str', 'str', 'str', 'str'],
        )

    gr.HTML("<p style='text-align: center;'> Please select the Number of Max Tokens and Temperature setting, if you are testing CLIP GPT2 and VIT GPT2 Models</p>")
    with gr.Row():
        with gr.Column():
            max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
            temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)
        model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")
    generate_button = gr.Button("Generate Caption")
    caption = gr.Textbox(label="Generated Caption")
    gr.HTML("<p style='text-align: center;'> Please select the Number of Max Tokens and Temperature setting, if you are testing CLIP GPT2 and VIT GPT2 Models</p>")

    def predict(img, model_name, max_tokens, temperature):
        """Caption ``img`` with the model chosen in the radio group.

        Args:
            img: PIL image from the upload widget, or None if nothing was uploaded.
            model_name: One of the radio labels, or None if nothing is selected.
            max_tokens: Max-token setting; used only by the GPT2-based models.
            temperature: Sampling temperature; used only by the GPT2-based models.

        Returns:
            The caption produced by the selected model, or a message string
            explaining why no caption could be generated.
        """
        # Guard the inputs instead of handing None to the models.
        if img is None:
            return "Please upload a chest X-ray image first."
        if not model_name:
            return "Please select a model."
        if model_name == "CLIP-GPT2":
            return generate_caption_clipgpt(img, max_tokens, temperature)
        if model_name == "ViT-GPT2":
            return generate_caption_vitgpt(img, max_tokens, temperature)
        if model_name == "ViT-CoAttention":
            # This model takes no max-token / temperature settings.
            return generate_caption_vitCoAtt(img)
        return "Caption generation for this model is not yet implemented."

    def predict_from_table(row, model_name, max_tokens=75, temperature=0.7):
        """Download the image referenced by a sample row and caption it.

        Args:
            row: Mapping with an ``'image'`` URL, e.g. an entry of ``sample_data``.
            model_name: Model label, as accepted by ``predict``.
            max_tokens: Max-token setting (defaults mirror the UI defaults).
            temperature: Sampling temperature (defaults mirror the UI defaults).

        Returns:
            Whatever ``predict`` returns for the downloaded image.
        """
        # io.imread returns a NumPy array; Image.open only accepts a path or a
        # file object, so the PIL image is built with Image.fromarray.
        img = Image.fromarray(io.imread(row['image']))
        return predict(img, model_name, max_tokens, temperature)

    def _on_table_select(model_name, max_tokens, temperature, evt: gr.SelectData):
        """Caption the sample whose table row was clicked, using the current UI settings."""
        # For a Dataframe, evt.index is [row, column]; only the row is needed.
        row_index = evt.index[0]
        return predict_from_table(sample_data[row_index], model_name, max_tokens, temperature)

    # Event handlers
    generate_button.click(predict, [image, model_choice, max_tokens, temperature], caption)
    # gr.Dataframe has no .click event; .select fires when a cell is clicked
    # and injects the clicked position via the gr.SelectData parameter.
    image_table.select(_on_table_select, [model_choice, max_tokens, temperature], caption)

demo.launch()