import torch
import gradio as gr
from transformers import pipeline

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}
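
# Note: the BLIP-2 checkpoints are large (~2.7B parameters), and the 8-bit
# variants assume bitsandbytes is installed; 8-bit loading generally needs CUDA.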

# Cache of loaded pipelines, keyed by (model, 8-bit flag), so re-selecting a model doesn't reload its weights
loaded_models = {}

# Caption an image; an uploaded image takes priority over the URL field
def caption_image(model_choice, image_input, url_input, load_in_8bit, device):
    if image_input is not None:
        input_data = image_input
    else:
        input_data = url_input

    model_key = (model_choice, load_in_8bit)  # Cache key: unique combination of model and 8-bit loading

    # Check if the model is already loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32  # fp16 compute alongside 8-bit weights
        # Models loaded in 8-bit are placed by accelerate and can't be moved afterwards,
        # so only pass an explicit device for full-precision loads
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=None if load_in_8bit else device,
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,  # Set the floating point precision
                             use_fast=True
                             )
        # Store the loaded pipeline in the cache
        loaded_models[model_key] = captioner

    caption = captioner(input_data)
    # The pipeline returns a list of dicts; join the generated texts into a single string
    return "\n".join(str(c['generated_text']).strip() for c in caption)
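
# Direct (non-UI) usage sketch; the URL below is a placeholder, not a real endpoint:
#   caption_image('blip-base', None, 'https://example.com/photo.jpg', False, 'cpu')
# loads BLIP base on CPU, fetches the image, and returns its caption text.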

model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")
url_input = gr.Text(label="Input URL")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')

iface = gr.Interface(caption_image,
                     inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
                     outputs=gr.Textbox(label="Caption"))
iface.launch()
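
# launch() serves the UI locally; launch(share=True) would additionally create a
# temporary public link if remote access is needed.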