import torch
import gradio as gr
from transformers import pipeline
import ast
CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}
# Create a dictionary to store loaded models
loaded_models = {}
# Simple caption creation
def caption_image(model_choice, image_input, url_inputs, load_in_8bit, device):
    if image_input is not None:
        input_data = [image_input]
    else:
        # Interpret the textbox contents as a Python-style list of URLs,
        # e.g. ["https://example.com/a.jpg", "https://example.com/b.jpg"]
        input_data = ast.literal_eval(url_inputs)

    captions = []
    model_key = (model_choice, load_in_8bit)  # Unique combination of model and 8-bit loading

    # Reuse the pipeline if this model/precision combination is already loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32  # Use fp16 alongside 8-bit loading
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=device,  # 'cpu' or 'cuda', taken from the UI radio button
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,
                             use_fast=True)
        # Cache the loaded pipeline for reuse
        loaded_models[model_key] = captioner

    for input_item in input_data:
        caption = captioner(input_item)[0]['generated_text']
        captions.append(str(caption).strip())
    return '\n'.join(captions)
def launch(model_choice, image_input, url_inputs, load_in_8bit, device):
    return caption_image(model_choice, image_input, url_inputs, load_in_8bit, device)
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # gr.Image takes a single image; use the URL box for batches
url_inputs = gr.Textbox(label="Input URLs")  # Expects a Python-style list of image URLs
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(['cpu', 'cuda'], value='cpu', label='Select device')

iface = gr.Interface(launch,
                     inputs=[model_dropdown, image_input, url_inputs, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))
iface.launch()