import gradio as gr
from transformers import pipeline
import torch
CAPTION_MODELS = {
'blip-base': 'Salesforce/blip-image-captioning-base',
'blip-large': 'Salesforce/blip-image-captioning-large',
'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}
# Create a dictionary to store loaded models
loaded_models = {}
# Caption a single image (PIL) or an image URL with the selected model
def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
    # Prefer the uploaded image; fall back to the URL field if no image was provided
    input_data = images_input if images_input is not None else urls_input
    model_key = (model_choice, load_in_8bit, device)  # Cache key includes quantization and device

    # Reuse a cached pipeline if this exact configuration was loaded before
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        # 8-bit loading requires the bitsandbytes package and a CUDA device
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32  # Half precision pairs with 8-bit loading
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=device,  # 'cpu' or 'cuda', as selected in the UI
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,
                             use_fast=True)
        # Cache the loaded pipeline for later calls
        loaded_models[model_key] = captioner

    captions = captioner(input_data)  # Run the model on the image or URL
    # Extract the generated caption text from each output dict
    results = [str(caption['generated_text']).strip() for caption in captions]
    return "\n".join(results)  # The Textbox output expects a single string
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # A single uploaded image
url_input = gr.Text(label="Input URL")  # Alternatively, a single image URL
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(choices=['cpu', 'cuda'], label='Device') # Radio button for device selection
iface = gr.Interface(caption_image,
                     inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))
iface.launch()
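# --- Quick check without the UI (sketch) ---
# The same captioning path can be exercised directly with the transformers
# pipeline; a minimal sketch assuming the 'blip-base' checkpoint from
# CAPTION_MODELS above. The image URL is a placeholder, not a real asset:
#
#   from transformers import pipeline
#   captioner = pipeline(task="image-to-text",
#                        model="Salesforce/blip-image-captioning-base",
#                        max_new_tokens=30)
#   print(captioner("https://example.com/cat.jpg")[0]["generated_text"].strip())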