File size: 4,214 Bytes
3f01084
 
b14f3d4
3f01084
 
 
 
 
 
 
e6a9c05
3f01084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a9c05
3f01084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import random
import torch
import spaces
import gradio as gr
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Hugging Face Hub id of the Granite Vision 3.1 2B preview checkpoint.
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
# use_fast=True selects the fast (Rust-backed) tokenizer/image processor.
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
# torch_dtype="auto" keeps the checkpoint's native dtype; device_map="auto"
# places the weights on whatever accelerator is available (GPU on Spaces).
model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

def get_text_from_content(content):
    """Flatten a chat-message content list into a single display string.

    Text items contribute their text; image items contribute the
    "[Image]" placeholder; items of any other type are ignored.
    """
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("[Image]")
    return " ".join(texts)

# NOTE: @spaces.GPU() must decorate the function that runs model.generate —
# previously it decorated the pure-text helper above, so the actual inference
# call would not be scheduled on a ZeroGPU worker.
@spaces.GPU()
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    """Run one chat turn: append the user message, generate a reply.

    Parameters mirror the Gradio inputs (PIL image or None, message text,
    sampling controls, and the conversation list held in gr.State).
    Returns (chat_history_for_display, updated_conversation_state).
    """
    if conversation is None:
        conversation = []

    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        # Nothing to send: leave the conversation untouched.
        return conversation_display(conversation), conversation

    conversation.append({
        "role": "user",
        "content": user_content
    })

    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Fresh seed per request so repeated identical prompts can vary.
    torch.manual_seed(random.randint(0, 10000))

    # Sampling with temperature == 0 raises in transformers; fall back to
    # greedy decoding when the slider is at 0 (its minimum).
    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
    }
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
        generation_kwargs["top_k"] = top_k

    output = model.generate(**inputs, **generation_kwargs)
    # Decode ONLY the newly generated tokens. Decoding output[0] wholesale
    # would include the entire rendered prompt (system text + full history),
    # which then gets stored as the assistant message and re-fed into the
    # template on every subsequent turn.
    prompt_len = inputs["input_ids"].shape[1]
    assistant_response = processor.decode(output[0][prompt_len:], skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })

    return conversation_display(conversation), conversation

def conversation_display(conversation):
    """Convert the internal conversation into Gradio 'messages' format.

    Each message becomes {"role": ..., "content": <flattened text>}, with
    images rendered as the "[Image]" placeholder. Messages are appended as
    they are encountered, so a conversation that starts with an assistant
    message no longer raises NameError (the old code read an unbound
    `user_text`), and a trailing user message awaiting a reply is shown
    instead of silently dropped. Output is identical for the well-formed
    user/assistant alternation that chat_inference produces.
    """
    chat_history = []
    for msg in conversation:
        # Flatten the structured content list to plain display text.
        parts = []
        for item in msg.get("content", []):
            if item["type"] == "text":
                parts.append(item["text"])
            elif item["type"] == "image":
                parts.append("[Image]")
        flat = " ".join(parts)
        if msg["role"] == "user":
            chat_history.append({"role": "user", "content": flat})
        elif msg["role"] == "assistant":
            # Defensive: if the decoded reply still contains the prompt echo,
            # keep only the text after the last <|assistant|> marker.
            chat_history.append({"role": "assistant", "content": flat.split("<|assistant|>")[-1].strip()})
    return chat_history

def clear_chat():
    """Reset the UI: empty chat display, empty state, blank textbox, no image."""
    cleared_history = []
    cleared_state = []
    return cleared_history, cleared_state, "", None
    
# Gradio UI: image + sampling controls on the left, chat on the right.
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        with gr.Column(scale=2):
            # Optional image; chat_inference accepts text-only turns too.
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                # Sampling controls passed straight through to model.generate.
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")

        with gr.Column(scale=3):
            # type='messages' expects [{"role": ..., "content": ...}] dicts,
            # which is what conversation_display produces.
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")


    # Per-session conversation history in the internal (structured) format.
    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )

    # clear_chat returns ([], [], "", None) matching these four outputs.
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch()