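"""Gradio demo for the ibm-granite/granite-vision-3.1-2b-preview vision-language model.

Users upload an image and/or type a message; the app runs multimodal chat
inference and renders the running conversation in a Chatbot component.
"""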
import spaces
import random
import torch
import gradio as gr
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

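# Load the Granite Vision processor and model; device_map="auto" places the
# weights on the available accelerator.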
model_id = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

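# Render a message's mixed content as plain text, standing in "<image>" for images.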
def get_text_from_content(content):
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("<image>")
    return " ".join(texts)

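# On Hugging Face Spaces, @spaces.GPU allocates a (Zero)GPU for the duration of the call.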
@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []
        
    user_content = []
    if image is not None:
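        # Downscale large images in place to keep the visual input small.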
        if image.width > 512 or image.height > 512:
            image.thumbnail((512, 512))
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(conversation), conversation

    conversation.append({
        "role": "user",
        "content": user_content
    })

    # Tokenize the whole conversation (text and images) into model inputs.
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

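    # Re-seed the RNG so repeated identical prompts can still vary when sampling.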
    torch.manual_seed(random.randint(0, 10000))

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        # Default to greedy decoding; sampling parameters are only passed
        # when a positive temperature enables sampling.
        "do_sample": False,
    }

    if temperature > 0:
        generation_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        })

    output = model.generate(**inputs, **generation_kwargs)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    assistant_response = processor.decode(new_tokens, skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })
    
    return conversation_display(conversation), conversation

def conversation_display(conversation):
    # Flatten the structured conversation into the {"role", "content"}
    # message dicts expected by gr.Chatbot(type='messages').
    chat_history = []
    for msg in conversation:
        if msg["role"] == "user":
            chat_history.append({"role": "user", "content": get_text_from_content(msg["content"])})
        elif msg["role"] == "assistant":
            chat_history.append({"role": "assistant", "content": msg["content"][0]["text"]})
    return chat_history

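# Reset the chatbot display, conversation state, message box, and image input.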
def clear_chat():
    return [], [], "", None
    
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
    
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max Tokens")
            
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")
    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch(show_api=False)