File size: 6,006 Bytes
2773523
bd5ba96
2773523
c57b6d0
2773523
cbdd0b8
 
2773523
3a227f4
800cbf3
2773523
2d29620
 
2773523
2d29620
 
05c2134
2d29620
 
 
 
 
05c2134
 
 
2773523
2a729e4
 
bd5ba96
336c80c
 
 
bd5ba96
 
 
c62a436
340acb8
 
 
 
c57b6d0
a808d84
 
2d29620
 
 
 
05c2134
bd5ba96
340acb8
2a729e4
a808d84
3072768
38ab1e3
3072768
 
 
2773523
a808d84
76c8f3a
 
d6b2a16
76c8f3a
 
2773523
 
 
 
c57b6d0
 
7e4d7af
 
c57b6d0
 
 
2773523
2d29620
 
 
05c2134
2d29620
05c2134
2d29620
2773523
05c2134
2773523
bd5ba96
c62a436
340acb8
c57b6d0
2a729e4
bd5ba96
336c80c
 
2a729e4
2773523
 
800cbf3
2a729e4
2773523
329d18e
bd5ba96
6334863
2773523
329d18e
2773523
05c2134
2773523
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration, VisionEncoderDecoderModel
import torch
import open_clip

from huggingface_hub import hf_hub_download

# Fetch the example images used by the demo into the working directory.
# Each entry is (source URL, local filename); the filenames match the ones
# listed in `examples` further down in this file.
_EXAMPLE_IMAGE_SOURCES = (
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png'),
    ('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg'),
)
for _source_url, _local_name in _EXAMPLE_IMAGE_SOURCES:
    torch.hub.download_url_to_file(_source_url, _local_name)

# --- Model and preprocessor loading ---------------------------------------
# Everything in this section runs once, when the script is executed.
# Commented-out pairs are model variants that are currently disabled.

# GIT-base (disabled).
# git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-coco")
# git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# GIT-large, "microsoft/git-large-coco" checkpoint.
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# GIT-large, "microsoft/git-large-r-textcaps" checkpoint.
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")

# BLIP-base (disabled).
# blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# blip_model_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# BLIP-large captioning checkpoint.
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# BLIP-2 OPT-2.7b in float16 (disabled).
# blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

# BLIP-2 OPT-6.7b, loaded with 8-bit weights and automatic device placement
# (device_map="auto"), so it is not moved with .to(device) later on.
# NOTE(review): load_in_8bit presumably needs a CUDA GPU and the bitsandbytes
# package — confirm against the deployment environment.
blip2_processor_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
blip2_model_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", load_in_8bit=True)

# ViT-GPT2 (disabled); this variant needs a separate tokenizer for decoding.
# vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# CoCa via open_clip. The call returns three values; the middle one is not
# used here, only the model and the transform kept as `coca_transform`.
coca_model, _, coca_transform = open_clip.create_model_and_transforms(
  model_name="coca_ViT-L-14",
  pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)

# Use the GPU when torch can see one; otherwise everything stays on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the eagerly loaded models onto the chosen device. The 8-bit BLIP-2
# model is not listed because it was loaded with device_map="auto".
for _loaded_model in (
    git_model_large_coco,
    git_model_large_textcaps,
    blip_model_large,
    coca_model,
):
    _loaded_model.to(device)

def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    """Caption ``image`` with a processor/model pair.

    Args:
        processor: Called as ``processor(images=image, return_tensors="pt")``;
            also decodes the generated ids when ``tokenizer`` is ``None``.
        model: Object whose ``generate(pixel_values=..., max_length=50)``
            produces the token ids.
        image: Input image, handed to the processor unchanged.
        tokenizer: Optional decoder. When given, its ``batch_decode`` is used
            instead of the processor's.
        use_float_16: When true, the processed inputs are additionally cast
            with ``.to(torch.float16)`` before generation.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    model_inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        model_inputs = model_inputs.to(torch.float16)

    token_ids = model.generate(pixel_values=model_inputs.pixel_values, max_length=50)

    # Some checkpoints ship a separate tokenizer; otherwise the processor decodes.
    decoder = processor if tokenizer is None else tokenizer
    decoded = decoder.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]


def generate_caption_coca(model, transform, image):
    """Caption ``image`` with an open_clip CoCa model.

    Args:
        model: Object whose ``generate(batch, seq_len=20)`` returns token ids.
        transform: Callable that turns ``image`` into a tensor; a leading
            batch dimension is added here with ``unsqueeze(0)``.
        image: Input image, handed to ``transform`` unchanged.

    Returns:
        The decoded text up to the first ``<end_of_text>`` marker, with any
        ``<start_of_text>`` markers removed.
    """
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        with torch.cuda.amp.autocast():
            token_ids = model.generate(batch, seq_len=20)

    decoded = open_clip.decode(token_ids[0].detach())
    # Keep only what precedes the end marker, then drop the start marker.
    caption, _, _ = decoded.partition("<end_of_text>")
    return caption.replace("<start_of_text>", "")


def generate_captions(image):
    """Run every enabled captioning model on ``image``.

    Args:
        image: Input image, forwarded unchanged to each captioning helper.

    Returns:
        A 5-tuple of captions in this order: GIT-large (COCO), GIT-large
        (TextCaps), BLIP-large, CoCa, BLIP-2 OPT-6.7b (8-bit, whitespace
        stripped).
    """
    # Models that go through the plain processor/model path, in output order.
    processor_model_pairs = (
        (git_processor_large_coco, git_model_large_coco),
        (git_processor_large_textcaps, git_model_large_textcaps),
        (blip_processor_large, blip_model_large),
    )
    standard_captions = [
        generate_caption(proc, net, image) for proc, net in processor_model_pairs
    ]

    coca_caption = generate_caption_coca(coca_model, coca_transform, image)

    # BLIP-2 takes float16 inputs; its decoded text is stripped of
    # surrounding whitespace.
    blip2_caption = generate_caption(
        blip2_processor_8_bit, blip2_model_8_bit, image, use_float_16=True
    ).strip()

    return (*standard_captions, coca_caption, blip2_caption)

   
# Example images (downloaded at startup) offered as one-click inputs.
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]

# One text box per caption returned by generate_captions, in the same order.
# Uses the top-level gr.Textbox component: the legacy gr.outputs namespace is
# deprecated in Gradio 3.x and removed in 4.x.
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on TextCaps"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by CoCa"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
]

title = "Interactive demo: comparing image captioning models"
description = "Gradio Demo to compare GIT, BLIP, CoCa, and BLIP-2, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

# gr.Image replaces the legacy gr.inputs.Image; type="pil" still hands the
# callback a PIL image.
interface = gr.Interface(
    fn=generate_captions,
    inputs=gr.Image(type="pil"),
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
# Queueing was previously requested with the enable_queue=True constructor
# argument, which newer Gradio releases no longer accept; .queue() is the
# supported equivalent.
interface.queue()
interface.launch(debug=True)