File size: 3,833 Bytes
42fea26
 
 
2005ef8
42fea26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2005ef8
42fea26
 
 
2005ef8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import spaces
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM

model_name = 'AIDC-AI/Ovis1.6-Gemma2-9B'
# Literal marker that ovis_chat prepends to a human turn's text to say
# "the image goes here".
image_placeholder = '<image>'

# Load the checkpoint in bfloat16 and move it onto the GPU.
# trust_remote_code is required because the Ovis model class (which provides
# get_text_tokenizer / get_visual_tokenizer / preprocess_inputs) is custom
# code shipped with the hub repository, not part of transformers itself.
# NOTE(review): multimodal_max_length is forwarded to the Ovis remote code;
# presumably it bounds the combined text + visual token length — confirm
# against the model card.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    multimodal_max_length=8192,
    trust_remote_code=True,
)
model = model.to(device='cuda')

# Ovis exposes separate tokenizers for the text and image modalities.
text_tokenizer, visual_tokenizer = model.get_text_tokenizer(), model.get_visual_tokenizer()


@spaces.GPU
def ovis_chat(chatbot, image_input, text_input):
    """Run one chat turn through Ovis and append the exchange to the history.

    Args:
        chatbot: Chat history from the ``gr.Chatbot`` component, an iterable of
            ``(query, response)`` pairs. ``None`` is treated as an empty history.
        image_input: PIL image from the ``gr.Image`` component, or ``None`` when
            no image was provided.
        text_input: The user's new prompt. Any literal ``<image>`` placeholder
            the user typed is stripped before it reaches the model.

    Returns:
        A ``(chatbot, "")`` tuple: the updated history and an empty string that
        clears the prompt textbox.
    """
    if chatbot is None:
        chatbot = []

    # Rebuild the whole dialogue in the role format consumed by
    # model.preprocess_inputs: alternating "human" / "gpt" turns.
    conversations = []
    for query, response in chatbot:
        conversations.append({"from": "human", "value": query})
        conversations.append({"from": "gpt", "value": response})

    # Strip user-typed placeholders so the only image marker is the one we add.
    text_input = text_input.replace(image_placeholder, '')
    conversations.append({"from": "human", "value": text_input})
    if image_input is not None:
        # The image is attached to the first human turn of the conversation.
        conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]

    _prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    # Add a batch dimension of 1 and move to the model's device.
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    # The generate call itself must run inside inference_mode. The original
    # code dedented it out of this block, so the forward pass recorded autograd
    # state and wasted GPU memory.
    with torch.inference_mode():
        gen_kwargs = dict(
            max_new_tokens=512,
            do_sample=False,
            # Explicit None values, presumably so greedy decoding does not pick up
            # sampling settings from the checkpoint's generation_config.
            top_p=None,
            top_k=None,
            temperature=None,
            repetition_penalty=None,
            eos_token_id=model.generation_config.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
            use_cache=True
        )
        output_ids = model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]

    output = text_tokenizer.decode(output_ids, skip_special_tokens=True)
    chatbot.append((text_input, output))

    return chatbot, ""


def clear_chat():
    """Reset the UI to its initial state.

    Returns:
        A 3-tuple matching the outputs ``[chatbot, image_input, text_input]``:
        an empty history list, no image, and an empty prompt string.
    """
    empty_history = []
    no_image = None
    empty_prompt = ""
    return empty_history, no_image, empty_prompt

md = f'''# <center>{model_name.split('/')[-1]}</center>
###
Ovis has been open-sourced on [GitHub](https://github.com/AIDC-AI/Ovis) and [Huggingface](https://huggingface.co/{model_name}). If you find Ovis useful, a star or a like would be appreciated.
'''

# The prompt box is created outside the Blocks context and placed into the
# layout later with .render(). This lets gr.Examples, which appears earlier
# in the layout, list it as one of its inputs.
text_input = gr.Textbox(label="prompt", placeholder="Enter your text here...", lines=1, container=False)
with gr.Blocks(title=model_name.split('/')[-1]) as demo:
    gr.Markdown(md)
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="image", height=350, type="pil")
            gr.Examples(
                examples=[
                    [os.path.join(cur_dir, "examples", "rs-1.png"), "What shape should come as the fourth shape?"]],
                inputs=[image_input, text_input]
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Ovis", layout="panel", height=470, show_copy_button=True)
            text_input.render()
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

    # Clicking "Send" and pressing Enter in the prompt box run the same
    # handler with the same wiring.
    chat_inputs = [chatbot, image_input, text_input]
    chat_outputs = [chatbot, text_input]
    send_btn.click(ovis_chat, chat_inputs, chat_outputs)
    text_input.submit(ovis_chat, chat_inputs, chat_outputs)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

demo.launch()