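# Gradio demo comparing image captioning models:
# GIT-large (fine-tuned on COCO), BLIP-large, BLIP-2 (OPT-6.7b) and InstructBLIP (Vicuna-7b).
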
import gradio as gr
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    BlipForConditionalGeneration,
    Blip2ForConditionalGeneration,
    InstructBlipForConditionalGeneration,
)
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

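# Download example images used by the demo.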
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

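# Load all models once at startup. GIT and BLIP run in full precision on `device`;
# BLIP-2 and InstructBLIP are loaded in float16 and placed automatically via device_map="auto".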
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco").to(device)

blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", torch_dtype=torch.float16)

instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto", torch_dtype=torch.float16)

def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
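    """Caption `image` with a processor/model pair (used here for GIT and BLIP)."""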
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)
    
    generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5) 

    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
   
    return generated_caption


def generate_caption_blip2(processor, model, image, replace_token=False):
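    """Caption `image` with a prompted model (used here for BLIP-2 and InstructBLIP)."""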
    prompt = "Generate a caption for the image:"
    
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16)

    # Pass the full processor output so the text prompt conditions the generation,
    # not just the pixel values.
    generated_ids = model.generate(**inputs,
                                   num_beams=5, max_length=50, min_length=1, top_p=0.9,
                                   repetition_penalty=1.5, length_penalty=1.0)
    if replace_token:
        # TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
        generated_ids[generated_ids == 0] = 2 
    
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_captions(image):
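    """Run all four captioning models on the same image and return their captions."""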
    caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)

    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)

    caption_blip2 = generate_caption_blip2(blip2_processor, blip2_model, image).strip()

    caption_instructblip = generate_caption_blip2(instructblip_processor, instructblip_model, image)

    return caption_git_large_coco, caption_blip_large, caption_blip2, caption_instructblip

   
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
    gr.Textbox(label="Caption generated by InstructBLIP"),
]

title = "Interactive demo: comparing image captioning models"
description = "Gradio demo comparing GIT, BLIP, BLIP-2 and InstructBLIP, four state-of-the-art vision-language models. To use it, simply upload an image and click 'Submit', or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_captions,
                         inputs=gr.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue()
interface.launch(debug=True)