import torch
import gradio as gr
from transformers import pipeline
CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}
# Cache of loaded pipelines, keyed by (model, 8-bit flag), so switching models
# back and forth doesn't re-download or re-initialize them
loaded_models = {}
# Simple caption creation
def caption_image(model_choice, image_input, url_input, load_in_8bit, device):
    # Prefer an uploaded image; otherwise fall back to the URL field
    if image_input is not None:
        input_data = image_input
    else:
        input_data = url_input

    # Unique cache key: the same model can be loaded with or without 8-bit quantization
    model_key = (model_choice, load_in_8bit)

    # Reuse an already-loaded pipeline when possible
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        # Use fp16 alongside 8-bit quantization, fp32 otherwise
        dtype = torch.float16 if load_in_8bit else torch.float32
        # Note: with load_in_8bit, accelerate places the weights itself, so the
        # explicit device argument can conflict on some transformers versions
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=device,  # run on the selected device
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,  # floating-point precision
                             use_fast=True)
        # Store the loaded pipeline for reuse
        loaded_models[model_key] = captioner

    caption = captioner(input_data)
    # The pipeline returns a list of dicts; join the generated texts into one string
    return "\n".join(str(c['generated_text']).strip() for c in caption)
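# A minimal sketch of exercising the captioner without the UI (assumes a local
# test file 'example.jpg' exists; kept commented out so nothing runs at import):
# from PIL import Image
# print(caption_image('blip-base', Image.open('example.jpg'), None, False, 'cpu'))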
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # gr.Image takes a single image; it has no `multiple` argument
url_input = gr.Text(label="Input URL")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')  # `value`, not `default`, sets the initial selection
iface = gr.Interface(caption_image,
                     inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))
iface.launch()
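# To expose a temporary public URL instead of serving on localhost only,
# launch() accepts a standard `share` flag: iface.launch(share=True)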