# app.py -- "Update app.py" by Mediocreatmybest (commit 9f77cf7, 2.49 kB)
import gradio as gr
from transformers import pipeline
import torch
# Short model names offered in the UI, mapped to the Hugging Face Hub
# repository id that is handed to the captioning pipeline.
CAPTION_MODELS = {
    # BLIP captioning checkpoints
    "blip-base": "Salesforce/blip-image-captioning-base",
    "blip-large": "Salesforce/blip-image-captioning-large",
    # ViT encoder + GPT-2 decoder
    "vit-gpt2-coco-en": "ydshieh/vit-gpt2-coco-en",
    # BLIP-2 OPT-2.7b variants
    "blip2-2.7b_8bit": "Mediocreatmybest/blip2-opt-2.7b_8bit",
    "blip2-2.7b-fp16": "Mediocreatmybest/blip2-opt-2.7b-fp16-sharded",
    "blip2-2.7b": "Salesforce/blip2-opt-2.7b",
}
# Cache of already-built pipelines keyed by (model_choice, load_in_8bit, device),
# so repeated requests with the same settings reuse the loaded model.
loaded_models = {}


def _select_input(images_input, urls_input):
    """Return the data to caption: the image(s) if supplied, otherwise the URL(s)."""
    if isinstance(images_input, (list, tuple)):
        # A batch is usable only when it is non-empty and has no missing entries.
        if images_input and all(image is not None for image in images_input):
            return images_input
    elif images_input is not None:
        # A single image object is passed through as-is; it must not be
        # iterated the way a batch is.
        return images_input

    if isinstance(urls_input, str):
        urls_input = urls_input.strip()
    if not urls_input:
        raise ValueError("Provide either an input image or an input URL.")
    return urls_input


def _flatten_captions(outputs):
    """Collect every 'generated_text' in a pipeline result as a stripped string."""
    if isinstance(outputs, dict):
        outputs = [outputs]
    captions = []
    for item in outputs:
        # Per the transformers pipeline docs a single input yields a list of
        # dicts, while a batch yields one such list per input; accept both.
        candidates = item if isinstance(item, list) else [item]
        captions.extend(
            str(candidate['generated_text']).strip() for candidate in candidates
        )
    return captions


def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
    """Generate captions for the supplied image(s) or image URL(s).

    Args:
        model_choice: Key into CAPTION_MODELS naming the model to use.
        images_input: A single image, a list/tuple of images, or None.
        urls_input: Image URL (or list of URLs) used when no image is given.
        load_in_8bit: When true, load the model with ``load_in_8bit`` and
            float16 weights; otherwise load it in float32.
        device: Device passed to the pipeline (e.g. 'cpu' or 'cuda').

    Returns:
        A list of caption strings with surrounding whitespace removed.

    Raises:
        ValueError: If neither an image nor a URL was provided.
        KeyError: If ``model_choice`` is not a key of CAPTION_MODELS.
    """
    # Validate the request first so an unusable one never triggers a model load.
    input_data = _select_input(images_input, urls_input)

    # The device is part of the key: the same model on another device is a
    # distinct pipeline instance.
    model_key = (model_choice, load_in_8bit, device)
    captioner = loaded_models.get(model_key)
    if captioner is None:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32
        captioner = pipeline(
            task="image-to-text",
            model=CAPTION_MODELS[model_choice],
            max_new_tokens=30,
            device=device,
            model_kwargs=model_kwargs,
            torch_dtype=dtype,
            use_fast=True,
        )
        loaded_models[model_key] = captioner

    return _flatten_captions(captioner(input_data))
# Input widgets, listed in the same order as caption_image's parameters.
model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # A single image, passed on as a PIL image
url_input = gr.Text(label="Input URL")  # A single text field holding an image URL
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(choices=['cpu', 'cuda'], label='Device')  # Radio button for device selection

iface = gr.Interface(
    caption_image,
    inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
    # gr.Textbox replaces the legacy gr.outputs.Textbox(type="auto", ...),
    # which is deprecated and no longer present in Gradio 4.
    outputs=gr.Textbox(label="Caption"),
)
iface.launch()