import torch
import gradio as gr
from transformers import pipeline
CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}
# Create a dictionary to store loaded models
loaded_models = {}
# Simple caption creation
def caption_image(model_choice, image_input, url_input, load_in_8bit):
    # Prefer the uploaded image; fall back to the URL field
    if image_input is not None:
        input_data = image_input
    else:
        input_data = url_input

    # Guard against the case where neither input was provided
    if input_data is None or input_data == "":
        return "Please provide an image or an image URL."

    # Tuple key uniquely identifies the combination of model and 8-bit loading
    model_key = (model_choice, load_in_8bit)

    # Reuse the pipeline if this combination has already been loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        captioner = pipeline(
            task="image-to-text",
            model=CAPTION_MODELS[model_choice],
            max_new_tokens=30,
            device_map="cpu",
            model_kwargs=model_kwargs,
            use_fast=True,
        )
        # Cache the loaded pipeline for subsequent calls
        loaded_models[model_key] = captioner

    caption = captioner(input_data)[0]['generated_text']
    return str(caption).strip()
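
# Example (hypothetical values): caption_image can also be called directly,
# bypassing the Gradio UI. The URL below is a placeholder, not one taken
# from this Space.
#
#   caption = caption_image('blip-base', None,
#                           'https://example.com/cat.jpg', False)
#   print(caption)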
def launch(model_choice, image_input, url_input, load_in_8bit):
    return caption_image(model_choice, image_input, url_input, load_in_8bit)
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")
url_input = gr.Text(label="Input URL")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
iface = gr.Interface(launch, inputs=[model_dropdown, image_input, url_input, load_in_8bit], outputs="text")
iface.launch()
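
# Once the interface is running, it can also be queried programmatically with
# gradio_client. This is a sketch assuming the app is served on Gradio's
# default local address; the URL and input values are placeholders.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict("blip-base", None,
#                           "https://example.com/cat.jpg", False,
#                           api_name="/predict")
#   print(result)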