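"""Gradio demo: image captioning with BLIP, BLIP-2, and ViT-GPT2 models
via transformers pipelines. Pick a model, provide an image (upload or
URL), and optionally load the weights in 8-bit (requires bitsandbytes
and a CUDA device).
"""
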
import torch
import gradio as gr
from transformers import pipeline

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'vit-gpt2-coco-en': 'ydshieh/vit-gpt2-coco-en',
    'blip2-2.7b_8bit': 'Mediocreatmybest/blip2-opt-2.7b_8bit',
    'blip2-2.7b-fp16': 'Mediocreatmybest/blip2-opt-2.7b-fp16-sharded',
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',
}

# Cache of loaded pipelines, keyed by (model name, 8-bit flag)
loaded_models = {}

# Simple caption creation
def caption_image(model_choice, image_input, url_input, load_in_8bit):
    # Prefer an uploaded image; fall back to the URL field
    if image_input is not None:
        input_data = image_input
    elif url_input:
        input_data = url_input
    else:
        raise gr.Error("Please provide an image or an image URL.")

    # Cache key: the same model can be loaded in either precision
    model_key = (model_choice, load_in_8bit)

    # Check if the model is already loaded
    if model_key in loaded_models:
        captioner = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        # 8-bit quantized weights compute in fp16; plain CPU inference uses fp32
        dtype = torch.float16 if load_in_8bit else torch.float32
        # bitsandbytes handles placement for 8-bit models itself, so only pin
        # the device on the regular CPU path
        device = None if load_in_8bit else 'cpu'
        captioner = pipeline(task="image-to-text",
                             model=CAPTION_MODELS[model_choice],
                             max_new_tokens=30,
                             device=device,
                             model_kwargs=model_kwargs,
                             torch_dtype=dtype,
                             use_fast=True)
        # Cache the pipeline so reselecting this model/precision skips the reload
        loaded_models[model_key] = captioner

    caption = captioner(input_data)[0]['generated_text']
    return str(caption).strip()

model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
image_input = gr.Image(type="pil", label="Input Image")
url_input = gr.Text(label="Input URL")
load_in_8bit = gr.Checkbox(label="Load model in 8-bit")

iface = gr.Interface(caption_image, inputs=[model_dropdown, image_input, url_input, load_in_8bit], outputs="text")
iface.launch()
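
# Example usage without the web UI (hypothetical file name):
#
#   from PIL import Image
#   print(caption_image('blip-base', Image.open('cat.jpg'), None, False))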