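# Gradio demo: caption images with a selectable model (BLIP, ViT-GPT2, or BLIP-2),
# with optional 8-bit loading and CPU/CUDA device selection.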
import gradio as gr
from transformers import pipeline
import torch

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}

# Create a dictionary to store loaded models
loaded_models = {}

# Caption a single uploaded image, or an image URL, with the selected model
def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
    input_data = images_input if images_input is not None else urls_input  # Prefer the uploaded image; fall back to the URL

    model_key = (model_choice, load_in_8bit, device)  # Cache key: model, quantization flag, and device

    # Check if the model is already loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32  # fp16 pairs with 8-bit quantization; fp32 otherwise
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             # accelerate places 8-bit models itself; moving them with `device` raises an error
                             device=None if load_in_8bit else device,
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,
                             use_fast=True)
        # Store the loaded model
        loaded_models[model_key] = captioner

    captions = captioner(input_data)  # The pipeline returns a list of {'generated_text': ...} dicts
    results = [str(caption['generated_text']).strip() for caption in captions]
    return '\n'.join(results)  # Join captions so they display cleanly in the Textbox output
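# A minimal sketch of exercising caption_image outside the UI, assuming a local
# test image at the hypothetical path 'test.jpg':
#
#     from PIL import Image
#     print(caption_image('blip-base', Image.open('test.jpg'), None, False, 'cpu'))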

model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")  # Single image upload
url_input = gr.Text(label="Input URL")  # Alternatively, a direct image URL
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(choices=['cpu', 'cuda'], label='Device')  # Radio button for device selection

iface = gr.Interface(caption_image,
                     inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))  # gr.outputs.Textbox is deprecated; use gr.Textbox
iface.launch()
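# To serve the demo with a temporary public link (e.g. from a notebook),
# Gradio also accepts: iface.launch(share=True)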