import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# disable some warnings
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'cognitivecomputations/dolphin-vision-7b'

# load the model in fp16; device_map='auto' lets Accelerate place the weights (GPU if available)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

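# run a single prompt/image pair through the model and return only the newly generated text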
def inference(prompt, image):
    messages = [
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

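    # split the prompt around the <image> tag; -200 is the image-token placeholder id that
    # the model's remote code swaps for the projected image features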
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0).to(device)

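    # preprocess the PIL image with the model's bundled vision processor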
    image_tensor = model.process_images([image], model.config).to(device)

    # generate under fp16 autocast
    with torch.cuda.amp.autocast():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=2048,
            use_cache=True
        )[0]

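    # strip the prompt tokens and decode only the generated continuation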
    return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()

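# minimal Gradio UI: prompt and image on the left, model output on the right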
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Describe this image in detail")
            image_input = gr.Image(label="Image", type="pil")
            submit_button = gr.Button("Submit")
        with gr.Column():
            output_text = gr.Textbox(label="Output")

    submit_button.click(fn=inference, inputs=[prompt_input, image_input], outputs=output_text)

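# share=True also serves the demo over a temporary public Gradio link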
demo.launch(share=True)