import gradio as gr
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM,
    BlipForConditionalGeneration,
    Blip2ForConditionalGeneration,
    VisionEncoderDecoderModel,
    InstructBlipForConditionalGeneration,
)
import torch
import open_clip

from huggingface_hub import hf_hub_download

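# Download a few sample images to use as clickable examples in the demo.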
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

# git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-coco")
# git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

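# GIT-large fine-tuned on COCO.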
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
# git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")

# blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# blip_model_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

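# BLIP-large image captioning model.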
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

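# BLIP-2 with the OPT-6.7b language model, loaded in 4-bit to reduce GPU memory usage.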
blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
blip2_model_4_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)

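# InstructBLIP (Vicuna-7b), also loaded in 4-bit.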
instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
instructblip_model_4_bit = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16)

# vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# coca_model, _, coca_transform = open_clip.create_model_and_transforms(
#  model_name="coca_ViT-L-14",
#  pretrained="mscoco_finetuned_laion2B-s13B-b90k"
# )

device = "cuda" if torch.cuda.is_available() else "cpu"

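# Move the full-precision models to the selected device; the 4-bit models above
# are already placed by device_map="auto".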
# git_model_base.to(device)
# blip_model_base.to(device)
git_model_large_coco.to(device)
# git_model_large_textcaps.to(device)
blip_model_large.to(device)
# vitgpt_model.to(device)
# coca_model.to(device)
# blip2_model.to(device)

def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
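    """Preprocess the image, generate token ids with the given model, and decode them into a caption."""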
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)
    
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)

    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
   
    return generated_caption


def generate_caption_coca(model, transform, image):
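    """Caption helper for the CoCa model from open_clip (currently disabled above)."""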
    im = transform(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(im, seq_len=20)
    return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")


def generate_caption_instructblip(processor, model, image):
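    """Caption helper for InstructBLIP, which is prompted with a short instruction in addition to the image."""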
    prompt = "Generate a caption for the image:"
    
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    # Cast only the image tensor to float16 to match the model's compute dtype; token ids must stay integers.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

    # Pass the full processor output (prompt ids + image features) to generate.
    generated_ids = model.generate(**inputs,
                                   num_beams=5, max_length=256, min_length=1, top_p=0.9, repetition_penalty=1.5, length_penalty=1.0, temperature=1)
    generated_ids[generated_ids == 0] = 2 # TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
    
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_captions(image):
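    """Run every enabled captioning model on the uploaded image and return one caption per model."""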
    # caption_git_base = generate_caption(git_processor_base, git_model_base, image)

    caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)

    # caption_git_large_textcaps = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    # caption_blip_base = generate_caption(blip_processor_base, blip_model_base, image)

    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)

    # caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)

    # caption_coca = generate_caption_coca(coca_model, coca_transform, image)

    # caption_blip2 = generate_caption(blip2_processor, blip2_model, image, use_float_16=True).strip()

    caption_blip2_4_bit = generate_caption(blip2_processor, blip2_model_4_bit, image, use_float_16=True).strip()

    caption_instructblip_4_bit = generate_caption_instructblip(instructblip_processor, instructblip_model_4_bit, image)

    return caption_git_large_coco, caption_blip_large, caption_blip2_4_bit, caption_instructblip_4_bit

   
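# Gradio UI: a single image input, one caption textbox per model, and a few example images.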
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
    gr.Textbox(label="Caption generated by InstructBLIP"),
]

title = "Interactive demo: comparing image captioning models"
description = "Gradio demo comparing GIT, BLIP, BLIP-2 and InstructBLIP, 4 state-of-the-art vision+language models. To use it, simply upload your image and click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_captions,
                         inputs=gr.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True)