File size: 3,566 Bytes
7ebfeb9
afda258
 
cf05f8b
22ed06b
 
623b4fb
931c795
7fcb6d2
623b4fb
 
7ebfeb9
afda258
aafac25
2bfbbec
7ebfeb9
 
aafac25
2bfbbec
cf05f8b
afda258
dee2758
6ae0110
dee2758
fc6f52f
7ebfeb9
623b4fb
22ed06b
 
 
 
dd914ca
 
 
 
aafac25
22ed06b
dd914ca
 
f63a88c
e6a32cf
dd914ca
e6a32cf
dd914ca
e6a32cf
 
dd914ca
 
f63a88c
dd914ca
22ed06b
 
dd914ca
 
 
 
a4fcc26
afda258
aafac25
b51c75c
aafac25
dee2758
 
afda258
dd914ca
623b4fb
dd914ca
 
 
 
 
 
 
 
 
 
 
afda258
dd914ca
aafac25
 
dd914ca
 
623b4fb
 
22ed06b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from PIL import Image
import clipGPT
import vitGPT
import skimage.io as io
import PIL.Image
import difflib
import ViTCoAtt
from build_vocab import Vocabulary



# Caption generation functions
def generate_caption_clipgpt(image, max_tokens, temperature):
    """Generate a report caption for an X-ray image via the CLIP-GPT2 backend.

    Delegates directly to ``clipGPT.generate_caption_clipgpt``.

    Args:
        image: Input chest X-ray (PIL image, as produced by the Gradio widget).
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature for generation.

    Returns:
        The generated caption string.
    """
    return clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)

def generate_caption_vitgpt(image, max_tokens, temperature):
    """Generate a report caption for an X-ray image via the ViT-GPT2 backend.

    Delegates directly to ``vitGPT.generate_caption``.

    Args:
        image: Input chest X-ray (PIL image, as produced by the Gradio widget).
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature for generation.

    Returns:
        The generated caption string.
    """
    return vitGPT.generate_caption(image, max_tokens, temperature)

def generate_caption_vitCoAtt(image):
    """Generate a report caption via the ViT co-attention backend.

    Delegates directly to ``ViTCoAtt.CaptionSampler.main``; this backend does
    not take sampling parameters.

    Args:
        image: Input chest X-ray (PIL image, as produced by the Gradio widget).

    Returns:
        The generated caption string.
    """
    return ViTCoAtt.CaptionSampler.main(image)


with gr.Blocks() as demo:
    # --- Page header -------------------------------------------------------
    gr.HTML("<h1 style='text-align: center;'>MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖</h1>")
    gr.HTML("<p style='text-align: center;'>You can generate captions by uploading an X-Ray and selecting a model of your choice below</p>")

    # Sample rows shown in the reference table.
    # NOTE(review): imgur page URLs (no file extension) are typically not
    # direct image links — verify these resolve to raw images.
    sample_data = [
        {'image': 'https://imgur.com/W1pIr9b', 'max token, temp': '75, 0.7', 'model supported': 'CLIP-GPT2, ViT-GPT2, ViT-CoAttention', 'ground truth': '...'},
        {'image': 'https://imgur.com/MLJaWnf', 'max token, temp': '50, 0.8', 'model supported': 'CLIP-GPT2, ViT-CoAttention', 'ground truth': '...'},
    ]

    with gr.Row():
        image = gr.Image(label="Upload Chest X-ray", type="pil")
        image_table = gr.Dataframe(sample_data, headers=['image', 'max token, temp', 'model supported', 'ground truth'], datatype=['picture', 'str', 'str', 'str'])

    gr.HTML("<p style='text-align: center;'> Please select the Number of Max Tokens and Temperature setting, if you are testing CLIP GPT2 and VIT GPT2 Models</p>")

    with gr.Row():
        with gr.Column():
            max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
            temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)

        model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")

    generate_button = gr.Button("Generate Caption")
    caption = gr.Textbox(label="Generated Caption")

    # (The instruction banner was previously duplicated here; removed.)

    def _dispatch(img, model_name, max_tokens, temperature):
        """Route an image to the selected captioning backend.

        Args:
            img: PIL image to caption.
            model_name: One of the radio choices ("CLIP-GPT2", "ViT-GPT2",
                "ViT-CoAttention") or None if nothing is selected.
            max_tokens: Max token count (used by the GPT2-based backends only).
            temperature: Sampling temperature (GPT2-based backends only).

        Returns:
            The generated caption, or a fallback message for unknown models.
        """
        if model_name == "CLIP-GPT2":
            return generate_caption_clipgpt(img, max_tokens, temperature)
        elif model_name == "ViT-GPT2":
            return generate_caption_vitgpt(img, max_tokens, temperature)
        elif model_name == "ViT-CoAttention":
            return generate_caption_vitCoAtt(img)
        else:
            return "Caption generation for this model is not yet implemented."

    def predict(img, model_name, max_tokens, temperature):
        """Handler for the Generate button: caption the uploaded image."""
        return _dispatch(img, model_name, max_tokens, temperature)

    def predict_from_table(table_value, model_name, max_tokens, temperature):
        """Handler for clicks on the sample table: caption the first sample.

        BUG FIX vs. original: the old code called
        ``Image.open(io.imread(url))`` — ``skimage.io.imread`` returns a numpy
        ndarray, which ``Image.open`` cannot accept; ``Image.fromarray`` is the
        correct conversion. It also read the module-level ``max_tokens`` /
        ``temperature`` *components* instead of their values; they are now
        passed in through the event's inputs.

        Args:
            table_value: The Dataframe component's value. Depending on the
                Gradio version this is a pandas DataFrame, a list of rows, or
                (as originally assumed) a dict-like row — all are handled.
            model_name: Selected model radio value.
            max_tokens: Max token count from the dropdown.
            temperature: Temperature from the slider.
        """
        # Extract the image URL from whatever shape Gradio hands us.
        if hasattr(table_value, 'iloc'):          # pandas DataFrame
            img_url = table_value.iloc[0, 0]
        elif isinstance(table_value, dict):       # single row as a mapping
            img_url = table_value['image']
        else:                                     # list of row lists
            img_url = table_value[0][0]
        # skimage.io.imread fetches the URL and returns an ndarray;
        # convert it to a PIL image for the captioning backends.
        img = Image.fromarray(io.imread(img_url))
        return _dispatch(img, model_name, max_tokens, temperature)

    # --- Event wiring ------------------------------------------------------
    generate_button.click(predict, [image, model_choice, max_tokens, temperature], caption)
    image_table.click(predict_from_table, [image_table, model_choice, max_tokens, temperature], caption)


if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) does not start a server.
    demo.launch()