File size: 2,748 Bytes
7ebfeb9 afda258 cf05f8b 22ed06b 623b4fb 931c795 7fcb6d2 623b4fb eba7622 7ebfeb9 afda258 8875dbc 7ebfeb9 8875dbc afda258 dee2758 8875dbc fc6f52f 7ebfeb9 623b4fb 8875dbc 22ed06b 8875dbc eba7622 9511ac2 478c334 96fc972 8875dbc 22ed06b 8875dbc 2e77581 8875dbc dd914ca eba7622 8875dbc afda258 8875dbc b51c75c 8875dbc dee2758 8875dbc afda258 8875dbc 478c334 8875dbc 623b4fb 8875dbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
from PIL import Image
import clipGPT
import vitGPT
import skimage.io as io
import PIL.Image
import difflib
import ViTCoAtt
from build_vocab import Vocabulary
# Caption generation functions
def generate_caption_clipgpt(image, max_tokens, temperature):
    """Caption ``image`` with the CLIP-GPT2 backend.

    All three arguments are forwarded unchanged to
    ``clipGPT.generate_caption_clipgpt`` and its result is returned as-is.
    """
    return clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)
def generate_caption_vitgpt(image, max_tokens, temperature):
    """Caption ``image`` with the ViT-GPT2 backend.

    All three arguments are forwarded unchanged to
    ``vitGPT.generate_caption`` and its result is returned as-is.
    """
    return vitGPT.generate_caption(image, max_tokens, temperature)
def generate_caption_vitCoAtt(image):
    """Caption ``image`` with the ViT-CoAttention backend.

    The image is forwarded unchanged to ``ViTCoAtt.CaptionSampler.main`` and
    its result is returned as-is. This backend takes no sampling parameters.
    """
    return ViTCoAtt.CaptionSampler.main(image)
def _load_gallery_image(gallery_items, index):
    """Return the gallery entry at ``index`` as a PIL image.

    A Gallery used as an event input is delivered to the callback as the
    whole list of entries rather than the clicked one, so the selected entry
    has to be picked out by index before it can be captioned.
    """
    item = gallery_items[index]
    # Entries are documented as (image, caption) pairs; unwrap when present.
    if isinstance(item, (tuple, list)):
        item = item[0]
    if isinstance(item, Image.Image):
        return item
    # Otherwise assume a local file path, the Gallery's default payload
    # -- TODO confirm against the installed Gradio version.
    return Image.open(item)


with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖</h1>")
    gr.HTML("<p style='text-align: center;'>You can generate captions by uploading an X-Ray and selecting a model of your choice below</p>")

    with gr.Row():
        image = gr.Image(label="Upload Chest X-ray", type="pil")
        sample_images_gallery = gr.Gallery(
            value=[
                "https://imgur.com/W1pIr9b",
                "https://imgur.com/MLJaWnf",
                "https://imgur.com/6XymFW1",
                "https://imgur.com/zdPjZZ1",
                "https://imgur.com/DKUlZbF",
            ],
            label="Sample Images",
            columns=5,
        )

    gr.HTML("<p style='text-align: center;'> Please select the Number of Max Tokens and Temperature setting, if you are testing CLIP GPT2 and VIT GPT2 Models</p>")

    # NOTE(review): the nesting below is inferred from the inline comment
    # ("dropdowns and model choice"); confirm the button/textbox placement.
    with gr.Row():
        with gr.Column():  # Column for dropdowns and model choice
            max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
            temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)
            model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")

    generate_button = gr.Button("Generate Caption")
    caption = gr.Textbox(label="Generated Caption")

    def predict(img, model_name, max_tokens, temperature):
        """Route ``img`` to the captioning backend named by ``model_name``.

        Returns the backend's caption, or a short message when the model name
        is not one of the radio options or when no image was supplied.
        ``max_tokens`` and ``temperature`` are only passed to the CLIP-GPT2
        and ViT-GPT2 backends; ViT-CoAttention receives the image alone.
        """
        if model_name not in ("CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"):
            return "Caption generation for this model is not yet implemented."
        # The image input is None until the user uploads something; don't
        # hand that to the models.
        if img is None:
            return "Please upload or select a chest X-ray first."
        if model_name == "CLIP-GPT2":
            return generate_caption_clipgpt(img, max_tokens, temperature)
        if model_name == "ViT-GPT2":
            return generate_caption_vitgpt(img, max_tokens, temperature)
        return generate_caption_vitCoAtt(img)

    def predict_from_gallery(gallery_items, model_name, max_tokens, temperature, evt: gr.SelectData):
        """Caption the sample image the user clicked in the gallery.

        ``evt`` is injected by Gradio because of its ``gr.SelectData``
        annotation; ``evt.index`` identifies the clicked entry.
        """
        if not gallery_items:
            return predict(None, model_name, max_tokens, temperature)
        selected = _load_gallery_image(gallery_items, evt.index)
        return predict(selected, model_name, max_tokens, temperature)

    # Event handlers
    generate_button.click(predict, [image, model_choice, max_tokens, temperature], caption)
    sample_images_gallery.select(
        predict_from_gallery,
        [sample_images_gallery, model_choice, max_tokens, temperature],
        caption,
    )

demo.launch()