import gradio as gr
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM,
    BlipForConditionalGeneration,
    Blip2ForConditionalGeneration,
    VisionEncoderDecoderModel,
)
import torch
import open_clip
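
# Download the example images shown in the demo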
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
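
# Load the captioning models; the smaller base variants are kept commented out to save memory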
# git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-coco")
# git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
git_processor_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
git_model_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
# blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# blip_model_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
# blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
blip2_processor_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
blip2_model_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", load_in_8bit=True)
# vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
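
# CoCa is loaded through open_clip rather than transformers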
coca_model, _, coca_transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)
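
# Move the active models to the GPU when one is available
# (the 8-bit BLIP-2 model is already placed by device_map="auto")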
device = "cuda" if torch.cuda.is_available() else "cpu"
# git_model_base.to(device)
# blip_model_base.to(device)
git_model_large_coco.to(device)
git_model_large_textcaps.to(device)
blip_model_large.to(device)
# vitgpt_model.to(device)
coca_model.to(device)
# blip2_model.to(device)
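
# Shared captioning helper for the transformers-based models (GIT, BLIP, BLIP-2);
# BLIP-2 passes use_float_16=True so the inputs match the model dtype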
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        inputs = inputs.to(torch.float16)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
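
# CoCa comes from open_clip, so it uses its own image transform and decode path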
def generate_caption_coca(model, transform, image):
    im = transform(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(im, seq_len=20)
    return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
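
# Run every model on the same input image and collect one caption per model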
def generate_captions(image):
    # caption_git_base = generate_caption(git_processor_base, git_model_base, image)
    caption_git_large_coco = generate_caption(git_processor_large_coco, git_model_large_coco, image)
    caption_git_large_textcaps = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)
    # caption_blip_base = generate_caption(blip_processor_base, blip_model_base, image)
    caption_blip_large = generate_caption(blip_processor_large, blip_model_large, image)
    # caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
    caption_coca = generate_caption_coca(coca_model, coca_transform, image)
    # caption_blip2 = generate_caption(blip2_processor, blip2_model, image, use_float_16=True).strip()
    caption_blip2_8_bit = generate_caption(blip2_processor_8_bit, blip2_model_8_bit, image, use_float_16=True).strip()
    return caption_git_large_coco, caption_git_large_textcaps, caption_blip_large, caption_coca, caption_blip2_8_bit
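
# Build the Gradio UI: one image in, one caption per model out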
examples = [["cats.jpg"], ["stop_sign.png"], ["astronaut.jpg"]]
outputs = [
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on COCO"),
    gr.Textbox(label="Caption generated by GIT-large fine-tuned on TextCaps"),
    gr.Textbox(label="Caption generated by BLIP-large"),
    gr.Textbox(label="Caption generated by CoCa"),
    gr.Textbox(label="Caption generated by BLIP-2 OPT 6.7b"),
]
title = "Interactive demo: comparing image captioning models"
description = "Gradio demo to compare GIT, BLIP, CoCa, and BLIP-2, four state-of-the-art vision+language models. To use it, upload an image and click 'Submit', or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
interface = gr.Interface(fn=generate_captions,
                         inputs=gr.Image(type="pil"),
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue().launch(debug=True)