File size: 2,642 Bytes
bd9b9f2
6b8d35c
bca9f16
20a5e29
6b8d35c
bca9f16
 
 
 
bc83b1f
52fd1d4
aa464d7
bca9f16
6b8d35c
e1acbd5
 
 
bd9b9f2
20a5e29
bd9b9f2
20a5e29
bd9b9f2
20a5e29
86573d5
20a5e29
bd9b9f2
34bb9f0
e1acbd5
34bb9f0
 
e1acbd5
aa464d7
593f239
e1acbd5
 
 
20a5e29
d867e19
593f239
d867e19
e1acbd5
 
34bb9f0
e1acbd5
20a5e29
 
 
 
 
 
 
6b8d35c
86573d5
20a5e29
 
aa464d7
20a5e29
364461b
20a5e29
 
34bb9f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import gradio as gr
from transformers import pipeline
import ast

# Maps each dropdown label to the Hugging Face Hub repository id of a
# captioning model. Insertion order is kept, so this is also the order in
# which the choices appear in the UI dropdown.
CAPTION_MODELS = dict(
    [
        ('blip-base', 'Salesforce/blip-image-captioning-base'),
        ('blip-large', 'Salesforce/blip-image-captioning-large'),
        ('vit-gpt2-coco-en', 'ydshieh/vit-gpt2-coco-en'),
        ('blip2-2.7b_8bit', 'Mediocreatmybest/blip2-opt-2.7b_8bit'),
        ('blip2-2.7b-fp16', 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded'),
        ('blip2-2.7b', 'Salesforce/blip2-opt-2.7b'),
    ]
)

# Process-wide cache of pipelines that have already been built, so a given
# model configuration is only loaded once.
loaded_models = dict()

# Simple caption creation
def _parse_url_inputs(url_inputs):
    """Turn the contents of the URL textbox into a list of URL strings.

    Accepted forms:
      * a Python list/tuple literal, e.g. ``['https://a/1.jpg', 'https://a/2.jpg']``
        (the original input format);
      * a single quoted string literal;
      * plain URLs separated by commas and/or whitespace.

    Missing, empty, or whitespace-only input yields an empty list.
    """
    if url_inputs is None:
        return []
    text = str(url_inputs).strip()
    if not text:
        return []
    try:
        # literal_eval only evaluates literals (no calls/names), so it is safe
        # on user-supplied text; it raises on anything that is not a literal.
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, str):
        # A single quoted URL: iterating the string would yield characters.
        return [parsed]
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    # Not a usable literal (e.g. bare, unquoted URLs): split on commas/whitespace.
    return text.replace(",", " ").split()


def _get_captioner(model_choice, load_in_8bit, device):
    """Return the image-to-text pipeline for this configuration, loading it once."""
    if model_choice not in CAPTION_MODELS:
        raise ValueError(f"Unknown caption model: {model_choice!r}")
    # The device is part of the key so switching devices builds a new pipeline
    # instead of silently reusing one placed elsewhere.
    model_key = (model_choice, bool(load_in_8bit), device)
    captioner = loaded_models.get(model_key)
    if captioner is None:
        model_kwargs = {"load_in_8bit": True} if load_in_8bit else {}
        # Same dtype choice as before: half precision alongside 8-bit loading,
        # full float32 otherwise.
        dtype = torch.float16 if load_in_8bit else torch.float32
        captioner = pipeline(
            task="image-to-text",
            model=CAPTION_MODELS[model_choice],
            max_new_tokens=30,
            device=device,
            model_kwargs=model_kwargs,
            torch_dtype=dtype,
            use_fast=True,
        )
        loaded_models[model_key] = captioner
    return captioner


def caption_image(model_choice, image_input, url_inputs, load_in_8bit, device='cpu'):
    """Caption an uploaded image, or the images referenced by ``url_inputs``.

    Args:
        model_choice: key of ``CAPTION_MODELS`` selecting the model.
        image_input: an uploaded image, or a list/tuple of images; when not
            None it takes precedence over ``url_inputs``.
        url_inputs: text from the URL box (see ``_parse_url_inputs``).
        load_in_8bit: whether to load the model weights in 8-bit.
        device: device handed to the pipeline; defaults to ``'cpu'`` (the
            value previously hard-coded). ``None`` is treated as ``'cpu'``.

    Returns:
        A list of caption strings, one per input, in input order. It is empty
        when there is nothing to caption; no model is loaded in that case.

    Raises:
        ValueError: if ``model_choice`` is not a key of ``CAPTION_MODELS``.
    """
    if image_input is not None:
        if isinstance(image_input, (list, tuple)):
            input_data = list(image_input)
        else:
            input_data = [image_input]
    else:
        input_data = _parse_url_inputs(url_inputs)

    if not input_data:
        return []

    captioner = _get_captioner(model_choice, load_in_8bit, device or 'cpu')
    return [str(captioner(item)[0]['generated_text']).strip() for item in input_data]


def launch(model_choice, image_input, url_inputs, load_in_8bit, device):
    """Gradio callback: caption the inputs and return one caption per line."""
    captions = caption_image(model_choice, image_input, url_inputs, load_in_8bit, device)
    # The output is a Textbox, so join the captions rather than returning a list.
    return "\n".join(captions)

# Pre-select the first model so the form can't be submitted with no model chosen.
model_names = list(CAPTION_MODELS.keys())
model_dropdown = gr.Dropdown(choices=model_names, value=model_names[0], label='Select Caption Model')
# gr.Image accepts a single image; `multiple=` is not one of its parameters
# (Gradio either ignores it with a warning or rejects it), so it is not passed.
image_input = gr.Image(type="pil", label="Input Image")
url_inputs = gr.Textbox(label="Input URLs")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
# `value=` sets the initial selection; the old `default=` keyword is deprecated.
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')

iface = gr.Interface(
    launch,
    inputs=[model_dropdown, image_input, url_inputs, load_in_8bit, device],
    # Use the component directly; the legacy gr.outputs namespace is deprecated.
    outputs=gr.Textbox(label="Caption"),
)
iface.launch()